From ebf90ac0a04e37876f328c66ef6a336fc813a35c Mon Sep 17 00:00:00 2001 From: octo-patch Date: Sun, 26 Apr 2026 10:11:13 +0800 Subject: [PATCH] fix(fns): handle BinaryImage/UserContent types as attachments in extract() and classify() When passing pydantic-ai UserContent types (BinaryImage, ImageUrl, AudioUrl, etc.) to marvin.extract() or marvin.classify(), the binary data was serialized into the task context as a string, causing token explosion (200K+ tokens for moderately-sized images). This is the same issue as #1246 which was already fixed for cast(). Fixes the same class of bug in extract() and classify() by detecting UserContent attachment types and passing them as task attachments instead of context values. --- src/marvin/fns/classify.py | 22 +++++++- src/marvin/fns/extract.py | 23 ++++++++- tests/basic/fns/__init__.py | 0 tests/basic/fns/test_classify.py | 87 ++++++++++++++++++++++++++++++++ tests/basic/fns/test_extract.py | 87 ++++++++++++++++++++++++++++++++ 5 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 tests/basic/fns/__init__.py create mode 100644 tests/basic/fns/test_classify.py create mode 100644 tests/basic/fns/test_extract.py diff --git a/src/marvin/fns/classify.py b/src/marvin/fns/classify.py index efde54cef..a660b4ce9 100644 --- a/src/marvin/fns/classify.py +++ b/src/marvin/fns/classify.py @@ -2,6 +2,14 @@ from collections.abc import Sequence from typing import Any, Literal, TypeVar, overload +from pydantic_ai.messages import ( + AudioUrl, + BinaryContent, + DocumentUrl, + ImageUrl, + VideoUrl, +) + import marvin from marvin.agents.agent import Agent from marvin.handlers.handlers import AsyncHandler, Handler @@ -9,6 +17,9 @@ from marvin.utilities.asyncio import run_sync from marvin.utilities.types import Labels, issubclass_safe +# Non-string UserContent types that should be passed as attachments +_ATTACHMENT_TYPES = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent) + T = TypeVar("T") DEFAULT_PROMPT = """ @@ -131,7 +142,15 @@ async def classify_async( """ task_context = context or {} - task_context.update({"Data to classify": data}) + attachments = [] + + # Handle non-string UserContent types (images, audio, etc.) as attachments + # to avoid serializing binary data into the prompt text + if isinstance(data, _ATTACHMENT_TYPES): + attachments.append(data) + task_context.update({"Data to classify": "(provided as attachment)"}) + else: + task_context.update({"Data to classify": data}) prompt = prompt or PROMPT if instructions: @@ -150,6 +169,7 @@ async def classify_async( task = marvin.Task[ReturnType]( name="Classification Task", instructions=prompt, + attachments=attachments, context=task_context, result_type=result_type, agents=[agent] if agent else None, diff --git a/src/marvin/fns/extract.py b/src/marvin/fns/extract.py index 11016cc86..12d8a5fe5 100644 --- a/src/marvin/fns/extract.py +++ b/src/marvin/fns/extract.py @@ -1,5 +1,13 @@ from typing import Any, TypeVar +from pydantic_ai.messages import ( + AudioUrl, + BinaryContent, + DocumentUrl, + ImageUrl, + VideoUrl, +) + import marvin from marvin.agents.agent import Agent from marvin.handlers.handlers import AsyncHandler, Handler @@ -7,6 +15,9 @@ from marvin.utilities.asyncio import run_sync from marvin.utilities.types import TargetType +# Non-string UserContent types that should be passed as attachments +_ATTACHMENT_TYPES = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent) + T = TypeVar("T") DEFAULT_PROMPT = """ @@ -68,7 +79,16 @@ async def extract_async( raise ValueError("Instructions are required when extracting string values.") task_context = context or {} - task_context["Data to extract"] = data + attachments = [] + + # Handle non-string UserContent types (images, audio, etc.) as attachments + # to avoid serializing binary data into the prompt text + if isinstance(data, _ATTACHMENT_TYPES): + attachments.append(data) + task_context["Data to extract"] = "(provided as attachment)" + else: + task_context["Data to extract"] = data + prompt = prompt or PROMPT if instructions: prompt += f"\n\nYou must follow these instructions for your extraction:\n{instructions}" @@ -76,6 +96,7 @@ async def extract_async( task = marvin.Task[list[target]]( name="Extraction Task", instructions=prompt, + attachments=attachments, context=task_context, result_type=list[target], agents=[agent] if agent else None, diff --git a/tests/basic/fns/__init__.py b/tests/basic/fns/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/basic/fns/test_classify.py b/tests/basic/fns/test_classify.py new file mode 100644 index 000000000..a855ba5b5 --- /dev/null +++ b/tests/basic/fns/test_classify.py @@ -0,0 +1,87 @@ +"""Tests for classify function - basic unit tests.""" + +from pydantic_ai.messages import BinaryImage, ImageUrl + +import marvin +from marvin.tasks.task import Task + + +class TestClassifyWithAttachments: + """Test that classify properly handles attachment types like images.""" + + def test_binary_image_passed_as_attachment(self, test_model): + """Test that BinaryImage is passed as attachment, not in context.""" + binary_image = BinaryImage(data=b"fake image data", media_type="image/png") + + original_task_init = Task.__init__ + captured_task = None + + def capture_task(self, *args, **kwargs): + nonlocal captured_task + original_task_init(self, *args, **kwargs) + captured_task = self + + Task.__init__ = capture_task + + try: + marvin.classify( + binary_image, + labels=["cat", "dog", "other"], + instructions="Classify the animal in the image", + ) + finally: + Task.__init__ = original_task_init + + assert captured_task is not None + assert len(captured_task.attachments) == 1 + assert captured_task.attachments[0] is binary_image + assert captured_task.context["Data to classify"] == "(provided as attachment)" + + def test_image_url_passed_as_attachment(self, test_model): + """Test that ImageUrl is passed as attachment, not in context.""" + image_url = ImageUrl(url="https://example.com/image.png") + + original_task_init = Task.__init__ + captured_task = None + + def capture_task(self, *args, **kwargs): + nonlocal captured_task + original_task_init(self, *args, **kwargs) + captured_task = self + + Task.__init__ = capture_task + + try: + marvin.classify( + image_url, + labels=["safe", "unsafe"], + instructions="Classify whether the image is safe", + ) + finally: + Task.__init__ = original_task_init + + assert captured_task is not None + assert len(captured_task.attachments) == 1 + assert captured_task.attachments[0] is image_url + assert captured_task.context["Data to classify"] == "(provided as attachment)" + + def test_string_data_not_treated_as_attachment(self, test_model): + """Test that string data is still passed in context, not as attachment.""" + original_task_init = Task.__init__ + captured_task = None + + def capture_task(self, *args, **kwargs): + nonlocal captured_task + original_task_init(self, *args, **kwargs) + captured_task = self + + Task.__init__ = capture_task + + try: + marvin.classify("I love this product!", labels=["positive", "negative", "neutral"]) + finally: + Task.__init__ = original_task_init + + assert captured_task is not None + assert len(captured_task.attachments) == 0 + assert captured_task.context["Data to classify"] == "I love this product!" diff --git a/tests/basic/fns/test_extract.py b/tests/basic/fns/test_extract.py new file mode 100644 index 000000000..20373d28f --- /dev/null +++ b/tests/basic/fns/test_extract.py @@ -0,0 +1,87 @@ +"""Tests for extract function - basic unit tests.""" + +from pydantic_ai.messages import BinaryImage, ImageUrl + +import marvin +from marvin.tasks.task import Task + + +class TestExtractWithAttachments: + """Test that extract properly handles attachment types like images.""" + + def test_binary_image_passed_as_attachment(self, test_model): + """Test that BinaryImage is passed as attachment, not in context.""" + binary_image = BinaryImage(data=b"fake image data", media_type="image/png") + + original_task_init = Task.__init__ + captured_task = None + + def capture_task(self, *args, **kwargs): + nonlocal captured_task + original_task_init(self, *args, **kwargs) + captured_task = self + + Task.__init__ = capture_task + + try: + marvin.extract( + binary_image, + target=str, + instructions="List all objects in the image", + ) + finally: + Task.__init__ = original_task_init + + assert captured_task is not None + assert len(captured_task.attachments) == 1 + assert captured_task.attachments[0] is binary_image + assert captured_task.context["Data to extract"] == "(provided as attachment)" + + def test_image_url_passed_as_attachment(self, test_model): + """Test that ImageUrl is passed as attachment, not in context.""" + image_url = ImageUrl(url="https://example.com/image.png") + + original_task_init = Task.__init__ + captured_task = None + + def capture_task(self, *args, **kwargs): + nonlocal captured_task + original_task_init(self, *args, **kwargs) + captured_task = self + + Task.__init__ = capture_task + + try: + marvin.extract( + image_url, + target=str, + instructions="List all objects in the image", + ) + finally: + Task.__init__ = original_task_init + + assert captured_task is not None + assert len(captured_task.attachments) == 1 + assert captured_task.attachments[0] is image_url + assert captured_task.context["Data to extract"] == "(provided as attachment)" + + def test_string_data_not_treated_as_attachment(self, test_model): + """Test that string data is still passed in context, not as attachment.""" + original_task_init = Task.__init__ + captured_task = None + + def capture_task(self, *args, **kwargs): + nonlocal captured_task + original_task_init(self, *args, **kwargs) + captured_task = self + + Task.__init__ = capture_task + + try: + marvin.extract("apple, banana, cherry", target=str, instructions="Extract the fruits") + finally: + Task.__init__ = original_task_init + + assert captured_task is not None + assert len(captured_task.attachments) == 0 + assert captured_task.context["Data to extract"] == "apple, banana, cherry"