@@ -82,6 +82,7 @@ class _LLMOptions:
seed: NotGivenOr[int]
safety_settings: NotGivenOr[list[types.SafetySettingOrDict]]
service_tier: NotGivenOr[types.ServiceTier]
cached_content: NotGivenOr[str]


BLOCKED_REASONS = [
@@ -119,6 +120,7 @@ def __init__(
seed: NotGivenOr[int] = NOT_GIVEN,
safety_settings: NotGivenOr[list[types.SafetySettingOrDict]] = NOT_GIVEN,
service_tier: NotGivenOr[types.ServiceTier] = NOT_GIVEN,
cached_content: NotGivenOr[str] = NOT_GIVEN,
credentials: google.auth.credentials.Credentials | None = None,
) -> None:
"""
@@ -151,6 +153,7 @@ def __init__(
seed (int, optional): Random seed for reproducible generation. Defaults to None.
safety_settings (list[SafetySettingOrDict], optional): Safety settings for content filtering. Defaults to None.
service_tier (types.ServiceTier, optional): The service tier for the request (e.g. types.ServiceTier.PRIORITY). Defaults to None.
cached_content (str, optional): Resource name of an explicit context cache to attach to every request from this LLM instance, e.g. ``"cachedContents/abc123"`` for the Gemini API or ``"projects/<project>/locations/<location>/cachedContents/abc123"`` for VertexAI. The cache must already exist: create it via ``client.caches.create(...)`` and pass the returned ``name``. Gemini rejects ``generateContent`` requests that combine ``cached_content`` with ``system_instruction``, ``tools``, or ``tool_config``, so when this option is set the plugin strips those fields from every outgoing request; the cache resource itself must contain whichever of them the model needs (typically the system prompt and the tool schemas). Useful for long-lived static prefixes where implicit caching is unreliable. See https://ai.google.dev/gemini-api/docs/caching for details and minimum prefix-token requirements. Defaults to None.
""" # noqa: E501
super().__init__()
gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
@@ -224,6 +227,7 @@ def __init__(
seed=seed,
safety_settings=safety_settings,
service_tier=service_tier,
cached_content=cached_content,
)
self._client = Client(
api_key=gemini_api_key,
@@ -391,6 +395,9 @@ def chat(
if is_given(self._opts.service_tier):
extra["service_tier"] = self._opts.service_tier

if is_given(self._opts.cached_content):
extra["cached_content"] = self._opts.cached_content

return LLMStream(
self,
client=self._client,
@@ -437,8 +444,19 @@ async def _run(self) -> None:
turns = [types.Content.model_validate(turn) for turn in turns_dict]
tool_context = llm.ToolContext(self._tools)
tools_config = create_tools_config(tool_context, _only_single_type=True)
if tools_config:
# Gemini's API rejects `generateContent` requests that pass
# `cached_content` together with `system_instruction`, `tools`,
# or `tool_config` — those fields must live INSIDE the
# CachedContent resource, not on the request. The application
# bakes them into the cache via `client.caches.create(...)`;
# here we just suppress the duplicates on the outgoing request
# whenever a cache is attached.
using_cache = "cached_content" in self._extra_kwargs
if tools_config and not using_cache:
self._extra_kwargs["tools"] = tools_config
elif using_cache:
self._extra_kwargs.pop("tools", None)
self._extra_kwargs.pop("tool_config", None)
http_options = self._llm._opts.http_options or types.HttpOptions(
timeout=int(self._conn_options.timeout * 1000)
)
@@ -447,9 +465,13 @@ async def _run(self) -> None:
http_options.headers["x-goog-api-client"] = f"livekit-agents/{__version__}"
config = types.GenerateContentConfig(
system_instruction=(
[types.Part(text=content) for content in extra_data.system_messages]
if extra_data.system_messages
else None
None
if using_cache
else (
[types.Part(text=content) for content in extra_data.system_messages]
if extra_data.system_messages
else None
)
),
http_options=http_options,
**self._extra_kwargs,
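For context, here is a minimal usage sketch of the cached_content option documented above. It is not part of this PR's diff; the model name, TTL, and prompt text are placeholders, and it assumes the google-genai explicit-caching API (client.caches.create with types.CreateCachedContentConfig) described at https://ai.google.dev/gemini-api/docs/caching.

from google import genai
from google.genai import types

from livekit.plugins.google.llm import LLM

client = genai.Client(api_key="...")  # or Client(vertexai=True, project=..., location=...)

# Bake the static prefix (system prompt, tools, large context) into the cache
# once; the minimum prefix-token requirement from the caching docs applies.
cache = client.caches.create(
    model="gemini-2.5-flash",
    config=types.CreateCachedContentConfig(
        system_instruction="You are a support agent for Acme.",  # placeholder
        contents=[
            types.Content(role="user", parts=[types.Part(text="<large static context>")]),
        ],
        ttl="3600s",
    ),
)

# Every request from this LLM instance now references the cache; per the diff
# above, the plugin strips system_instruction / tools / tool_config from the
# outgoing request so Gemini does not reject it.
llm = LLM(
    model="gemini-2.5-flash",
    api_key="...",
    cached_content=cache.name,  # e.g. "cachedContents/abc123"
)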
179 changes: 177 additions & 2 deletions tests/test_plugin_google_llm.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from google.genai import types

from livekit.plugins.google.llm import LLMStream
from livekit.agents.llm import ChatContext, function_tool
from livekit.plugins.google.llm import LLM, LLMStream


@pytest.fixture
@@ -63,3 +64,177 @@ def test_empty_text_part_returns_none(self, llm_stream: LLMStream):
chunk = llm_stream._parse_part("test-id", part)

assert chunk is None


class TestCachedContentOption:
"""Verify the ``cached_content`` constructor option propagates from
``LLM.__init__`` through ``_LLMOptions`` and into the keyword
arguments handed to ``GenerateContentConfig`` for every request.

The propagation tests are ``async def`` because ``LLM.chat()`` builds
an ``LLMStream`` whose constructor schedules a metrics-monitoring
task on the running event loop. A sync test would raise
``RuntimeError: no running event loop`` before reaching the
assertion.
"""

@pytest.mark.asyncio
async def test_cached_content_propagates_to_extra_kwargs(self) -> None:
llm = LLM(model="gemini-2.5-flash", api_key="test", cached_content="cachedContents/abc123")

stream = llm.chat(chat_ctx=ChatContext.empty())
try:
assert stream._extra_kwargs.get("cached_content") == "cachedContents/abc123"
finally:
await stream.aclose()

@pytest.mark.asyncio
async def test_cached_content_omitted_when_not_set(self) -> None:
"""Backward compatibility: callers that don't pass
``cached_content`` must produce a request config without the
field, so existing behaviour is unchanged."""
llm = LLM(model="gemini-2.5-flash", api_key="test")

stream = llm.chat(chat_ctx=ChatContext.empty())
try:
assert "cached_content" not in stream._extra_kwargs
finally:
await stream.aclose()

def test_cached_content_stored_on_opts(self) -> None:
llm = LLM(
model="gemini-2.5-flash",
api_key="test",
cached_content="projects/p/locations/us-central1/cachedContents/xyz",
)

assert llm._opts.cached_content == "projects/p/locations/us-central1/cachedContents/xyz"


class TestCachedContentRequestSuppression:
"""Gemini's API rejects ``generateContent`` requests that pass
``cached_content`` together with ``system_instruction``, ``tools``,
or ``tool_config`` — those fields belong inside the CachedContent
resource. The plugin therefore strips them off the outgoing request
whenever a cache is attached. These tests run the LLMStream against
a stubbed ``generate_content_stream`` and assert on the
``GenerateContentConfig`` it received.
"""

@staticmethod
async def _single_response_async_iter():
"""Emit one minimal-but-valid GenerateContentResponse so the
retry layer in livekit.agents.LLM doesn't treat the stream as
empty and re-issue the request three more times."""
yield types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(
role="model",
parts=[types.Part(text="ok")],
),
finish_reason=types.FinishReason.STOP,
)
],
)

@classmethod
def _patched_stream_capture(cls) -> tuple[AsyncMock, dict]:
captured: dict = {}

async def fake_stream(**kwargs):
captured["model"] = kwargs.get("model")
captured["contents"] = kwargs.get("contents")
captured["config"] = kwargs.get("config")
return cls._single_response_async_iter()

return AsyncMock(side_effect=fake_stream), captured

@pytest.mark.asyncio
async def test_request_omits_system_instruction_when_cached_content_set(self) -> None:
"""With a cache attached, the outgoing request must carry
``system_instruction=None`` — the system prompt lives in the
cache resource and re-sending it would make Gemini return 400."""
llm = LLM(
model="gemini-2.5-flash",
api_key="test",
cached_content="cachedContents/abc123",
)

chat_ctx = ChatContext.empty()
chat_ctx.add_message(role="system", content="system prompt that lives in cache")
chat_ctx.add_message(role="user", content="hi")

fake, captured = self._patched_stream_capture()
with patch.object(llm._client.aio.models, "generate_content_stream", fake):
stream = llm.chat(chat_ctx=chat_ctx)
try:
async for _ in stream:
pass
finally:
await stream.aclose()

config = captured["config"]
assert config.system_instruction is None
assert config.cached_content == "cachedContents/abc123"

@pytest.mark.asyncio
async def test_request_omits_tools_when_cached_content_set(self) -> None:
"""With a cache attached, the outgoing request must NOT include
``tools`` even if the LLMStream was constructed with function
tools — the tool schemas belong inside the cache resource."""

@function_tool
async def example_tool(query: str) -> str:
"""Look something up."""
return query

llm = LLM(
model="gemini-2.5-flash",
api_key="test",
cached_content="cachedContents/abc123",
)

fake, captured = self._patched_stream_capture()
with patch.object(llm._client.aio.models, "generate_content_stream", fake):
stream = llm.chat(chat_ctx=ChatContext.empty(), tools=[example_tool])
try:
async for _ in stream:
pass
finally:
await stream.aclose()

config = captured["config"]
assert config.tools is None
assert config.tool_config is None
assert config.cached_content == "cachedContents/abc123"

@pytest.mark.asyncio
async def test_request_includes_system_instruction_and_tools_when_no_cache(self) -> None:
"""Backward compatibility: without ``cached_content``, the
request still carries ``system_instruction`` and ``tools`` as
before. Suppression is gated strictly on the cache being set."""

@function_tool
async def example_tool(query: str) -> str:
"""Look something up."""
return query

llm = LLM(model="gemini-2.5-flash", api_key="test")

chat_ctx = ChatContext.empty()
chat_ctx.add_message(role="system", content="system prompt sent on every request")
chat_ctx.add_message(role="user", content="hi")

fake, captured = self._patched_stream_capture()
with patch.object(llm._client.aio.models, "generate_content_stream", fake):
stream = llm.chat(chat_ctx=chat_ctx, tools=[example_tool])
try:
async for _ in stream:
pass
finally:
await stream.aclose()

config = captured["config"]
assert config.system_instruction is not None
assert config.tools is not None and len(config.tools) >= 1