diff --git a/src/memu/app/settings.py b/src/memu/app/settings.py
index adcb4f16..a370a2b7 100644
--- a/src/memu/app/settings.py
+++ b/src/memu/app/settings.py
@@ -135,6 +135,15 @@ def set_provider_defaults(self) -> "LLMConfig":
                 self.api_key = "XAI_API_KEY"
             if self.chat_model == "gpt-4o-mini":
                 self.chat_model = "grok-2-latest"
+        elif self.provider == "gemini":
+            if self.base_url == "https://api.openai.com/v1":
+                self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
+            if self.api_key == "OPENAI_API_KEY":
+                self.api_key = "GEMINI_API_KEY"
+            if self.chat_model == "gpt-4o-mini":
+                self.chat_model = "gemini-2.0-flash"
+            if self.embed_model == "text-embedding-3-small":
+                self.embed_model = "gemini-embedding-001"
         return self
 
 
diff --git a/src/memu/llm/backends/__init__.py b/src/memu/llm/backends/__init__.py
index 5350e7b2..b3f869ff 100644
--- a/src/memu/llm/backends/__init__.py
+++ b/src/memu/llm/backends/__init__.py
@@ -1,7 +1,8 @@
 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
+from memu.llm.backends.gemini import GeminiLLMBackend
 from memu.llm.backends.grok import GrokBackend
 from memu.llm.backends.openai import OpenAILLMBackend
 from memu.llm.backends.openrouter import OpenRouterLLMBackend
 
-__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
+__all__ = ["DoubaoLLMBackend", "GeminiLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
diff --git a/src/memu/llm/backends/gemini.py b/src/memu/llm/backends/gemini.py
new file mode 100644
index 00000000..125c7afa
--- /dev/null
+++ b/src/memu/llm/backends/gemini.py
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from memu.llm.backends.openai import OpenAILLMBackend
+
+
+class GeminiLLMBackend(OpenAILLMBackend):
+    """Backend for Google Gemini via its OpenAI-compatible API endpoint."""
+
+    name = "gemini"
+    # Gemini's OpenAI-compatible chat endpoint is the same as OpenAI's
+    summary_endpoint = "/chat/completions"
diff --git a/src/memu/llm/http_client.py b/src/memu/llm/http_client.py
index ba84b05b..6d9be74a 100644
--- a/src/memu/llm/http_client.py
+++ b/src/memu/llm/http_client.py
@@ -11,6 +11,7 @@
 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
+from memu.llm.backends.gemini import GeminiLLMBackend
 from memu.llm.backends.grok import GrokBackend
 from memu.llm.backends.openai import OpenAILLMBackend
 from memu.llm.backends.openrouter import OpenRouterLLMBackend
 
@@ -72,6 +73,7 @@ def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
 LLM_BACKENDS: dict[str, Callable[[], LLMBackend]] = {
     OpenAILLMBackend.name: OpenAILLMBackend,
     DoubaoLLMBackend.name: DoubaoLLMBackend,
+    GeminiLLMBackend.name: GeminiLLMBackend,
     GrokBackend.name: GrokBackend,
     OpenRouterLLMBackend.name: OpenRouterLLMBackend,
 }
@@ -291,6 +293,7 @@ def _load_embedding_backend(self, provider: str) -> _EmbeddingBackend:
             _OpenAIEmbeddingBackend.name: _OpenAIEmbeddingBackend,
             _DoubaoEmbeddingBackend.name: _DoubaoEmbeddingBackend,
             "grok": _OpenAIEmbeddingBackend,
+            "gemini": _OpenAIEmbeddingBackend,
             _OpenRouterEmbeddingBackend.name: _OpenRouterEmbeddingBackend,
         }
         factory = backends.get(provider)
