diff --git a/cli/serve/app.py b/cli/serve/app.py index 583b28c01..859d37045 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -7,6 +7,7 @@ import sys import time import uuid +from typing import Any try: import typer @@ -14,6 +15,7 @@ from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse + from pydantic import BaseModel, create_model except ImportError as e: raise ImportError( "The 'm serve' command requires extra dependencies. " @@ -90,6 +92,58 @@ def create_openai_error_response( ) +def _json_schema_to_pydantic( + schema: dict[str, Any], model_name: str = "DynamicModel" +) -> type[BaseModel]: + """Convert a JSON Schema to a Pydantic model dynamically. + + Args: + schema: JSON Schema definition (must have 'properties' and 'type': 'object'). + model_name: Name for the generated Pydantic model. + + Returns: + A dynamically created Pydantic model class. + + Raises: + ValueError: If the schema is invalid or unsupported. + """ + if not isinstance(schema, dict): + raise ValueError("Schema must be a dictionary") + + if schema.get("type") != "object": + raise ValueError("Only object-type schemas are supported") + + properties = schema.get("properties", {}) + required = schema.get("required", []) + + if not properties: + raise ValueError("Schema must have 'properties' field") + + # Map JSON Schema types to Python types + type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + } + + # Build field definitions for create_model + field_definitions: dict[str, Any] = {} + for field_name, field_schema in properties.items(): + field_type = field_schema.get("type", "string") + python_type = type_mapping.get(field_type, str) + + # Handle optional fields + if field_name in required: + field_definitions[field_name] = (python_type, ...) 
+ else: + field_definitions[field_name] = (python_type | None, None) + + return create_model(model_name, **field_definitions) + + def _build_model_options(request: ChatCompletionRequest) -> dict: """Build model_options dict from OpenAI-compatible request parameters.""" excluded_fields = { @@ -108,7 +162,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "presence_penalty", # Presence penalty - not yet implemented "frequency_penalty", # Frequency penalty - not yet implemented "logit_bias", # Logit bias - not yet implemented - "response_format", # Response format (json_object) - not yet implemented + "response_format", # Response format - handled separately "functions", # Legacy function calling - not yet implemented "function_call", # Legacy function calling - not yet implemented "tools", # Tool calling - not yet implemented @@ -154,22 +208,71 @@ async def endpoint(request: ChatCompletionRequest): model_options = _build_model_options(request) + # Handle response_format + format_model: type[BaseModel] | None = None + if request.response_format is not None: + if request.response_format.type == "json_schema": + if request.response_format.json_schema is None: + return create_openai_error_response( + status_code=400, + message="json_schema field is required when response_format.type is 'json_schema'", + error_type="invalid_request_error", + param="response_format.json_schema", + ) + try: + format_model = _json_schema_to_pydantic( + request.response_format.json_schema.schema_, + request.response_format.json_schema.name, + ) + except ValueError as e: + return create_openai_error_response( + status_code=400, + message=f"Invalid JSON schema: {e!s}", + error_type="invalid_request_error", + param="response_format.json_schema.schema", + ) + elif request.response_format.type == "json_object": + # For json_object, we don't enforce a specific schema + # The backend will handle JSON mode if supported + pass + + # Check if serve function accepts format parameter 
+ serve_sig = inspect.signature(module.serve) + accepts_format = "format" in serve_sig.parameters + # Detect if serve is async or sync and handle accordingly if inspect.iscoroutinefunction(module.serve): # It's async, await it directly - output = await module.serve( - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + if accepts_format: + output = await module.serve( + input=request.messages, + requirements=request.requirements, + model_options=model_options, + format=format_model, + ) + else: + output = await module.serve( + input=request.messages, + requirements=request.requirements, + model_options=model_options, + ) else: # It's sync, run in thread pool to avoid blocking event loop - output = await asyncio.to_thread( - module.serve, - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + if accepts_format: + output = await asyncio.to_thread( + module.serve, + input=request.messages, + requirements=request.requirements, + model_options=model_options, + format=format_model, + ) + else: + output = await asyncio.to_thread( + module.serve, + input=request.messages, + requirements=request.requirements, + model_options=model_options, + ) # system_fingerprint represents backend config hash, not model name # The model name is already in response.model (line 73) @@ -186,6 +289,7 @@ async def endpoint(request: ChatCompletionRequest): created=created_timestamp, stream_options=request.stream_options, system_fingerprint=system_fingerprint, + format_model=format_model, ), media_type="text/event-stream", ) diff --git a/cli/serve/models.py b/cli/serve/models.py index 7e738730e..b68588417 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -29,8 +29,26 @@ class ToolFunction(BaseModel): function: FunctionDefinition +class JsonSchemaFormat(BaseModel): + """JSON Schema definition for structured output.""" + + name: str + """Name of the schema.""" + + schema_: dict[str, Any] = 
Field(alias="schema") + """JSON Schema definition.""" + + strict: bool | None = None + """Whether to enforce strict schema validation.""" + + model_config = {"populate_by_name": True} + + class ResponseFormat(BaseModel): - type: Literal["text", "json_object"] + type: Literal["text", "json_object", "json_schema"] + + json_schema: JsonSchemaFormat | None = None + """JSON Schema definition when type is 'json_schema'.""" class StreamOptions(BaseModel): diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 51ff33c3c..15298a98a 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -1,7 +1,10 @@ """Streaming utilities for OpenAI-compatible server responses.""" +import json from collections.abc import AsyncGenerator +from pydantic import BaseModel, ValidationError + from mellea.core.base import ModelOutputThunk from mellea.core.utils import MelleaLogger from mellea.helpers.openai_compatible_helpers import build_completion_usage @@ -23,6 +26,7 @@ async def stream_chat_completion_chunks( created: int, stream_options: StreamOptions | None = None, system_fingerprint: str | None = None, + format_model: type[BaseModel] | None = None, ) -> AsyncGenerator[str, None]: """Generate OpenAI-compatible SSE chat completion chunks from a model output. @@ -36,6 +40,9 @@ async def stream_chat_completion_chunks( ``include_usage`` field. system_fingerprint: Backend configuration fingerprint to include in chunks. Defaults to ``None``. + format_model: Optional Pydantic model for validating structured output. + When provided, the complete streamed output will be validated against + this schema before the final chunk is sent. 
Yields: Server-sent event payload strings representing OpenAI-compatible chat @@ -98,6 +105,45 @@ async def stream_chat_completion_chunks( ) yield f"data: {chunk.model_dump_json()}\n\n" + # Validate format if format_model is provided + if format_model is not None: + if output.value is None: + error_response = OpenAIErrorResponse( + error=OpenAIError( + message="Output value is None, cannot validate format", + type="invalid_response_error", + ) + ) + yield f"data: {error_response.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + return + + try: + # Parse the complete output as JSON + output_json = json.loads(output.value) + # Validate against the Pydantic model + format_model.model_validate(output_json) + except json.JSONDecodeError as e: + error_response = OpenAIErrorResponse( + error=OpenAIError( + message=f"Output is not valid JSON: {e!s}", + type="invalid_response_error", + ) + ) + yield f"data: {error_response.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + return + except ValidationError as e: + error_response = OpenAIErrorResponse( + error=OpenAIError( + message=f"Output does not match required schema: {e!s}", + type="invalid_response_error", + ) + ) + yield f"data: {error_response.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + return + # Include usage in final chunk only if explicitly requested via stream_options # Per OpenAI spec: usage is only included when stream_options.include_usage=True include_usage = stream_options is not None and stream_options.include_usage diff --git a/docs/examples/m_serve/README.md b/docs/examples/m_serve/README.md index 70fcb5f5e..c65ba8819 100644 --- a/docs/examples/m_serve/README.md +++ b/docs/examples/m_serve/README.md @@ -19,6 +19,14 @@ A dedicated streaming example for `m serve` that supports both modes: - `stream=True` returns an uncomputed thunk so the server can emit incremental Server-Sent Events (SSE) chunks +### m_serve_example_response_format.py +Example demonstrating structured output with the 
`response_format` parameter. + +**Key Features:** +- Supporting the `format` parameter in serve functions +- Structured output validation with JSON schemas +- Three format types: `text`, `json_object`, `json_schema` + ### pii_serve.py Example of serving a PII (Personally Identifiable Information) detection service. @@ -29,6 +37,9 @@ Client code for testing the served API endpoints with non-streaming requests. Client code demonstrating streaming responses using Server-Sent Events (SSE) against `m_serve_example_streaming.py`. +### client_response_format.py +Client code demonstrating all three `response_format` types with examples. + ## Concepts Demonstrated - **API Deployment**: Exposing Mellea programs as REST APIs @@ -37,6 +48,7 @@ against `m_serve_example_streaming.py`. - **Validation in Production**: Using requirements in deployed services - **Model Options**: Passing model configuration through API - **Streaming Responses**: Real-time token streaming via Server-Sent Events (SSE) +- **Structured Output**: Using `response_format` for JSON schema validation ## Basic Pattern @@ -84,6 +96,85 @@ m serve docs/examples/m_serve/m_serve_example_streaming.py python docs/examples/m_serve/client_streaming.py ``` +### Response Format + +```bash +# Start the response_format example server +m serve docs/examples/m_serve/m_serve_example_response_format.py + +# In another terminal, test with the response_format client +python docs/examples/m_serve/client_response_format.py +``` + +## Response Format Support + +The server supports structured output via the `response_format` parameter, which allows you to control the format of the model's response. This is compatible with OpenAI's response format API. + +**Three Format Types:** + +1. **`text`** (default): Plain text output +2. **`json_object`**: Unstructured JSON output (model decides the schema) +3. 
**`json_schema`**: Structured output validated against a JSON schema + +**Key Features:** +- Automatic JSON schema to Pydantic model conversion +- Schema validation for structured outputs +- OpenAI-compatible API +- Works with the `format` parameter in serve functions + +**Example - JSON Schema:** +```python +import openai + +client = openai.OpenAI(api_key="na", base_url="http://0.0.0.0:8080/v1") + +# Define a schema for structured output +person_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "email": {"type": "string"}, + }, + "required": ["name", "age", "email"], +} + +response = client.chat.completions.create( + messages=[{"role": "user", "content": "Generate a person named Alice"}], + model="granite4:micro-h", + response_format={ + "type": "json_schema", + "json_schema": { + "name": "Person", + "schema": person_schema, + "strict": True, + }, + }, +) + +# Response will be valid JSON matching the schema +print(response.choices[0].message.content) +``` + +**Server Implementation:** +Your serve function must accept a `format` parameter to support `json_schema`: + +```python +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: dict | None = None, + format: type | None = None, # Add this parameter +) -> ModelOutputThunk: + result = session.instruct( + description=input[-1].content, + requirements=requirements, + model_options=model_options, + format=format, # Pass to instruct() + ) + return result +``` + ## Streaming Support The server supports streaming responses via Server-Sent Events (SSE) when the diff --git a/docs/examples/m_serve/client_response_format.py b/docs/examples/m_serve/client_response_format.py new file mode 100644 index 000000000..a51f371b1 --- /dev/null +++ b/docs/examples/m_serve/client_response_format.py @@ -0,0 +1,254 @@ +# pytest: skip_always +"""Client demonstrating response_format parameter with m serve. 
+ +This example shows how to use the three response_format types: +1. text - Plain text output (default) +2. json_object - Unstructured JSON output +3. json_schema - Structured output validated against a JSON schema + +Prerequisites: + Start the server first: + m serve docs/examples/m_serve/m_serve_example_response_format.py + + Then run this client: + python docs/examples/m_serve/client_response_format.py +""" + +import json + +import openai + +PORT = 8080 +BASE_URL = f"http://0.0.0.0:{PORT}/v1" + +# Create OpenAI client pointing to our m serve endpoint +client = openai.OpenAI(api_key="not-needed", base_url=BASE_URL) + + +def example_text_format(): + """Example 1: Plain text output (default behavior).""" + print("\n" + "=" * 60) + print("Example 1: Text Format (default)") + print("=" * 60) + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[{"role": "user", "content": "Write a haiku about programming."}], + response_format={"type": "text"}, + ) + + print(f"Response: {response.choices[0].message.content}") + + +def example_json_object(): + """Example 2: Unstructured JSON output. + + Note: json_object format requests JSON but doesn't enforce it strictly. + The model may wrap JSON in markdown or add explanatory text. + For strict JSON validation, use json_schema instead. + """ + print("\n" + "=" * 60) + print("Example 2: JSON Object Format") + print("=" * 60) + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate a JSON object with information about a fictional person. Include name, age, and city. 
Return ONLY the JSON, no markdown formatting.", + } + ], + response_format={"type": "json_object"}, + ) + + content = response.choices[0].message.content or "" + print(f"Response: {content}") + + # First, try to parse as-is (valid JSON) + try: + data = json.loads(content) + print("\n✓ Valid JSON received") + print(f"\nParsed JSON:\n{json.dumps(data, indent=2)}") + return + except json.JSONDecodeError: + # Not valid JSON, try to extract from markdown + print("\n⚠ Response is not valid JSON, attempting to extract from markdown...") + + # Fallback: Try to extract JSON from markdown code blocks + json_content = content + if "```json" in content: + # Extract JSON from markdown code block + start = content.find("```json") + 7 + end = content.find("```", start) + if end > start: + json_content = content[start:end].strip() + print("Extracted from ```json block") + elif "```" in content: + # Generic code block + start = content.find("```") + 3 + end = content.find("```", start) + if end > start: + json_content = content[start:end].strip() + print("Extracted from ``` block") + + # Try parsing the extracted content + try: + data = json.loads(json_content) + print( + f"\n✓ Successfully extracted and parsed JSON:\n{json.dumps(data, indent=2)}" + ) + except json.JSONDecodeError as e: + print("\n✗ Failed to parse JSON even after extraction") + print("Note: json_object format doesn't enforce strict JSON.") + print("For guaranteed JSON output, use json_schema format instead.") + print(f"Parse error: {e}") + + +def example_json_schema_person(): + """Example 3: Structured output with JSON schema validation.""" + print("\n" + "=" * 60) + print("Example 3: JSON Schema Format - Person") + print("=" * 60) + + # Define a JSON schema for a person + person_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The person's full name"}, + "age": {"type": "integer", "description": "The person's age in years"}, + "email": {"type": "string", "description": 
"The person's email address"}, + "city": { + "type": "string", + "description": "The city where the person lives", + }, + }, + "required": ["name", "age", "email"], + "additionalProperties": False, + } + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate information about a software engineer named Alice.", + } + ], + response_format={ + "type": "json_schema", + "json_schema": {"name": "Person", "schema": person_schema, "strict": True}, + }, + ) + + content = response.choices[0].message.content + print(f"Response: {content}") + + # Parse and validate the structured output + try: + data = json.loads(content or "{}") + print(f"\nParsed structured output:\n{json.dumps(data, indent=2)}") + + # Verify required fields + assert "name" in data, "Missing required field: name" + assert "age" in data, "Missing required field: age" + assert "email" in data, "Missing required field: email" + print("\n✓ All required fields present") + + except json.JSONDecodeError as e: + print(f"Failed to parse JSON: {e}") + except AssertionError as e: + print(f"Validation error: {e}") + + +def example_json_schema_product(): + """Example 4: Structured output for a product catalog.""" + print("\n" + "=" * 60) + print("Example 4: JSON Schema Format - Product") + print("=" * 60) + + # Define a JSON schema for a product + product_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Product name"}, + "price": {"type": "number", "description": "Price in USD"}, + "category": { + "type": "string", + "enum": ["electronics", "clothing", "food", "books"], + "description": "Product category", + }, + "in_stock": { + "type": "boolean", + "description": "Whether the product is in stock", + }, + "description": {"type": "string", "description": "Product description"}, + }, + "required": ["name", "price", "category", "in_stock"], + "additionalProperties": False, + } + + response = 
client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate a product listing for a laptop computer.", + } + ], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "Product", + "schema": product_schema, + "strict": True, + }, + }, + ) + + content = response.choices[0].message.content + print(f"Response: {content}") + + # Parse and display the structured output + try: + data = json.loads(content or "{}") + print(f"\nParsed product data:\n{json.dumps(data, indent=2)}") + + # Verify the category is valid + valid_categories = ["electronics", "clothing", "food", "books"] + if data.get("category") in valid_categories: + print(f"\n✓ Valid category: {data['category']}") + + except json.JSONDecodeError as e: + print(f"Failed to parse JSON: {e}") + + +def main(): + """Run all examples.""" + print("\n" + "=" * 60) + print("RESPONSE_FORMAT EXAMPLES") + print("=" * 60) + print(f"Connecting to: {BASE_URL}") + print("=" * 60) + + try: + # Run all examples + example_text_format() + example_json_object() + example_json_schema_person() + example_json_schema_product() + + print("\n" + "=" * 60) + print("ALL EXAMPLES COMPLETED") + print("=" * 60) + + except Exception as e: + print(f"\nError: {e}") + print("\nMake sure the server is running:") + print( + f" m serve docs/examples/m_serve/m_serve_example_response_format.py --port {PORT}" + ) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/m_serve/m_serve_example_response_format.py b/docs/examples/m_serve/m_serve_example_response_format.py new file mode 100644 index 000000000..4d2bc6b5c --- /dev/null +++ b/docs/examples/m_serve/m_serve_example_response_format.py @@ -0,0 +1,56 @@ +# pytest: ollama, e2e + +"""Example demonstrating response_format with m serve. + +This example shows how to use the response_format parameter to get structured +output from the model. 
The server supports three format types: +- text: Plain text output (default) +- json_object: Unstructured JSON output +- json_schema: Structured output validated against a JSON schema + +Run the server: + m serve docs/examples/m_serve/m_serve_example_response_format.py + +Test with the client: + python docs/examples/m_serve/client_response_format.py +""" + +from typing import Any + +import mellea +from cli.serve.models import ChatMessage +from mellea.core import ModelOutputThunk +from mellea.stdlib.context import ChatContext + +session = mellea.start_session(ctx=ChatContext()) + + +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: dict[str, Any] | None = None, + format: type | None = None, +) -> ModelOutputThunk: + """Serve function that supports response_format parameter. + + Args: + input: List of chat messages from the client + requirements: Optional list of requirement strings + model_options: Optional model configuration parameters + format: Optional Pydantic model for structured output (from response_format) + + Returns: + ModelOutputThunk with the generated response + """ + message = input[-1].content or "No message provided" + + # When format is provided (from json_schema response_format), + # pass it to instruct() to get structured output + result = session.instruct( + description=message, + requirements=requirements, # type: ignore + model_options=model_options, + format=format, # This enables structured output validation + ) + + return result diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 515cc82f2..afeda4ebe 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -535,3 +535,535 @@ async def test_response_format_excluded_from_model_options(self, mock_module): # response_format should NOT be in model_options assert "response_format" not in model_options + + +class TestResponseFormat: + """Tests for response_format parameter handling.""" + + @pytest.mark.asyncio + async def 
test_json_schema_format_passed_to_serve(self): + """Test that json_schema response_format is converted to Pydantic model and passed to serve.""" + from pydantic import BaseModel + + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Track calls manually + captured_format = None + + def mock_serve(input, requirements=None, model_options=None, format=None): + nonlocal captured_format + captured_format = format + return ModelOutputThunk('{"name": "Alice", "age": 30}') + + # Assign the real function so signature inspection works + mock_module.serve = mock_serve + + # Create a request with json_schema response_format + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate a person")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Person", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify format was passed + assert captured_format is not None + assert issubclass(captured_format, BaseModel) + assert "name" in captured_format.model_fields + assert "age" in captured_format.model_fields + + # Verify response is successful + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == '{"name": "Alice", "age": 30}' + + @pytest.mark.asyncio + async def test_json_object_format_no_schema(self, mock_module): + """Test that json_object response_format doesn't pass a format model.""" + from cli.serve.models import ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate JSON")], + 
response_format=ResponseFormat(type="json_object"), + ) + + mock_output = ModelOutputThunk('{"result": "success"}') + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + + # For json_object, format should be None (no specific schema) + if "format" in call_args.kwargs: + assert call_args.kwargs["format"] is None + + # Verify response is successful + assert isinstance(response, ChatCompletion) + + @pytest.mark.asyncio + async def test_text_format_no_schema(self, mock_module): + """Test that text response_format doesn't pass a format model.""" + from cli.serve.models import ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + response_format=ResponseFormat(type="text"), + ) + + mock_output = ModelOutputThunk("Hello there!") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + + # For text, format should be None + if "format" in call_args.kwargs: + assert call_args.kwargs["format"] is None + + # Verify response is successful + assert isinstance(response, ChatCompletion) + + @pytest.mark.asyncio + async def test_json_schema_missing_schema_field(self, mock_module): + """Test that json_schema without schema field returns error.""" + import json + + from fastapi.responses import JSONResponse + + from cli.serve.models import ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=None, # Missing schema + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) 
+ + # Should return error + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert "json_schema" in error_data["error"]["message"].lower() + + @pytest.mark.asyncio + async def test_json_schema_invalid_schema(self, mock_module): + """Test that invalid JSON schema returns error.""" + import json + + from fastapi.responses import JSONResponse + + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Invalid", + schema={ + "type": "array", # Not supported (only object) + "items": {"type": "string"}, + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should return error + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert "schema" in error_data["error"]["message"].lower() + + @pytest.mark.asyncio + async def test_serve_without_format_parameter(self, mock_module): + """Test that serve functions without format parameter still work.""" + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a serve function that doesn't accept format + def serve_no_format(input, requirements=None, model_options=None): + return ModelOutputThunk("Response without format") + + mock_module.serve = serve_no_format + + request = 
ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Test", + schema={ + "type": "object", + "properties": {"result": {"type": "string"}}, + "required": ["result"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed even though serve doesn't accept format + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == "Response without format" + + @pytest.mark.asyncio + async def test_json_schema_with_optional_fields(self): + """Test that JSON schema with optional fields is handled correctly.""" + from pydantic import BaseModel + + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Track calls manually + captured_format = None + + def mock_serve(input, requirements=None, model_options=None, format=None): + nonlocal captured_format + captured_format = format + return ModelOutputThunk('{"name": "Widget", "price": 9.99}') + + # Assign the real function so signature inspection works + mock_module.serve = mock_serve + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Product", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "price": {"type": "number"}, + "description": {"type": "string"}, + }, + "required": ["name", "price"], # description is optional + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify format model was created correctly + assert captured_format is not None + assert issubclass(captured_format, BaseModel) + assert 
class TestResponseFormatStreaming:
    """Tests for the response_format parameter with streaming enabled.

    Covers json_schema (valid output, schema-violating output, non-JSON
    output), json_object, and text response formats over SSE streaming
    responses. Shared chunk-decoding and chunk-scanning logic lives in the
    private helpers below so the individual tests stay declarative.
    """

    @staticmethod
    def _chunk_text(chunk) -> str:
        """Decode a single streamed chunk to str (chunks may be bytes or str)."""
        return chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk

    @classmethod
    async def _collect_chunks(cls, response) -> list:
        """Consume a StreamingResponse body and return its decoded chunks."""
        return [cls._chunk_text(c) async for c in response.body_iterator]

    @classmethod
    def _stream_mentions(cls, chunks, *needles) -> bool:
        """Return True if any chunk contains all needles (case-insensitive)."""
        return any(
            all(needle in chunk.lower() for needle in needles) for chunk in chunks
        )

    @pytest.mark.asyncio
    async def test_json_schema_format_with_streaming(self):
        """Test that json_schema response_format works with stream=True."""
        from cli.serve.models import JsonSchemaFormat, ResponseFormat

        # Mock module whose serve() accepts a `format` parameter.
        mock_module = Mock()
        mock_module.__name__ = "test_module"

        # Pre-computed output so the endpoint can stream it immediately.
        mock_output = ModelOutputThunk('{"name": "Alice", "age": 30}')
        mock_output._computed = True

        def mock_serve(input, requirements=None, model_options=None, format=None):
            return mock_output

        # A real function (not a Mock attribute) so signature inspection works.
        mock_module.serve = mock_serve

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Generate a person")],
            stream=True,
            response_format=ResponseFormat(
                type="json_schema",
                json_schema=JsonSchemaFormat(
                    name="Person",
                    schema={
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "age": {"type": "integer"},
                        },
                        "required": ["name", "age"],
                    },
                ),
            ),
        )

        endpoint = make_chat_endpoint(mock_module)
        response = await endpoint(request)

        from fastapi.responses import StreamingResponse

        assert isinstance(response, StreamingResponse)

        chunks = await self._collect_chunks(response)

        # Should have multiple chunks: initial, content, final, and [DONE].
        assert len(chunks) > 0

        # No error chunks: every well-formed SSE chunk starts with "data: ".
        for chunk in chunks:
            assert chunk.startswith("data: ")

    @pytest.mark.asyncio
    async def test_json_schema_format_streaming_validation_error(self):
        """Test that invalid JSON in streaming response returns error chunk."""
        from cli.serve.models import JsonSchemaFormat, ResponseFormat

        mock_module = Mock()
        mock_module.__name__ = "test_module"

        # Output violates the schema: required field 'age' is missing.
        mock_output = ModelOutputThunk('{"name": "Alice"}')
        mock_output._computed = True

        def mock_serve(input, requirements=None, model_options=None, format=None):
            return mock_output

        mock_module.serve = mock_serve

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Generate")],
            stream=True,
            response_format=ResponseFormat(
                type="json_schema",
                json_schema=JsonSchemaFormat(
                    name="Person",
                    schema={
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "age": {"type": "integer"},
                        },
                        "required": ["name", "age"],
                    },
                ),
            ),
        )

        endpoint = make_chat_endpoint(mock_module)
        response = await endpoint(request)

        from fastapi.responses import StreamingResponse

        assert isinstance(response, StreamingResponse)

        chunks = await self._collect_chunks(response)

        # An error chunk mentioning the schema failure must be present.
        assert self._stream_mentions(chunks, "error", "schema"), (
            "Expected validation error in stream"
        )

    @pytest.mark.asyncio
    async def test_json_schema_format_streaming_invalid_json(self):
        """Test that non-JSON output in streaming response returns error chunk."""
        from cli.serve.models import JsonSchemaFormat, ResponseFormat

        mock_module = Mock()
        mock_module.__name__ = "test_module"

        # Output is not JSON at all.
        mock_output = ModelOutputThunk("This is not JSON")
        mock_output._computed = True

        def mock_serve(input, requirements=None, model_options=None, format=None):
            return mock_output

        mock_module.serve = mock_serve

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Generate")],
            stream=True,
            response_format=ResponseFormat(
                type="json_schema",
                json_schema=JsonSchemaFormat(
                    name="Person",
                    schema={
                        "type": "object",
                        "properties": {"name": {"type": "string"}},
                        "required": ["name"],
                    },
                ),
            ),
        )

        endpoint = make_chat_endpoint(mock_module)
        response = await endpoint(request)

        from fastapi.responses import StreamingResponse

        assert isinstance(response, StreamingResponse)

        chunks = await self._collect_chunks(response)

        # An error chunk about invalid JSON must be present.
        assert self._stream_mentions(chunks, "error", "json"), (
            "Expected JSON parsing error in stream"
        )

    @pytest.mark.asyncio
    async def test_json_object_format_with_streaming(self):
        """Test that json_object response_format works with stream=True."""
        from cli.serve.models import ResponseFormat

        mock_module = Mock()
        mock_module.__name__ = "test_module"

        # Valid JSON output; json_object mode enforces no specific schema,
        # so a plain Mock return_value suffices (no `format` kwarg expected).
        mock_output = ModelOutputThunk('{"result": "success"}')
        mock_output._computed = True
        mock_module.serve.return_value = mock_output

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Generate JSON")],
            stream=True,
            response_format=ResponseFormat(type="json_object"),
        )

        endpoint = make_chat_endpoint(mock_module)
        response = await endpoint(request)

        from fastapi.responses import StreamingResponse

        assert isinstance(response, StreamingResponse)

        chunks = await self._collect_chunks(response)

        # Should complete successfully without errors.
        assert len(chunks) > 0
        # No chunk may report an error. (The previous guard excused chunks
        # starting with "data: [DONE]", but that sentinel never contains
        # "error", so the plain check is equivalent and clearer.)
        for chunk in chunks:
            assert "error" not in chunk.lower()

    @pytest.mark.asyncio
    async def test_text_format_with_streaming(self):
        """Test that text response_format works with stream=True."""
        from cli.serve.models import ResponseFormat

        mock_module = Mock()
        mock_module.__name__ = "test_module"

        mock_output = ModelOutputThunk("Plain text response")
        mock_output._computed = True
        mock_module.serve.return_value = mock_output

        request = ChatCompletionRequest(
            model="test-model",
            messages=[ChatMessage(role="user", content="Hello")],
            stream=True,
            response_format=ResponseFormat(type="text"),
        )

        endpoint = make_chat_endpoint(mock_module)
        response = await endpoint(request)

        from fastapi.responses import StreamingResponse

        assert isinstance(response, StreamingResponse)

        chunks = await self._collect_chunks(response)

        # Stream must produce at least one chunk.
        assert len(chunks) > 0