diff --git a/README.md b/README.md index 2dedb745..af012712 100644 --- a/README.md +++ b/README.md @@ -478,6 +478,7 @@ result = await service.retrieve( "categories": [...], # Relevant topic areas (auto-prioritized) "items": [...], # Specific memory facts "resources": [...], # Original sources for traceability + "graph_nodes": [...], # Graph-enhanced context (if enabled) "next_step_query": "..." # Predicted follow-up context } ``` @@ -487,6 +488,35 @@ result = await service.retrieve( - `where={"agent_id__in": ["1", "2"]}` - Multi-agent coordination - Omit `where` for global context awareness +#### Graph-Enhanced Retrieval + +MemU can optionally build a **knowledge graph** from stored memories, enabling retrieval that follows semantic relationships between concepts — not just vector similarity. + +```python +service = MemoryService( + retrieve_config={ + "method": "rag", + "graph": { + "enabled": True, # Enable graph recall alongside vector search + "weight": 0.3, # Score fusion: 70% vector + 30% graph + "max_nodes": 6, # Max graph nodes per query + }, + }, + # ... other config +) +``` + +When enabled, the retrieve pipeline runs a **dual-path graph recall**: +1. **Precise path**: Vector/FTS seed nodes → community expansion → BFS walk → Personalized PageRank +2. **Generalized path**: Community representatives → shallow walk → PPR + +Results are fused with vector retrieval using configurable weights (`α * vector_score + β * graph_ppr`), giving you both direct semantic matches and structurally related context. + +The graph store supports: +- **Personalized PageRank** for query-relevant ranking +- **Label Propagation** for automatic community detection +- **Global PageRank** for baseline node importance + --- ## 💡 Proactive Scenarios diff --git a/src/memu/app/admission.py b/src/memu/app/admission.py new file mode 100644 index 00000000..97db04f3 --- /dev/null +++ b/src/memu/app/admission.py @@ -0,0 +1,125 @@ +"""Memory Admission Gate — filter low-quality content before memorization. + +Inspired by A-MAC (arXiv:2603.04549): score content at write-time, +reject below threshold. Pure heuristics, no LLM calls. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from memu.app.settings import MemorizeAdmissionConfig + + +class AdmissionRejectedError(Exception): + """Raised when content fails the admission gate.""" + + def __init__(self, result: AdmissionResult) -> None: + self.result = result + super().__init__(result.reason) + +# Built-in noise patterns (always applied when gate is enabled) +_BUILTIN_NOISE_PATTERNS: list[re.Pattern[str]] = [ + re.compile(r""), + re.compile(r"^\s*EXIT:\s*\d+", re.MULTILINE), + # Bare shell prompt lines, e.g. "$ ls -la" or "> git status" + re.compile(r"^\s*[$>]\s+\S+", re.MULTILINE), + # Pure JSON blob with no natural language around it + re.compile(r"^\s*[\[{][\s\S]*[\]}]\s*$"), +] + + +@dataclass(frozen=True, slots=True) +class AdmissionResult: + allowed: bool + reason: str + score: float + + +class AdmissionGate: + """Stateless, cheap content gate.""" + + def __init__(self, config: MemorizeAdmissionConfig) -> None: + self._enabled = config.enabled + self._min_length = config.min_length + self._threshold = config.threshold + # Compile user-supplied patterns once + self._extra_patterns = [re.compile(p) for p in (config.noise_patterns or [])] + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def check(self, content: str) -> AdmissionResult: + """Return admission decision for *content*. + + When the gate is disabled every input is allowed (score=1.0). + """ + if not self._enabled: + return AdmissionResult(allowed=True, reason="gate_disabled", score=1.0) + + stripped = content.strip() + + # --- length filter --- + if len(stripped) < self._min_length: + return AdmissionResult( + allowed=False, + reason=f"too_short (len={len(stripped)}<{self._min_length})", + score=0.0, + ) + + # --- noise pattern filter --- + for pat in _BUILTIN_NOISE_PATTERNS: + if pat.search(stripped): + return AdmissionResult( + allowed=False, + reason=f"noise_pattern ({pat.pattern!r})", + score=0.1, + ) + + for pat in self._extra_patterns: + if pat.search(stripped): + return AdmissionResult( + allowed=False, + reason=f"custom_noise_pattern ({pat.pattern!r})", + score=0.1, + ) + + # --- basic quality score (simple heuristic) --- + score = self._score(stripped) + if score < self._threshold: + return AdmissionResult( + allowed=False, + reason=f"low_score ({score:.2f}<{self._threshold})", + score=score, + ) + + return AdmissionResult(allowed=True, reason="pass", score=score) + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + @staticmethod + def _score(text: str) -> float: + """Cheap 0-1 quality heuristic. + + Rewards: longer text, presence of spaces (sentence-like), mixed case. + """ + length = len(text) + # length component: ramp 0→1 over 30-300 chars + len_score = min(length / 300.0, 1.0) + + # space ratio — natural language has ~15-20% spaces + space_ratio = text.count(" ") / max(length, 1) + space_score = min(space_ratio / 0.15, 1.0) + + # mixed case (not ALL CAPS or all lower) + has_upper = any(c.isupper() for c in text) + has_lower = any(c.islower() for c in text) + case_score = 1.0 if (has_upper and has_lower) else 0.5 + + return round(0.4 * len_score + 0.35 * space_score + 0.25 * case_score, 4) diff --git a/src/memu/app/memorize.py b/src/memu/app/memorize.py index 0f2a06fc..de53e047 100644 --- a/src/memu/app/memorize.py +++ b/src/memu/app/memorize.py @@ -12,6 +12,7 @@ import defusedxml.ElementTree as ET from pydantic import BaseModel +from memu.app.admission import AdmissionGate, AdmissionRejectedError from memu.app.settings import CategoryConfig, CustomPrompt from memu.database.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource from memu.prompts.category_summary import ( @@ -87,7 +88,16 @@ async def memorize( "user": user_scope, } - result = await self._run_workflow("memorize", state) + try: + result = await self._run_workflow("memorize", state) + except AdmissionRejectedError as exc: + return { + "memories": [], + "relations": [], + "admission_rejected": True, + "admission_reason": exc.result.reason, + "admission_score": exc.result.score, + } response = cast(dict[str, Any] | None, result.get("response")) if response is None: msg = "Memorize workflow failed to produce a response" @@ -104,6 +114,14 @@ def _build_memorize_workflow(self) -> list[WorkflowStep]: produces={"local_path", "raw_text"}, capabilities={"io"}, ), + WorkflowStep( + step_id="admission_check", + role="filter", + handler=self._memorize_admission_check, + requires={"raw_text"}, + produces=set(), + capabilities=set(), + ), WorkflowStep( step_id="preprocess_multimodal", role="preprocess", @@ -183,6 +201,16 @@ async def _memorize_ingest_resource(self, state: WorkflowState, step_context: An state.update({"local_path": local_path, "raw_text": raw_text}) return state + async def _memorize_admission_check(self, state: WorkflowState, step_context: Any) -> WorkflowState: + """Early exit when content fails admission gate.""" + gate = AdmissionGate(self.memorize_config.admission) + raw_text = state.get("raw_text") or "" + result = gate.check(raw_text) + if not result.allowed: + logger.info("Admission gate rejected content: %s (score=%.2f)", result.reason, result.score) + raise AdmissionRejectedError(result) + return state + async def _memorize_preprocess_multimodal(self, state: WorkflowState, step_context: Any) -> WorkflowState: llm_client = self._get_step_llm_client(step_context) preprocessed = await self._preprocess_resource_url( diff --git a/src/memu/app/retrieve.py b/src/memu/app/retrieve.py index a7cbff5c..463d9aa7 100644 --- a/src/memu/app/retrieve.py +++ b/src/memu/app/retrieve.py @@ -198,6 +198,20 @@ def _build_rag_retrieve_workflow(self) -> list[WorkflowStep]: capabilities={"vector"}, config={"embed_llm_profile": "embedding"}, ), + WorkflowStep( + step_id="recall_graph", + role="recall_graph", + handler=self._rag_recall_graph, + requires={ + "needs_retrieval", + "active_query", + "query_vector", + "store", + }, + produces={"graph_hits", "graph_recall_result"}, + capabilities={"vector"}, + config={"embed_llm_profile": "embedding"}, + ), WorkflowStep( step_id="build_context", role="build_context", @@ -423,6 +437,36 @@ async def _rag_recall_resources(self, state: WorkflowState, step_context: Any) - state["resource_hits"] = cosine_topk(qvec, corpus, k=self.retrieve_config.resource.top_k) return state + async def _rag_recall_graph(self, state: WorkflowState, step_context: Any) -> WorkflowState: + if not state.get("needs_retrieval") or not self.retrieve_config.graph.enabled: + state["graph_hits"] = [] + return state + + store = state["store"] + graph_store = getattr(store, "graph_store", None) + if graph_store is None: + state["graph_hits"] = [] + return state + + query_vec = state.get("query_vector") + if query_vec is None: + embed_client = self._get_step_embedding_client(step_context) + query_vec = (await embed_client.embed([state["active_query"]]))[0] + state["query_vector"] = query_vec + + result = graph_store.graph_recall( + state["active_query"], + query_vec=query_vec, + max_nodes=self.retrieve_config.graph.max_nodes, + where=state.get("where"), + ) + # Convert to (id, score) tuples for consistency with other hits + state["graph_hits"] = [ + (n.id, n.ppr_score) for n in result.nodes + ] + state["graph_recall_result"] = result + return state + def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState: response = { "needs_retrieval": bool(state.get("needs_retrieval")), @@ -432,6 +476,7 @@ def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState: "categories": [], "items": [], "resources": [], + "graph_nodes": [], } if state.get("needs_retrieval"): store = state["store"] @@ -443,11 +488,49 @@ def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState: state.get("category_hits", []), categories_pool, ) - response["items"] = self._materialize_hits(state.get("item_hits", []), items_pool) + + # Score fusion: only deflate item scores when graph is enabled AND returned results + graph_recall_result = state.get("graph_recall_result") + graph_active = ( + self.retrieve_config.graph.enabled + and graph_recall_result + and graph_recall_result.nodes + ) + item_hits = state.get("item_hits", []) + if graph_active: + graph_weight = self.retrieve_config.graph.weight + vector_weight = 1.0 - graph_weight + response["items"] = [ + {**d, "score": d["score"] * vector_weight} + for d in self._materialize_hits(item_hits, items_pool) + ] + else: + response["items"] = self._materialize_hits(item_hits, items_pool) + response["resources"] = self._materialize_hits( state.get("resource_hits", []), resources_pool, ) + + # Graph nodes: materialize from RecallResult + if graph_active: + gw = self.retrieve_config.graph.weight + max_ppr = max((n.ppr_score for n in graph_recall_result.nodes), default=0.0) or 1.0 + graph_entries = [] + for n in graph_recall_result.nodes: + ppr_norm = n.ppr_score / max_ppr + graph_entries.append({ + "id": n.id, + "type": n.type, + "name": n.name, + "description": n.description, + "content": n.content, + "community_id": n.community_id, + "score": ppr_norm * gw, + "ppr_score": n.ppr_score, + }) + response["graph_nodes"] = graph_entries + state["response"] = response return state @@ -714,6 +797,7 @@ def _llm_build_context(self, state: WorkflowState, _: Any) -> WorkflowState: "categories": [], "items": [], "resources": [], + "graph_nodes": [], } if state.get("needs_retrieval"): response["categories"] = list(state.get("category_hits") or []) diff --git a/src/memu/app/settings.py b/src/memu/app/settings.py index adcb4f16..bf7e4c58 100644 --- a/src/memu/app/settings.py +++ b/src/memu/app/settings.py @@ -167,6 +167,17 @@ class RetrieveItemConfig(BaseModel): ) +class RetrieveGraphConfig(BaseModel): + enabled: bool = Field(default=False, description="Whether to enable graph-enhanced retrieval.") + max_nodes: int = Field(default=6, description="Maximum graph nodes to return per recall.") + weight: float = Field( + default=0.3, + ge=0.0, + le=1.0, + description="Graph score weight (β) in fusion. Vector weight is 1-β.", + ) + + class RetrieveResourceConfig(BaseModel): enabled: bool = Field(default=True, description="Whether to enable resource retrieval.") top_k: int = Field(default=5, description="Total number of resources to retrieve.") @@ -195,12 +206,24 @@ class RetrieveConfig(BaseModel): category: RetrieveCategoryConfig = Field(default=RetrieveCategoryConfig()) item: RetrieveItemConfig = Field(default=RetrieveItemConfig()) resource: RetrieveResourceConfig = Field(default=RetrieveResourceConfig()) + graph: RetrieveGraphConfig = Field(default=RetrieveGraphConfig()) sufficiency_check: bool = Field(default=True, description="Whether to check sufficiency after each tier.") sufficiency_check_prompt: str = Field(default="", description="User prompt for sufficiency check.") sufficiency_check_llm_profile: str = Field(default="default", description="LLM profile for sufficiency check.") llm_ranking_llm_profile: str = Field(default="default", description="LLM profile for LLM ranking.") +class MemorizeAdmissionConfig(BaseModel): + """Configuration for the memory admission gate (write-time quality filter).""" + + enabled: bool = Field(default=False, description="Enable admission gate. When False, all content is accepted.") + min_length: int = Field(default=30, description="Reject content shorter than this (after stripping).") + threshold: float = Field(default=0.3, description="Minimum quality score (0-1) to admit content.") + noise_patterns: list[str] = Field( + default_factory=list, description="Additional regex patterns to reject as noise." + ) + + class MemorizeConfig(BaseModel): category_assign_threshold: float = Field(default=0.25) multimodal_preprocess_prompts: dict[str, str | CustomPrompt] = Field( @@ -240,6 +263,10 @@ class MemorizeConfig(BaseModel): default=False, description="Enable reinforcement tracking for memory items.", ) + admission: MemorizeAdmissionConfig = Field( + default_factory=MemorizeAdmissionConfig, + description="Write-time admission gate configuration.", + ) class PatchConfig(BaseModel): diff --git a/src/memu/database/inmemory/repo.py b/src/memu/database/inmemory/repo.py index 44275f9f..9a1c2005 100644 --- a/src/memu/database/inmemory/repo.py +++ b/src/memu/database/inmemory/repo.py @@ -18,6 +18,8 @@ class InMemoryStore(Database): + graph_store: Any | None = None + def __init__( self, *, diff --git a/src/memu/database/interfaces.py b/src/memu/database/interfaces.py index 4acd0312..094179cb 100644 --- a/src/memu/database/interfaces.py +++ b/src/memu/database/interfaces.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from memu.database.models import CategoryItem as CategoryItemRecord from memu.database.models import MemoryCategory as MemoryCategoryRecord @@ -17,6 +17,7 @@ class Database(Protocol): memory_category_repo: MemoryCategoryRepo memory_item_repo: MemoryItemRepo category_item_repo: CategoryItemRepo + graph_store: Any | None resources: dict[str, ResourceRecord] items: dict[str, MemoryItemRecord] diff --git a/src/memu/database/models.py b/src/memu/database/models.py index 0124b784..4c395dd6 100644 --- a/src/memu/database/models.py +++ b/src/memu/database/models.py @@ -105,6 +105,34 @@ class CategoryItem(BaseRecord): category_id: str +class GraphNode(BaseRecord): + type: str + name: str + description: str = "" + content: str = "" + status: str = "active" + validated_count: int = 1 + source_sessions: list[str] = Field(default_factory=list) + community_id: str | None = None + pagerank: float = 0.0 + embedding: list[float] | None = None + + +class GraphEdge(BaseRecord): + from_id: str + to_id: str + type: str + instruction: str = "" + condition: str | None = None + session_id: str | None = None + + +class GraphCommunity(BaseRecord): + summary: str | None = None + node_count: int = 0 + embedding: list[float] | None = None + + def merge_scope_model[TBaseRecord: BaseRecord]( user_model: type[BaseModel], core_model: type[TBaseRecord], *, name_suffix: str ) -> type[TBaseRecord]: @@ -137,6 +165,9 @@ def build_scoped_models( __all__ = [ "BaseRecord", "CategoryItem", + "GraphCommunity", + "GraphEdge", + "GraphNode", "MemoryCategory", "MemoryItem", "MemoryType", diff --git a/src/memu/database/postgres/migrations/versions/001_add_graph_tables.py b/src/memu/database/postgres/migrations/versions/001_add_graph_tables.py new file mode 100644 index 00000000..7b07cf6f --- /dev/null +++ b/src/memu/database/postgres/migrations/versions/001_add_graph_tables.py @@ -0,0 +1,88 @@ +"""Add graph tables (gm_nodes, gm_edges, gm_communities). + +Revision ID: 001_add_graph +Revises: +Create Date: 2026-03-27 + +Uses IF NOT EXISTS to safely adopt pre-existing tables created outside Alembic. +""" + +revision = "001_add_graph" +down_revision = None +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa +from pgvector.sqlalchemy import Vector + + +def upgrade() -> None: + # gm_nodes — graph knowledge nodes + op.execute(""" + CREATE TABLE IF NOT EXISTS gm_nodes ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + content TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'active', + validated_count INTEGER DEFAULT 1, + source_sessions TEXT[] DEFAULT '{}', + community_id TEXT, + pagerank REAL DEFAULT 0, + embedding vector, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() + ) + """) + + # gm_edges — directed graph edges + op.execute(""" + CREATE TABLE IF NOT EXISTS gm_edges ( + id TEXT PRIMARY KEY, + from_id TEXT NOT NULL, + to_id TEXT NOT NULL, + type TEXT NOT NULL, + instruction TEXT NOT NULL, + condition TEXT, + session_id TEXT, + created_at TIMESTAMPTZ DEFAULT now() + ) + """) + + # gm_communities — LPA community aggregates + op.execute(""" + CREATE TABLE IF NOT EXISTS gm_communities ( + id TEXT PRIMARY KEY, + summary TEXT, + node_count INTEGER DEFAULT 0, + embedding vector, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() + ) + """) + + # Scope column for multi-user support (added to pre-existing tables safely) + for table in ("gm_nodes", "gm_edges", "gm_communities"): + op.execute(f""" + DO $$ BEGIN + ALTER TABLE {table} ADD COLUMN IF NOT EXISTS user_id TEXT DEFAULT ''; + EXCEPTION WHEN duplicate_column THEN NULL; + END $$ + """) + + # Indexes (IF NOT EXISTS for safety) + op.execute("CREATE INDEX IF NOT EXISTS ix_gm_nodes_status ON gm_nodes (status)") + op.execute("CREATE INDEX IF NOT EXISTS ix_gm_nodes_community ON gm_nodes (community_id)") + op.execute("CREATE INDEX IF NOT EXISTS ix_gm_edges_from ON gm_edges (from_id)") + op.execute("CREATE INDEX IF NOT EXISTS ix_gm_edges_to ON gm_edges (to_id)") + op.execute("CREATE INDEX IF NOT EXISTS ix_gm_nodes__scope ON gm_nodes (user_id)") + op.execute("CREATE INDEX IF NOT EXISTS ix_gm_edges__scope ON gm_edges (user_id)") + op.execute("CREATE INDEX IF NOT EXISTS ix_gm_communities__scope ON gm_communities (user_id)") + + +def downgrade() -> None: + op.drop_table("gm_communities") + op.drop_table("gm_edges") + op.drop_table("gm_nodes") diff --git a/src/memu/database/postgres/migrations/versions/002_add_relation_category.py b/src/memu/database/postgres/migrations/versions/002_add_relation_category.py new file mode 100644 index 00000000..c074a0c4 --- /dev/null +++ b/src/memu/database/postgres/migrations/versions/002_add_relation_category.py @@ -0,0 +1,33 @@ +"""Add relation_category column to gm_edges. + +Revision ID: 002_relation_category +Revises: 001_add_graph +Create Date: 2026-03-28 + +Supports disentangled relation graphs (MAGMA 2026 pattern): +edges classified as semantic/temporal/causal/entity/synthesis. +""" + +revision = "002_relation_category" +down_revision = "001_add_graph" +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa + + +def upgrade() -> None: + op.execute(""" + ALTER TABLE gm_edges + ADD COLUMN IF NOT EXISTS relation_category TEXT NOT NULL DEFAULT 'semantic' + """) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_gm_edges_relation_category " + "ON gm_edges (relation_category)" + ) + + +def downgrade() -> None: + op.drop_index("ix_gm_edges_relation_category", table_name="gm_edges") + op.drop_column("gm_edges", "relation_category") diff --git a/src/memu/database/postgres/models.py b/src/memu/database/postgres/models.py index e83797a2..aa0cdd0e 100644 --- a/src/memu/database/postgres/models.py +++ b/src/memu/database/postgres/models.py @@ -13,11 +13,23 @@ raise ImportError(msg) from exc from pydantic import BaseModel +import sqlalchemy as sa from sqlalchemy import ForeignKey, MetaData, String, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Column, DateTime, Field, Index, SQLModel, func -from memu.database.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource +from sqlalchemy.dialects.postgresql import ARRAY as PgArray + +from memu.database.models import ( + CategoryItem, + GraphCommunity, + GraphEdge, + GraphNode, + MemoryCategory, + MemoryItem, + MemoryType, + Resource, +) class TZDateTime(DateTime): @@ -74,6 +86,37 @@ class CategoryItemModel(BaseModelMixin, CategoryItem): __table_args__ = (Index("idx_category_items_unique", "item_id", "category_id", unique=True),) +class GraphNodeModel(BaseModelMixin, GraphNode): + type: str = Field(sa_column=Column(String, nullable=False)) + name: str = Field(sa_column=Column(String, nullable=False)) + description: str = Field(default="", sa_column=Column(Text, nullable=False, server_default="")) + content: str = Field(default="", sa_column=Column(Text, nullable=False)) + status: str = Field(default="active", sa_column=Column(String, nullable=False, server_default="active")) + validated_count: int | None = Field(default=1, sa_column=Column(sa.Integer, nullable=True, server_default="1")) + source_sessions: list[str] = Field( + default_factory=list, + sa_column=Column(PgArray(Text), nullable=True, server_default="{}"), + ) + community_id: str | None = Field(default=None, sa_column=Column(String, nullable=True)) + pagerank: float | None = Field(default=0.0, sa_column=Column(sa.Float, nullable=True, server_default="0")) + embedding: list[float] | None = Field(default=None, sa_column=Column(Vector(), nullable=True)) + + +class GraphEdgeModel(BaseModelMixin, GraphEdge): + from_id: str = Field(sa_column=Column(String, nullable=False)) + to_id: str = Field(sa_column=Column(String, nullable=False)) + type: str = Field(sa_column=Column(String, nullable=False)) + instruction: str = Field(default="", sa_column=Column(Text, nullable=False)) + condition: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + session_id: str | None = Field(default=None, sa_column=Column(String, nullable=True)) + + +class GraphCommunityModel(BaseModelMixin, GraphCommunity): + summary: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + node_count: int | None = Field(default=0, sa_column=Column(sa.Integer, nullable=True, server_default="0")) + embedding: list[float] | None = Field(default=None, sa_column=Column(Vector(), nullable=True)) + + def _normalize_table_args(table_args: Any) -> tuple[list[Any], dict[str, Any]]: if table_args is None: return [], {} @@ -172,6 +215,9 @@ def build_scoped_models( __all__ = [ "BaseModelMixin", "CategoryItemModel", + "GraphCommunityModel", + "GraphEdgeModel", + "GraphNodeModel", "MemoryCategoryModel", "MemoryItemModel", "ResourceModel", diff --git a/src/memu/database/postgres/postgres.py b/src/memu/database/postgres/postgres.py index d1ff7b05..a1ab818b 100644 --- a/src/memu/database/postgres/postgres.py +++ b/src/memu/database/postgres/postgres.py @@ -9,6 +9,7 @@ from memu.database.models import CategoryItem, MemoryCategory, MemoryItem, Resource from memu.database.postgres.migration import DDLMode, run_migrations from memu.database.postgres.repositories.category_item_repo import PostgresCategoryItemRepo +from memu.database.postgres.repositories.graph_store import PostgresGraphStore from memu.database.postgres.repositories.memory_category_repo import PostgresMemoryCategoryRepo from memu.database.postgres.repositories.memory_item_repo import PostgresMemoryItemRepo from memu.database.postgres.repositories.resource_repo import PostgresResourceRepo @@ -25,6 +26,7 @@ class PostgresStore(Database): memory_category_repo: MemoryCategoryRepo memory_item_repo: MemoryItemRepo category_item_repo: CategoryItemRepo + graph_store: Any | None resources: dict[str, Resource] items: dict[str, MemoryItem] categories: dict[str, MemoryCategory] @@ -90,6 +92,13 @@ def __init__( sessions=self._sessions, scope_fields=self._scope_fields, ) + self.graph_store = PostgresGraphStore( + state=self._state, + sqla_models=self._sqla_models, + sessions=self._sessions, + scope_fields=self._scope_fields, + use_vector=self._use_vector_type, + ) self.resources = self._state.resources self.items = self._state.items diff --git a/src/memu/database/postgres/repositories/__init__.py b/src/memu/database/postgres/repositories/__init__.py index 648623e5..389c9329 100644 --- a/src/memu/database/postgres/repositories/__init__.py +++ b/src/memu/database/postgres/repositories/__init__.py @@ -1,10 +1,12 @@ from memu.database.postgres.repositories.category_item_repo import PostgresCategoryItemRepo +from memu.database.postgres.repositories.graph_store import PostgresGraphStore from memu.database.postgres.repositories.memory_category_repo import PostgresMemoryCategoryRepo from memu.database.postgres.repositories.memory_item_repo import PostgresMemoryItemRepo from memu.database.postgres.repositories.resource_repo import PostgresResourceRepo __all__ = [ "PostgresCategoryItemRepo", + "PostgresGraphStore", "PostgresMemoryCategoryRepo", "PostgresMemoryItemRepo", "PostgresResourceRepo", diff --git a/src/memu/database/postgres/repositories/graph_store.py b/src/memu/database/postgres/repositories/graph_store.py new file mode 100644 index 00000000..5dee0607 --- /dev/null +++ b/src/memu/database/postgres/repositories/graph_store.py @@ -0,0 +1,840 @@ +"""Graph-enhanced memory storage and retrieval. + +Provides GraphStore repository for managing knowledge graph nodes, edges, +and communities, plus dual-path graph recall (precise + generalized) with +Personalized PageRank scoring. +""" + +from __future__ import annotations + +import random +from collections import defaultdict +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +from memu.database.postgres.repositories.base import PostgresRepoBase +from memu.database.postgres.session import SessionManager +from memu.database.state import DatabaseState + + +@dataclass +class RecallNode: + id: str + name: str + type: str + description: str + content: str + community_id: str | None + pagerank: float + ppr_score: float + + +@dataclass +class RecallEdge: + from_name: str + to_name: str + type: str + instruction: str + + +@dataclass +class RecallResult: + nodes: list[RecallNode] + edges: list[RecallEdge] + path: str # "precise" | "generalized" | "merged" | "empty" + + +class PostgresGraphStore(PostgresRepoBase): + """Repository for graph nodes, edges, and communities. + + Handles CRUD + dual-path graph recall with PPR scoring. + """ + + def __init__( + self, + *, + state: DatabaseState, + sqla_models: Any, + sessions: SessionManager, + scope_fields: list[str], + use_vector: bool = True, + ) -> None: + super().__init__( + state=state, + sqla_models=sqla_models, + sessions=sessions, + scope_fields=scope_fields, + use_vector=use_vector, + ) + + # ── Node CRUD ────────────────────────────────────────────────── + + def get_node(self, node_id: str) -> Any | None: + from sqlmodel import select + + model = self._sqla_models.GraphNode + with self._sessions.session() as session: + return session.scalar(select(model).where(model.id == node_id)) + + def list_nodes(self, where: Mapping[str, Any] | None = None) -> list[Any]: + from sqlmodel import select + + model = self._sqla_models.GraphNode + filters = self._build_filters(model, where) + with self._sessions.session() as session: + return list(session.scalars(select(model).where(*filters)).all()) + + def create_node(self, **kwargs: Any) -> Any: + model = self._sqla_models.GraphNode + now = self._now() + obj = model(created_at=now, updated_at=now, **kwargs) + with self._sessions.session() as session: + session.add(obj) + session.commit() + session.refresh(obj) + return obj + + def update_node(self, node_id: str, **kwargs: Any) -> Any: + from sqlmodel import select + + model = self._sqla_models.GraphNode + with self._sessions.session() as session: + obj = session.scalar(select(model).where(model.id == node_id)) + if obj is None: + msg = f"GraphNode {node_id} not found" + raise KeyError(msg) + for k, v in kwargs.items(): + setattr(obj, k, v) + obj.updated_at = self._now() + session.add(obj) + session.commit() + session.refresh(obj) + return obj + + def delete_node(self, node_id: str) -> None: + from sqlmodel import delete + + node_model = self._sqla_models.GraphNode + edge_model = self._sqla_models.GraphEdge + with self._sessions.session() as session: + # Cascade: remove all edges touching this node + session.exec( + delete(edge_model).where( + (edge_model.from_id == node_id) | (edge_model.to_id == node_id) + ) + ) + session.exec(delete(node_model).where(node_model.id == node_id)) + session.commit() + + # ── Edge CRUD ────────────────────────────────────────────────── + + def create_edge(self, **kwargs: Any) -> Any: + model = self._sqla_models.GraphEdge + now = self._now() + obj = model(created_at=now, **kwargs) + with self._sessions.session() as session: + session.add(obj) + session.commit() + session.refresh(obj) + return obj + + def list_edges(self, where: Mapping[str, Any] | None = None) -> list[Any]: + from sqlmodel import select + + model = self._sqla_models.GraphEdge + filters = self._build_filters(model, where) + with self._sessions.session() as session: + return list(session.scalars(select(model).where(*filters)).all()) + + def delete_edge(self, edge_id: str) -> None: + from sqlmodel import delete + + model = self._sqla_models.GraphEdge + with self._sessions.session() as session: + session.exec(delete(model).where(model.id == edge_id)) + session.commit() + + # ── Community CRUD ───────────────────────────────────────────── + + def create_community(self, **kwargs: Any) -> Any: + model = self._sqla_models.GraphCommunity + now = self._now() + obj = model(created_at=now, updated_at=now, **kwargs) + with self._sessions.session() as session: + session.add(obj) + session.commit() + session.refresh(obj) + return obj + + def list_communities(self, where: Mapping[str, Any] | None = None) -> list[Any]: + from sqlmodel import select + + model = self._sqla_models.GraphCommunity + filters = self._build_filters(model, where) + with self._sessions.session() as session: + return list(session.scalars(select(model).where(*filters)).all()) + + def clear_communities(self) -> None: + from sqlmodel import delete + + model = self._sqla_models.GraphCommunity + with self._sessions.session() as session: + session.exec(delete(model)) + session.commit() + + # ── Graph loading (for PPR) ──────────────────────────────────── + + def load_graph( + self, where: Mapping[str, Any] | None = None + ) -> tuple[set[str], dict[str, set[str]]]: + """Load active node IDs and undirected adjacency from DB.""" + from sqlmodel import select + + node_model = self._sqla_models.GraphNode + edge_model = self._sqla_models.GraphEdge + scope_filters = self._build_filters(node_model, where) + + with self._sessions.session() as session: + node_ids = set( + session.scalars( + select(node_model.id).where( + node_model.status == "active", *scope_filters + ) + ).all() + ) + + adj: dict[str, set[str]] = defaultdict(set) + edge_scope_filters = self._build_filters(edge_model, where) + edges = session.execute( + select(edge_model.from_id, edge_model.to_id).where(*edge_scope_filters) + ).all() + for from_id, to_id in edges: + if from_id in node_ids and to_id in node_ids: + adj[from_id].add(to_id) + adj[to_id].add(from_id) + + for nid in node_ids: + if nid not in adj: + adj[nid] = set() + + return node_ids, adj + + # ── Seed search ──────────────────────────────────────────────── + + def vector_seed_search( + self, + query_vec: list[float], + limit: int = 6, + min_score: float = 0.35, + where: Mapping[str, Any] | None = None, + ) -> list[tuple[str, float]]: + """Find seed nodes by pgvector cosine similarity.""" + if not self._use_vector: + return [] + + node_model = self._sqla_models.GraphNode + distance = node_model.embedding.cosine_distance(query_vec) + score_col = (1 - distance).label("score") + scope_filters = self._build_filters(node_model, where) + + from sqlmodel import select + + stmt = ( + select(node_model.id, score_col) + .where( + node_model.status == "active", + node_model.embedding.isnot(None), + (1 - distance) >= min_score, + *scope_filters, + ) + .order_by(distance) + .limit(limit) + ) + + with self._sessions.session() as session: + rows = session.execute(stmt).all() + return [(rid, float(score)) for rid, score in rows] + + def fts_seed_search( + self, query: str, limit: int = 6, where: Mapping[str, Any] | None = None + ) -> list[str]: + """Fallback: full-text search on node name/description/content.""" + from sqlalchemy import func, text + from sqlmodel import select + + node_model = self._sqla_models.GraphNode + scope_filters = self._build_filters(node_model, where) + tsvec = func.to_tsvector( + "simple", + func.coalesce(node_model.name, "") + + " " + + func.coalesce(node_model.description, "") + + " " + + func.coalesce(node_model.content, ""), + ) + tsq = func.plainto_tsquery(text("'simple'"), query) + + stmt = ( + select(node_model.id) + .where(node_model.status == "active", tsvec.op("@@")(tsq), *scope_filters) + .order_by(node_model.pagerank.desc()) + .limit(limit) + ) + + with self._sessions.session() as session: + return list(session.scalars(stmt).all()) + + def get_community_peers( + self, node_id: str, limit: int = 2, where: Mapping[str, Any] | None = None + ) -> list[str]: + """Get peers in the same community, ordered by validated_count.""" + from sqlmodel import select + + node_model = self._sqla_models.GraphNode + scope_filters = self._build_filters(node_model, where) + + # First get this node's community_id + with self._sessions.session() as session: + node = session.scalar(select(node_model).where(node_model.id == node_id)) + if not node or not node.community_id: + return [] + + stmt = ( + select(node_model.id) + .where( + node_model.community_id == node.community_id, + node_model.id != node_id, + node_model.status == "active", + *scope_filters, + ) + .order_by(node_model.validated_count.desc(), node_model.updated_at.desc()) + .limit(limit) + ) + return list(session.scalars(stmt).all()) + + # ── Graph walk ───────────────────────────────────────────────── + + def graph_walk( + self, start_ids: set[str], depth: int = 2, where: Mapping[str, Any] | None = None + ) -> set[str]: + """BFS graph walk up to `depth` hops, undirected.""" + from sqlalchemy import or_ + from sqlmodel import select + + edge_model = self._sqla_models.GraphEdge + node_model = self._sqla_models.GraphNode + scope_filters = self._build_filters(node_model, where) + + visited = set(start_ids) + frontier = set(start_ids) + + with self._sessions.session() as session: + for _ in range(depth): + if not frontier: + break + frontier_list = list(frontier) + edge_scope_filters = self._build_filters(edge_model, where) + stmt = select(edge_model.from_id, edge_model.to_id).where( + or_( + edge_model.from_id.in_(frontier_list), + edge_model.to_id.in_(frontier_list), + ), + *edge_scope_filters, + ) + rows = session.execute(stmt).all() + neighbors: set[str] = set() + for from_id, to_id in rows: + if from_id in frontier: + neighbors.add(to_id) + if to_id in frontier: + neighbors.add(from_id) + new_nodes = neighbors - visited + visited |= new_nodes + frontier = new_nodes + + # Filter to active nodes only, respecting scope + active = set( + session.scalars( + select(node_model.id).where( + node_model.id.in_(list(visited)), + node_model.status == "active", + *scope_filters, + ) + ).all() + ) + + return active + + # ── Node/Edge loading for recall results ─────────────────────── + + def load_recall_nodes( + self, node_ids: set[str], where: Mapping[str, Any] | None = None + ) -> dict[str, RecallNode]: + """Load full node data as RecallNode dataclasses.""" + if not node_ids: + return {} + + from sqlmodel import select + + node_model = self._sqla_models.GraphNode + scope_filters = self._build_filters(node_model, where) + with self._sessions.session() as session: + rows = session.scalars( + select(node_model).where(node_model.id.in_(list(node_ids)), *scope_filters) + ).all() + return { + r.id: RecallNode( + id=r.id, + name=r.name, + type=r.type, + description=r.description or "", + content=r.content or "", + community_id=r.community_id, + pagerank=r.pagerank or 0.0, + ppr_score=0.0, + ) + for r in rows + } + + def load_recall_edges( + self, node_ids: set[str], where: Mapping[str, Any] | None = None + ) -> list[RecallEdge]: + """Load edges where both endpoints are in node_ids, with scope filtering.""" + if not node_ids: + return [] + + from sqlalchemy import alias + from sqlmodel import select + + edge_model = self._sqla_models.GraphEdge + node_model = self._sqla_models.GraphNode + node_list = list(node_ids) + + # Join node table twice to resolve from/to names + to_node = alias(node_model.__table__, name="to_node") + from_node = alias(node_model.__table__, name="from_node") + + with self._sessions.session() as session: + stmt = ( + select( + from_node.c.name.label("from_name"), + to_node.c.name.label("to_name"), + edge_model.type, + edge_model.instruction, + ) + .join(from_node, edge_model.from_id == from_node.c.id) + .join(to_node, edge_model.to_id == to_node.c.id) + .where( + edge_model.from_id.in_(node_list), + edge_model.to_id.in_(node_list), + *self._build_filters(edge_model, where), + ) + ) + + rows = session.execute(stmt).all() + return [ + RecallEdge( + from_name=r.from_name, + to_name=r.to_name, + type=r.type, + instruction=r.instruction or "", + ) + for r in rows + ] + + # ── PPR algorithm ────────────────────────────────────────────── + + @staticmethod + def personalized_pagerank( + node_ids: set[str], + adj: dict[str, set[str]], + seed_ids: list[str], + candidate_ids: set[str] | None = None, + damping: float = 0.85, + iterations: int = 20, + ) -> dict[str, float]: + """Personalized PageRank from seed nodes.""" + valid_seeds = [s for s in seed_ids if s in node_ids] + if not valid_seeds: + return {} + + teleport_weight = 1.0 / len(valid_seeds) + seed_set = set(valid_seeds) + + rank = {nid: (teleport_weight if nid in seed_set else 0.0) for nid in node_ids} + + for _ in range(iterations): + new_rank = { + nid: ((1 - damping) * teleport_weight if nid in seed_set else 0.0) + for nid in node_ids + } + + for nid in node_ids: + neighbors = adj[nid] + if not neighbors: + continue + contrib = rank[nid] / len(neighbors) + for nb in neighbors: + new_rank[nb] = new_rank.get(nb, 0.0) + damping * contrib + + dangling_sum = sum(rank[nid] for nid in node_ids if not adj[nid]) + if dangling_sum > 0: + dangling_contrib = damping * dangling_sum * teleport_weight + for sid in valid_seeds: + new_rank[sid] += dangling_contrib + + rank = new_rank + + if candidate_ids is not None: + return {nid: rank.get(nid, 0.0) for nid in candidate_ids} + return rank + + @staticmethod + def global_pagerank( + node_ids: set[str], + adj: dict[str, set[str]], + damping: float = 0.85, + iterations: int = 20, + ) -> dict[str, float]: + """Global PageRank — uniform teleport.""" + n = len(node_ids) + if n == 0: + return {} + + teleport_base = (1 - damping) / n + rank = dict.fromkeys(node_ids, 1.0 / n) + + for _ in range(iterations): + new_rank = dict.fromkeys(node_ids, teleport_base) + + for nid in node_ids: + neighbors = adj[nid] + if not neighbors: + continue + contrib = rank[nid] / len(neighbors) + for nb in neighbors: + new_rank[nb] += damping * contrib + + dangling_sum = sum(rank[nid] for nid in node_ids if not adj[nid]) + if dangling_sum > 0: + dangling_contrib = damping * dangling_sum / n + for nid in node_ids: + new_rank[nid] += dangling_contrib + + rank = new_rank + + return rank + + @staticmethod + def label_propagation( + node_ids: set[str], + adj: dict[str, set[str]], + max_iter: int = 50, + seed: int | None = None, + ) -> dict[str, str]: + """Label Propagation Algorithm for community detection.""" + rng = random.Random(seed) # noqa: S311 — LPA shuffle is not security-sensitive + nodes = list(node_ids) + label = {nid: nid for nid in nodes} + + for _ in range(max_iter): + changed = False + rng.shuffle(nodes) + + for nid in nodes: + neighbors = adj.get(nid, set()) + if not neighbors: + continue + + freq: dict[str, int] = defaultdict(int) + for nb in neighbors: + freq[label[nb]] += 1 + + max_freq = max(freq.values()) + candidates = [lbl for lbl, f in freq.items() if f == max_freq] + best_label = min(candidates) + + if label[nid] != best_label: + label[nid] = best_label + changed = True + + if not changed: + break + + # Renumber by descending size + communities: dict[str, list[str]] = defaultdict(list) + for nid, lab in label.items(): + communities[lab].append(nid) + + sorted_communities = sorted(communities.items(), key=lambda x: -len(x[1])) + rename = {old_label: f"c-{rank + 1}" for rank, (old_label, _) in enumerate(sorted_communities)} + + return {nid: rename[label[nid]] for nid in nodes} + + # ── Maintenance (PageRank + LPA) ────────────────────────────── + + def run_maintenance(self) -> dict[str, int]: + """Run global PageRank + LPA community detection. Returns stats.""" + node_ids, adj = self.load_graph() + if not node_ids: + return {"nodes": 0, "communities": 0} + + # Global PageRank + pr_scores = self.global_pagerank(node_ids, adj) + self.write_pagerank(pr_scores) + + # Community detection + labels = self.label_propagation(node_ids, adj, seed=42) + self.write_communities(labels) + + return {"nodes": len(pr_scores), "communities": len(set(labels.values()))} + + def write_pagerank(self, scores: dict[str, float]) -> None: + """Write global PageRank scores to graph nodes.""" + from sqlmodel import select + + node_model = self._sqla_models.GraphNode + with self._sessions.session() as session: + for nid, score in scores.items(): + node = session.scalar(select(node_model).where(node_model.id == nid)) + if node: + node.pagerank = score + session.add(node) + session.commit() + + def write_communities(self, labels: dict[str, str]) -> None: + """Write community labels to nodes and rebuild community table.""" + from sqlmodel import delete, select + + node_model = self._sqla_models.GraphNode + community_model = self._sqla_models.GraphCommunity + + with self._sessions.session() as session: + # Update node community_id + for nid, cid in labels.items(): + node = session.scalar(select(node_model).where(node_model.id == nid)) + if node: + node.community_id = cid + session.add(node) + + # Rebuild communities + session.exec(delete(community_model)) + + community_members: dict[str, list[str]] = defaultdict(list) + for nid, cid in labels.items(): + community_members[cid].append(nid) + + # Extract scope fields from an existing node to propagate to communities + scope_kwargs: dict[str, Any] = {} + if labels: + sample_nid = next(iter(labels)) + sample_node = session.scalar( + select(node_model).where(node_model.id == sample_nid) + ) + if sample_node: + for field in self._scope_fields: + val = getattr(sample_node, field, None) + if val is not None and hasattr(community_model, field): + scope_kwargs[field] = val + + now = self._now() + for cid, members in community_members.items(): + obj = community_model( + id=cid, node_count=len(members), + created_at=now, updated_at=now, + **scope_kwargs, + ) + session.add(obj) + + session.commit() + + # ── Dual-path graph recall ───────────────────────────────────── + + def recall_precise( + self, + query: str, + query_vec: list[float] | None, + node_ids: set[str], + adj: dict[str, set[str]], + max_nodes: int = 6, + where: Mapping[str, Any] | None = None, + ) -> RecallResult: + """Precise path: vector/FTS seed → community expansion → walk → PPR.""" + seeds: list[tuple[str, float]] = [] + if query_vec: + seeds = self.vector_seed_search(query_vec, limit=max_nodes // 2, where=where) + + seed_ids = [s[0] for s in seeds] + if len(seed_ids) < 2: + fts_ids = self.fts_seed_search(query, limit=max_nodes, where=where) + seed_id_set = set(seed_ids) + for fid in fts_ids: + if fid not in seed_id_set: + seed_ids.append(fid) + + if not seed_ids: + return RecallResult(nodes=[], edges=[], path="precise") + + # Community expansion + expanded = set(seed_ids) + for sid in seed_ids: + peers = self.get_community_peers(sid, limit=2, where=where) + expanded.update(peers) + + # Graph walk + walked = self.graph_walk(expanded, depth=2, where=where) + + # PPR ranking + ppr = self.personalized_pagerank(node_ids, adj, seed_ids, candidate_ids=walked) + + nodes_data = self.load_recall_nodes(walked, where=where) + for nid, node in nodes_data.items(): + node.ppr_score = ppr.get(nid, 0.0) + + sorted_nodes = sorted( + nodes_data.values(), + key=lambda n: (-n.ppr_score, -n.pagerank), + )[:max_nodes] + + result_ids = {n.id for n in sorted_nodes} + edges = self.load_recall_edges(result_ids, where=where) + + return RecallResult(nodes=sorted_nodes, edges=edges, path="precise") + + def recall_generalized( + self, + query: str, + query_vec: list[float] | None, + node_ids: set[str], + adj: dict[str, set[str]], + max_nodes: int = 6, + where: Mapping[str, Any] | None = None, + ) -> RecallResult: + """Generalized path: community representatives → shallow walk → PPR.""" + from sqlmodel import select + + node_model = self._sqla_models.GraphNode + scope_filters = self._build_filters(node_model, where) + + with self._sessions.session() as session: + # Pick top representative per community + stmt = ( + select(node_model.id, node_model.community_id) + .where( + node_model.status == "active", + node_model.community_id.isnot(None), + *scope_filters, + ) + .order_by( + node_model.community_id, + node_model.validated_count.desc(), + node_model.updated_at.desc(), + ) + ) + rows = session.execute(stmt).all() + + # Deduplicate: first per community wins (since ordered by validated_count desc) + seen_communities: set[str] = set() + rep_ids: list[str] = [] + for nid, cid in rows: + if cid not in seen_communities: + seen_communities.add(cid) + rep_ids.append(nid) + + if not rep_ids: + return RecallResult(nodes=[], edges=[], path="generalized") + + # Rank community representatives by query relevance (P1 #1 fix) + if query_vec and self._use_vector: + rep_scores = self.vector_seed_search( + query_vec, limit=max_nodes, where=where + ) + rep_score_map = dict(rep_scores) + # Sort reps by cosine similarity; reps not in results get score 0 + rep_ids = sorted( + rep_ids, + key=lambda nid: rep_score_map.get(nid, 0.0), + reverse=True, + ) + seed_ids = rep_ids[:max_nodes] + + # Shallow walk + walked = self.graph_walk(set(seed_ids), depth=1, where=where) + + # PPR ranking + ppr = self.personalized_pagerank(node_ids, adj, seed_ids, candidate_ids=walked) + + nodes_data = self.load_recall_nodes(walked, where=where) + for nid, node in nodes_data.items(): + node.ppr_score = ppr.get(nid, 0.0) + + sorted_nodes = sorted( + nodes_data.values(), + key=lambda n: (-n.ppr_score, -n.pagerank), + )[:max_nodes] + + result_ids = {n.id for n in sorted_nodes} + edges = self.load_recall_edges(result_ids, where=where) + + return RecallResult(nodes=sorted_nodes, edges=edges, path="generalized") + + @staticmethod + def merge_results( + precise: RecallResult, + generalized: RecallResult, + max_nodes: int = 0, + ) -> RecallResult: + """Merge: precise wins on dedup, generalized fills gaps, cap to max_nodes.""" + seen_ids: set[str] = set() + merged_nodes: list[RecallNode] = [] + + for n in precise.nodes: + if n.id not in seen_ids: + merged_nodes.append(n) + seen_ids.add(n.id) + + for n in generalized.nodes: + if n.id not in seen_ids: + merged_nodes.append(n) + seen_ids.add(n.id) + + if max_nodes > 0: + merged_nodes = merged_nodes[:max_nodes] + + edge_set: set[tuple[str, str, str]] = set() + merged_edges: list[RecallEdge] = [] + for e in precise.edges + generalized.edges: + key = (e.from_name, e.to_name, e.type) + if key not in edge_set: + merged_edges.append(e) + edge_set.add(key) + + return RecallResult(nodes=merged_nodes, edges=merged_edges, path="merged") + + def graph_recall( + self, + query: str, + query_vec: list[float] | None = None, + max_nodes: int = 6, + where: Mapping[str, Any] | None = None, + ) -> RecallResult: + """Full dual-path graph recall.""" + node_ids, adj = self.load_graph(where=where) + if not node_ids: + return RecallResult(nodes=[], edges=[], path="empty") + + precise = self.recall_precise(query, query_vec, node_ids, adj, max_nodes, where=where) + generalized = self.recall_generalized( + query, query_vec, node_ids, adj, max_nodes, where=where + ) + + return self.merge_results(precise, generalized, max_nodes=max_nodes) + + +__all__ = [ + "PostgresGraphStore", + "RecallEdge", + "RecallNode", + "RecallResult", +] diff --git a/src/memu/database/postgres/schema.py b/src/memu/database/postgres/schema.py index ac6e8b52..00a53a64 100644 --- a/src/memu/database/postgres/schema.py +++ b/src/memu/database/postgres/schema.py @@ -25,6 +25,9 @@ from memu.database.postgres.models import ( CategoryItemModel, + GraphCommunityModel, + GraphEdgeModel, + GraphNodeModel, MemoryCategoryModel, MemoryItemModel, ResourceModel, @@ -39,6 +42,9 @@ class SQLAModels: MemoryCategory: type[Any] MemoryItem: type[Any] CategoryItem: type[Any] + GraphNode: type[Any] | None = None + GraphEdge: type[Any] | None = None + GraphCommunity: type[Any] | None = None _MODEL_CACHE: dict[type[Any], SQLAModels] = {} @@ -85,6 +91,24 @@ def get_sqlalchemy_models(*, scope_model: type[BaseModel] | None = None) -> SQLA tablename="category_items", metadata=metadata_obj, ) + graph_node_model = build_table_model( + scope, + GraphNodeModel, + tablename="gm_nodes", + metadata=metadata_obj, + ) + graph_edge_model = build_table_model( + scope, + GraphEdgeModel, + tablename="gm_edges", + metadata=metadata_obj, + ) + graph_community_model = build_table_model( + scope, + GraphCommunityModel, + tablename="gm_communities", + metadata=metadata_obj, + ) class Base(SQLModel): __abstract__ = True @@ -96,6 +120,9 @@ class Base(SQLModel): MemoryCategory=memory_category_model, MemoryItem=memory_item_model, CategoryItem=category_item_model, + GraphNode=graph_node_model, + GraphEdge=graph_edge_model, + GraphCommunity=graph_community_model, ) _MODEL_CACHE[cache_key] = models return models diff --git a/src/memu/database/sqlite/sqlite.py b/src/memu/database/sqlite/sqlite.py index 2083dd99..42930015 100644 --- a/src/memu/database/sqlite/sqlite.py +++ b/src/memu/database/sqlite/sqlite.py @@ -44,6 +44,7 @@ class SQLiteStore(Database): memory_category_repo: MemoryCategoryRepo memory_item_repo: MemoryItemRepo category_item_repo: CategoryItemRepo + graph_store: Any | None = None resources: dict[str, Resource] items: dict[str, MemoryItem] categories: dict[str, MemoryCategory] diff --git a/tests/test_admission.py b/tests/test_admission.py new file mode 100644 index 00000000..79c70b5d --- /dev/null +++ b/tests/test_admission.py @@ -0,0 +1,134 @@ +"""Tests for the Memory Admission Gate.""" + +from __future__ import annotations + +import pytest + +from memu.app.admission import AdmissionGate, AdmissionResult +from memu.app.settings import MemorizeAdmissionConfig + + +def _gate(*, enabled: bool = True, min_length: int = 30, threshold: float = 0.3, noise_patterns: list[str] | None = None) -> AdmissionGate: + return AdmissionGate( + MemorizeAdmissionConfig( + enabled=enabled, + min_length=min_length, + threshold=threshold, + noise_patterns=noise_patterns or [], + ) + ) + + +# ------------------------------------------------------------------ +# Gate disabled → everything passes +# ------------------------------------------------------------------ + +class TestGateDisabled: + def test_short_string_passes_when_disabled(self): + r = _gate(enabled=False).check("hi") + assert r.allowed is True + assert r.score == 1.0 + assert r.reason == "gate_disabled" + + def test_noise_passes_when_disabled(self): + r = _gate(enabled=False).check(" stuff") + assert r.allowed is True + + +# ------------------------------------------------------------------ +# Min-length filter +# ------------------------------------------------------------------ + +class TestMinLength: + def test_too_short_rejected(self): + r = _gate(min_length=30).check("short") + assert r.allowed is False + assert "too_short" in r.reason + assert r.score == 0.0 + + def test_exactly_at_min_length(self): + text = "A" * 30 + r = _gate(min_length=30).check(text) + # Not rejected by length (may still be rejected by score) + assert "too_short" not in r.reason + + def test_whitespace_only_rejected(self): + r = _gate(min_length=5).check(" ") + assert r.allowed is False + assert "too_short" in r.reason + + +# ------------------------------------------------------------------ +# Built-in noise patterns +# ------------------------------------------------------------------ + +class TestNoisePatterns: + def test_local_command_caveat(self): + r = _gate().check("This is a message with in it and some more text here.") + assert r.allowed is False + assert "noise_pattern" in r.reason + + def test_exit_code(self): + r = _gate().check("EXIT: 0\nSome other output that makes it long enough to pass length check.") + assert r.allowed is False + assert "noise_pattern" in r.reason + + def test_shell_prompt(self): + r = _gate().check("$ git status\nOn branch main nothing to commit working tree clean enough text") + assert r.allowed is False + assert "noise_pattern" in r.reason + + def test_pure_json(self): + r = _gate().check('{"key": "value", "another": 123, "nested": {"a": true}}') + assert r.allowed is False + assert "noise_pattern" in r.reason + + def test_json_with_natural_language_not_rejected_by_json_pattern(self): + # Text that starts with [ but has natural language around it shouldn't match the pure-JSON pattern + text = "Here is some context about the user. They prefer dark mode and use vim." + r = _gate().check(text) + # Should not be rejected by noise patterns + assert "noise_pattern" not in r.reason + + +# ------------------------------------------------------------------ +# Custom noise patterns +# ------------------------------------------------------------------ + +class TestCustomPatterns: + def test_custom_pattern_rejects(self): + r = _gate(noise_patterns=[r"IGNORE_THIS"]).check( + "Some text that contains IGNORE_THIS marker and is long enough." + ) + assert r.allowed is False + assert "custom_noise_pattern" in r.reason + + +# ------------------------------------------------------------------ +# Quality score / threshold +# ------------------------------------------------------------------ + +class TestScoreThreshold: + def test_natural_language_passes(self): + text = "The user prefers dark mode and uses Vim as their primary editor. They are experienced with Python." + r = _gate(threshold=0.3).check(text) + assert r.allowed is True + assert r.score >= 0.3 + + def test_low_quality_rejected(self): + # All caps, no spaces, short-ish → low score + text = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" # 31 chars, no spaces, no mixed case + r = _gate(threshold=0.5).check(text) + assert r.allowed is False + assert "low_score" in r.reason + + +# ------------------------------------------------------------------ +# AdmissionResult dataclass +# ------------------------------------------------------------------ + +class TestAdmissionResult: + def test_is_frozen(self): + r = AdmissionResult(allowed=True, reason="pass", score=0.8) + with pytest.raises(AttributeError): + r.allowed = False # type: ignore[misc] diff --git a/tests/test_graph_store.py b/tests/test_graph_store.py new file mode 100644 index 00000000..4ca074b7 --- /dev/null +++ b/tests/test_graph_store.py @@ -0,0 +1,372 @@ +""" +Tests for graph-enhanced memory: GraphStore algorithms, score fusion, and degradation paths. + +Tests are pure-Python (no DB) except where noted — they test the static algorithm methods +directly and mock the DB layer for integration paths. +""" + +from __future__ import annotations + +from collections import defaultdict + +import pytest + +from memu.database.postgres.repositories.graph_store import ( + PostgresGraphStore, + RecallEdge, + RecallNode, + RecallResult, +) + + +# ── PPR algorithm tests ────────────────────────────────────────── + + +class TestPersonalizedPageRank: + """Test the static PPR implementation.""" + + def test_single_seed_no_edges(self): + """Single seed with no edges: all mass stays on seed.""" + node_ids = {"a", "b", "c"} + adj: dict[str, set[str]] = {"a": set(), "b": set(), "c": set()} + result = PostgresGraphStore.personalized_pagerank(node_ids, adj, ["a"]) + assert result["a"] > 0 + # Non-seeds should have zero or near-zero score + assert result["b"] < 0.01 + assert result["c"] < 0.01 + + def test_two_nodes_linked(self): + """Two connected nodes: seed propagates to neighbor.""" + node_ids = {"a", "b"} + adj: dict[str, set[str]] = {"a": {"b"}, "b": {"a"}} + result = PostgresGraphStore.personalized_pagerank(node_ids, adj, ["a"]) + assert result["a"] > result["b"] + assert result["b"] > 0 + + def test_multiple_seeds(self): + """Multiple seeds share teleport mass.""" + node_ids = {"a", "b", "c"} + adj: dict[str, set[str]] = {"a": {"b"}, "b": {"a", "c"}, "c": {"b"}} + result = PostgresGraphStore.personalized_pagerank(node_ids, adj, ["a", "c"]) + # Both seeds should have significant mass + assert result["a"] > 0.1 + assert result["c"] > 0.1 + + def test_empty_graph(self): + """Empty graph returns empty dict.""" + result = PostgresGraphStore.personalized_pagerank(set(), {}, ["a"]) + assert result == {} + + def test_invalid_seeds_ignored(self): + """Seeds not in node_ids are filtered out.""" + node_ids = {"a", "b"} + adj: dict[str, set[str]] = {"a": {"b"}, "b": {"a"}} + result = PostgresGraphStore.personalized_pagerank(node_ids, adj, ["x", "y"]) + assert result == {} + + def test_candidate_filtering(self): + """Only candidate_ids appear in results.""" + node_ids = {"a", "b", "c"} + adj: dict[str, set[str]] = {"a": {"b"}, "b": {"a", "c"}, "c": {"b"}} + result = PostgresGraphStore.personalized_pagerank( + node_ids, adj, ["a"], candidate_ids={"b"} + ) + assert "b" in result + assert "a" not in result + assert "c" not in result + + def test_scores_sum_to_approximately_one(self): + """PPR scores should approximately sum to 1.0.""" + node_ids = {"a", "b", "c", "d"} + adj: dict[str, set[str]] = { + "a": {"b", "c"}, + "b": {"a", "d"}, + "c": {"a"}, + "d": {"b"}, + } + result = PostgresGraphStore.personalized_pagerank(node_ids, adj, ["a"]) + total = sum(result.values()) + assert abs(total - 1.0) < 0.05 + + def test_damping_factor(self): + """Lower damping = more teleport = non-seed gets less mass.""" + node_ids = {"a", "b", "c", "d"} + adj: dict[str, set[str]] = { + "a": {"b"}, "b": {"a", "c"}, "c": {"b", "d"}, "d": {"c"}, + } + high_damp = PostgresGraphStore.personalized_pagerank( + node_ids, adj, ["a"], damping=0.95 + ) + low_damp = PostgresGraphStore.personalized_pagerank( + node_ids, adj, ["a"], damping=0.5 + ) + # Lower damping → distant node (d) gets less mass + assert low_damp["d"] < high_damp["d"] + + +# ── Global PageRank tests ───────────────────────────────────────── + + +class TestGlobalPageRank: + """Test global (uniform teleport) PageRank.""" + + def test_empty_graph(self): + result = PostgresGraphStore.global_pagerank(set(), {}) + assert result == {} + + def test_uniform_for_symmetric_graph(self): + """Symmetric graph → approximately uniform scores.""" + node_ids = {"a", "b", "c"} + adj: dict[str, set[str]] = {"a": {"b", "c"}, "b": {"a", "c"}, "c": {"a", "b"}} + result = PostgresGraphStore.global_pagerank(node_ids, adj) + scores = list(result.values()) + assert max(scores) - min(scores) < 0.05 + + def test_hub_node_gets_higher_score(self): + """Node with more connections gets higher PageRank.""" + node_ids = {"hub", "a", "b", "c"} + adj: dict[str, set[str]] = { + "hub": {"a", "b", "c"}, + "a": {"hub"}, + "b": {"hub"}, + "c": {"hub"}, + } + result = PostgresGraphStore.global_pagerank(node_ids, adj) + assert result["hub"] > result["a"] + + +# ── LPA community detection tests ───────────────────────────────── + + +class TestLabelPropagation: + """Test Label Propagation Algorithm.""" + + def test_disconnected_components(self): + """Two disconnected cliques → two communities.""" + node_ids = {"a", "b", "c", "x", "y", "z"} + adj: dict[str, set[str]] = { + "a": {"b", "c"}, + "b": {"a", "c"}, + "c": {"a", "b"}, + "x": {"y", "z"}, + "y": {"x", "z"}, + "z": {"x", "y"}, + } + labels = PostgresGraphStore.label_propagation(node_ids, adj, seed=42) + # Same clique → same community + assert labels["a"] == labels["b"] == labels["c"] + assert labels["x"] == labels["y"] == labels["z"] + # Different cliques → different communities + assert labels["a"] != labels["x"] + + def test_single_node(self): + """Single isolated node gets its own community.""" + labels = PostgresGraphStore.label_propagation({"a"}, {"a": set()}, seed=42) + assert "a" in labels + + def test_deterministic_with_seed(self): + """Same seed → same result.""" + node_ids = {"a", "b", "c", "d"} + adj: dict[str, set[str]] = { + "a": {"b"}, + "b": {"a", "c"}, + "c": {"b", "d"}, + "d": {"c"}, + } + r1 = PostgresGraphStore.label_propagation(node_ids, adj, seed=123) + r2 = PostgresGraphStore.label_propagation(node_ids, adj, seed=123) + assert r1 == r2 + + def test_community_ids_format(self): + """Communities are named c-1, c-2, ... sorted by size desc.""" + node_ids = {"a", "b", "c", "x"} + adj: dict[str, set[str]] = { + "a": {"b", "c"}, + "b": {"a", "c"}, + "c": {"a", "b"}, + "x": set(), + } + labels = PostgresGraphStore.label_propagation(node_ids, adj, seed=42) + community_ids = set(labels.values()) + assert all(c.startswith("c-") for c in community_ids) + # The larger group (a,b,c) should be c-1 + assert labels["a"] == "c-1" + + +# ── Merge results tests ─────────────────────────────────────────── + + +class TestMergeResults: + """Test dual-path merge logic.""" + + def _make_node(self, id: str, ppr: float = 0.5) -> RecallNode: + return RecallNode( + id=id, name=id, type="TEST", description="", content="", + community_id=None, pagerank=0.0, ppr_score=ppr, + ) + + def _make_edge(self, f: str, t: str) -> RecallEdge: + return RecallEdge(from_name=f, to_name=t, type="TEST", instruction="") + + def test_precise_wins_on_dedup(self): + """Precise path nodes take priority over generalized.""" + precise = RecallResult( + nodes=[self._make_node("a", 0.9)], + edges=[], + path="precise", + ) + generalized = RecallResult( + nodes=[self._make_node("a", 0.5), self._make_node("b", 0.3)], + edges=[], + path="generalized", + ) + merged = PostgresGraphStore.merge_results(precise, generalized) + assert len(merged.nodes) == 2 + # Node "a" should have precise score (0.9), not generalized (0.5) + a_node = [n for n in merged.nodes if n.id == "a"][0] + assert a_node.ppr_score == 0.9 + + def test_empty_merge(self): + """Merging two empty results.""" + empty = RecallResult(nodes=[], edges=[], path="precise") + merged = PostgresGraphStore.merge_results(empty, empty) + assert merged.nodes == [] + assert merged.edges == [] + + def test_edge_dedup(self): + """Duplicate edges are deduplicated.""" + e = self._make_edge("a", "b") + r1 = RecallResult(nodes=[], edges=[e], path="precise") + r2 = RecallResult(nodes=[], edges=[e], path="generalized") + merged = PostgresGraphStore.merge_results(r1, r2) + assert len(merged.edges) == 1 + + def test_path_is_merged(self): + """Merged result has path='merged'.""" + r1 = RecallResult(nodes=[], edges=[], path="precise") + r2 = RecallResult(nodes=[], edges=[], path="generalized") + merged = PostgresGraphStore.merge_results(r1, r2) + assert merged.path == "merged" + + +# ── Score fusion tests ──────────────────────────────────────────── + + +class TestScoreFusion: + """Test the score fusion logic used in _rag_build_context.""" + + def test_vector_weight_applied(self): + """Vector scores are scaled by (1 - graph_weight).""" + graph_weight = 0.3 + vector_weight = 1.0 - graph_weight + original_score = 0.8 + fused = original_score * vector_weight + assert abs(fused - 0.56) < 0.01 + + def test_graph_weight_applied(self): + """Graph PPR scores are normalized and scaled by graph_weight.""" + graph_weight = 0.3 + ppr_scores = [0.5, 0.3, 0.1] + max_ppr = max(ppr_scores) + fused = [(ppr / max_ppr) * graph_weight for ppr in ppr_scores] + assert abs(fused[0] - 0.3) < 0.01 # 0.5/0.5 * 0.3 + assert abs(fused[1] - 0.18) < 0.01 # 0.3/0.5 * 0.3 + + def test_zero_graph_weight_no_fusion(self): + """With graph_weight=0, vector scores are unchanged.""" + graph_weight = 0.0 + vector_weight = 1.0 - graph_weight + original = 0.75 + assert original * vector_weight == original + + +# ── RetrieveGraphConfig tests ───────────────────────────────────── + + +class TestRetrieveGraphConfig: + """Test graph config defaults and validation.""" + + def test_defaults(self): + from memu.app.settings import RetrieveGraphConfig + + cfg = RetrieveGraphConfig() + assert cfg.enabled is False + assert cfg.max_nodes == 6 + assert cfg.weight == 0.3 + + def test_custom_values(self): + from memu.app.settings import RetrieveGraphConfig + + cfg = RetrieveGraphConfig(enabled=True, max_nodes=10, weight=0.5) + assert cfg.enabled is True + assert cfg.max_nodes == 10 + assert cfg.weight == 0.5 + + def test_retrieve_config_has_graph(self): + from memu.app.settings import RetrieveConfig + + cfg = RetrieveConfig() + assert hasattr(cfg, "graph") + assert cfg.graph.enabled is False + + def test_retrieve_config_graph_from_dict(self): + from memu.app.settings import RetrieveConfig + + cfg = RetrieveConfig(graph={"enabled": True, "weight": 0.4}) + assert cfg.graph.enabled is True + assert cfg.graph.weight == 0.4 + + +# ── Domain model tests ──────────────────────────────────────────── + + +class TestGraphDomainModels: + """Test base domain models.""" + + def test_graph_node_defaults(self): + from memu.database.models import GraphNode + + node = GraphNode(type="SKILL", name="test", content="body") + assert node.status == "active" + assert node.validated_count == 1 + assert node.source_sessions == [] + assert node.pagerank == 0.0 + assert node.embedding is None + + def test_graph_edge_defaults(self): + from memu.database.models import GraphEdge + + edge = GraphEdge(from_id="a", to_id="b", type="USES") + assert edge.instruction == "" + assert edge.condition is None + + def test_graph_community_defaults(self): + from memu.database.models import GraphCommunity + + c = GraphCommunity() + assert c.node_count == 0 + assert c.summary is None + + +# ── ORM model registration tests ────────────────────────────────── + + +class TestGraphORMModels: + """Test that graph ORM models register correctly in schema.""" + + def test_sqla_models_have_graph_fields_and_table(self): + from pydantic import BaseModel, Field + + class GraphTestScope(BaseModel): + user_id: str = Field(default="") + + from memu.database.postgres.schema import get_sqlalchemy_models + + models = get_sqlalchemy_models(scope_model=GraphTestScope) + assert models.GraphNode is not None + assert models.GraphEdge is not None + assert models.GraphCommunity is not None + # Table names + assert models.GraphNode.__tablename__ == "gm_nodes" + assert models.GraphEdge.__tablename__ == "gm_edges" + assert models.GraphCommunity.__tablename__ == "gm_communities" diff --git a/uv.lock b/uv.lock index 76e7b0c6..24d4c353 100644 --- a/uv.lock +++ b/uv.lock @@ -929,7 +929,7 @@ wheels = [ [[package]] name = "memu-py" -version = "1.5.0" +version = "1.5.1" source = { editable = "." } dependencies = [ { name = "alembic" },