diff --git a/cli/serve/app.py b/cli/serve/app.py index 583b28c01..83676fedd 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -3,10 +3,12 @@ import asyncio import importlib.util import inspect +import json import os import sys import time import uuid +from typing import Literal try: import typer @@ -26,10 +28,12 @@ from .models import ( ChatCompletion, ChatCompletionMessage, + ChatCompletionMessageToolCall, ChatCompletionRequest, Choice, OpenAIError, OpenAIErrorResponse, + ToolCallFunction, ) from .streaming import stream_chat_completion_chunks @@ -111,14 +115,14 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "response_format", # Response format (json_object) - not yet implemented "functions", # Legacy function calling - not yet implemented "function_call", # Legacy function calling - not yet implemented - "tools", # Tool calling - not yet implemented - "tool_choice", # Tool choice - not yet implemented } openai_to_model_option = { "temperature": ModelOption.TEMPERATURE, "max_tokens": ModelOption.MAX_NEW_TOKENS, "seed": ModelOption.SEED, "stream": ModelOption.STREAM, + "tools": ModelOption.TOOLS, + "tool_choice": ModelOption.TOOL_CHOICE, } # Get all non-None fields @@ -171,6 +175,36 @@ async def endpoint(request: ChatCompletionRequest): model_options=model_options, ) + # Extract tool calls from the ModelOutputThunk if available + tool_calls = None + finish_reason: Literal[ + "stop", "length", "content_filter", "tool_calls", "function_call" + ] = "stop" + if ( + hasattr(output, "tool_calls") + and output.tool_calls is not None + and isinstance(output.tool_calls, dict) + and output.tool_calls # Check dict is not empty + ): + tool_calls = [] + for model_tool_call in output.tool_calls.values(): + # Generate a unique ID for this tool call + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + + # Serialize the arguments to JSON string + args_json = json.dumps(model_tool_call.args) + + tool_calls.append( + ChatCompletionMessageToolCall( + id=tool_call_id, + type="function", + function=ToolCallFunction( + name=model_tool_call.name, arguments=args_json + ), + ) + ) + finish_reason = "tool_calls" + # system_fingerprint represents backend config hash, not model name # The model name is already in response.model (line 73) # Leave as None since we don't track backend config fingerprints yet @@ -198,9 +232,11 @@ async def endpoint(request: ChatCompletionRequest): Choice( index=0, message=ChatCompletionMessage( - content=output.value, role="assistant" + content=output.value, + role="assistant", + tool_calls=tool_calls, ), - finish_reason="stop", + finish_reason=finish_reason, ) ], object="chat.completion", # type: ignore diff --git a/cli/serve/models.py b/cli/serve/models.py index 7e738730e..ba0bd8cca 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -80,6 +80,29 @@ class ChatCompletionRequest(BaseModel): extra: dict[str, Any] = Field(default_factory=dict) +class ToolCallFunction(BaseModel): + """Function details for a tool call.""" + + name: str + """The name of the function to call.""" + + arguments: str + """The arguments to call the function with, as a JSON string.""" + + +class ChatCompletionMessageToolCall(BaseModel): + """A tool call generated by the model.""" + + id: str + """The ID of the tool call.""" + + type: Literal["function"] + """The type of the tool. Currently, only 'function' is supported.""" + + function: ToolCallFunction + """The function that the model called.""" + + # Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py, class ChatCompletionMessage(BaseModel): content: str | None = None @@ -91,6 +114,9 @@ class ChatCompletionMessage(BaseModel): role: Literal["assistant"] """The role of the author of this message.""" + tool_calls: list[ChatCompletionMessageToolCall] | None = None + """The tool calls generated by the model, such as function calls.""" + class Choice(BaseModel): index: int diff --git a/docs/examples/m_serve/client_tool_calling.py b/docs/examples/m_serve/client_tool_calling.py new file mode 100644 index 000000000..d68e5d238 --- /dev/null +++ b/docs/examples/m_serve/client_tool_calling.py @@ -0,0 +1,291 @@ +"""Client example for testing tool calling with m serve. + +This script demonstrates how to interact with an m serve server +that supports tool calling using the OpenAI-compatible API. + +Usage: + 1. Start the server: + uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py + + 2. Run this client: + uv run python docs/examples/m_serve/client_tool_calling.py +""" + +import json + +import requests + +# Server configuration +BASE_URL = "http://localhost:8080" +ENDPOINT = f"{BASE_URL}/v1/chat/completions" + +# Define tools in OpenAI format +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + } + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_stock_price", + "description": "Get the current stock price for a given ticker symbol", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], + } + }, + }, + }, +] + + +def make_request( + messages: list[dict], tools: list[dict] | None = None, tool_name: str | None = None +) -> dict: + """Make a request to the m serve API. + + Args: + messages: List of message dictionaries + tools: Optional list of tool definitions + tool_name: Optional tool name to request explicitly + + Returns: + Response dictionary from the API + """ + payload = { + "model": "gpt-3.5-turbo", # Model name (not used by m serve) + "messages": messages, + "temperature": 0.7, + } + + if tools: + payload["tools"] = tools + if tool_name is not None: + # m serve forwards tool_choice to compatible backends, but the + # downstream provider/model may ignore it or treat it as a weak + # preference rather than a guarantee. Use an explicit function + # selection in this client so the example demonstrates the API + # contract even when the model would otherwise decline to call tools. + payload["tool_choice"] = { + "type": "function", + "function": {"name": tool_name}, + } + else: + payload["tool_choice"] = "auto" + + response = requests.post(ENDPOINT, json=payload, timeout=30) + + if response.status_code >= 400: + try: + error_payload = response.json() + except ValueError: + error_payload = {"error": {"message": response.text}} + + error_message = error_payload.get("error", {}).get("message", response.text) + raise requests.HTTPError( + f"{response.status_code} Server Error: {error_message}", response=response + ) + + return response.json() + + +def _run_local_tool(tool_name: str, args: dict) -> str: + """Simulate local execution of the example tools.""" + if tool_name == "get_weather": + units = args.get("units") or "celsius" + unit_suffix = "C" if units == "celsius" else "F" + return f"The weather in {args['location']} is sunny and 22°{unit_suffix}" + if tool_name == "get_stock_price": + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + symbol = args["symbol"].upper() + return f"The current price of {symbol} is {mock_prices.get(symbol, '$100.00')}" + return "Tool result" + + +def main(): + """Run example tool calling interactions.""" + print("=" * 60) + print("Tool Calling Example with m serve") + print("=" * 60) + + # Example 1: Request that should trigger weather tool + print("\n1. Weather Query") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather like in Tokyo?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools, tool_name="get_weather") + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + + if choice.get("message", {}).get("tool_calls"): + print("\nTool Calls:") + for tool_call in choice["message"]["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif choice.get("message", {}).get("content"): + print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content and no tool calls.") + + # Example 2: Request that should trigger stock price tool + print("\n\n2. Stock Price Query") + print("-" * 60) + messages = [{"role": "user", "content": "What's the current stock price of AAPL?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools, tool_name="get_stock_price") + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + + if choice.get("message", {}).get("tool_calls"): + print("\nTool Calls:") + for tool_call in choice["message"]["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif choice.get("message", {}).get("content"): + print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content and no tool calls.") + + # Example 3: Request without tools (normal chat) + print("\n\n3. Normal Chat (No Tools)") + print("-" * 60) + messages = [{"role": "user", "content": "Hello! How are you?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=None) + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + print(f"Assistant: {choice['message']['content']}") + + # Example 4: Multi-turn conversation with tool use + print("\n\n4. Multi-turn Conversation") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools, tool_name="get_weather") + + choice = response["choices"][0] + assistant_message = choice["message"] + + if assistant_message.get("tool_calls"): + print("\nAssistant requested tool calls:") + + # Add assistant message once before processing tool calls + messages.append( + { + "role": "assistant", + "content": assistant_message.get("content"), + "tool_calls": assistant_message["tool_calls"], + } + ) + + tool_results: list[str] = [] + + # Process each tool call and add tool responses + for tool_call in assistant_message["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + + tool_result = _run_local_tool(func["name"], args) + tool_results.append(tool_result) + print(f" Result: {tool_result}") + + # Add tool response to conversation + messages.append( + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": tool_result, + } + ) + + # Get final response after tool execution. + # Ask for a concise answer that explicitly uses the tool result so the + # example output includes the actual weather/price instead of only a + # conversational acknowledgement. + messages.append( + { + "role": "user", + "content": ( + f"Original question: {messages[0]['content']}\n" + f"Tool result: {'; '.join(tool_results)}\n" + "Answer the original question directly using only that tool " + "result. Do not mention unrelated topics or other tools." + ), + } + ) + print("\nGetting final response after tool execution...") + response = make_request(messages, tools=None) + choice = response["choices"][0] + if choice.get("message", {}).get("content"): + print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content after tool execution.") + elif assistant_message.get("content"): + print(f"Assistant: {assistant_message['content']}") + else: + print("Assistant returned no content and no tool calls.") + + print("\n" + "=" * 60) + print("Examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + try: + main() + except requests.exceptions.ConnectionError: + print("Error: Could not connect to server.") + print("Make sure the server is running:") + print(" uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py") + except requests.exceptions.HTTPError as e: + print(f"Error: {e}") + if e.response is not None: + try: + print("Server response:", json.dumps(e.response.json(), indent=2)) + except ValueError: + print("Server response:", e.response.text) + except Exception as e: + print(f"Error: {e}") diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py new file mode 100644 index 000000000..839c91b1b --- /dev/null +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -0,0 +1,271 @@ +# pytest: ollama, e2e + +"""Example demonstrating tool calling with m serve. + +This file supports two distinct usage patterns: + +1. Running it directly with ``uv run python ...`` performs a local smoke test + using native Mellea tool calling. +2. Serving it with ``m serve`` exposes an OpenAI-compatible endpoint that + accepts OpenAI-style tool definitions in the request. + +The direct ``__main__`` smoke test is intentionally separate from the +OpenAI-compatible server flow because local ``session.instruct(...)`` calls +should use ``MelleaTool`` objects directly. +""" + +import os +from typing import Any + +from cli.serve.models import ChatMessage +from mellea.backends import ModelOption +from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO +from mellea.backends.openai import OpenAIBackend +from mellea.backends.tools import MelleaTool +from mellea.core import ModelOutputThunk, Requirement +from mellea.core.base import AbstractMelleaTool +from mellea.formatters import TemplateFormatter +from mellea.stdlib.context import ChatContext +from mellea.stdlib.session import MelleaSession + +_ollama_host = os.environ.get("OLLAMA_HOST", "localhost:11434") +if not _ollama_host.startswith(("http://", "https://")): + _ollama_host = f"http://{_ollama_host}" + +backend = OpenAIBackend( + model_id=IBM_GRANITE_4_HYBRID_MICRO.ollama_name, # type: ignore[arg-type] + formatter=TemplateFormatter(model_id=IBM_GRANITE_4_HYBRID_MICRO.hf_model_name), # type: ignore[arg-type] + base_url=f"{_ollama_host}/v1", + api_key="ollama", +) +session = MelleaSession(backend, ctx=ChatContext()) + + +class GetWeatherTool(AbstractMelleaTool): + """Tool for getting weather information.""" + + name = "get_weather" + + def run(self, location: str, units: str | None = "celsius") -> str: + """Get the current weather for a location. + + Args: + location: The city name + units: Temperature units (celsius or fahrenheit) + + Returns: + Weather information as a string + """ + # Models sometimes emit optional arguments explicitly as null/None. + resolved_units = units or "celsius" + # In a real implementation, this would call a weather API + return f"The weather in {location} is sunny and 22°{resolved_units[0].upper()}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + }, + }, + } + + +class GetStockPriceTool(AbstractMelleaTool): + """Tool for getting stock price information.""" + + name = "get_stock_price" + + def run(self, symbol: str) -> str: + """Get the current stock price for a symbol. + + Args: + symbol: The stock ticker symbol (e.g., AAPL, GOOGL) + + Returns: + Stock price information as a string + """ + # In a real implementation, this would call a stock market API + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + price = mock_prices.get(symbol.upper(), "$100.00") + return f"The current price of {symbol.upper()} is {price}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": "Get the current stock price for a given ticker symbol", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], + }, + }, + } + + +# Create tool instances for server-side lookup +weather_tool_impl = GetWeatherTool() +stock_price_tool_impl = GetStockPriceTool() + +# Native MelleaTool wrappers are only needed for the direct ``__main__`` path. +# The backend helper used by local ``session.instruct(..., ModelOption.TOOLS=[...])`` +# expects ``MelleaTool`` instances in a list, while the server path below uses the +# class-based implementations via the ``TOOLS`` lookup. +weather_tool = MelleaTool( + name=weather_tool_impl.name, + tool_call=weather_tool_impl.run, + as_json_tool=weather_tool_impl.as_json_tool, +) +stock_price_tool = MelleaTool( + name=stock_price_tool_impl.name, + tool_call=stock_price_tool_impl.run, + as_json_tool=stock_price_tool_impl.as_json_tool, +) + +# Map tool names to server-side tool implementations for easy lookup +TOOLS = { + weather_tool_impl.name: weather_tool_impl, + stock_price_tool_impl.name: stock_price_tool_impl, +} + + +def _extract_mellea_tools_from_model_options( + model_options: dict | None, +) -> dict[str, AbstractMelleaTool]: + """Normalize example tool inputs to native tool instances. + + This example supports only two shapes: + - OpenAI-style JSON tool definitions from the server path + - native tool objects from the direct ``__main__`` path + """ + if model_options is None or ModelOption.TOOLS not in model_options: + return {} + + provided_tools = model_options[ModelOption.TOOLS] + tools: dict[str, AbstractMelleaTool] = {} + + for tool_def in provided_tools: + if isinstance(tool_def, AbstractMelleaTool): + tools[tool_def.name] = tool_def + else: + tool_name = tool_def["function"]["name"] + if tool_name in TOOLS: + tools[tool_name] = TOOLS[tool_name] + + return tools + + +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: None | dict = None, +) -> ModelOutputThunk: + """Serve function that handles tool calling. + + This function demonstrates how to use tools with m serve. The tools + are passed via model_options using ModelOption.TOOLS, and tool_choice + can be specified using ModelOption.TOOL_CHOICE. Mellea forwards that + setting to compatible backends, but the downstream provider/model may + still ignore it or treat it as a weak preference. + + Args: + input: List of chat messages + requirements: Optional list of requirement strings + model_options: Model options including ModelOption.TOOLS and ModelOption.TOOL_CHOICE + + Returns: + ModelOutputThunk with potential tool calls + """ + requirements = requirements if requirements else [] + message = input[-1].content + + # Extract tools from model_options if provided + tools = _extract_mellea_tools_from_model_options(model_options) + + # Build model options with tools. + # If the caller explicitly selected a single function via tool_choice, + # narrow the advertised tool set to that one tool so the backend/model + # is not asked to choose among unrelated tools. + final_model_options = dict(model_options or {}) + selected_tool_name: str | None = None + if tools: + selected_tools = tools + if model_options is not None and ModelOption.TOOL_CHOICE in model_options: + tool_choice = model_options[ModelOption.TOOL_CHOICE] + if isinstance(tool_choice, dict): + selected_tool_name = tool_choice.get("function", {}).get("name") + if selected_tool_name in tools: + selected_tools = {selected_tool_name: tools[selected_tool_name]} + final_model_options[ModelOption.TOOLS] = selected_tools + + # Keep the serve path deterministic for the client example by retrying only + # at the request level. Enforcing uses_tool(...) inside session.instruct() + # caused noisy server-side failures when the model ignored the tool request + # on a particular sample. + result = session.instruct( + description=message, # type: ignore + requirements=[Requirement(req) for req in requirements], # type: ignore + model_options=final_model_options, + tool_calls=True, + strategy=None, + ) + + return result + + +if __name__ == "__main__": + response = session.instruct( + "What's the weather in Boston?", + model_options={ + ModelOption.TOOLS: [weather_tool], + # This direct path now uses the OpenAI backend against Ollama's + # OpenAI-compatible endpoint, so TOOL_CHOICE is forwarded by + # Mellea. Ollama and/or the selected model may still ignore it + # or not enforce it strictly in practice. + ModelOption.TOOL_CHOICE: "auto", + ModelOption.MAX_NEW_TOKENS: 1000, + }, + strategy=None, + tool_calls=True, + ) + + print(f"Response: {response.value}") + print( + "Tool calls requested:", + None if response.tool_calls is None else list(response.tool_calls.keys()), + ) + + if response.tool_calls and weather_tool.name in response.tool_calls: + tool_result = response.tool_calls[weather_tool.name].call_func() + print(f"Tool result: {tool_result}") diff --git a/mellea/backends/model_options.py b/mellea/backends/model_options.py index f71ddfb53..682c77bc7 100644 --- a/mellea/backends/model_options.py +++ b/mellea/backends/model_options.py @@ -22,6 +22,7 @@ class ModelOption: Attributes: TOOLS (str): Sentinel key for a list or dict of tools to expose for tool calling. + TOOL_CHOICE (str): Key for tool choice strategy (passed through to the backend). MAX_NEW_TOKENS (str): Sentinel key for the maximum number of new tokens to generate. SYSTEM_PROMPT (str): Sentinel key for the system prompt string. TEMPERATURE (str): Key for the sampling temperature (passed through to the backend). @@ -34,6 +35,9 @@ class ModelOption: TOOLS = "@@@tools@@@" """Must be a list[Callable] or a dict[str, Callable] where str is the name of the function.""" + TOOL_CHOICE = "tool_choice" + """Controls which tool the model should use. Can be "none", "auto", or a specific tool name.""" + MAX_NEW_TOKENS = "@@@max_new_tokens@@@" SYSTEM_PROMPT = "@@@system_prompt@@@" TEMPERATURE = "temperature" diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 515cc82f2..688c01b21 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -455,18 +455,19 @@ async def test_unsupported_params_excluded_from_model_options(self, mock_module) assert "logit_bias" not in model_options @pytest.mark.asyncio - async def test_tool_params_excluded_from_model_options(self, mock_module): - """Test that tool-related parameters are excluded from model_options.""" + async def test_tool_params_passed_to_model_options(self, mock_module): + """Test that tool-related parameters are passed to model_options.""" from cli.serve.models import ( FunctionDefinition, FunctionParameters, ToolFunction, ) + from mellea.backends.model_options import ModelOption request = ChatCompletionRequest( model="test-model", messages=[ChatMessage(role="user", content="Hello")], - # Tool-related parameters that should be excluded + # Tool-related parameters tools=[ ToolFunction( type="function", @@ -502,9 +503,12 @@ async def test_tool_params_excluded_from_model_options(self, mock_module): assert call_args is not None model_options = call_args.kwargs["model_options"] - # Tool-related parameters should NOT be in model_options - assert "tools" not in model_options - assert "tool_choice" not in model_options + # Tools should be passed with ModelOption.TOOLS key + assert ModelOption.TOOLS in model_options + # tool_choice should be passed through using ModelOption.TOOL_CHOICE + assert ModelOption.TOOL_CHOICE in model_options + assert model_options[ModelOption.TOOL_CHOICE] == "auto" + # Legacy function calling parameters should still be excluded assert "functions" not in model_options assert "function_call" not in model_options diff --git a/test/cli/test_serve_tool_calling.py b/test/cli/test_serve_tool_calling.py new file mode 100644 index 000000000..d1e76cd09 --- /dev/null +++ b/test/cli/test_serve_tool_calling.py @@ -0,0 +1,313 @@ +"""Tests for tool calling support in m serve OpenAI-compatible API server.""" + +import json +from typing import Any +from unittest.mock import Mock + +import pytest + +from cli.serve.app import make_chat_endpoint +from cli.serve.models import ( + ChatCompletion, + ChatCompletionRequest, + ChatMessage, + FunctionDefinition, + FunctionParameters, + ToolFunction, +) +from mellea.backends import ModelOption +from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall + + +class MockTool(AbstractMelleaTool): + """Mock tool for testing.""" + + name = "get_weather" + + def run(self, location: str) -> str: + """Mock run method.""" + return f"Weather in {location} is sunny" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + } + + +@pytest.fixture +def mock_module(): + """Create a mock module with a serve function.""" + module = Mock() + module.__name__ = "test_module" + return module + + +@pytest.fixture +def sample_tool_request(): + """Create a sample ChatCompletionRequest with tools.""" + return ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="What's the weather in Paris?")], + tools=[ + ToolFunction( + type="function", + function=FunctionDefinition( + name="get_weather", + description="Get the current weather in a location", + parameters=FunctionParameters( + RootModel={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name", + } + }, + "required": ["location"], + } + ), + ), + ) + ], + tool_choice="auto", + ) + + +class TestToolCalling: + """Tests for tool calling functionality.""" + + @pytest.mark.asyncio + async def test_tool_calls_in_response(self, mock_module, sample_tool_request): + """Test that tool calls are properly formatted in the response.""" + # Setup mock output with tool calls + mock_output = ModelOutputThunk("I'll check the weather for you.") + mock_tool = MockTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_module.serve.return_value = mock_output + + # Create endpoint and call it + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify response structure + assert isinstance(response, ChatCompletion) + assert response.choices[0].finish_reason == "tool_calls" + assert response.choices[0].message.tool_calls is not None + assert len(response.choices[0].message.tool_calls) == 1 + + # Verify tool call details + tool_call = response.choices[0].message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_weather" + + # Parse and verify arguments + args = json.loads(tool_call.function.arguments) + assert args == {"location": "Paris"} + + # Verify tool call ID format + assert tool_call.id.startswith("call_") + assert len(tool_call.id) > len("call_") + + @pytest.mark.asyncio + async def test_multiple_tool_calls(self, mock_module, sample_tool_request): + """Test handling multiple tool calls in a single response.""" + mock_output = ModelOutputThunk("I'll check multiple locations.") + mock_tool = MockTool() + mock_output.tool_calls = { + "get_weather_paris": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ), + "get_weather_london": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "London"} + ), + } + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify multiple tool calls + assert response.choices[0].finish_reason == "tool_calls" + assert len(response.choices[0].message.tool_calls) == 2 + + # Verify each tool call has unique ID + ids = [tc.id for tc in response.choices[0].message.tool_calls] + assert len(ids) == len(set(ids)), "Tool call IDs should be unique" + + @pytest.mark.asyncio + async def test_no_tool_calls_finish_reason_stop( + self, mock_module, sample_tool_request + ): + """Test that finish_reason is 'stop' when no tool calls are made.""" + mock_output = ModelOutputThunk("The weather is sunny.") + # No tool_calls set + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + + @pytest.mark.asyncio + async def test_empty_tool_calls_dict_finish_reason_stop( + self, mock_module, sample_tool_request + ): + """Test that finish_reason is 'stop' when tool_calls is an empty dict. + + Regression test for bug where empty tool_calls dict {} produces + finish_reason='tool_calls' with an empty array instead of + finish_reason='stop' with tool_calls=None. + """ + mock_output = ModelOutputThunk("Hello! How can I help?") + # Set tool_calls to empty dict (the bug case) + mock_output.tool_calls = {} + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Should behave like no tool calls at all + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + + @pytest.mark.asyncio + async def test_tools_passed_to_model_options( + self, mock_module, sample_tool_request + ): + """Test that tools are passed to serve function in model_options.""" + from mellea.backends.model_options import ModelOption + + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + await endpoint(sample_tool_request) + + # Verify serve was called with tools in model_options + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # Tools should be in model_options with the ModelOption.TOOLS key + assert ModelOption.TOOLS in model_options + assert model_options[ModelOption.TOOLS] is not None + + @pytest.mark.asyncio + async def test_tool_choice_passed_to_model_options( + self, mock_module, sample_tool_request + ): + """Test that tool_choice is passed to serve function in model_options.""" + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + await endpoint(sample_tool_request) + + # Verify serve was called with tool_choice in model_options + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # tool_choice should be passed through using ModelOption.TOOL_CHOICE + assert ModelOption.TOOL_CHOICE in model_options + assert model_options[ModelOption.TOOL_CHOICE] == "auto" + + @pytest.mark.asyncio + async def test_tool_calls_with_complex_arguments( + self, mock_module, sample_tool_request + ): + """Test tool calls with complex nested arguments.""" + mock_output = ModelOutputThunk("Processing complex request.") + mock_tool = MockTool() + mock_output.tool_calls = { + "complex_tool": ModelToolCall( + name="complex_function", + func=mock_tool, + args={ + "location": "Paris", + "options": { + "units": "celsius", + "include_forecast": True, + "days": 5, + }, + "tags": ["weather", "forecast"], + }, + ) + } + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify complex arguments are properly serialized + tool_call = response.choices[0].message.tool_calls[0] + args = json.loads(tool_call.function.arguments) + + assert args["location"] == "Paris" + assert args["options"]["units"] == "celsius" + assert args["options"]["include_forecast"] is True + assert args["options"]["days"] == 5 + assert args["tags"] == ["weather", "forecast"] + + @pytest.mark.asyncio + async def test_tool_calls_with_usage_info(self, mock_module, sample_tool_request): + """Test that usage info is included alongside tool calls.""" + mock_output = ModelOutputThunk("Calling tool.") + mock_tool = MockTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_output.usage = { + "prompt_tokens": 50, + "completion_tokens": 20, + "total_tokens": 70, + } + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify both tool calls and usage are present + assert response.choices[0].finish_reason == "tool_calls" + assert response.choices[0].message.tool_calls is not None + assert response.usage is not None + assert response.usage.total_tokens == 70 + + @pytest.mark.asyncio + async def test_request_without_tools(self, mock_module): + """Test that requests without tools still work normally.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + # No tools specified + ) + + mock_output = ModelOutputThunk("Hello! How can I help?") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should work normally without tool-related fields + assert isinstance(response, ChatCompletion) + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + assert response.choices[0].message.content == "Hello! How can I help?" diff --git a/test/core/test_component_typing.py b/test/core/test_component_typing.py index bbc3de9ef..5003fbe34 100644 --- a/test/core/test_component_typing.py +++ b/test/core/test_component_typing.py @@ -78,16 +78,16 @@ def session(backend) -> MelleaSession: def test_mot_init_typing(): mot = ModelOutputThunk[float](value="1") - assert hasattr(mot, "__orig_class__"), ( - "mots are generics and should have this field" + assert "__orig_class__" in mot.__dict__, ( + "mots are generics and should have this field in instance dict" ) assert get_args(mot.__orig_class__)[0] is float, ( # type: ignore f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" # type: ignore ) # type: ignore unknown_mot = ModelOutputThunk(value="2") - assert not hasattr(unknown_mot, "__orig_class__"), ( - "unknown mots / mots with no type defined at instantiate don't have this attribute" + assert "__orig_class__" not in unknown_mot.__dict__, ( + "unknown mots / mots with no type defined at instantiate don't have this attribute in instance dict" )