diff --git a/tests/llm/test_gemini_provider.py b/tests/llm/test_gemini_provider.py
new file mode 100644
index 00000000..9d7990dc
--- /dev/null
+++ b/tests/llm/test_gemini_provider.py
@@ -0,0 +1,132 @@
+import os
+import unittest
+
+from memu.app.settings import LLMConfig
+from memu.llm.backends.gemini import GeminiLLMBackend
+from memu.llm.http_client import HTTPLLMClient, LLM_BACKENDS, _OpenAIEmbeddingBackend
+
+
+class TestGeminiSettings(unittest.TestCase):
+    def test_settings_defaults(self):
+        """provider='gemini' should set Gemini-specific defaults."""
+        config = LLMConfig(provider="gemini")
+        self.assertEqual(config.base_url, "https://generativelanguage.googleapis.com/v1beta/openai")
+        self.assertEqual(config.api_key, "GEMINI_API_KEY")
+        self.assertEqual(config.chat_model, "gemini-2.0-flash")
+        self.assertEqual(config.embed_model, "gemini-embedding-001")
+
+    def test_explicit_values_not_overridden(self):
+        """Explicit values should not be replaced by defaults."""
+        config = LLMConfig(
+            provider="gemini",
+            chat_model="gemini-2.5-flash",
+            embed_model="gemini-embedding-001",
+            api_key="my-real-key",
+        )
+        self.assertEqual(config.chat_model, "gemini-2.5-flash")
+        self.assertEqual(config.embed_model, "gemini-embedding-001")
+        self.assertEqual(config.api_key, "my-real-key")
+
+
+class TestGeminiBackend(unittest.TestCase):
+    def setUp(self):
+        self.backend = GeminiLLMBackend()
+
+    def test_backend_registered(self):
+        """GeminiLLMBackend must be in the LLM_BACKENDS registry."""
+        self.assertIn("gemini", LLM_BACKENDS)
+        self.assertIsInstance(LLM_BACKENDS["gemini"](), GeminiLLMBackend)
+
+    def test_summary_endpoint(self):
+        self.assertEqual(self.backend.summary_endpoint, "/chat/completions")
+
+    def test_build_summary_payload(self):
+        payload = self.backend.build_summary_payload(
+            text="Hello world",
+            system_prompt="Be concise.",
+            chat_model="gemini-2.0-flash",
+            max_tokens=100,
+        )
+        self.assertEqual(payload["model"], "gemini-2.0-flash")
+        self.assertEqual(payload["messages"][0]["role"], "system")
+        self.assertEqual(payload["messages"][1]["content"], "Hello world")
+        self.assertEqual(payload["max_tokens"], 100)
+
+    def test_parse_summary_response(self):
+        dummy = {"choices": [{"message": {"content": "Gemini reply", "role": "assistant"}}]}
+        result = self.backend.parse_summary_response(dummy)
+        self.assertEqual(result, "Gemini reply")
+
+    def test_build_vision_payload(self):
+        payload = self.backend.build_vision_payload(
+            prompt="Describe this image",
+            base64_image="abc123",
+            mime_type="image/png",
+            system_prompt=None,
+            chat_model="gemini-2.0-flash",
+            max_tokens=None,
+        )
+        self.assertEqual(payload["model"], "gemini-2.0-flash")
+        content = payload["messages"][0]["content"]
+        image_part = next(p for p in content if p["type"] == "image_url")
+        self.assertIn("data:image/png;base64,abc123", image_part["image_url"]["url"])
+
+
+class TestGeminiHTTPClient(unittest.TestCase):
+    def test_client_loads_gemini_backend(self):
+        """HTTPLLMClient with provider='gemini' should load GeminiLLMBackend."""
+        client = HTTPLLMClient(
+            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
+            api_key="fake-key",
+            chat_model="gemini-2.0-flash",
+            provider="gemini",
+            embed_model="gemini-embedding-001",
+        )
+        self.assertIsInstance(client.backend, GeminiLLMBackend)
+        self.assertIsInstance(client.embedding_backend, _OpenAIEmbeddingBackend)
+        self.assertEqual(client.embed_model, "gemini-embedding-001")
+
+    def test_embedding_endpoint(self):
+        client = HTTPLLMClient(
+            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
+            api_key="fake-key",
+            chat_model="gemini-2.0-flash",
+            provider="gemini",
+            embed_model="gemini-embedding-001",
+        )
+        self.assertEqual(client.embedding_endpoint, "embeddings")
+
+
+class TestGeminiLiveAPI(unittest.IsolatedAsyncioTestCase):
+    """Live tests — skipped if GEMINI_API_KEY is not set."""
+
+    def setUp(self):
+        self.api_key = os.getenv("GEMINI_API_KEY")
+        if not self.api_key:
+            self.skipTest("GEMINI_API_KEY not set")
+        self.client = HTTPLLMClient(
+            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
+            api_key=self.api_key,
+            chat_model="gemini-2.5-flash",
+            provider="gemini",
+            embed_model="gemini-embedding-001",
+        )
+
+    async def test_chat(self):
+        response, _ = await self.client.chat("Say hello in one word.")
+        self.assertIsInstance(response, str)
+        self.assertGreater(len(response), 0)
+
+    async def test_summarize(self):
+        response, _ = await self.client.summarize("The sky is blue and the grass is green.")
+        self.assertIsInstance(response, str)
+        self.assertGreater(len(response), 0)
+
+    async def test_embed(self):
+        vectors, _ = await self.client.embed(["Hello world", "Gemini embeddings"])
+        self.assertEqual(len(vectors), 2)
+        self.assertEqual(len(vectors[0]), 3072)  # gemini-embedding-001 dimension
+
+
+if __name__ == "__main__":
+    unittest.main()