Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ result = await service.retrieve(
"categories": [...], # Relevant topic areas (auto-prioritized)
"items": [...], # Specific memory facts
"resources": [...], # Original sources for traceability
"graph_nodes": [...], # Graph-enhanced context (if enabled)
"next_step_query": "..." # Predicted follow-up context
}
```
Expand All @@ -487,6 +488,35 @@ result = await service.retrieve(
- `where={"agent_id__in": ["1", "2"]}` - Multi-agent coordination
- Omit `where` for global context awareness

#### Graph-Enhanced Retrieval

MemU can optionally build a **knowledge graph** from stored memories, enabling retrieval that follows semantic relationships between concepts — not just vector similarity.

```python
service = MemoryService(
retrieve_config={
"method": "rag",
"graph": {
"enabled": True, # Enable graph recall alongside vector search
"weight": 0.3, # Score fusion: 70% vector + 30% graph
"max_nodes": 6, # Max graph nodes per query
},
},
# ... other config
)
```

When enabled, the retrieve pipeline runs a **dual-path graph recall**:
1. **Precise path**: Vector/FTS seed nodes → community expansion → BFS walk → Personalized PageRank
2. **Generalized path**: Community representatives → shallow walk → PPR

Results are fused with vector retrieval using a configurable weight β (`(1 - β) * vector_score + β * graph_ppr`), giving you both direct semantic matches and structurally related context.

The graph store supports:
- **Personalized PageRank** for query-relevant ranking
- **Label Propagation** for automatic community detection
- **Global PageRank** for baseline node importance

---

## 💡 Proactive Scenarios
Expand Down
86 changes: 85 additions & 1 deletion src/memu/app/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,20 @@ def _build_rag_retrieve_workflow(self) -> list[WorkflowStep]:
capabilities={"vector"},
config={"embed_llm_profile": "embedding"},
),
WorkflowStep(
step_id="recall_graph",
role="recall_graph",
handler=self._rag_recall_graph,
requires={
"needs_retrieval",
"active_query",
"query_vector",
"store",
},
produces={"graph_hits", "graph_recall_result"},
capabilities={"vector"},
config={"embed_llm_profile": "embedding"},
),
WorkflowStep(
step_id="build_context",
role="build_context",
Expand Down Expand Up @@ -423,6 +437,36 @@ async def _rag_recall_resources(self, state: WorkflowState, step_context: Any) -
state["resource_hits"] = cosine_topk(qvec, corpus, k=self.retrieve_config.resource.top_k)
return state

async def _rag_recall_graph(self, state: WorkflowState, step_context: Any) -> WorkflowState:
    """Graph-recall step of the RAG retrieve workflow.

    Runs the graph store's recall for the active query and records the hits
    as ``(id, ppr_score)`` tuples — mirroring the shape of the other
    ``*_hits`` state entries — along with the raw recall result for the
    context-building step.

    Args:
        state: Mutable workflow state; reads ``needs_retrieval``,
            ``active_query``, ``store``, and optionally ``query_vector`` /
            ``where``; writes ``graph_hits``, ``graph_recall_result``, and
            (when it has to embed) ``query_vector``.
        step_context: Per-step context used to resolve the embedding client.

    Returns:
        The same ``state`` object, updated in place.
    """
    # This step declares produces={"graph_hits", "graph_recall_result"}:
    # set BOTH keys on every exit path so downstream consumers (and any
    # workflow-engine produces validation) never see a missing key.
    if not state.get("needs_retrieval") or not self.retrieve_config.graph.enabled:
        state["graph_hits"] = []
        state["graph_recall_result"] = None
        return state

    store = state["store"]
    graph_store = getattr(store, "graph_store", None)
    if graph_store is None:
        # The store backend has no graph support; degrade gracefully.
        state["graph_hits"] = []
        state["graph_recall_result"] = None
        return state

    # Reuse the query embedding produced by the vector-recall step when
    # present; otherwise embed lazily and cache it for downstream steps.
    query_vec = state.get("query_vector")
    if query_vec is None:
        embed_client = self._get_step_embedding_client(step_context)
        query_vec = (await embed_client.embed([state["active_query"]]))[0]
        state["query_vector"] = query_vec

    # NOTE(review): graph_recall is a synchronous call inside an async step —
    # confirm it is cheap enough not to block the event loop.
    result = graph_store.graph_recall(
        state["active_query"],
        query_vec=query_vec,
        max_nodes=self.retrieve_config.graph.max_nodes,
        where=state.get("where"),
    )
    # Convert to (id, score) tuples for consistency with other hits.
    state["graph_hits"] = [(n.id, n.ppr_score) for n in result.nodes]
    state["graph_recall_result"] = result
    return state

def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
response = {
"needs_retrieval": bool(state.get("needs_retrieval")),
Expand All @@ -432,6 +476,7 @@ def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
"categories": [],
"items": [],
"resources": [],
"graph_nodes": [],
}
if state.get("needs_retrieval"):
store = state["store"]
Expand All @@ -443,11 +488,49 @@ def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
state.get("category_hits", []),
categories_pool,
)
response["items"] = self._materialize_hits(state.get("item_hits", []), items_pool)

# Score fusion: only deflate item scores when graph is enabled AND returned results
graph_recall_result = state.get("graph_recall_result")
graph_active = (
self.retrieve_config.graph.enabled
and graph_recall_result
and graph_recall_result.nodes
)
item_hits = state.get("item_hits", [])
if graph_active:
graph_weight = self.retrieve_config.graph.weight
vector_weight = 1.0 - graph_weight
response["items"] = [
{**d, "score": d["score"] * vector_weight}
for d in self._materialize_hits(item_hits, items_pool)
]
else:
response["items"] = self._materialize_hits(item_hits, items_pool)

response["resources"] = self._materialize_hits(
state.get("resource_hits", []),
resources_pool,
)

# Graph nodes: materialize from RecallResult
if graph_active:
gw = self.retrieve_config.graph.weight
max_ppr = max((n.ppr_score for n in graph_recall_result.nodes), default=0.0) or 1.0
graph_entries = []
for n in graph_recall_result.nodes:
ppr_norm = n.ppr_score / max_ppr
graph_entries.append({
"id": n.id,
"type": n.type,
"name": n.name,
"description": n.description,
"content": n.content,
"community_id": n.community_id,
"score": ppr_norm * gw,
"ppr_score": n.ppr_score,
})
response["graph_nodes"] = graph_entries

state["response"] = response
return state

Expand Down Expand Up @@ -714,6 +797,7 @@ def _llm_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
"categories": [],
"items": [],
"resources": [],
"graph_nodes": [],
}
if state.get("needs_retrieval"):
response["categories"] = list(state.get("category_hits") or [])
Expand Down
12 changes: 12 additions & 0 deletions src/memu/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,17 @@ class RetrieveItemConfig(BaseModel):
)


class RetrieveGraphConfig(BaseModel):
    """Configuration for graph-enhanced retrieval (off by default)."""

    enabled: bool = Field(default=False, description="Whether to enable graph-enhanced retrieval.")
    # ge=1 matches the range validation style of `weight` below: a zero or
    # negative node budget is never meaningful — use `enabled=False` to turn
    # graph recall off instead.
    max_nodes: int = Field(default=6, ge=1, description="Maximum graph nodes to return per recall.")
    weight: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Graph score weight (β) in fusion. Vector weight is 1-β.",
    )


class RetrieveResourceConfig(BaseModel):
enabled: bool = Field(default=True, description="Whether to enable resource retrieval.")
top_k: int = Field(default=5, description="Total number of resources to retrieve.")
Expand Down Expand Up @@ -195,6 +206,7 @@ class RetrieveConfig(BaseModel):
category: RetrieveCategoryConfig = Field(default=RetrieveCategoryConfig())
item: RetrieveItemConfig = Field(default=RetrieveItemConfig())
resource: RetrieveResourceConfig = Field(default=RetrieveResourceConfig())
graph: RetrieveGraphConfig = Field(default=RetrieveGraphConfig())
sufficiency_check: bool = Field(default=True, description="Whether to check sufficiency after each tier.")
sufficiency_check_prompt: str = Field(default="", description="User prompt for sufficiency check.")
sufficiency_check_llm_profile: str = Field(default="default", description="LLM profile for sufficiency check.")
Expand Down
31 changes: 31 additions & 0 deletions src/memu/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,34 @@ class CategoryItem(BaseRecord):
category_id: str


class GraphNode(BaseRecord):
    """A knowledge-graph node extracted from stored memories."""

    # Node kind (free-form string; the set of values is not visible here).
    type: str
    # Human-readable label for the node.
    name: str
    description: str = ""
    content: str = ""
    # Lifecycle flag; defaults to "active" — other states not shown in this file.
    status: str = "active"
    # Presumably how many times this fact was (re)confirmed — verify against writer.
    validated_count: int = 1
    # Session ids that contributed to this node.
    source_sessions: list[str] = Field(default_factory=list)
    # Community assignment (label propagation); None until detection runs.
    community_id: str | None = None
    # Global PageRank score — baseline node importance.
    pagerank: float = 0.0
    # Optional embedding used to seed vector-based graph recall.
    embedding: list[float] | None = None


class GraphEdge(BaseRecord):
    """A directed edge between two graph nodes (``from_id`` -> ``to_id``)."""

    # Source node id.
    from_id: str
    # Target node id.
    to_id: str
    # Relationship kind (free-form string).
    type: str
    instruction: str = ""
    # Optional applicability condition — semantics defined by the graph builder.
    condition: str | None = None
    # Session that produced this edge, when known.
    session_id: str | None = None


class GraphCommunity(BaseRecord):
    """Aggregate record for a detected community of graph nodes."""

    # Optional LLM- or builder-generated summary of the community.
    summary: str | None = None
    # Number of member nodes at last update.
    node_count: int = 0
    # Optional community-level embedding (e.g. for representative seeding).
    embedding: list[float] | None = None


def merge_scope_model[TBaseRecord: BaseRecord](
user_model: type[BaseModel], core_model: type[TBaseRecord], *, name_suffix: str
) -> type[TBaseRecord]:
Expand Down Expand Up @@ -137,6 +165,9 @@ def build_scoped_models(
__all__ = [
"BaseRecord",
"CategoryItem",
"GraphCommunity",
"GraphEdge",
"GraphNode",
"MemoryCategory",
"MemoryItem",
"MemoryType",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Add graph tables (gm_nodes, gm_edges, gm_communities).

Revision ID: 001_add_graph
Revises:
Create Date: 2026-03-27

Uses IF NOT EXISTS to safely adopt pre-existing tables created outside Alembic.
"""

# Imports first (PEP 8 / standard Alembic template), then revision identifiers.
from alembic import op
import sqlalchemy as sa  # noqa: F401  (Alembic template import; unused by the raw-SQL statements below)
from pgvector.sqlalchemy import Vector  # noqa: F401  (documents the `vector` column type used below)

# Revision identifiers, used by Alembic's migration graph.
revision = "001_add_graph"
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the graph tables (gm_nodes, gm_edges, gm_communities).

    Every DDL statement uses ``IF NOT EXISTS`` so the migration can safely
    adopt tables that were already created outside Alembic.
    """
    # gm_nodes — graph knowledge nodes.
    # `content` gets DEFAULT '' to match `description` and the GraphNode
    # model, whose `content` field defaults to an empty string.
    op.execute("""
        CREATE TABLE IF NOT EXISTS gm_nodes (
            id TEXT PRIMARY KEY,
            type TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT NOT NULL DEFAULT '',
            content TEXT NOT NULL DEFAULT '',
            status TEXT NOT NULL DEFAULT 'active',
            validated_count INTEGER DEFAULT 1,
            source_sessions TEXT[] DEFAULT '{}',
            community_id TEXT,
            pagerank REAL DEFAULT 0,
            embedding vector,
            created_at TIMESTAMPTZ DEFAULT now(),
            updated_at TIMESTAMPTZ DEFAULT now()
        )
    """)

    # gm_edges — directed graph edges.
    op.execute("""
        CREATE TABLE IF NOT EXISTS gm_edges (
            id TEXT PRIMARY KEY,
            from_id TEXT NOT NULL,
            to_id TEXT NOT NULL,
            type TEXT NOT NULL,
            instruction TEXT NOT NULL,
            condition TEXT,
            session_id TEXT,
            created_at TIMESTAMPTZ DEFAULT now()
        )
    """)

    # gm_communities — LPA community aggregates.
    op.execute("""
        CREATE TABLE IF NOT EXISTS gm_communities (
            id TEXT PRIMARY KEY,
            summary TEXT,
            node_count INTEGER DEFAULT 0,
            embedding vector,
            created_at TIMESTAMPTZ DEFAULT now(),
            updated_at TIMESTAMPTZ DEFAULT now()
        )
    """)

    # Scope column for multi-user support, added safely to tables that may
    # pre-exist. ADD COLUMN IF NOT EXISTS already tolerates a duplicate
    # column, so no DO $$ ... duplicate_column handler is needed.
    for table in ("gm_nodes", "gm_edges", "gm_communities"):
        op.execute(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS user_id TEXT DEFAULT ''")

    # Indexes (IF NOT EXISTS for safety on adopted databases).
    op.execute("CREATE INDEX IF NOT EXISTS ix_gm_nodes_status ON gm_nodes (status)")
    op.execute("CREATE INDEX IF NOT EXISTS ix_gm_nodes_community ON gm_nodes (community_id)")
    op.execute("CREATE INDEX IF NOT EXISTS ix_gm_edges_from ON gm_edges (from_id)")
    op.execute("CREATE INDEX IF NOT EXISTS ix_gm_edges_to ON gm_edges (to_id)")
    op.execute("CREATE INDEX IF NOT EXISTS ix_gm_nodes__scope ON gm_nodes (user_id)")
    op.execute("CREATE INDEX IF NOT EXISTS ix_gm_edges__scope ON gm_edges (user_id)")
    op.execute("CREATE INDEX IF NOT EXISTS ix_gm_communities__scope ON gm_communities (user_id)")


def downgrade() -> None:
    """Drop the graph tables created (or adopted) by :func:`upgrade`.

    Uses ``IF EXISTS`` to mirror the upgrade's ``IF NOT EXISTS`` tolerance,
    so the downgrade does not fail when a table was never created or was
    removed out of band.  NOTE(review): tables adopted from outside Alembic
    are dropped as well — confirm this is acceptable before downgrading in
    production.
    """
    # Reverse of creation order; attached indexes are dropped with the tables.
    op.execute("DROP TABLE IF EXISTS gm_communities")
    op.execute("DROP TABLE IF EXISTS gm_edges")
    op.execute("DROP TABLE IF EXISTS gm_nodes")
Loading