diff --git a/src/any_llm/providers/platform/platform.py b/src/any_llm/providers/platform/platform.py index 58be0f4f5..caac2d0eb 100644 --- a/src/any_llm/providers/platform/platform.py +++ b/src/any_llm/providers/platform/platform.py @@ -104,8 +104,12 @@ async def _acompletion( **kwargs: Any, ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: client_name = kwargs.pop("client_name", None) - if self.client_name is None: - self.client_name = client_name + if client_name is not None: + msg = ( + "Passing client_name at request time is not supported for PlatformProvider. " + "Set client_name when creating the provider (for example, AnyLLM.create(..., client_name=...))." + ) + raise ValueError(msg) start_time = time.perf_counter() diff --git a/tests/unit/providers/test_platform_provider.py b/tests/unit/providers/test_platform_provider.py index f5322aff3..5d65ddb61 100644 --- a/tests/unit/providers/test_platform_provider.py +++ b/tests/unit/providers/test_platform_provider.py @@ -1163,14 +1163,14 @@ async def test_usage_event_includes_version_header( @pytest.mark.asyncio @patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key") @patch("any_llm.providers.platform.platform.post_completion_usage_event") -async def test_acompletion_handles_client_name_in_kwargs( +async def test_acompletion_rejects_client_name_in_kwargs_when_not_initialized( mock_post_usage: AsyncMock, mock_get_decrypted_provider_key: Mock, any_llm_key: str, mock_decrypted_provider_key: DecryptedProviderKey, mock_completion: ChatCompletion, ) -> None: - """Test that _acompletion correctly handles client_name passed in kwargs.""" + """Test that _acompletion rejects request-time client_name when provider was not initialized with it.""" mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key provider_instance = PlatformProvider(api_key=any_llm_key) @@ -1186,29 +1186,24 @@ async def test_acompletion_handles_client_name_in_kwargs( client_name = "test-client-from-kwargs" - # Call _acompletion with client_name in kwargs - await provider_instance._acompletion(params, client_name=client_name) + with pytest.raises(ValueError, match="Passing client_name at request time is not supported"): + await provider_instance._acompletion(params, client_name=client_name) - # Verify self.client_name was updated - assert provider_instance.client_name == client_name - - # Verify post_completion_usage_event was called with the correct client_name - mock_post_usage.assert_called_once() - call_args = mock_post_usage.call_args - assert call_args.kwargs["client_name"] == client_name + provider_instance.provider._acompletion.assert_not_called() + mock_post_usage.assert_not_called() @pytest.mark.asyncio @patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key") @patch("any_llm.providers.platform.platform.post_completion_usage_event") -async def test_acompletion_does_not_overwrite_existing_client_name( +async def test_acompletion_rejects_changing_existing_client_name( mock_post_usage: AsyncMock, mock_get_decrypted_provider_key: Mock, any_llm_key: str, mock_decrypted_provider_key: DecryptedProviderKey, mock_completion: ChatCompletion, ) -> None: - """Test that _acompletion does not overwrite an existing client_name if one is already set.""" + """Test that _acompletion rejects a request-time client_name that differs from configured client_name.""" mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key initial_client_name = "initial-client" @@ -1225,13 +1220,39 @@ async def test_acompletion_does_not_overwrite_existing_client_name( new_client_name = "new-client-from-kwargs" - # Call _acompletion with a new client_name in kwargs - await provider_instance._acompletion(params, client_name=new_client_name) + with pytest.raises(ValueError, match="Passing client_name at request time is not supported"): + await provider_instance._acompletion(params, client_name=new_client_name) - # Verify self.client_name was NOT updated - assert provider_instance.client_name == initial_client_name + provider_instance.provider._acompletion.assert_not_called() + mock_post_usage.assert_not_called() - # Verify post_completion_usage_event was called with the INITIAL client_name - mock_post_usage.assert_called_once() - call_args = mock_post_usage.call_args - assert call_args.kwargs["client_name"] == initial_client_name + +@pytest.mark.asyncio +@patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key") +@patch("any_llm.providers.platform.platform.post_completion_usage_event") +async def test_acompletion_rejects_same_client_name_in_kwargs( + mock_post_usage: AsyncMock, + mock_get_decrypted_provider_key: Mock, + any_llm_key: str, + mock_decrypted_provider_key: DecryptedProviderKey, + mock_completion: ChatCompletion, +) -> None: + """Test that _acompletion rejects request-time client_name even when it matches configured value.""" + mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key + + client_name = "configured-client" + provider_instance = PlatformProvider(api_key=any_llm_key, client_name=client_name) + provider_instance.provider = OpenaiProvider + provider_instance.provider._acompletion = AsyncMock(return_value=mock_completion) # type: ignore[method-assign] + + params = CompletionParams( + model_id="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + stream=False, + ) + + with pytest.raises(ValueError, match="Passing client_name at request time is not supported"): + await provider_instance._acompletion(params, client_name=client_name) + + provider_instance.provider._acompletion.assert_not_called() + mock_post_usage.assert_not_called()