diff --git a/src/memu/app/service.py b/src/memu/app/service.py
index 4e2dea04..0002ed59 100644
--- a/src/memu/app/service.py
+++ b/src/memu/app/service.py
@@ -107,6 +107,7 @@ def _init_llm_client(self, config: LLMConfig | None = None) -> Any:
                 chat_model=cfg.chat_model,
                 embed_model=cfg.embed_model,
                 embed_batch_size=cfg.embed_batch_size,
+                proxy=cfg.proxy,
             )
         elif backend == "httpx":
             return HTTPLLMClient(
@@ -116,6 +117,7 @@ def _init_llm_client(self, config: LLMConfig | None = None) -> Any:
                 provider=cfg.provider,
                 endpoint_overrides=cfg.endpoint_overrides,
                 embed_model=cfg.embed_model,
+                proxy=cfg.proxy,
             )
         elif backend == "lazyllm_backend":
             from memu.llm.lazyllm_client import LazyLLMClient
@@ -129,6 +131,7 @@ def _init_llm_client(self, config: LLMConfig | None = None) -> Any:
                 embed_model=cfg.embed_model,
                 vlm_model=cfg.lazyllm_source.vlm_model,
                 stt_model=cfg.lazyllm_source.stt_model,
+                proxy=cfg.proxy,
             )
         else:
             msg = f"Unknown llm_client_backend '{cfg.client_backend}'"
diff --git a/src/memu/app/settings.py b/src/memu/app/settings.py
index adcb4f16..ff32d4f9 100644
--- a/src/memu/app/settings.py
+++ b/src/memu/app/settings.py
@@ -124,6 +124,10 @@ class LLMConfig(BaseModel):
         default=1,
         description="Maximum batch size for embedding API calls (used by SDK client backends).",
     )
+    proxy: str | None = Field(
+        default=None,
+        description="HTTP proxy URL for LLM requests (e.g., 'http://proxy.example.com:8080').",
+    )
 
     @model_validator(mode="after")
     def set_provider_defaults(self) -> "LLMConfig":
diff --git a/src/memu/embedding/backends/doubao.py b/src/memu/embedding/backends/doubao.py
index ae8a4ae7..c9405927 100644
--- a/src/memu/embedding/backends/doubao.py
+++ b/src/memu/embedding/backends/doubao.py
@@ -37,7 +37,7 @@ class DoubaoEmbeddingBackend(EmbeddingBackend):
 
     def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
         """Build payload for standard text embeddings."""
-        return {"model": embed_model, "input": inputs, "encoding_format": "float"}
+        return {"model": embed_model, "inputs": inputs, "encoding_format": "float"}
 
     def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
         """Parse embedding response."""
@@ -64,7 +64,7 @@ def build_multimodal_embedding_payload(
         return {
             "model": embed_model,
             "encoding_format": encoding_format,
-            "input": [inp.to_dict() for inp in inputs],
+            "inputs": [inp.to_dict() for inp in inputs],
         }
 
     def parse_multimodal_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
diff --git a/src/memu/llm/http_client.py b/src/memu/llm/http_client.py
index ba84b05b..b601c158 100644
--- a/src/memu/llm/http_client.py
+++ b/src/memu/llm/http_client.py
@@ -90,6 +90,7 @@ def __init__(
         endpoint_overrides: dict[str, str] | None = None,
         timeout: int = 60,
         embed_model: str | None = None,
+        proxy: str | None = None,
     ):
         # Ensure base_url ends with "/" so httpx doesn't discard the path
         # component when joining with endpoint paths.
@@ -101,7 +102,9 @@ def __init__(
         self.backend = self._load_backend(self.provider)
         self.embedding_backend = self._load_embedding_backend(self.provider)
         overrides = endpoint_overrides or {}
-        raw_summary_ep = overrides.get("chat") or overrides.get("summary") or self.backend.summary_endpoint
+        raw_summary_ep = (
+            overrides.get("chat") or overrides.get("summary") or self.backend.summary_endpoint
+        )
         raw_embedding_ep = (
             overrides.get("embeddings")
             or overrides.get("embedding")
@@ -114,7 +117,7 @@ def __init__(
         self.embedding_endpoint = raw_embedding_ep.lstrip("/")
         self.timeout = timeout
         self.embed_model = embed_model or chat_model
-        self.proxy = _load_proxy()
+        self.proxy = proxy or _load_proxy()
 
     async def chat(
         self,
diff --git a/src/memu/llm/lazyllm_client.py b/src/memu/llm/lazyllm_client.py
index 8446b6a5..c94bf750 100644
--- a/src/memu/llm/lazyllm_client.py
+++ b/src/memu/llm/lazyllm_client.py
@@ -22,6 +22,7 @@ def __init__(
         vlm_model: str | None = None,
         embed_model: str | None = None,
         stt_model: str | None = None,
+        proxy: str | None = None,
     ):
         self.llm_source = llm_source or self.DEFAULT_SOURCE
         self.vlm_source = vlm_source or self.DEFAULT_SOURCE
@@ -31,6 +32,13 @@ def __init__(
         self.vlm_model = vlm_model
         self.embed_model = embed_model
         self.stt_model = stt_model
+        self.proxy = proxy
+
+        # Set proxy for LazyLLM if provided
+        if proxy:
+            import os
+            os.environ["HTTP_PROXY"] = proxy
+            os.environ["HTTPS_PROXY"] = proxy
 
     async def _call_async(self, client: Any, *args: Any, **kwargs: Any) -> Any:
         """
diff --git a/src/memu/llm/openai_sdk.py b/src/memu/llm/openai_sdk.py
index 38c6c8bb..8ad08b83 100644
--- a/src/memu/llm/openai_sdk.py
+++ b/src/memu/llm/openai_sdk.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Any, Literal, cast
 
+import httpx
 from openai import AsyncOpenAI
 from openai.types import CreateEmbeddingResponse
 from openai.types.chat import (
@@ -28,13 +29,24 @@ def __init__(
         chat_model: str,
         embed_model: str,
         embed_batch_size: int = 1,
+        proxy: str | None = None,
     ):
         self.base_url = base_url.rstrip("/")
         self.api_key = api_key or ""
         self.chat_model = chat_model
         self.embed_model = embed_model
         self.embed_batch_size = embed_batch_size
-        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+
+        # Create httpx client with proxy if provided
+        http_client = None
+        if proxy:
+            http_client = httpx.AsyncClient(proxy=proxy)
+
+        self.client = AsyncOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            http_client=http_client
+        )
 
     async def chat(
     self,
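
For reference, a minimal sketch of what the new `proxy` plumbing does on the OpenAI SDK path: pass an explicitly constructed `httpx.AsyncClient` to `AsyncOpenAI` so every request is routed through the proxy. The proxy URL, API key, base URL, and model name below are placeholders, and the sketch assumes an httpx version that accepts the singular `proxy=` keyword (0.26+); it is not the repository's own code.

```python
import asyncio

import httpx
from openai import AsyncOpenAI

PROXY_URL = "http://proxy.example.com:8080"  # placeholder, same shape as LLMConfig.proxy


async def main() -> None:
    # Mirror the openai_sdk.py change: only build a custom transport when a
    # proxy is configured, otherwise let the SDK use its default client.
    http_client = httpx.AsyncClient(proxy=PROXY_URL) if PROXY_URL else None

    client = AsyncOpenAI(
        api_key="sk-placeholder",               # placeholder key
        base_url="https://api.openai.com/v1",   # placeholder base URL
        http_client=http_client,
    )
    resp = await client.chat.completions.create(
        model="gpt-4o-mini",                    # placeholder model
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```

On the httpx backend the diff keeps the previous environment-based lookup as a fallback (`self.proxy = proxy or _load_proxy()`), so an explicit config value takes precedence over the environment.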