Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions tests/test_model_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(self, model_name: str, provider: object) -> None:
captured["provider"] = provider

monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("OPENAI_EMBEDDING_MODEL", raising=False)
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-key")
monkeypatch.setenv(
"AZURE_OPENAI_ENDPOINT_EMBEDDING",
Expand All @@ -217,13 +218,10 @@ def __init__(self, model_name: str, provider: object) -> None:
lambda endpoint_envvar, api_key_envvar: provider,
)
monkeypatch.setattr(
"pydantic_ai.embeddings.openai.OpenAIEmbeddingModel",
FakeOpenAIEmbeddingModel,
"pydantic_ai.embeddings.openai.OpenAIEmbeddingModel", FakeOpenAIEmbeddingModel
)
monkeypatch.setattr(
model_adapters,
"_PydanticAIEmbedder",
lambda embedding_model: embedding_model,
model_adapters, "_PydanticAIEmbedder", lambda embedding_model: embedding_model
)

embedder = create_embedding_model()
Expand Down Expand Up @@ -262,8 +260,7 @@ async def request(self, *args: object, **kwargs: object) -> ModelResponse:
lambda endpoint_envvar="AZURE_OPENAI_ENDPOINT", api_key_envvar="AZURE_OPENAI_API_KEY": provider,
)
monkeypatch.setattr(
"pydantic_ai.models.openai.OpenAIChatModel",
FakeOpenAIChatModel,
"pydantic_ai.models.openai.OpenAIChatModel", FakeOpenAIChatModel
)

chat_model = create_chat_model()
Expand Down
47 changes: 39 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
import pydantic.dataclasses
import typechat

import typeagent.aitools.utils as utils
from typeagent.aitools import utils


def test_timelog():
buf = StringIO()
with redirect_stderr(buf):
with utils.timelog("test block"):
pass
with redirect_stderr(buf), utils.timelog("test block"):
pass
out = buf.getvalue()
assert "test block..." in out

Expand Down Expand Up @@ -136,8 +135,7 @@ def test_query_string_stripped_from_endpoint(
) -> None:
"""Returned endpoint should not contain query string parameters."""
monkeypatch.setenv(
"TEST_ENDPOINT",
"https://myhost.openai.azure.com?api-version=2024-06-01",
"TEST_ENDPOINT", "https://myhost.openai.azure.com?api-version=2024-06-01"
)
endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
assert endpoint == "https://myhost.openai.azure.com"
Expand Down Expand Up @@ -205,8 +203,41 @@ def test_apim_prefix_preserved(self, monkeypatch: pytest.MonkeyPatch) -> None:
def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""RuntimeError when the endpoint has no api-version field."""
monkeypatch.setenv(
"TEST_ENDPOINT",
"https://myhost.openai.azure.com/openai/deployments/gpt-4",
"TEST_ENDPOINT", "https://myhost.openai.azure.com/openai/deployments/gpt-4"
)
with pytest.raises(RuntimeError, match="doesn't contain valid api-version"):
utils.parse_azure_endpoint("TEST_ENDPOINT")


class TestResolveAzureModelName:
"""Tests for resolve_azure_model_name."""

def test_returns_deployment_name_from_endpoint(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Deployment name in the endpoint takes precedence over the fallback."""
monkeypatch.setenv(
"TEST_ENDPOINT",
"https://myhost.openai.azure.com/openai/deployments/gpt-4o-custom?api-version=2025-01-01-preview",
)
result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT")
assert result == "gpt-4o-custom"

def test_falls_back_to_model_name(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Falls back to the provided model_name when no deployment in the endpoint."""
monkeypatch.setenv(
"TEST_ENDPOINT", "https://myhost.openai.azure.com?api-version=2024-06-01"
)
result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT")
assert result == "gpt-4o"

def test_uses_default_endpoint_envvar(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Uses AZURE_OPENAI_ENDPOINT by default."""
monkeypatch.setenv(
"AZURE_OPENAI_ENDPOINT",
"https://myhost.openai.azure.com/openai/deployments/my-deploy?api-version=2024-06-01",
)
result = utils.resolve_azure_model_name("gpt-4o")
assert result == "my-deploy"
Loading