Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/memu/app/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,6 @@ async def _rag_recall_items(self, state: WorkflowState, step_context: Any) -> Wo

store = state["store"]
where_filters = state.get("where") or {}
items_pool = store.memory_item_repo.list_items(where_filters)
qvec = state.get("query_vector")
if qvec is None:
embed_client = self._get_step_embedding_client(step_context)
Expand All @@ -363,6 +362,13 @@ async def _rag_recall_items(self, state: WorkflowState, step_context: Any) -> Wo
ranking=self.retrieve_config.item.ranking,
recency_decay_days=self.retrieve_config.item.recency_decay_days,
)
# Build mini pool from hit IDs only — avoids full table scan of all items
hit_ids = [_id for _id, _ in state["item_hits"]]
items_pool: dict[str, Any] = {}
for _id in hit_ids:
item = store.memory_item_repo.get_item(_id)
if item is not None:
items_pool[_id] = item
state["item_pool"] = items_pool
return state

Expand Down Expand Up @@ -437,7 +443,7 @@ def _rag_build_context(self, state: WorkflowState, _: Any) -> WorkflowState:
store = state["store"]
where_filters = state.get("where") or {}
categories_pool = state.get("category_pool") or store.memory_category_repo.list_categories(where_filters)
items_pool = state.get("item_pool") or store.memory_item_repo.list_items(where_filters)
items_pool = state["item_pool"] if "item_pool" in state else store.memory_item_repo.list_items(where_filters)
resources_pool = state.get("resource_pool") or store.resource_repo.list_resources(where_filters)
response["categories"] = self._materialize_hits(
state.get("category_hits", []),
Expand Down Expand Up @@ -737,7 +743,17 @@ async def _rank_categories_by_summary(
return [], {}
summary_texts = [summary for _, summary in entries]
client = embed_client or self._get_llm_client()
summary_embeddings = await client.embed(summary_texts)

# Cache category summary embeddings — summaries rarely change
if not hasattr(self, "_cat_embed_cache"):
self._cat_embed_cache: dict[tuple[tuple[str, str], ...], list[list[float]]] = {}
cache_key = tuple(sorted(entries))
if cache_key in self._cat_embed_cache:
summary_embeddings = self._cat_embed_cache[cache_key]
else:
summary_embeddings = await client.embed(summary_texts)
self._cat_embed_cache[cache_key] = summary_embeddings

corpus = [(cid, emb) for (cid, _), emb in zip(entries, summary_embeddings, strict=True)]
hits = cosine_topk(query_vec, corpus, k=top_k)
summary_lookup = dict(entries)
Expand Down