-
Notifications
You must be signed in to change notification settings - Fork 112
feat: cli OpenAI-compatible API response_format support
#884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,13 +7,15 @@ | |
| import sys | ||
| import time | ||
| import uuid | ||
| from typing import Any | ||
|
|
||
| try: | ||
| import typer | ||
| import uvicorn | ||
| from fastapi import FastAPI, Request | ||
| from fastapi.exceptions import RequestValidationError | ||
| from fastapi.responses import JSONResponse, StreamingResponse | ||
| from pydantic import BaseModel, create_model | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "The 'm serve' command requires extra dependencies. " | ||
|
|
@@ -90,6 +92,58 @@ def create_openai_error_response( | |
| ) | ||
|
|
||
|
|
||
| def _json_schema_to_pydantic( | ||
| schema: dict[str, Any], model_name: str = "DynamicModel" | ||
| ) -> type[BaseModel]: | ||
| """Convert a JSON Schema to a Pydantic model dynamically. | ||
|
|
||
| Args: | ||
| schema: JSON Schema definition (must have 'properties' and 'type': 'object'). | ||
| model_name: Name for the generated Pydantic model. | ||
|
|
||
| Returns: | ||
| A dynamically created Pydantic model class. | ||
|
|
||
| Raises: | ||
| ValueError: If the schema is invalid or unsupported. | ||
| """ | ||
| if not isinstance(schema, dict): | ||
| raise ValueError("Schema must be a dictionary") | ||
|
|
||
| if schema.get("type") != "object": | ||
| raise ValueError("Only object-type schemas are supported") | ||
|
|
||
| properties = schema.get("properties", {}) | ||
| required = schema.get("required", []) | ||
|
|
||
| if not properties: | ||
| raise ValueError("Schema must have 'properties' field") | ||
|
|
||
| # Map JSON Schema types to Python types | ||
| type_mapping = { | ||
| "string": str, | ||
| "integer": int, | ||
| "number": float, | ||
| "boolean": bool, | ||
| "array": list, | ||
| "object": dict, | ||
| } | ||
|
|
||
| # Build field definitions for create_model | ||
| field_definitions: dict[str, Any] = {} | ||
| for field_name, field_schema in properties.items(): | ||
| field_type = field_schema.get("type", "string") | ||
| python_type = type_mapping.get(field_type, str) | ||
|
|
||
| # Handle optional fields | ||
| if field_name in required: | ||
| field_definitions[field_name] = (python_type, ...) | ||
| else: | ||
| field_definitions[field_name] = (python_type | None, None) | ||
|
|
||
| return create_model(model_name, **field_definitions) | ||
|
|
||
|
|
||
| def _build_model_options(request: ChatCompletionRequest) -> dict: | ||
| """Build model_options dict from OpenAI-compatible request parameters.""" | ||
| excluded_fields = { | ||
|
|
@@ -108,7 +162,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: | |
| "presence_penalty", # Presence penalty - not yet implemented | ||
| "frequency_penalty", # Frequency penalty - not yet implemented | ||
| "logit_bias", # Logit bias - not yet implemented | ||
| "response_format", # Response format (json_object) - not yet implemented | ||
| "response_format", # Response format - handled separately | ||
| "functions", # Legacy function calling - not yet implemented | ||
| "function_call", # Legacy function calling - not yet implemented | ||
| "tools", # Tool calling - not yet implemented | ||
|
|
@@ -154,22 +208,71 @@ async def endpoint(request: ChatCompletionRequest): | |
|
|
||
| model_options = _build_model_options(request) | ||
|
|
||
| # Handle response_format | ||
| format_model: type[BaseModel] | None = None | ||
| if request.response_format is not None: | ||
| if request.response_format.type == "json_schema": | ||
| if request.response_format.json_schema is None: | ||
| return create_openai_error_response( | ||
| status_code=400, | ||
| message="json_schema field is required when response_format.type is 'json_schema'", | ||
| error_type="invalid_request_error", | ||
| param="response_format.json_schema", | ||
| ) | ||
| try: | ||
| format_model = _json_schema_to_pydantic( | ||
| request.response_format.json_schema.schema_, | ||
| request.response_format.json_schema.name, | ||
| ) | ||
| except ValueError as e: | ||
| return create_openai_error_response( | ||
| status_code=400, | ||
| message=f"Invalid JSON schema: {e!s}", | ||
| error_type="invalid_request_error", | ||
| param="response_format.json_schema.schema", | ||
| ) | ||
| elif request.response_format.type == "json_object": | ||
| # For json_object, we don't enforce a specific schema | ||
| # The backend will handle JSON mode if supported | ||
| pass | ||
|
|
||
| # Check if serve function accepts format parameter | ||
| serve_sig = inspect.signature(module.serve) | ||
| accepts_format = "format" in serve_sig.parameters | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This signature inspection is cacheable and could be done up front: it currently runs on every request, but the result will not change between requests. |
||
|
|
||
| # Detect if serve is async or sync and handle accordingly | ||
| if inspect.iscoroutinefunction(module.serve): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar (though not identical) code is repeated multiple times — a possible opportunity to factor out a common helper. Minor. |
||
| # It's async, await it directly | ||
| output = await module.serve( | ||
| input=request.messages, | ||
| requirements=request.requirements, | ||
| model_options=model_options, | ||
| ) | ||
| if accepts_format: | ||
| output = await module.serve( | ||
| input=request.messages, | ||
| requirements=request.requirements, | ||
| model_options=model_options, | ||
| format=format_model, | ||
| ) | ||
| else: | ||
| output = await module.serve( | ||
| input=request.messages, | ||
| requirements=request.requirements, | ||
| model_options=model_options, | ||
| ) | ||
|
Comment on lines
+246
to
+258
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can these calls be combined? If format defaults to |
||
| else: | ||
| # It's sync, run in thread pool to avoid blocking event loop | ||
| output = await asyncio.to_thread( | ||
| module.serve, | ||
| input=request.messages, | ||
| requirements=request.requirements, | ||
| model_options=model_options, | ||
| ) | ||
| if accepts_format: | ||
| output = await asyncio.to_thread( | ||
| module.serve, | ||
| input=request.messages, | ||
| requirements=request.requirements, | ||
| model_options=model_options, | ||
| format=format_model, | ||
| ) | ||
| else: | ||
| output = await asyncio.to_thread( | ||
| module.serve, | ||
| input=request.messages, | ||
| requirements=request.requirements, | ||
| model_options=model_options, | ||
| ) | ||
|
|
||
| # system_fingerprint represents backend config hash, not model name | ||
| # The model name is already in response.model (line 73) | ||
|
|
@@ -186,6 +289,7 @@ async def endpoint(request: ChatCompletionRequest): | |
| created=created_timestamp, | ||
| stream_options=request.stream_options, | ||
| system_fingerprint=system_fingerprint, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The non-streaming path ( if format_model is not None and output.value is not None:
try:
format_model.model_validate(json.loads(output.value))
except (json.JSONDecodeError, ValidationError) as e:
return create_openai_error_response(
status_code=400,
message=f"Output does not match required schema: {e!s}",
error_type="invalid_response_error",
)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe OpenAI responses can return output that is not valid for a given schema if things like token limits are hit. Do we want to match that behavior? Or should we always error on our side if the format isn't met? |
||
| format_model=format_model, | ||
| ), | ||
| media_type="text/event-stream", | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,8 +29,26 @@ class ToolFunction(BaseModel): | |
| function: FunctionDefinition | ||
|
|
||
|
|
||
class JsonSchemaFormat(BaseModel):
    """Schema wrapper for structured-output requests (OpenAI-compatible)."""

    # Human-readable identifier for the schema.
    name: str

    # The JSON Schema itself. Trailing underscore avoids clashing with
    # pydantic's own "schema" attribute; the wire name stays "schema".
    schema_: dict[str, Any] = Field(alias="schema")

    # Whether strict schema validation was requested.
    # NOTE(review): accepted but not currently enforced server-side — confirm
    # intended behaviour before relying on it.
    strict: bool | None = None

    # Accept both the field name ("schema_") and the alias ("schema")
    # when populating the model.
    model_config = {"populate_by_name": True}
|
|
||
|
|
||
class ResponseFormat(BaseModel):
    """Requested output format for a chat completion (OpenAI-compatible)."""

    # One of: plain text, best-effort JSON mode, or schema-constrained JSON.
    type: Literal["text", "json_object", "json_schema"]

    # Schema payload; only meaningful (and required by the endpoint)
    # when type == "json_schema".
    json_schema: JsonSchemaFormat | None = None
|
|
||
|
|
||
| class StreamOptions(BaseModel): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,10 @@ | ||
| """Streaming utilities for OpenAI-compatible server responses.""" | ||
|
|
||
| import json | ||
| from collections.abc import AsyncGenerator | ||
|
|
||
| from pydantic import BaseModel, ValidationError | ||
|
|
||
| from mellea.core.base import ModelOutputThunk | ||
| from mellea.core.utils import MelleaLogger | ||
| from mellea.helpers.openai_compatible_helpers import build_completion_usage | ||
|
|
@@ -23,6 +26,7 @@ async def stream_chat_completion_chunks( | |
| created: int, | ||
| stream_options: StreamOptions | None = None, | ||
| system_fingerprint: str | None = None, | ||
| format_model: type[BaseModel] | None = None, | ||
| ) -> AsyncGenerator[str, None]: | ||
| """Generate OpenAI-compatible SSE chat completion chunks from a model output. | ||
|
|
||
|
|
@@ -36,6 +40,9 @@ async def stream_chat_completion_chunks( | |
| ``include_usage`` field. | ||
| system_fingerprint: Backend configuration fingerprint to include in chunks. | ||
| Defaults to ``None``. | ||
| format_model: Optional Pydantic model for validating structured output. | ||
| When provided, the complete streamed output will be validated against | ||
| this schema before the final chunk is sent. | ||
|
|
||
| Yields: | ||
| Server-sent event payload strings representing OpenAI-compatible chat | ||
|
|
@@ -98,6 +105,45 @@ async def stream_chat_completion_chunks( | |
| ) | ||
| yield f"data: {chunk.model_dump_json()}\n\n" | ||
|
|
||
| # Validate format if format_model is provided | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validation runs after all content chunks are already sent (lines 68–106), so the error arrives after the client has consumed the data. A few options:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. related to #891 right? |
||
| if format_model is not None: | ||
| if output.value is None: | ||
| error_response = OpenAIErrorResponse( | ||
| error=OpenAIError( | ||
| message="Output value is None, cannot validate format", | ||
| type="invalid_response_error", | ||
| ) | ||
| ) | ||
| yield f"data: {error_response.model_dump_json()}\n\n" | ||
| yield "data: [DONE]\n\n" | ||
| return | ||
|
|
||
| try: | ||
| # Parse the complete output as JSON | ||
| output_json = json.loads(output.value) | ||
| # Validate against the Pydantic model | ||
| format_model.model_validate(output_json) | ||
| except json.JSONDecodeError as e: | ||
| error_response = OpenAIErrorResponse( | ||
| error=OpenAIError( | ||
| message=f"Output is not valid JSON: {e!s}", | ||
| type="invalid_response_error", | ||
| ) | ||
| ) | ||
| yield f"data: {error_response.model_dump_json()}\n\n" | ||
| yield "data: [DONE]\n\n" | ||
| return | ||
| except ValidationError as e: | ||
| error_response = OpenAIErrorResponse( | ||
| error=OpenAIError( | ||
| message=f"Output does not match required schema: {e!s}", | ||
| type="invalid_response_error", | ||
| ) | ||
| ) | ||
| yield f"data: {error_response.model_dump_json()}\n\n" | ||
| yield "data: [DONE]\n\n" | ||
| return | ||
|
|
||
| # Include usage in final chunk only if explicitly requested via stream_options | ||
| # Per OpenAI spec: usage is only included when stream_options.include_usage=True | ||
| include_usage = stream_options is not None and stream_options.include_usage | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This handles the `type` keyword, but will not handle `enum`, `additionalProperties`, nested object types, array `items`, `$ref`, `allOf`, or `anyOf`.
Suggest clarifying these caveats in comments, or determining whether any further validation is viable.