14 changes: 13 additions & 1 deletion src/memu/app/settings.py
@@ -104,9 +104,13 @@ class LLMConfig(BaseModel):
default="openai",
description="Identifier for the LLM provider implementation (used by HTTP client backend).",
)
base_url: str = Field(default="https://api.openai.com/v1")
base_url: str = Field(
default="https://api.openai.com/v1",
description="API base URL for the configured provider."
)
api_key: str = Field(default="OPENAI_API_KEY")
chat_model: str = Field(default="gpt-4o-mini")

client_backend: str = Field(
default="sdk",
description="Which LLM client backend to use: 'httpx' (httpx), 'sdk' (official OpenAI), or 'lazyllm_backend' (for more LLM source like Qwen, Doubao, SIliconflow, etc.)",
@@ -135,6 +139,14 @@ def set_provider_defaults(self) -> "LLMConfig":
self.api_key = "XAI_API_KEY"
if self.chat_model == "gpt-4o-mini":
self.chat_model = "grok-2-latest"
if self.provider == "novita":
# If values match the OpenAI defaults, switch them to Novita defaults
if self.base_url == "https://api.openai.com/v1":
self.base_url = "https://api.novita.ai/openai"
if self.api_key == "OPENAI_API_KEY":
self.api_key = "NOVITA_API_KEY"
if self.chat_model == "gpt-4o-mini":
self.chat_model = "deepseek/deepseek-r1"
return self


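The defaults above only swap when the current values still match the OpenAI defaults, so explicit user settings survive the provider switch. A minimal sketch of how the resolution behaves, using only names visible in this diff (the override model name is illustrative):

    from memu.app.settings import LLMConfig

    # All fields still hold the OpenAI defaults, so they swap to Novita's.
    config = LLMConfig(provider="novita")
    assert config.base_url == "https://api.novita.ai/openai"
    assert config.api_key == "NOVITA_API_KEY"  # an env var name, not a secret
    assert config.chat_model == "deepseek/deepseek-r1"

    # An explicitly supplied value is left untouched.
    custom = LLMConfig(provider="novita", chat_model="my-custom-model")
    assert custom.chat_model == "my-custom-model"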
3 changes: 2 additions & 1 deletion src/memu/llm/backends/__init__.py
@@ -1,7 +1,8 @@
from memu.llm.backends.base import LLMBackend
from memu.llm.backends.doubao import DoubaoLLMBackend
from memu.llm.backends.grok import GrokBackend
from memu.llm.backends.novita import NovitaBackend
from memu.llm.backends.openai import OpenAILLMBackend
from memu.llm.backends.openrouter import OpenRouterLLMBackend

__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "NovitaBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
10 changes: 10 additions & 0 deletions src/memu/llm/backends/novita.py
@@ -0,0 +1,10 @@
from __future__ import annotations

from memu.llm.backends.openai import OpenAILLMBackend


class NovitaBackend(OpenAILLMBackend):
"""Backend for Novita LLM API."""

name = "novita"
# Novita uses the same payload structure as OpenAI
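Since NovitaBackend subclasses OpenAILLMBackend and overrides nothing but name, it inherits the OpenAI request payload and response parsing wholesale; the distinct name exists so the HTTP client registry can key on it. A quick sketch of the inherited behavior (the response shape mirrors the test at the bottom of this PR):

    from memu.llm.backends.novita import NovitaBackend

    backend = NovitaBackend()
    data = {"choices": [{"message": {"content": "hi", "role": "assistant"}}]}
    assert backend.name == "novita"
    # parse_summary_response is inherited unchanged from OpenAILLMBackend.
    assert backend.parse_summary_response(data) == "hi"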
3 changes: 3 additions & 0 deletions src/memu/llm/http_client.py
@@ -12,6 +12,7 @@
from memu.llm.backends.base import LLMBackend
from memu.llm.backends.doubao import DoubaoLLMBackend
from memu.llm.backends.grok import GrokBackend
from memu.llm.backends.novita import NovitaBackend
from memu.llm.backends.openai import OpenAILLMBackend
from memu.llm.backends.openrouter import OpenRouterLLMBackend

@@ -73,6 +74,7 @@ def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
OpenAILLMBackend.name: OpenAILLMBackend,
DoubaoLLMBackend.name: DoubaoLLMBackend,
GrokBackend.name: GrokBackend,
NovitaBackend.name: NovitaBackend,
OpenRouterLLMBackend.name: OpenRouterLLMBackend,
}

@@ -291,6 +293,7 @@ def _load_embedding_backend(self, provider: str) -> _EmbeddingBackend:
_OpenAIEmbeddingBackend.name: _OpenAIEmbeddingBackend,
_DoubaoEmbeddingBackend.name: _DoubaoEmbeddingBackend,
"grok": _OpenAIEmbeddingBackend,
"novita": _OpenAIEmbeddingBackend,
_OpenRouterEmbeddingBackend.name: _OpenRouterEmbeddingBackend,
}
factory = backends.get(provider)
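Both registries key on the provider string: chat traffic resolves to NovitaBackend, while embeddings reuse the OpenAI-compatible embedding backend, mirroring the existing "grok" entry. A hedged sketch of the lookup pattern; the mapping and loader names below are stand-ins, since the real identifiers are cropped out of this diff:

    from memu.llm.backends.base import LLMBackend
    from memu.llm.backends.novita import NovitaBackend

    # Stand-in excerpt of the module-level registry shown above.
    CHAT_BACKENDS = {NovitaBackend.name: NovitaBackend}

    def load_chat_backend(provider: str) -> LLMBackend:
        # Hypothetical helper; the real loader lives in memu/llm/http_client.py.
        factory = CHAT_BACKENDS.get(provider)
        if factory is None:
            raise ValueError(f"Unknown LLM provider: {provider}")
        return factory()

    assert isinstance(load_chat_backend("novita"), NovitaBackend)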
56 changes: 56 additions & 0 deletions tests/llm/test_novita_provider.py
@@ -0,0 +1,56 @@
import unittest
from unittest.mock import patch

from memu.app.settings import LLMConfig
from memu.llm.backends.novita import NovitaBackend
from memu.llm.openai_sdk import OpenAISDKClient


class TestNovitaProvider(unittest.IsolatedAsyncioTestCase):
def test_settings_defaults(self):
"""Test that setting provider='novita' sets the correct defaults."""
config = LLMConfig(provider="novita")
self.assertEqual(config.base_url, "https://api.novita.ai/openai")
self.assertEqual(config.api_key, "NOVITA_API_KEY")
self.assertEqual(config.chat_model, "deepseek/deepseek-r1")

def test_settings_do_not_override_non_openai_defaults(self):
"""Test that provider defaults only apply when values are OpenAI defaults."""
config = LLMConfig(
provider="novita",
base_url="https://custom.novita.endpoint/v1",
api_key="CUSTOM_NOVITA_KEY",
chat_model="custom-model",
)
self.assertEqual(config.base_url, "https://custom.novita.endpoint/v1")
self.assertEqual(config.api_key, "CUSTOM_NOVITA_KEY")
self.assertEqual(config.chat_model, "custom-model")

@patch.dict("os.environ", {"NOVITA_API_KEY": "env-key"}, clear=True)
def test_openai_provider_not_auto_switched_by_env(self):
"""Test that NOVITA_API_KEY env var does not auto-switch provider defaults."""
config = LLMConfig(provider="openai")
self.assertEqual(config.base_url, "https://api.openai.com/v1")
self.assertEqual(config.api_key, "OPENAI_API_KEY")

@patch("memu.llm.openai_sdk.AsyncOpenAI")
async def test_client_initialization_with_novita_config(self, mock_async_openai):
"""Test that OpenAISDKClient initializes with Novita base URL when configured."""
config = LLMConfig(provider="novita")

client = OpenAISDKClient(
base_url=config.base_url,
api_key="fake-key",
chat_model=config.chat_model,
embed_model=config.embed_model,
)

mock_async_openai.assert_called_with(api_key="fake-key", base_url="https://api.novita.ai/openai")
self.assertEqual(client.chat_model, "deepseek/deepseek-r1")

def test_novita_backend_payload_parsing(self):
"""Test that NovitaBackend parses responses correctly (inherited from OpenAI)."""
backend = NovitaBackend()
dummy_response = {"choices": [{"message": {"content": "Novita response content", "role": "assistant"}}]}
result = backend.parse_summary_response(dummy_response)
self.assertEqual(result, "Novita response content")
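Because the AsyncOpenAI constructor is patched and no network is touched, the suite should run with the stock runner, e.g. python -m unittest tests.llm.test_novita_provider.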