diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index a6f0770d..eaf36f88 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -206,6 +206,7 @@ def __init__(self, model_name: str, provider: object) -> None: captured["provider"] = provider monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("OPENAI_EMBEDDING_MODEL", raising=False) monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-key") monkeypatch.setenv( "AZURE_OPENAI_ENDPOINT_EMBEDDING", @@ -217,13 +218,10 @@ def __init__(self, model_name: str, provider: object) -> None: lambda endpoint_envvar, api_key_envvar: provider, ) monkeypatch.setattr( - "pydantic_ai.embeddings.openai.OpenAIEmbeddingModel", - FakeOpenAIEmbeddingModel, + "pydantic_ai.embeddings.openai.OpenAIEmbeddingModel", FakeOpenAIEmbeddingModel ) monkeypatch.setattr( - model_adapters, - "_PydanticAIEmbedder", - lambda embedding_model: embedding_model, + model_adapters, "_PydanticAIEmbedder", lambda embedding_model: embedding_model ) embedder = create_embedding_model() @@ -262,8 +260,7 @@ async def request(self, *args: object, **kwargs: object) -> ModelResponse: lambda endpoint_envvar="AZURE_OPENAI_ENDPOINT", api_key_envvar="AZURE_OPENAI_API_KEY": provider, ) monkeypatch.setattr( - "pydantic_ai.models.openai.OpenAIChatModel", - FakeOpenAIChatModel, + "pydantic_ai.models.openai.OpenAIChatModel", FakeOpenAIChatModel ) chat_model = create_chat_model() diff --git a/tests/test_utils.py b/tests/test_utils.py index cd3336f7..ad526129 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,14 +11,13 @@ import pydantic.dataclasses import typechat -import typeagent.aitools.utils as utils +from typeagent.aitools import utils def test_timelog(): buf = StringIO() - with redirect_stderr(buf): - with utils.timelog("test block"): - pass + with redirect_stderr(buf), utils.timelog("test block"): + pass out = buf.getvalue() assert "test block..." in out @@ -136,8 +135,7 @@ def test_query_string_stripped_from_endpoint( ) -> None: """Returned endpoint should not contain query string parameters.""" monkeypatch.setenv( - "TEST_ENDPOINT", - "https://myhost.openai.azure.com?api-version=2024-06-01", + "TEST_ENDPOINT", "https://myhost.openai.azure.com?api-version=2024-06-01" ) endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") assert endpoint == "https://myhost.openai.azure.com" @@ -205,8 +203,41 @@ def test_apim_prefix_preserved(self, monkeypatch: pytest.MonkeyPatch) -> None: def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: """RuntimeError when the endpoint has no api-version field.""" monkeypatch.setenv( - "TEST_ENDPOINT", - "https://myhost.openai.azure.com/openai/deployments/gpt-4", + "TEST_ENDPOINT", "https://myhost.openai.azure.com/openai/deployments/gpt-4" ) with pytest.raises(RuntimeError, match="doesn't contain valid api-version"): utils.parse_azure_endpoint("TEST_ENDPOINT") + + +class TestResolveAzureModelName: + """Tests for resolve_azure_model_name.""" + + def test_returns_deployment_name_from_endpoint( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Deployment name in the endpoint takes precedence over the fallback.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com/openai/deployments/gpt-4o-custom?api-version=2025-01-01-preview", + ) + result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT") + assert result == "gpt-4o-custom" + + def test_falls_back_to_model_name(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Falls back to the provided model_name when no deployment in the endpoint.""" + monkeypatch.setenv( + "TEST_ENDPOINT", "https://myhost.openai.azure.com?api-version=2024-06-01" + ) + result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT") + assert result == "gpt-4o" + + def test_uses_default_endpoint_envvar( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Uses AZURE_OPENAI_ENDPOINT by default.""" + monkeypatch.setenv( + "AZURE_OPENAI_ENDPOINT", + "https://myhost.openai.azure.com/openai/deployments/my-deploy?api-version=2024-06-01", + ) + result = utils.resolve_azure_model_name("gpt-4o") + assert result == "my-deploy"