Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/marvin/fns/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@
from collections.abc import Sequence
from typing import Any, Literal, TypeVar, overload

from pydantic_ai.messages import (
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
VideoUrl,
)

import marvin
from marvin.agents.agent import Agent
from marvin.handlers.handlers import AsyncHandler, Handler
from marvin.thread import Thread
from marvin.utilities.asyncio import run_sync
from marvin.utilities.types import Labels, issubclass_safe

# Non-string UserContent types that should be passed as attachments
_ATTACHMENT_TYPES = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)

T = TypeVar("T")

DEFAULT_PROMPT = """
Expand Down Expand Up @@ -131,7 +142,15 @@ async def classify_async(

"""
task_context = context or {}
task_context.update({"Data to classify": data})
attachments = []

# Handle non-string UserContent types (images, audio, etc.) as attachments
# to avoid serializing binary data into the prompt text
if isinstance(data, _ATTACHMENT_TYPES):
attachments.append(data)
task_context.update({"Data to classify": "(provided as attachment)"})
else:
task_context.update({"Data to classify": data})

prompt = prompt or PROMPT
if instructions:
Expand All @@ -150,6 +169,7 @@ async def classify_async(
task = marvin.Task[ReturnType](
name="Classification Task",
instructions=prompt,
attachments=attachments,
context=task_context,
result_type=result_type,
agents=[agent] if agent else None,
Expand Down
23 changes: 22 additions & 1 deletion src/marvin/fns/extract.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
from typing import Any, TypeVar

from pydantic_ai.messages import (
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
VideoUrl,
)

import marvin
from marvin.agents.agent import Agent
from marvin.handlers.handlers import AsyncHandler, Handler
from marvin.thread import Thread
from marvin.utilities.asyncio import run_sync
from marvin.utilities.types import TargetType

# Non-string UserContent types that should be passed as attachments
_ATTACHMENT_TYPES = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)

T = TypeVar("T")

DEFAULT_PROMPT = """
Expand Down Expand Up @@ -68,14 +79,24 @@ async def extract_async(
raise ValueError("Instructions are required when extracting string values.")

task_context = context or {}
task_context["Data to extract"] = data
attachments = []

# Handle non-string UserContent types (images, audio, etc.) as attachments
# to avoid serializing binary data into the prompt text
if isinstance(data, _ATTACHMENT_TYPES):
attachments.append(data)
task_context["Data to extract"] = "(provided as attachment)"
else:
task_context["Data to extract"] = data

prompt = prompt or PROMPT
if instructions:
prompt += f"\n\nYou must follow these instructions for your extraction:\n{instructions}"

task = marvin.Task[list[target]](
name="Extraction Task",
instructions=prompt,
attachments=attachments,
context=task_context,
result_type=list[target],
agents=[agent] if agent else None,
Expand Down
Empty file added tests/basic/fns/__init__.py
Empty file.
87 changes: 87 additions & 0 deletions tests/basic/fns/test_classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests for classify function - basic unit tests."""

from pydantic_ai.messages import BinaryImage, ImageUrl

import marvin
from marvin.tasks.task import Task


class TestClassifyWithAttachments:
"""Test that classify properly handles attachment types like images."""

def test_binary_image_passed_as_attachment(self, test_model):
"""Test that BinaryImage is passed as attachment, not in context."""
binary_image = BinaryImage(data=b"fake image data", media_type="image/png")

original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.classify(
binary_image,
labels=["cat", "dog", "other"],
instructions="Classify the animal in the image",
)
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 1
assert captured_task.attachments[0] is binary_image
assert captured_task.context["Data to classify"] == "(provided as attachment)"

def test_image_url_passed_as_attachment(self, test_model):
"""Test that ImageUrl is passed as attachment, not in context."""
image_url = ImageUrl(url="https://example.com/image.png")

original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.classify(
image_url,
labels=["safe", "unsafe"],
instructions="Classify whether the image is safe",
)
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 1
assert captured_task.attachments[0] is image_url
assert captured_task.context["Data to classify"] == "(provided as attachment)"

def test_string_data_not_treated_as_attachment(self, test_model):
"""Test that string data is still passed in context, not as attachment."""
original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.classify("I love this product!", labels=["positive", "negative", "neutral"])
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 0
assert captured_task.context["Data to classify"] == "I love this product!"
87 changes: 87 additions & 0 deletions tests/basic/fns/test_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests for extract function - basic unit tests."""

from pydantic_ai.messages import BinaryImage, ImageUrl

import marvin
from marvin.tasks.task import Task


class TestExtractWithAttachments:
"""Test that extract properly handles attachment types like images."""

def test_binary_image_passed_as_attachment(self, test_model):
"""Test that BinaryImage is passed as attachment, not in context."""
binary_image = BinaryImage(data=b"fake image data", media_type="image/png")

original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.extract(
binary_image,
target=str,
instructions="List all objects in the image",
)
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 1
assert captured_task.attachments[0] is binary_image
assert captured_task.context["Data to extract"] == "(provided as attachment)"

def test_image_url_passed_as_attachment(self, test_model):
"""Test that ImageUrl is passed as attachment, not in context."""
image_url = ImageUrl(url="https://example.com/image.png")

original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.extract(
image_url,
target=str,
instructions="List all objects in the image",
)
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 1
assert captured_task.attachments[0] is image_url
assert captured_task.context["Data to extract"] == "(provided as attachment)"

def test_string_data_not_treated_as_attachment(self, test_model):
"""Test that string data is still passed in context, not as attachment."""
original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.extract("apple, banana, cherry", target=str, instructions="Extract the fruits")
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 0
assert captured_task.context["Data to extract"] == "apple, banana, cherry"
Loading