Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/any_llm/providers/platform/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,12 @@ async def _acompletion(
**kwargs: Any,
) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]:
client_name = kwargs.pop("client_name", None)
if self.client_name is None:
self.client_name = client_name
if client_name is not None:
msg = (
"Passing client_name at request time is not supported for PlatformProvider. "
"Set client_name when creating the provider (for example, AnyLLM.create(..., client_name=...))."
)
raise ValueError(msg)

start_time = time.perf_counter()

Expand Down
63 changes: 42 additions & 21 deletions tests/unit/providers/test_platform_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,14 +1163,14 @@ async def test_usage_event_includes_version_header(
@pytest.mark.asyncio
@patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key")
@patch("any_llm.providers.platform.platform.post_completion_usage_event")
async def test_acompletion_handles_client_name_in_kwargs(
async def test_acompletion_rejects_client_name_in_kwargs_when_not_initialized(
mock_post_usage: AsyncMock,
mock_get_decrypted_provider_key: Mock,
any_llm_key: str,
mock_decrypted_provider_key: DecryptedProviderKey,
mock_completion: ChatCompletion,
) -> None:
"""Test that _acompletion correctly handles client_name passed in kwargs."""
"""Test that _acompletion rejects request-time client_name when provider was not initialized with it."""
mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key

provider_instance = PlatformProvider(api_key=any_llm_key)
Expand All @@ -1186,29 +1186,24 @@ async def test_acompletion_handles_client_name_in_kwargs(

client_name = "test-client-from-kwargs"

# Call _acompletion with client_name in kwargs
await provider_instance._acompletion(params, client_name=client_name)
with pytest.raises(ValueError, match="Passing client_name at request time is not supported"):
await provider_instance._acompletion(params, client_name=client_name)

# Verify self.client_name was updated
assert provider_instance.client_name == client_name

# Verify post_completion_usage_event was called with the correct client_name
mock_post_usage.assert_called_once()
call_args = mock_post_usage.call_args
assert call_args.kwargs["client_name"] == client_name
provider_instance.provider._acompletion.assert_not_called()
mock_post_usage.assert_not_called()


@pytest.mark.asyncio
@patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key")
@patch("any_llm.providers.platform.platform.post_completion_usage_event")
async def test_acompletion_does_not_overwrite_existing_client_name(
async def test_acompletion_rejects_changing_existing_client_name(
mock_post_usage: AsyncMock,
mock_get_decrypted_provider_key: Mock,
any_llm_key: str,
mock_decrypted_provider_key: DecryptedProviderKey,
mock_completion: ChatCompletion,
) -> None:
"""Test that _acompletion does not overwrite an existing client_name if one is already set."""
"""Test that _acompletion rejects a request-time client_name that differs from configured client_name."""
mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key

initial_client_name = "initial-client"
Expand All @@ -1225,13 +1220,39 @@ async def test_acompletion_does_not_overwrite_existing_client_name(

new_client_name = "new-client-from-kwargs"

# Call _acompletion with a new client_name in kwargs
await provider_instance._acompletion(params, client_name=new_client_name)
with pytest.raises(ValueError, match="Passing client_name at request time is not supported"):
await provider_instance._acompletion(params, client_name=new_client_name)

# Verify self.client_name was NOT updated
assert provider_instance.client_name == initial_client_name
provider_instance.provider._acompletion.assert_not_called()
mock_post_usage.assert_not_called()

# Verify post_completion_usage_event was called with the INITIAL client_name
mock_post_usage.assert_called_once()
call_args = mock_post_usage.call_args
assert call_args.kwargs["client_name"] == initial_client_name

@pytest.mark.asyncio
@patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key")
@patch("any_llm.providers.platform.platform.post_completion_usage_event")
async def test_acompletion_rejects_same_client_name_in_kwargs(
mock_post_usage: AsyncMock,
mock_get_decrypted_provider_key: Mock,
any_llm_key: str,
mock_decrypted_provider_key: DecryptedProviderKey,
mock_completion: ChatCompletion,
) -> None:
"""Test that _acompletion rejects request-time client_name even when it matches configured value."""
mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key

client_name = "configured-client"
provider_instance = PlatformProvider(api_key=any_llm_key, client_name=client_name)
provider_instance.provider = OpenaiProvider
provider_instance.provider._acompletion = AsyncMock(return_value=mock_completion) # type: ignore[method-assign]

params = CompletionParams(
model_id="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
stream=False,
)

with pytest.raises(ValueError, match="Passing client_name at request time is not supported"):
await provider_instance._acompletion(params, client_name=client_name)

provider_instance.provider._acompletion.assert_not_called()
mock_post_usage.assert_not_called()