diff --git a/src/memu/app/settings.py b/src/memu/app/settings.py
index adcb4f16..1a270cf1 100644
--- a/src/memu/app/settings.py
+++ b/src/memu/app/settings.py
@@ -104,9 +104,13 @@ class LLMConfig(BaseModel):
         default="openai",
         description="Identifier for the LLM provider implementation (used by HTTP client backend).",
     )
-    base_url: str = Field(default="https://api.openai.com/v1")
+    base_url: str = Field(
+        default="https://api.openai.com/v1",
+        description="API base URL for the configured provider."
+    )
     api_key: str = Field(default="OPENAI_API_KEY")
     chat_model: str = Field(default="gpt-4o-mini")
+
     client_backend: str = Field(
         default="sdk",
         description="Which LLM client backend to use: 'httpx' (httpx), 'sdk' (official OpenAI), or 'lazyllm_backend' (for more LLM source like Qwen, Doubao, SIliconflow, etc.)",
@@ -135,6 +139,14 @@ def set_provider_defaults(self) -> "LLMConfig":
                 self.api_key = "XAI_API_KEY"
             if self.chat_model == "gpt-4o-mini":
                 self.chat_model = "grok-2-latest"
+        if self.provider == "novita":
+            # If values match the OpenAI defaults, switch them to Novita defaults
+            if self.base_url == "https://api.openai.com/v1":
+                self.base_url = "https://api.novita.ai/openai"
+            if self.api_key == "OPENAI_API_KEY":
+                self.api_key = "NOVITA_API_KEY"
+            if self.chat_model == "gpt-4o-mini":
+                self.chat_model = "deepseek/deepseek-r1"
         return self
 
 
diff --git a/src/memu/llm/backends/__init__.py b/src/memu/llm/backends/__init__.py
index 5350e7b2..f33e2ed2 100644
--- a/src/memu/llm/backends/__init__.py
+++ b/src/memu/llm/backends/__init__.py
@@ -1,7 +1,8 @@
 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
 from memu.llm.backends.grok import GrokBackend
+from memu.llm.backends.novita import NovitaBackend
 from memu.llm.backends.openai import OpenAILLMBackend
 from memu.llm.backends.openrouter import OpenRouterLLMBackend
 
-__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
+__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "NovitaBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
diff --git a/src/memu/llm/backends/novita.py b/src/memu/llm/backends/novita.py
new file mode 100644
index 00000000..2f8fce46
--- /dev/null
+++ b/src/memu/llm/backends/novita.py
@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+from memu.llm.backends.openai import OpenAILLMBackend
+
+
+class NovitaBackend(OpenAILLMBackend):
+    """Backend for Novita LLM API."""
+
+    name = "novita"
+    # Novita uses the same payload structure as OpenAI
diff --git a/src/memu/llm/http_client.py b/src/memu/llm/http_client.py
index ba84b05b..1056cd28 100644
--- a/src/memu/llm/http_client.py
+++ b/src/memu/llm/http_client.py
@@ -12,6 +12,7 @@
 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
 from memu.llm.backends.grok import GrokBackend
+from memu.llm.backends.novita import NovitaBackend
 from memu.llm.backends.openai import OpenAILLMBackend
 from memu.llm.backends.openrouter import OpenRouterLLMBackend
 
@@ -73,6 +74,7 @@ def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
     OpenAILLMBackend.name: OpenAILLMBackend,
     DoubaoLLMBackend.name: DoubaoLLMBackend,
     GrokBackend.name: GrokBackend,
+    NovitaBackend.name: NovitaBackend,
     OpenRouterLLMBackend.name: OpenRouterLLMBackend,
 }
 
@@ -291,6 +293,7 @@ def _load_embedding_backend(self, provider: str) -> _EmbeddingBackend:
             _OpenAIEmbeddingBackend.name: _OpenAIEmbeddingBackend,
             _DoubaoEmbeddingBackend.name: _DoubaoEmbeddingBackend,
             "grok": _OpenAIEmbeddingBackend,
+            "novita": _OpenAIEmbeddingBackend,
             _OpenRouterEmbeddingBackend.name: _OpenRouterEmbeddingBackend,
         }
         factory = backends.get(provider)
diff --git a/tests/llm/test_novita_provider.py b/tests/llm/test_novita_provider.py
new file mode 100644
index 00000000..54677e93
--- /dev/null
+++ b/tests/llm/test_novita_provider.py
@@ -0,0 +1,56 @@
+import unittest
+from unittest.mock import patch
+
+from memu.app.settings import LLMConfig
+from memu.llm.backends.novita import NovitaBackend
+from memu.llm.openai_sdk import OpenAISDKClient
+
+
+class TestNovitaProvider(unittest.IsolatedAsyncioTestCase):
+    def test_settings_defaults(self):
+        """Test that setting provider='novita' sets the correct defaults."""
+        config = LLMConfig(provider="novita")
+        self.assertEqual(config.base_url, "https://api.novita.ai/openai")
+        self.assertEqual(config.api_key, "NOVITA_API_KEY")
+        self.assertEqual(config.chat_model, "deepseek/deepseek-r1")
+
+    def test_settings_do_not_override_non_openai_defaults(self):
+        """Test that provider defaults only apply when values are OpenAI defaults."""
+        config = LLMConfig(
+            provider="novita",
+            base_url="https://custom.novita.endpoint/v1",
+            api_key="CUSTOM_NOVITA_KEY",
+            chat_model="custom-model",
+        )
+        self.assertEqual(config.base_url, "https://custom.novita.endpoint/v1")
+        self.assertEqual(config.api_key, "CUSTOM_NOVITA_KEY")
+        self.assertEqual(config.chat_model, "custom-model")
+
+    @patch.dict("os.environ", {"NOVITA_API_KEY": "env-key"}, clear=True)
+    def test_openai_provider_not_auto_switched_by_env(self):
+        """Test that NOVITA_API_KEY env var does not auto-switch provider defaults."""
+        config = LLMConfig(provider="openai")
+        self.assertEqual(config.base_url, "https://api.openai.com/v1")
+        self.assertEqual(config.api_key, "OPENAI_API_KEY")
+
+    @patch("memu.llm.openai_sdk.AsyncOpenAI")
+    async def test_client_initialization_with_novita_config(self, mock_async_openai):
+        """Test that OpenAISDKClient initializes with Novita base URL when configured."""
+        config = LLMConfig(provider="novita")
+
+        client = OpenAISDKClient(
+            base_url=config.base_url,
+            api_key="fake-key",
+            chat_model=config.chat_model,
+            embed_model=config.embed_model,
+        )
+
+        mock_async_openai.assert_called_with(api_key="fake-key", base_url="https://api.novita.ai/openai")
+        self.assertEqual(client.chat_model, "deepseek/deepseek-r1")
+
+    def test_novita_backend_payload_parsing(self):
+        """Test that NovitaBackend parses responses correctly (inherited from OpenAI)."""
+        backend = NovitaBackend()
+        dummy_response = {"choices": [{"message": {"content": "Novita response content", "role": "assistant"}}]}
+        result = backend.parse_summary_response(dummy_response)
+        self.assertEqual(result, "Novita response content")