Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 29 additions & 63 deletions src/memu/database/sqlite/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,39 @@
import pendulum
from pydantic import BaseModel
from sqlalchemy import JSON, MetaData, String, Text
from sqlalchemy.types import TypeDecorator
from sqlmodel import Column, DateTime, Field, Index, SQLModel, func

from memu.database.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource

logger = logging.getLogger(__name__)


class JSONEncodedList(TypeDecorator):
    """Store a list of floats as a JSON-encoded TEXT column.

    SQLite has no native vector type, so embeddings are serialized to JSON
    strings for storage and deserialized back to ``list[float]`` on read.

    Stored values that cannot be read back as a list of numbers (malformed
    JSON, a payload that is not a JSON array, or non-numeric elements) are
    logged at WARNING level and returned as ``None``.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value: list[float] | None, dialect: Any) -> str | None:
        """Serialize ``value`` to a JSON string; ``None`` is stored as NULL."""
        if value is None:
            return None
        return json.dumps(value)

    def process_result_value(self, value: str | None, dialect: Any) -> list[float] | None:
        """Deserialize a stored JSON string into a list of floats.

        Returns ``None`` for NULL columns and for any value that cannot be
        decoded into a list of numbers.
        """
        if value is None:
            return None

        try:
            decoded = json.loads(value)
        except (json.JSONDecodeError, TypeError) as exc:
            logger.warning("Failed to decode embedding JSON from SQLite: %s", exc)
            return None

        # JSON strings and objects are iterable too; without this guard
        # '"12"' would come back as [1.0, 2.0] and '{"3": 0}' as [3.0].
        if not isinstance(decoded, list):
            logger.warning(
                "Failed to decode embedding JSON from SQLite: expected a JSON array, got %s",
                type(decoded).__name__,
            )
            return None

        try:
            return [float(x) for x in decoded]
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to decode embedding JSON from SQLite: %s", exc)
            return None


class TZDateTime(DateTime):
"""DateTime type with timezone support."""

Expand Down Expand Up @@ -52,27 +78,7 @@ class SQLiteResourceModel(SQLiteBaseModelMixin, Resource):
modality: str = Field(sa_column=Column(String, nullable=False))
local_path: str = Field(sa_column=Column(String, nullable=False))
caption: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
# Store embedding as JSON string since SQLite doesn't have native vector type
embedding_json: str | None = Field(default=None, sa_column=Column(Text, nullable=True))

@property
def embedding(self) -> list[float] | None:
"""Parse embedding from JSON string."""
if self.embedding_json is None:
return None
try:
return list(json.loads(self.embedding_json))
except (json.JSONDecodeError, TypeError) as e:
logger.warning("Failed to parse resource embedding JSON: %s", e)
return None

@embedding.setter
def embedding(self, value: list[float] | None) -> None:
"""Serialize embedding to JSON string."""
if value is None:
self.embedding_json = None
else:
self.embedding_json = json.dumps(value)
embedding: list[float] | None = Field(default=None, sa_column=Column(JSONEncodedList(), nullable=True))


class SQLiteMemoryItemModel(SQLiteBaseModelMixin, MemoryItem):
Expand All @@ -81,59 +87,19 @@ class SQLiteMemoryItemModel(SQLiteBaseModelMixin, MemoryItem):
resource_id: str | None = Field(sa_column=Column(String, nullable=True))
memory_type: MemoryType = Field(sa_column=Column(String, nullable=False))
summary: str = Field(sa_column=Column(Text, nullable=False))
# Store embedding as JSON string since SQLite doesn't have native vector type
embedding_json: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
embedding: list[float] | None = Field(default=None, sa_column=Column(JSONEncodedList(), nullable=True))
happened_at: datetime | None = Field(default=None, sa_column=Column(DateTime, nullable=True))
extra: dict[str, Any] = Field(default={}, sa_column=Column(JSON, nullable=True))

@property
def embedding(self) -> list[float] | None:
"""Parse embedding from JSON string."""
if self.embedding_json is None:
return None
try:
return list(json.loads(self.embedding_json))
except (json.JSONDecodeError, TypeError) as e:
logger.warning("Failed to parse memory item embedding JSON: %s", e)
return None

@embedding.setter
def embedding(self, value: list[float] | None) -> None:
"""Serialize embedding to JSON string."""
if value is None:
self.embedding_json = None
else:
self.embedding_json = json.dumps(value)


class SQLiteMemoryCategoryModel(SQLiteBaseModelMixin, MemoryCategory):
"""SQLite memory category model."""

name: str = Field(sa_column=Column(String, nullable=False, index=True))
description: str = Field(sa_column=Column(Text, nullable=False))
# Store embedding as JSON string since SQLite doesn't have native vector type
embedding_json: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
embedding: list[float] | None = Field(default=None, sa_column=Column(JSONEncodedList(), nullable=True))
summary: str | None = Field(default=None, sa_column=Column(Text, nullable=True))

@property
def embedding(self) -> list[float] | None:
"""Parse embedding from JSON string."""
if self.embedding_json is None:
return None
try:
return list(json.loads(self.embedding_json))
except (json.JSONDecodeError, TypeError) as e:
logger.warning("Failed to parse category embedding JSON: %s", e)
return None

@embedding.setter
def embedding(self, value: list[float] | None) -> None:
"""Serialize embedding to JSON string."""
if value is None:
self.embedding_json = None
else:
self.embedding_json = json.dumps(value)


class SQLiteCategoryItemModel(SQLiteBaseModelMixin, CategoryItem):
"""SQLite category-item relation model."""
Expand Down
12 changes: 6 additions & 6 deletions src/memu/database/sqlite/repositories/memory_category_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def list_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, M
id=row.id,
name=row.name,
description=row.description,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
summary=row.summary,
created_at=row.created_at,
updated_at=row.updated_at,
Expand Down Expand Up @@ -104,7 +104,7 @@ def clear_categories(self, where: Mapping[str, Any] | None = None) -> dict[str,
id=row.id,
name=row.name,
description=row.description,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
summary=row.summary,
created_at=row.created_at,
updated_at=row.updated_at,
Expand Down Expand Up @@ -156,7 +156,7 @@ def get_or_create_category(
id=existing.id,
name=existing.name,
description=existing.description,
embedding=self._normalize_embedding(existing.embedding_json),
embedding=existing.embedding,
summary=existing.summary,
created_at=existing.created_at,
updated_at=existing.updated_at,
Expand All @@ -170,7 +170,7 @@ def get_or_create_category(
row = self._memory_category_model(
name=name,
description=description,
embedding_json=self._prepare_embedding(embedding),
embedding=embedding,
summary=None,
created_at=now,
updated_at=now,
Expand Down Expand Up @@ -230,7 +230,7 @@ def update_category(
if description is not None:
row.description = description
if embedding is not None:
row.embedding_json = self._prepare_embedding(embedding)
row.embedding = embedding
if summary is not None:
row.summary = summary
row.updated_at = self._now()
Expand All @@ -243,7 +243,7 @@ def update_category(
id=row.id,
name=row.name,
description=row.description,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
summary=row.summary,
created_at=row.created_at,
updated_at=row.updated_at,
Expand Down
18 changes: 9 additions & 9 deletions src/memu/database/sqlite/repositories/memory_item_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_item(self, item_id: str) -> MemoryItem | None:
resource_id=row.resource_id,
memory_type=row.memory_type,
summary=row.summary,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
created_at=row.created_at,
updated_at=row.updated_at,
**self._scope_kwargs_from(row),
Expand Down Expand Up @@ -106,7 +106,7 @@ def list_items(self, where: Mapping[str, Any] | None = None) -> dict[str, Memory
resource_id=row.resource_id,
memory_type=row.memory_type,
summary=row.summary,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
created_at=row.created_at,
updated_at=row.updated_at,
**self._scope_kwargs_from(row),
Expand Down Expand Up @@ -151,7 +151,7 @@ def list_items_by_ref_ids(
resource_id=row.resource_id,
memory_type=row.memory_type,
summary=row.summary,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
created_at=row.created_at,
updated_at=row.updated_at,
**self._scope_kwargs_from(row),
Expand Down Expand Up @@ -185,7 +185,7 @@ def clear_items(self, where: Mapping[str, Any] | None = None) -> dict[str, Memor
resource_id=row.resource_id,
memory_type=row.memory_type,
summary=row.summary,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
created_at=row.created_at,
updated_at=row.updated_at,
**self._scope_kwargs_from(row),
Expand Down Expand Up @@ -257,7 +257,7 @@ def create_item(
resource_id=resource_id,
memory_type=memory_type,
summary=summary,
embedding_json=self._prepare_embedding(embedding),
embedding=embedding,
extra=extra if extra else {},
created_at=now,
updated_at=now,
Expand Down Expand Up @@ -338,7 +338,7 @@ def create_item_reinforce(
resource_id=existing.resource_id,
memory_type=existing.memory_type,
summary=existing.summary,
embedding=self._normalize_embedding(existing.embedding_json),
embedding=existing.embedding,
created_at=existing.created_at,
updated_at=existing.updated_at,
extra=existing.extra,
Expand All @@ -360,7 +360,7 @@ def create_item_reinforce(
resource_id=resource_id,
memory_type=memory_type,
summary=summary,
embedding_json=self._prepare_embedding(embedding),
embedding=embedding,
extra=item_extra,
created_at=now,
updated_at=now,
Expand Down Expand Up @@ -424,7 +424,7 @@ def update_item(
if summary is not None:
row.summary = summary
if embedding is not None:
row.embedding_json = self._prepare_embedding(embedding)
row.embedding = embedding

# Merge extra and tool_record into existing extra dict
current_extra = row.extra or {}
Expand All @@ -449,7 +449,7 @@ def update_item(
resource_id=row.resource_id,
memory_type=row.memory_type,
summary=row.summary,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
extra=row.extra,
created_at=row.created_at,
updated_at=row.updated_at,
Expand Down
6 changes: 3 additions & 3 deletions src/memu/database/sqlite/repositories/resource_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def list_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, Re
modality=row.modality,
local_path=row.local_path,
caption=row.caption,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
created_at=row.created_at,
updated_at=row.updated_at,
**self._scope_kwargs_from(row),
Expand Down Expand Up @@ -111,7 +111,7 @@ def clear_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, R
modality=row.modality,
local_path=row.local_path,
caption=row.caption,
embedding=self._normalize_embedding(row.embedding_json),
embedding=row.embedding,
created_at=row.created_at,
updated_at=row.updated_at,
**self._scope_kwargs_from(row),
Expand Down Expand Up @@ -163,7 +163,7 @@ def create_resource(
modality=modality,
local_path=local_path,
caption=caption,
embedding_json=self._prepare_embedding(embedding),
embedding=embedding,
created_at=now,
updated_at=now,
**user_data,
Expand Down