Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
7 changes: 6 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ that make changes to the repository. Not even `git add`**

When moving, copying or deleting files, use the git commands: `git mv`, `git cp`, `git rm`

When I ask to update AGENTS.md (even tentatively), extract a general rule from what I
said before and update AGENTS.md (unless it's already in there -- in that case consider
reformulating it, since the existing wording apparently didn't work). Also, when it
looks like I state a general rule, add it to AGENTS.md. In all cases show what you
added to AGENTS.md.

- Don't use '!' on the command line, it's some bash magic (even inside single quotes)
- Activate `.venv`: `make venv; source .venv/bin/activate` (run this only once)
- When running 'make' commands, do not use the venv (the Makefile uses 'uv run')
- To get API keys in ad-hoc code, call `load_dotenv()`
- Use `pytest test` to run tests in test/
- Use `pyright` to check type annotations in src/, tools/, tests/, examples/
Expand Down
15 changes: 13 additions & 2 deletions src/typeagent/aitools/model_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,18 @@ def create_chat_model(
if _needs_azure_fallback(provider):
from pydantic_ai.models.openai import OpenAIChatModel

from .utils import parse_azure_endpoint_parts

if os.getenv("OPENAI_MODEL"):
print(
f"OPENAI_MODEL={os.getenv('OPENAI_MODEL')!r} ignored; "
f"Azure deployment is determined by AZURE_OPENAI_ENDPOINT"
)
model = OpenAIChatModel(model_name, provider=_make_azure_provider())
_, _, deployment_name = parse_azure_endpoint_parts()
model = OpenAIChatModel(
deployment_name or model_name,
provider=_make_azure_provider(),
)
else:
model = infer_model(model_spec)
return PydanticAIChatModel(model)
Expand Down Expand Up @@ -283,6 +289,7 @@ def create_embedding_model(
from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel

from .embeddings import model_to_envvar
from .utils import parse_azure_endpoint_parts

# Look up model-specific Azure endpoint, falling back to the generic one.
suggested_envvar = model_to_envvar.get(model_name)
Expand All @@ -296,7 +303,11 @@ def create_embedding_model(
api_key_envvar = "AZURE_OPENAI_API_KEY"

azure_provider = _make_azure_provider(endpoint_envvar, api_key_envvar)
embedding_model = OpenAIEmbeddingModel(model_name, provider=azure_provider)
_, _, deployment_name = parse_azure_endpoint_parts(endpoint_envvar)
embedding_model = OpenAIEmbeddingModel(
deployment_name or model_name,
provider=azure_provider,
)
embedder = _PydanticAIEmbedder(embedding_model)
else:
embedder = _PydanticAIEmbedder(model_spec)
Expand Down
45 changes: 41 additions & 4 deletions src/typeagent/aitools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,21 @@ def parse_azure_endpoint(
Raises:
RuntimeError: If endpoint is not found or doesn't contain api-version.
"""
endpoint, version, _ = parse_azure_endpoint_parts(endpoint_envvar)
return endpoint, version


def parse_azure_endpoint_parts(
endpoint_envvar: str = "AZURE_OPENAI_ENDPOINT",
) -> tuple[str, str, str | None]:
"""Parse Azure OpenAI endpoint, version, and optional deployment name.

Returns:
Tuple of (endpoint_url, api_version, deployment_name).

The deployment name is extracted from endpoints of the form
``.../openai/deployments/<deployment>/...`` and is ``None`` otherwise.
"""
azure_endpoint = os.getenv(endpoint_envvar)
if not azure_endpoint:
raise RuntimeError(f"Environment variable {endpoint_envvar} not found")
Expand All @@ -200,12 +215,22 @@ def parse_azure_endpoint(
f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field"
)

clean_endpoint = azure_endpoint.split("?", 1)[0]
deployment_match = re.search(
r"/openai/deployments/([^/?]+)(?:/.*)?$",
clean_endpoint,
)
deployment_name = deployment_match.group(1) if deployment_match else None

# Strip query string and /openai... path — AsyncAzureOpenAI expects a
# clean base URL and builds the deployment path internally.
clean_endpoint = azure_endpoint.split("?", 1)[0]
clean_endpoint = re.sub(r"/openai(/deployments/.*)?$", "", clean_endpoint)
clean_endpoint = re.sub(
r"/openai(?:/deployments/[^/?]+(?:/.*)?)?$",
"",
clean_endpoint,
)

return clean_endpoint, m.group(1)
return clean_endpoint, m.group(1), deployment_name


def get_azure_api_key(azure_api_key: str) -> str:
Expand Down Expand Up @@ -272,6 +297,15 @@ def create_async_openai_client(
)


def resolve_azure_model_name(
    model_name: str,
    endpoint_envvar: str = "AZURE_OPENAI_ENDPOINT",
) -> str:
    """Resolve an Azure deployment name from an endpoint, if present.

    Falls back to *model_name* when the endpoint URL does not embed a
    ``/openai/deployments/<name>/`` segment.
    """
    deployment = parse_azure_endpoint_parts(endpoint_envvar)[2]
    if deployment:
        return deployment
    return model_name


# The true return type is pydantic_ai.Agent[T], but that's an optional dependency.
def make_agent[T](cls: type[T]):
Comment thread
bmerkle marked this conversation as resolved.
"""Create Pydantic AI agent using hardcoded preferences."""
Expand All @@ -291,7 +325,10 @@ def make_agent[T](cls: type[T]):
Wrapper = ToolOutput

print(f"## Using Azure with {Wrapper.__name__} ##")
model = OpenAIChatModel("gpt-4o", provider=azure_provider)
model = OpenAIChatModel(
resolve_azure_model_name("gpt-4o"),
provider=azure_provider,
)

else:
raise RuntimeError(
Expand Down
71 changes: 68 additions & 3 deletions tests/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import json
import os
import sys
from typing import Any
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import AsyncMock

import pytest
Expand All @@ -24,7 +25,7 @@
from openai.types.chat import ChatCompletionMessageParam
import typechat

from typeagent.aitools.utils import create_async_openai_client
from typeagent.aitools.utils import create_async_openai_client, resolve_azure_model_name
from typeagent.mcp.server import MCPTypeChatModel, QuestionResponse

from conftest import EPISODE_53_INDEX
Expand Down Expand Up @@ -76,8 +77,12 @@ async def sampling_callback(
messages.insert(0, {"role": "system", "content": params.systemPrompt})

# Call OpenAI
model_name = "gpt-4o"
if os.getenv("AZURE_OPENAI_API_KEY") and not os.getenv("OPENAI_API_KEY"):
model_name = resolve_azure_model_name(model_name)

response = await client.chat.completions.create(
model="gpt-4o",
model=model_name,
messages=messages,
max_tokens=params.maxTokens,
temperature=params.temperature if params.temperature is not None else 1.0,
Expand Down Expand Up @@ -343,3 +348,63 @@ def test_answer_type_coverage(self) -> None:
assert answered.type == "Answered"
no_answer = AnswerResponse(type="NoAnswer", why_no_answer="dunno")
assert no_answer.type == "NoAnswer"


@pytest.mark.asyncio
async def test_sampling_callback_uses_azure_deployment_name(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Azure-only sampling should send the resolved deployment name."""
    fake_response = SimpleNamespace(
        choices=[SimpleNamespace(message=SimpleNamespace(content="response"))],
        model="gpt-4o-2",
    )
    mock_create = AsyncMock(return_value=fake_response)
    stub_client = SimpleNamespace(
        chat=SimpleNamespace(completions=SimpleNamespace(create=mock_create))
    )

    # Simulate an Azure-only environment: no OpenAI key, Azure key present.
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-key")
    this_module = sys.modules[__name__]
    monkeypatch.setattr(this_module, "create_async_openai_client", lambda: stub_client)
    monkeypatch.setattr(
        this_module,
        "resolve_azure_model_name",
        lambda model_name: f"{model_name}-2",
    )

    request = CreateMessageRequestParams(
        messages=[
            SamplingMessage(
                role="user",
                content=TextContent(type="text", text="hello"),
            )
        ],
        maxTokens=32,
    )

    result = await sampling_callback(
        cast(RequestContext[ClientSessionType, Any, Any], None),
        request,
    )

    # The resolved deployment name, not the raw model name, must be sent.
    mock_create.assert_awaited_once_with(
        model="gpt-4o-2",
        messages=[{"role": "user", "content": "hello"}],
        max_tokens=32,
        temperature=1.0,
    )
    assert result.model == "gpt-4o-2"
82 changes: 82 additions & 0 deletions tests/test_model_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from pydantic_ai.models import Model
import typechat

from typeagent.aitools import model_adapters
from typeagent.aitools.embeddings import CachingEmbeddingModel, NormalizedEmbedding
from typeagent.aitools.model_adapters import (
configure_models,
create_chat_model,
create_embedding_model,
PydanticAIChatModel,
PydanticAIEmbedder,
)
Expand Down Expand Up @@ -189,3 +191,83 @@ def test_configure_models_returns_correct_types(
chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small")
assert isinstance(chat, PydanticAIChatModel)
assert isinstance(embedder, CachingEmbeddingModel)


def test_create_embedding_model_uses_azure_deployment_name(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Azure embedding endpoints contribute the deployment name."""
    seen: dict[str, object] = {}
    sentinel_provider = object()

    class StubEmbeddingModel:
        # Record the constructor arguments so the test can inspect them.
        def __init__(self, model_name: str, provider: object) -> None:
            seen["azure_model_name"] = model_name
            seen["provider"] = provider

    # Azure-only environment with a deployment-style embedding endpoint.
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-key")
    monkeypatch.setenv(
        "AZURE_OPENAI_ENDPOINT_EMBEDDING",
        "https://myhost.openai.azure.com/openai/deployments/ada-002/embeddings?api-version=2025-01-01-preview",
    )
    monkeypatch.setattr(
        model_adapters,
        "_make_azure_provider",
        lambda endpoint_envvar, api_key_envvar: sentinel_provider,
    )
    monkeypatch.setattr(
        "pydantic_ai.embeddings.openai.OpenAIEmbeddingModel",
        StubEmbeddingModel,
    )
    monkeypatch.setattr(
        model_adapters,
        "_PydanticAIEmbedder",
        lambda embedding_model: embedding_model,
    )

    embedder = create_embedding_model()

    assert isinstance(embedder, CachingEmbeddingModel)
    assert seen["azure_model_name"] == "ada-002"
    assert seen["provider"] is sentinel_provider
    assert embedder.model_name == "text-embedding-ada-002"


def test_create_chat_model_uses_azure_deployment_name(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Azure chat endpoints contribute the deployment name."""
    seen: dict[str, object] = {}
    sentinel_provider = object()

    class StubChatModel:
        # Record the constructor arguments so the test can inspect them.
        def __init__(self, model_name: str, provider: object) -> None:
            seen["azure_model_name"] = model_name
            seen["provider"] = provider

        async def request(self, *args: object, **kwargs: object) -> ModelResponse:
            raise AssertionError("request() should not be called in this test")

    # Azure-only environment with a deployment-style chat endpoint.
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("OPENAI_MODEL", raising=False)
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-key")
    monkeypatch.setenv(
        "AZURE_OPENAI_ENDPOINT",
        "https://myhost.openai.azure.com/openai/deployments/gpt-4o-2/chat/completions?api-version=2025-01-01-preview",
    )
    monkeypatch.setattr(
        model_adapters,
        "_make_azure_provider",
        lambda endpoint_envvar="AZURE_OPENAI_ENDPOINT", api_key_envvar="AZURE_OPENAI_API_KEY": sentinel_provider,
    )
    monkeypatch.setattr(
        "pydantic_ai.models.openai.OpenAIChatModel",
        StubChatModel,
    )

    chat_model = create_chat_model()

    assert isinstance(chat_model, PydanticAIChatModel)
    assert seen["azure_model_name"] == "gpt-4o-2"
    assert seen["provider"] is sentinel_provider
13 changes: 13 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ def test_query_string_stripped_with_path(
assert "?" not in endpoint
assert version == "2025-01-01-preview"

def test_deployment_name_extracted(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """Deployment name is extracted from deployment-style endpoints."""
    url = (
        "https://myhost.openai.azure.com/openai/deployments/ada-002/"
        "embeddings?api-version=2025-01-01-preview"
    )
    monkeypatch.setenv("TEST_ENDPOINT", url)

    endpoint, version, deployment = utils.parse_azure_endpoint_parts("TEST_ENDPOINT")

    assert (endpoint, version, deployment) == (
        "https://myhost.openai.azure.com",
        "2025-01-01-preview",
        "ada-002",
    )

def test_query_string_stripped_multiple_params(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
WEBVTT

NOTE Intro titles, announcer, beach 'It's' man, end credits and final score.

00:00:00.000 --> 00:00:05.000
[action] A ragged man struggles out of the sea onto a beach, collapses, and announces.

00:00:05.000 --> 00:00:07.000
<v It's Man>
It's…

00:00:07.000 --> 00:00:10.000
<v Voice Over>
Monty Python's Flying Circus.

00:00:10.000 --> 00:00:20.000
[action] Titles begin with the words 'Monty Python's Flying Circus'. Various bizarre things happen. Titles end on an ordinary grey-suited announcer by a desk, smiling confidently.

00:00:20.000 --> 00:00:22.000
<v Announcer>
Good evening.

00:00:22.000 --> 00:00:30.000
[action] Announcer sits; a loud squeal like a pig being sat upon. Cut to a blackboard with rows of pigs drawn. A man crosses one pig off in chalk.

00:29:30.000 --> 00:29:40.000
[action] The seashore again. The 'It's' man lies on the beach. A stick from off-screen prods him. Exhausted, he rises and staggers back into the sea.

00:29:40.000 --> 00:29:50.000
[caption] "WHITHER CANADA" WAS CONCEIVED WRITTEN AND PERFORMED BY… (CREDITS)

00:29:50.000 --> 00:29:58.000
<v Announcer>
And here is the final score: Pigs 9 - British Bipeds 4. The Pigs go on to meet Vicki Carr in the final.

Loading
Loading