Skip to content
Open
44 changes: 40 additions & 4 deletions cli/serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import asyncio
import importlib.util
import inspect
import json
import os
import sys
import time
import uuid
from typing import Literal

try:
import typer
Expand All @@ -26,10 +28,12 @@
from .models import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
ChatCompletionRequest,
Choice,
OpenAIError,
OpenAIErrorResponse,
ToolCallFunction,
)
from .streaming import stream_chat_completion_chunks

Expand Down Expand Up @@ -111,14 +115,14 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
"response_format", # Response format (json_object) - not yet implemented
"functions", # Legacy function calling - not yet implemented
"function_call", # Legacy function calling - not yet implemented
"tools", # Tool calling - not yet implemented
"tool_choice", # Tool choice - not yet implemented
}
openai_to_model_option = {
"temperature": ModelOption.TEMPERATURE,
"max_tokens": ModelOption.MAX_NEW_TOKENS,
"seed": ModelOption.SEED,
"stream": ModelOption.STREAM,
"tools": ModelOption.TOOLS,
"tool_choice": ModelOption.TOOL_CHOICE,
}

# Get all non-None fields
Expand Down Expand Up @@ -171,6 +175,36 @@ async def endpoint(request: ChatCompletionRequest):
model_options=model_options,
)

# Extract tool calls from the ModelOutputThunk if available
tool_calls = None
finish_reason: Literal[
"stop", "length", "content_filter", "tool_calls", "function_call"
] = "stop"
if (
hasattr(output, "tool_calls")
and output.tool_calls is not None
and isinstance(output.tool_calls, dict)
and output.tool_calls # Check dict is not empty
):
tool_calls = []
for model_tool_call in output.tool_calls.values():
# Generate a unique ID for this tool call
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"

# Serialize the arguments to JSON string
args_json = json.dumps(model_tool_call.args)

tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call_id,
type="function",
function=ToolCallFunction(
name=model_tool_call.name, arguments=args_json
),
)
)
finish_reason = "tool_calls"
Comment thread
markstur marked this conversation as resolved.

# system_fingerprint represents backend config hash, not model name
# The model name is already in response.model (line 73)
# Leave as None since we don't track backend config fingerprints yet
Expand Down Expand Up @@ -198,9 +232,11 @@ async def endpoint(request: ChatCompletionRequest):
Choice(
index=0,
message=ChatCompletionMessage(
content=output.value, role="assistant"
content=output.value,
role="assistant",
tool_calls=tool_calls,
),
finish_reason="stop",
finish_reason=finish_reason,
)
],
object="chat.completion", # type: ignore
Expand Down
26 changes: 26 additions & 0 deletions cli/serve/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,29 @@ class ChatCompletionRequest(BaseModel):
extra: dict[str, Any] = Field(default_factory=dict)


class ToolCallFunction(BaseModel):
    """Function details for a tool call.

    Mirrors the ``function`` object on an OpenAI chat-completion tool
    call: the target function's name plus its arguments pre-serialized
    as a JSON string (the OpenAI wire format sends arguments as a
    string, not a nested object).
    """

    name: str
    """The name of the function to call."""

    arguments: str
    """The arguments to call the function with, as a JSON string."""


class ChatCompletionMessageToolCall(BaseModel):
    """A tool call generated by the model.

    Mirrors OpenAI's ``ChatCompletionMessageToolCall``: a unique id,
    the tool-type discriminator, and the function invocation details.
    """

    id: str
    """The ID of the tool call."""

    # "function" is the only admissible value, so default it — callers
    # may still pass type="function" explicitly (backward compatible).
    type: Literal["function"] = "function"
    """The type of the tool. Currently, only 'function' is supported."""

    function: ToolCallFunction
    """The function that the model called."""


# Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py,
class ChatCompletionMessage(BaseModel):
content: str | None = None
Expand All @@ -91,6 +114,9 @@ class ChatCompletionMessage(BaseModel):
role: Literal["assistant"]
"""The role of the author of this message."""

tool_calls: list[ChatCompletionMessageToolCall] | None = None
"""The tool calls generated by the model, such as function calls."""


class Choice(BaseModel):
index: int
Expand Down
Loading
Loading