diff --git a/src/memu/llm/backends/base.py b/src/memu/llm/backends/base.py
index 8f76330a..9bfe43a8 100644
--- a/src/memu/llm/backends/base.py
+++ b/src/memu/llm/backends/base.py
@@ -3,6 +3,19 @@
 from typing import Any
 
 
+def _extract_content_from_dict(data: dict[str, Any]) -> str:
+    """Extract text content from a raw API response dict.
+
+    Falls back to ``reasoning_content`` for reasoning models (e.g. MiniMax-M2,
+    DeepSeek-R1) that put their output there instead of ``content``.
+    """
+    msg = data["choices"][0]["message"]
+    content = msg.get("content")
+    if not content:
+        content = msg.get("reasoning_content")
+    return content or ""
+
+
 class LLMBackend:
     """Defines how to talk to a specific HTTP LLM provider."""
 
diff --git a/src/memu/llm/backends/doubao.py b/src/memu/llm/backends/doubao.py
index 9dd7012a..fcf5b9a6 100644
--- a/src/memu/llm/backends/doubao.py
+++ b/src/memu/llm/backends/doubao.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
-from typing import Any, cast
+from typing import Any
 
-from memu.llm.backends.base import LLMBackend
+from memu.llm.backends.base import LLMBackend, _extract_content_from_dict
 
 
 class DoubaoLLMBackend(LLMBackend):
@@ -29,7 +29,7 @@ def build_summary_payload(
         return payload
 
     def parse_summary_response(self, data: dict[str, Any]) -> str:
-        return cast(str, data["choices"][0]["message"]["content"])
+        return _extract_content_from_dict(data)
 
     def build_vision_payload(
         self,
diff --git a/src/memu/llm/backends/openai.py b/src/memu/llm/backends/openai.py
index aef24fc6..dd8f98d1 100644
--- a/src/memu/llm/backends/openai.py
+++ b/src/memu/llm/backends/openai.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
-from typing import Any, cast
+from typing import Any
 
-from memu.llm.backends.base import LLMBackend
+from memu.llm.backends.base import LLMBackend, _extract_content_from_dict
 
 
 class OpenAILLMBackend(LLMBackend):
@@ -26,7 +26,7 @@ def build_summary_payload(
         }
 
     def parse_summary_response(self, data: dict[str, Any]) -> str:
-        return cast(str, data["choices"][0]["message"]["content"])
+        return _extract_content_from_dict(data)
 
     def build_vision_payload(
         self,
diff --git a/src/memu/llm/backends/openrouter.py b/src/memu/llm/backends/openrouter.py
index 1a8cdeef..8ef00bd0 100644
--- a/src/memu/llm/backends/openrouter.py
+++ b/src/memu/llm/backends/openrouter.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
-from typing import Any, cast
+from typing import Any
 
-from memu.llm.backends.base import LLMBackend
+from memu.llm.backends.base import LLMBackend, _extract_content_from_dict
 
 
 class OpenRouterLLMBackend(LLMBackend):
@@ -30,7 +30,7 @@ def build_summary_payload(
 
     def parse_summary_response(self, data: dict[str, Any]) -> str:
         """Parse OpenRouter response (OpenAI-compatible format)."""
-        return cast(str, data["choices"][0]["message"]["content"])
+        return _extract_content_from_dict(data)
 
     def build_vision_payload(
         self,
diff --git a/src/memu/llm/openai_sdk.py b/src/memu/llm/openai_sdk.py
index 38c6c8bb..8d78d0c5 100644
--- a/src/memu/llm/openai_sdk.py
+++ b/src/memu/llm/openai_sdk.py
@@ -17,6 +17,19 @@
 logger = logging.getLogger(__name__)
 
 
+def _extract_content(response: ChatCompletion) -> str:
+    """Extract text content from a chat completion, with fallback for reasoning models.
+
+    Some reasoning models (e.g. MiniMax-M2, DeepSeek-R1) return their output in
+    ``reasoning_content`` instead of ``content``. This helper checks both fields.
+    """
+    msg = response.choices[0].message
+    content = msg.content
+    if not content:
+        content = getattr(msg, "reasoning_content", None)
+    return content or ""
+
+
 class OpenAISDKClient:
     """OpenAI LLM client that relies on the official Python SDK."""
@@ -59,9 +72,9 @@ async def chat(
             temperature=temperature,
             max_tokens=max_tokens,
         )
-        content = response.choices[0].message.content
+        content = _extract_content(response)
         logger.debug("OpenAI chat response: %s", response)
-        return content or "", response
+        return content, response
 
     async def summarize(
         self,
@@ -82,9 +95,9 @@ async def summarize(
             temperature=1,
             max_tokens=max_tokens,
         )
-        content = response.choices[0].message.content
+        content = _extract_content(response)
         logger.debug("OpenAI summarize response: %s", response)
-        return content or "", response
+        return content, response
 
     async def vision(
         self,
@@ -148,9 +161,9 @@ async def vision(
             temperature=1,
             max_tokens=max_tokens,
         )
-        content = response.choices[0].message.content
+        content = _extract_content(response)
         logger.debug("OpenAI vision response: %s", response)
-        return content or "", response
+        return content, response
 
     async def embed(self, inputs: list[str]) -> tuple[list[list[float]], CreateEmbeddingResponse | None]:
         """Create text embeddings via the official SDK."""
diff --git a/tests/llm/test_extract_content.py b/tests/llm/test_extract_content.py
new file mode 100644
index 00000000..15511970
--- /dev/null
+++ b/tests/llm/test_extract_content.py
@@ -0,0 +1,87 @@
+"""Tests for reasoning_content fallback in content extraction helpers."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+from memu.llm.backends.base import _extract_content_from_dict
+from memu.llm.openai_sdk import _extract_content
+
+
+# -- _extract_content (SDK path, ChatCompletion objects) --
+
+
+def _fake_completion(content=None, reasoning_content=None):
+    """Build a minimal ChatCompletion-like object."""
+    msg = MagicMock()
+    msg.content = content
+    # reasoning_content is an extra attr on some providers
+    if reasoning_content is not None:
+        msg.reasoning_content = reasoning_content
+    else:
+        # simulate the attr not existing at all
+        del msg.reasoning_content
+    choice = SimpleNamespace(message=msg)
+    return SimpleNamespace(choices=[choice])
+
+
+class TestExtractContent:
+    def test_normal_content(self):
+        resp = _fake_completion(content="hello world")
+        assert _extract_content(resp) == "hello world"
+
+    def test_reasoning_content_fallback(self):
+        resp = _fake_completion(content=None, reasoning_content="thought result")
+        assert _extract_content(resp) == "thought result"
+
+    def test_empty_string_content_falls_back(self):
+        resp = _fake_completion(content="", reasoning_content="fallback")
+        assert _extract_content(resp) == "fallback"
+
+    def test_both_none_returns_empty(self):
+        resp = _fake_completion(content=None)
+        assert _extract_content(resp) == ""
+
+    def test_content_preferred_over_reasoning(self):
+        resp = _fake_completion(content="real answer", reasoning_content="thinking")
+        assert _extract_content(resp) == "real answer"
+
+
+# -- _extract_content_from_dict (HTTP path, raw dicts) --
+
+
+def _fake_dict_response(content=None, reasoning_content=None):
+    """Build a minimal raw API response dict."""
+    msg = {}
+    if content is not None:
+        msg["content"] = content
+    if reasoning_content is not None:
+        msg["reasoning_content"] = reasoning_content
+    return {"choices": [{"message": msg}]}
+
+
+class TestExtractContentFromDict:
+    def test_normal_content(self):
+        data = _fake_dict_response(content="hello")
+        assert _extract_content_from_dict(data) == "hello"
+
+    def test_reasoning_content_fallback(self):
+        data = _fake_dict_response(reasoning_content="thought")
+        assert _extract_content_from_dict(data) == "thought"
+
+    def test_empty_string_content_falls_back(self):
+        data = _fake_dict_response(content="", reasoning_content="fb")
+        assert _extract_content_from_dict(data) == "fb"
+
+    def test_both_missing_returns_empty(self):
+        data = _fake_dict_response()
+        assert _extract_content_from_dict(data) == ""
+
+    def test_content_preferred_over_reasoning(self):
+        data = _fake_dict_response(content="answer", reasoning_content="thinking")
+        assert _extract_content_from_dict(data) == "answer"
+
+    def test_none_content_with_reasoning(self):
+        data = _fake_dict_response(content=None, reasoning_content="result")
+        assert _extract_content_from_dict(data) == "result"
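
Usage sketch (illustrative, not part of the patch): an OpenAI-compatible reasoning endpoint such as DeepSeek-R1 can return a message whose "content" is null while the generated text sits in "reasoning_content"; the new helpers prefer "content" whenever it is non-empty and only then fall back. The payload values below are hypothetical, mirroring the fixtures in tests/llm/test_extract_content.py.

from memu.llm.backends.base import _extract_content_from_dict

# Hypothetical reasoning-model reply: "content" is null, so the helper
# falls back to the text in "reasoning_content".
reasoning_reply = {
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": None,
                "reasoning_content": "Paris is the capital of France.",
            }
        }
    ]
}
assert _extract_content_from_dict(reasoning_reply) == "Paris is the capital of France."

# A conventional reply is unaffected: non-empty "content" always wins.
plain_reply = {"choices": [{"message": {"content": "Paris"}}]}
assert _extract_content_from_dict(plain_reply) == "Paris"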