Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ result = await service.retrieve(
"categories": [...], # Relevant topic areas (auto-prioritized)
"items": [...], # Specific memory facts
"resources": [...], # Original sources for traceability
"graph_nodes": [...], # Graph-enhanced context (if enabled)
"next_step_query": "..." # Predicted follow-up context
}
```
Expand All @@ -487,6 +488,35 @@ result = await service.retrieve(
- `where={"agent_id__in": ["1", "2"]}` - Multi-agent coordination
- Omit `where` for global context awareness

#### Graph-Enhanced Retrieval

MemU can optionally build a **knowledge graph** from stored memories, enabling retrieval that follows semantic relationships between concepts — not just vector similarity.

```python
service = MemoryService(
retrieve_config={
"method": "rag",
"graph": {
"enabled": True, # Enable graph recall alongside vector search
"weight": 0.3, # Score fusion: 70% vector + 30% graph
"max_nodes": 6, # Max graph nodes per query
},
},
# ... other config
)
```

When enabled, the retrieve pipeline runs a **dual-path graph recall**:
1. **Precise path**: Vector/FTS seed nodes → community expansion → BFS walk → Personalized PageRank
2. **Generalized path**: Community representatives → shallow walk → PPR

Results are fused with vector retrieval using configurable weights (`α * vector_score + β * graph_ppr`, where `β` is `graph.weight` and `α = 1 − β`), giving you both direct semantic matches and structurally related context.

The graph store supports:
- **Personalized PageRank** for query-relevant ranking
- **Label Propagation** for automatic community detection
- **Global PageRank** for baseline node importance

---

## 💡 Proactive Scenarios
Expand Down
125 changes: 125 additions & 0 deletions src/memu/app/admission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Memory Admission Gate — filter low-quality content before memorization.

Inspired by A-MAC (arXiv:2603.04549): score content at write-time,
reject below threshold. Pure heuristics, no LLM calls.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from memu.app.settings import MemorizeAdmissionConfig


class AdmissionRejectedError(Exception):
    """Signals that the admission gate refused a piece of content.

    The full :class:`AdmissionResult` is attached as ``result`` so callers
    can inspect the rejection reason and score programmatically.
    """

    def __init__(self, result: AdmissionResult) -> None:
        # Use the human-readable reason as the exception message.
        super().__init__(result.reason)
        self.result = result

# Built-in noise patterns (always applied when gate is enabled).
# Matching any of these rejects the content with reason "noise_pattern";
# input is stripped before matching (see AdmissionGate.check).
_BUILTIN_NOISE_PATTERNS: list[re.Pattern[str]] = [
    # Literal wrapper tag — presumably injected by a local command harness;
    # carries no memorable content. TODO confirm against the producer.
    re.compile(r"<local-command-caveat>"),
    # Process exit-status lines, e.g. "EXIT: 0" or "  EXIT: 127".
    re.compile(r"^\s*EXIT:\s*\d+", re.MULTILINE),
    # Bare shell prompt lines, e.g. "$ ls -la" or "> git status"
    # NOTE(review): "> ..." also matches markdown blockquotes — confirm intended.
    re.compile(r"^\s*[$>]\s+\S+", re.MULTILINE),
    # Pure JSON blob with no natural language around it
    # (anchored to the whole string: first char "[" or "{", last "]" or "}").
    re.compile(r"^\s*[\[{][\s\S]*[\]}]\s*$"),
]


@dataclass(frozen=True, slots=True)
class AdmissionResult:
    """Immutable outcome of a single admission-gate check."""

    # True when the content may be memorized.
    allowed: bool
    # Short machine-readable explanation, e.g. "pass", "gate_disabled",
    # "too_short (...)", "noise_pattern (...)", "low_score (...)".
    reason: str
    # Quality score in [0, 1]: 0.0 for too-short, 0.1 for noise matches,
    # 1.0 when the gate is disabled, otherwise the heuristic score.
    score: float


class AdmissionGate:
    """Stateless, cheap content gate.

    Applies three write-time filters in order — minimum length, noise-pattern
    match, and a heuristic quality score — returning the first failure as an
    :class:`AdmissionResult`. Pure heuristics, no LLM calls, so it is cheap
    to construct per request.
    """

    def __init__(self, config: MemorizeAdmissionConfig) -> None:
        self._enabled = config.enabled
        self._min_length = config.min_length
        self._threshold = config.threshold
        # Compile user-supplied patterns once so check() stays cheap.
        self._extra_patterns = [re.compile(p) for p in (config.noise_patterns or [])]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def check(self, content: str) -> AdmissionResult:
        """Return admission decision for *content*.

        When the gate is disabled every input is allowed (score=1.0).
        """
        if not self._enabled:
            return AdmissionResult(allowed=True, reason="gate_disabled", score=1.0)

        stripped = content.strip()

        # --- length filter ---
        if len(stripped) < self._min_length:
            return AdmissionResult(
                allowed=False,
                reason=f"too_short (len={len(stripped)}<{self._min_length})",
                score=0.0,
            )

        # --- noise pattern filter: built-ins first, then user-supplied ---
        # Single loop over both groups; the reason label distinguishes them.
        for label, patterns in (
            ("noise_pattern", _BUILTIN_NOISE_PATTERNS),
            ("custom_noise_pattern", self._extra_patterns),
        ):
            for pat in patterns:
                if pat.search(stripped):
                    return AdmissionResult(
                        allowed=False,
                        reason=f"{label} ({pat.pattern!r})",
                        score=0.1,
                    )

        # --- basic quality score (simple heuristic) ---
        score = self._score(stripped)
        if score < self._threshold:
            return AdmissionResult(
                allowed=False,
                reason=f"low_score ({score:.2f}<{self._threshold})",
                score=score,
            )

        return AdmissionResult(allowed=True, reason="pass", score=score)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    @staticmethod
    def _score(text: str) -> float:
        """Cheap 0-1 quality heuristic.

        Rewards: longer text, presence of spaces (sentence-like), mixed case.
        Weighted sum: 0.4 * length + 0.35 * space ratio + 0.25 * case mix.
        """
        length = len(text)
        # length component: linear ramp 0→1 over 0-300 chars, capped at 1.
        # (Previous comment claimed a 30-300 ramp; the code ramps from 0.)
        len_score = min(length / 300.0, 1.0)

        # space ratio — natural language has ~15-20% spaces; saturates at 15%.
        # max(length, 1) guards the division for empty input.
        space_ratio = text.count(" ") / max(length, 1)
        space_score = min(space_ratio / 0.15, 1.0)

        # mixed case (not ALL CAPS or all lower) suggests prose over logs/code.
        has_upper = any(c.isupper() for c in text)
        has_lower = any(c.islower() for c in text)
        case_score = 1.0 if (has_upper and has_lower) else 0.5

        return round(0.4 * len_score + 0.35 * space_score + 0.25 * case_score, 4)
30 changes: 29 additions & 1 deletion src/memu/app/memorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import defusedxml.ElementTree as ET
from pydantic import BaseModel

from memu.app.admission import AdmissionGate, AdmissionRejectedError
from memu.app.settings import CategoryConfig, CustomPrompt
from memu.database.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource
from memu.prompts.category_summary import (
Expand Down Expand Up @@ -87,7 +88,16 @@ async def memorize(
"user": user_scope,
}

result = await self._run_workflow("memorize", state)
try:
result = await self._run_workflow("memorize", state)
except AdmissionRejectedError as exc:
return {
"memories": [],
"relations": [],
"admission_rejected": True,
"admission_reason": exc.result.reason,
"admission_score": exc.result.score,
}
response = cast(dict[str, Any] | None, result.get("response"))
if response is None:
msg = "Memorize workflow failed to produce a response"
Expand All @@ -104,6 +114,14 @@ def _build_memorize_workflow(self) -> list[WorkflowStep]:
produces={"local_path", "raw_text"},
capabilities={"io"},
),
WorkflowStep(
step_id="admission_check",
role="filter",
handler=self._memorize_admission_check,
requires={"raw_text"},
produces=set(),
capabilities=set(),
),
WorkflowStep(
step_id="preprocess_multimodal",
role="preprocess",
Expand Down Expand Up @@ -183,6 +201,16 @@ async def _memorize_ingest_resource(self, state: WorkflowState, step_context: An
state.update({"local_path": local_path, "raw_text": raw_text})
return state

async def _memorize_admission_check(self, state: WorkflowState, step_context: Any) -> WorkflowState:
    """Abort the memorize workflow early when the admission gate rejects the content."""
    text = state.get("raw_text") or ""
    decision = AdmissionGate(self.memorize_config.admission).check(text)
    if decision.allowed:
        return state
    logger.info("Admission gate rejected content: %s (score=%.2f)", decision.reason, decision.score)
    raise AdmissionRejectedError(decision)

async def _memorize_preprocess_multimodal(self, state: WorkflowState, step_context: Any) -> WorkflowState:
llm_client = self._get_step_llm_client(step_context)
preprocessed = await self._preprocess_resource_url(
Expand Down
86 changes: 85 additions & 1 deletion src/memu/app/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,20 @@ def _build_rag_retrieve_workflow(self) -> list[WorkflowStep]:
capabilities={"vector"},
config={"embed_llm_profile": "embedding"},
),
WorkflowStep(
step_id="recall_graph",
role="recall_graph",
handler=self._rag_recall_graph,
requires={
"needs_retrieval",
"active_query",
"query_vector",
"store",
},
produces={"graph_hits", "graph_recall_result"},
capabilities={"vector"},
config={"embed_llm_profile": "embedding"},
),
WorkflowStep(
step_id="build_context",
role="build_context",
Expand Down Expand Up @@ -423,6 +437,36 @@ async def _rag_recall_resources(self, state: WorkflowState, step_context: Any) -
state["resource_hits"] = cosine_topk(qvec, corpus, k=self.retrieve_config.resource.top_k)
return state

async def _rag_recall_graph(self, state: WorkflowState, step_context: Any) -> WorkflowState:
    """Graph-based recall step: collect (node_id, ppr_score) hits for the query.

    Leaves ``graph_hits`` empty when retrieval is not needed, graph recall is
    disabled in the config, or the active store exposes no ``graph_store``.
    """
    if not (state.get("needs_retrieval") and self.retrieve_config.graph.enabled):
        state["graph_hits"] = []
        return state

    backend = getattr(state["store"], "graph_store", None)
    if backend is None:
        state["graph_hits"] = []
        return state

    # Reuse the vector produced by an earlier recall step; embed lazily otherwise.
    qvec = state.get("query_vector")
    if qvec is None:
        embedder = self._get_step_embedding_client(step_context)
        qvec = (await embedder.embed([state["active_query"]]))[0]
        state["query_vector"] = qvec

    recall = backend.graph_recall(
        state["active_query"],
        query_vec=qvec,
        max_nodes=self.retrieve_config.graph.max_nodes,
        where=state.get("where"),
    )
    # Normalize to (id, score) tuples, mirroring the other *_hits entries.
    state["graph_hits"] = [(node.id, node.ppr_score) for node in recall.nodes]
    state["graph_recall_result"] = recall
    return state

def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
response = {
"needs_retrieval": bool(state.get("needs_retrieval")),
Expand All @@ -432,6 +476,7 @@ def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
"categories": [],
"items": [],
"resources": [],
"graph_nodes": [],
}
if state.get("needs_retrieval"):
store = state["store"]
Expand All @@ -443,11 +488,49 @@ def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
state.get("category_hits", []),
categories_pool,
)
response["items"] = self._materialize_hits(state.get("item_hits", []), items_pool)

# Score fusion: only deflate item scores when graph is enabled AND returned results
graph_recall_result = state.get("graph_recall_result")
graph_active = (
self.retrieve_config.graph.enabled
and graph_recall_result
and graph_recall_result.nodes
)
item_hits = state.get("item_hits", [])
if graph_active:
graph_weight = self.retrieve_config.graph.weight
vector_weight = 1.0 - graph_weight
response["items"] = [
{**d, "score": d["score"] * vector_weight}
for d in self._materialize_hits(item_hits, items_pool)
]
else:
response["items"] = self._materialize_hits(item_hits, items_pool)

response["resources"] = self._materialize_hits(
state.get("resource_hits", []),
resources_pool,
)

# Graph nodes: materialize from RecallResult
if graph_active:
gw = self.retrieve_config.graph.weight
max_ppr = max((n.ppr_score for n in graph_recall_result.nodes), default=0.0) or 1.0
graph_entries = []
for n in graph_recall_result.nodes:
ppr_norm = n.ppr_score / max_ppr
graph_entries.append({
"id": n.id,
"type": n.type,
"name": n.name,
"description": n.description,
"content": n.content,
"community_id": n.community_id,
"score": ppr_norm * gw,
"ppr_score": n.ppr_score,
})
response["graph_nodes"] = graph_entries

state["response"] = response
return state

Expand Down Expand Up @@ -714,6 +797,7 @@ def _llm_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
"categories": [],
"items": [],
"resources": [],
"graph_nodes": [],
}
if state.get("needs_retrieval"):
response["categories"] = list(state.get("category_hits") or [])
Expand Down
Loading