Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/memu/llm/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@
from typing import Any


def _extract_content_from_dict(data: dict[str, Any]) -> str:
"""Extract text content from a raw API response dict.

Falls back to ``reasoning_content`` for reasoning models (e.g. MiniMax-M2.7,
DeepSeek-R1) that put their output there instead of ``content``.
"""
msg = data["choices"][0]["message"]
content = msg.get("content")
if not content:
content = msg.get("reasoning_content")
return content or ""


class LLMBackend:
"""Defines how to talk to a specific HTTP LLM provider."""

Expand Down
6 changes: 3 additions & 3 deletions src/memu/llm/backends/doubao.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from typing import Any, cast
from typing import Any

from memu.llm.backends.base import LLMBackend
from memu.llm.backends.base import LLMBackend, _extract_content_from_dict


class DoubaoLLMBackend(LLMBackend):
Expand All @@ -29,7 +29,7 @@ def build_summary_payload(
return payload

def parse_summary_response(self, data: dict[str, Any]) -> str:
return cast(str, data["choices"][0]["message"]["content"])
return _extract_content_from_dict(data)

def build_vision_payload(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/memu/llm/backends/openai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from typing import Any, cast
from typing import Any

from memu.llm.backends.base import LLMBackend
from memu.llm.backends.base import LLMBackend, _extract_content_from_dict


class OpenAILLMBackend(LLMBackend):
Expand All @@ -26,7 +26,7 @@ def build_summary_payload(
}

def parse_summary_response(self, data: dict[str, Any]) -> str:
return cast(str, data["choices"][0]["message"]["content"])
return _extract_content_from_dict(data)

def build_vision_payload(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/memu/llm/backends/openrouter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from typing import Any, cast
from typing import Any

from memu.llm.backends.base import LLMBackend
from memu.llm.backends.base import LLMBackend, _extract_content_from_dict


class OpenRouterLLMBackend(LLMBackend):
Expand Down Expand Up @@ -30,7 +30,7 @@ def build_summary_payload(

def parse_summary_response(self, data: dict[str, Any]) -> str:
"""Parse OpenRouter response (OpenAI-compatible format)."""
return cast(str, data["choices"][0]["message"]["content"])
return _extract_content_from_dict(data)

def build_vision_payload(
self,
Expand Down
25 changes: 19 additions & 6 deletions src/memu/llm/openai_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
logger = logging.getLogger(__name__)


def _extract_content(response: ChatCompletion) -> str:
"""Extract text content from a chat completion, with fallback for reasoning models.

Some reasoning models (e.g. MiniMax-M2.7, DeepSeek-R1) return their output in
``reasoning_content`` instead of ``content``. This helper checks both fields.
"""
msg = response.choices[0].message
content = msg.content
if not content:
content = getattr(msg, "reasoning_content", None)
return content or ""


class OpenAISDKClient:
"""OpenAI LLM client that relies on the official Python SDK."""

Expand Down Expand Up @@ -59,9 +72,9 @@ async def chat(
temperature=temperature,
max_tokens=max_tokens,
)
content = response.choices[0].message.content
content = _extract_content(response)
logger.debug("OpenAI chat response: %s", response)
return content or "", response
return content, response

async def summarize(
self,
Expand All @@ -82,9 +95,9 @@ async def summarize(
temperature=1,
max_tokens=max_tokens,
)
content = response.choices[0].message.content
content = _extract_content(response)
logger.debug("OpenAI summarize response: %s", response)
return content or "", response
return content, response

async def vision(
self,
Expand Down Expand Up @@ -148,9 +161,9 @@ async def vision(
temperature=1,
max_tokens=max_tokens,
)
content = response.choices[0].message.content
content = _extract_content(response)
logger.debug("OpenAI vision response: %s", response)
return content or "", response
return content, response

async def embed(self, inputs: list[str]) -> tuple[list[list[float]], CreateEmbeddingResponse | None]:
"""Create text embeddings via the official SDK."""
Expand Down
87 changes: 87 additions & 0 deletions tests/llm/test_extract_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests for reasoning_content fallback in content extraction helpers."""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import MagicMock

from memu.llm.backends.base import _extract_content_from_dict
from memu.llm.openai_sdk import _extract_content


# -- _extract_content (SDK path, ChatCompletion objects) --


def _fake_completion(content=None, reasoning_content=None):
"""Build a minimal ChatCompletion-like object."""
msg = MagicMock()
msg.content = content
# reasoning_content is an extra attr on some providers
if reasoning_content is not None:
msg.reasoning_content = reasoning_content
else:
# simulate the attr not existing at all
del msg.reasoning_content
choice = SimpleNamespace(message=msg)
return SimpleNamespace(choices=[choice])


class TestExtractContent:
    """SDK path: ``_extract_content`` over ChatCompletion-like objects."""

    def test_normal_content(self):
        assert _extract_content(_fake_completion(content="hello world")) == "hello world"

    def test_reasoning_content_fallback(self):
        completion = _fake_completion(content=None, reasoning_content="thought result")
        assert _extract_content(completion) == "thought result"

    def test_empty_string_content_falls_back(self):
        completion = _fake_completion(content="", reasoning_content="fallback")
        assert _extract_content(completion) == "fallback"

    def test_both_none_returns_empty(self):
        assert _extract_content(_fake_completion(content=None)) == ""

    def test_content_preferred_over_reasoning(self):
        completion = _fake_completion(content="real answer", reasoning_content="thinking")
        assert _extract_content(completion) == "real answer"


# -- _extract_content_from_dict (HTTP path, raw dicts) --


def _fake_dict_response(content=None, reasoning_content=None):
"""Build a minimal raw API response dict."""
msg = {}
if content is not None:
msg["content"] = content
if reasoning_content is not None:
msg["reasoning_content"] = reasoning_content
return {"choices": [{"message": msg}]}


class TestExtractContentFromDict:
    """HTTP path: ``_extract_content_from_dict`` over raw response dicts."""

    def test_normal_content(self):
        assert _extract_content_from_dict(_fake_dict_response(content="hello")) == "hello"

    def test_reasoning_content_fallback(self):
        payload = _fake_dict_response(reasoning_content="thought")
        assert _extract_content_from_dict(payload) == "thought"

    def test_empty_string_content_falls_back(self):
        payload = _fake_dict_response(content="", reasoning_content="fb")
        assert _extract_content_from_dict(payload) == "fb"

    def test_both_missing_returns_empty(self):
        assert _extract_content_from_dict(_fake_dict_response()) == ""

    def test_content_preferred_over_reasoning(self):
        payload = _fake_dict_response(content="answer", reasoning_content="thinking")
        assert _extract_content_from_dict(payload) == "answer"

    def test_none_content_with_reasoning(self):
        payload = _fake_dict_response(content=None, reasoning_content="result")
        assert _extract_content_from_dict(payload) == "result"