class JSONEncodedList(TypeDecorator):
    """Store a list of floats as a JSON-encoded TEXT column.

    SQLite has no native vector type, so embeddings are serialized to JSON
    strings for storage and deserialized back to ``list[float]`` on read.
    ``None`` is passed through unchanged in both directions.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value: list[float] | None, dialect: Any) -> str | None:
        """Serialize an embedding to JSON text before it is bound to a statement.

        Args:
            value: The embedding to store, or ``None``.
            dialect: The dialect in use (not consulted).

        Returns:
            The JSON array text, or ``None`` when ``value`` is ``None``.

        Raises:
            TypeError: If ``value`` is not iterable or an element has no float conversion.
            ValueError: If an element cannot be parsed as a number.
        """
        if value is None:
            return None
        # Coerce each element to a builtin float so that numeric types the json
        # module cannot encode directly are still stored, and so that anything
        # non-numeric fails here instead of being written as undecodable text.
        return json.dumps([float(x) for x in value])

    def process_result_value(self, value: str | None, dialect: Any) -> list[float] | None:
        """Deserialize stored JSON text back into a list of floats.

        Args:
            value: The raw column text, or ``None``.
            dialect: The dialect in use (not consulted).

        Returns:
            The decoded embedding, or ``None`` when the column is ``None`` or
            its content is not a JSON array of numbers (a warning is logged in
            the latter case).
        """
        if value is None:
            return None
        try:
            decoded = json.loads(value)
            # Only a JSON array is a valid embedding. Without this check a JSON
            # string or object would be iterated character-by-character or
            # key-by-key and silently yield a bogus vector.
            if not isinstance(decoded, list):
                raise TypeError(f"expected a JSON array, got {type(decoded).__name__}")
            return [float(x) for x in decoded]
        except (json.JSONDecodeError, TypeError, ValueError) as exc:
            logger.warning("Failed to decode embedding JSON from SQLite: %s", exc)
            return None
None: - return None - try: - return list(json.loads(self.embedding_json)) - except (json.JSONDecodeError, TypeError) as e: - logger.warning("Failed to parse resource embedding JSON: %s", e) - return None - - @embedding.setter - def embedding(self, value: list[float] | None) -> None: - """Serialize embedding to JSON string.""" - if value is None: - self.embedding_json = None - else: - self.embedding_json = json.dumps(value) + embedding: list[float] | None = Field(default=None, sa_column=Column(JSONEncodedList(), nullable=True)) class SQLiteMemoryItemModel(SQLiteBaseModelMixin, MemoryItem): @@ -81,59 +87,19 @@ class SQLiteMemoryItemModel(SQLiteBaseModelMixin, MemoryItem): resource_id: str | None = Field(sa_column=Column(String, nullable=True)) memory_type: MemoryType = Field(sa_column=Column(String, nullable=False)) summary: str = Field(sa_column=Column(Text, nullable=False)) - # Store embedding as JSON string since SQLite doesn't have native vector type - embedding_json: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + embedding: list[float] | None = Field(default=None, sa_column=Column(JSONEncodedList(), nullable=True)) happened_at: datetime | None = Field(default=None, sa_column=Column(DateTime, nullable=True)) extra: dict[str, Any] = Field(default={}, sa_column=Column(JSON, nullable=True)) - @property - def embedding(self) -> list[float] | None: - """Parse embedding from JSON string.""" - if self.embedding_json is None: - return None - try: - return list(json.loads(self.embedding_json)) - except (json.JSONDecodeError, TypeError) as e: - logger.warning("Failed to parse memory item embedding JSON: %s", e) - return None - - @embedding.setter - def embedding(self, value: list[float] | None) -> None: - """Serialize embedding to JSON string.""" - if value is None: - self.embedding_json = None - else: - self.embedding_json = json.dumps(value) - class SQLiteMemoryCategoryModel(SQLiteBaseModelMixin, MemoryCategory): """SQLite memory category 
model.""" name: str = Field(sa_column=Column(String, nullable=False, index=True)) description: str = Field(sa_column=Column(Text, nullable=False)) - # Store embedding as JSON string since SQLite doesn't have native vector type - embedding_json: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + embedding: list[float] | None = Field(default=None, sa_column=Column(JSONEncodedList(), nullable=True)) summary: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) - @property - def embedding(self) -> list[float] | None: - """Parse embedding from JSON string.""" - if self.embedding_json is None: - return None - try: - return list(json.loads(self.embedding_json)) - except (json.JSONDecodeError, TypeError) as e: - logger.warning("Failed to parse category embedding JSON: %s", e) - return None - - @embedding.setter - def embedding(self, value: list[float] | None) -> None: - """Serialize embedding to JSON string.""" - if value is None: - self.embedding_json = None - else: - self.embedding_json = json.dumps(value) - class SQLiteCategoryItemModel(SQLiteBaseModelMixin, CategoryItem): """SQLite category-item relation model.""" diff --git a/src/memu/database/sqlite/repositories/memory_category_repo.py b/src/memu/database/sqlite/repositories/memory_category_repo.py index a2cd2c46..4c4c1674 100644 --- a/src/memu/database/sqlite/repositories/memory_category_repo.py +++ b/src/memu/database/sqlite/repositories/memory_category_repo.py @@ -70,7 +70,7 @@ def list_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, M id=row.id, name=row.name, description=row.description, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, summary=row.summary, created_at=row.created_at, updated_at=row.updated_at, @@ -104,7 +104,7 @@ def clear_categories(self, where: Mapping[str, Any] | None = None) -> dict[str, id=row.id, name=row.name, description=row.description, - 
embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, summary=row.summary, created_at=row.created_at, updated_at=row.updated_at, @@ -156,7 +156,7 @@ def get_or_create_category( id=existing.id, name=existing.name, description=existing.description, - embedding=self._normalize_embedding(existing.embedding_json), + embedding=existing.embedding, summary=existing.summary, created_at=existing.created_at, updated_at=existing.updated_at, @@ -170,7 +170,7 @@ def get_or_create_category( row = self._memory_category_model( name=name, description=description, - embedding_json=self._prepare_embedding(embedding), + embedding=embedding, summary=None, created_at=now, updated_at=now, @@ -230,7 +230,7 @@ def update_category( if description is not None: row.description = description if embedding is not None: - row.embedding_json = self._prepare_embedding(embedding) + row.embedding = embedding if summary is not None: row.summary = summary row.updated_at = self._now() @@ -243,7 +243,7 @@ def update_category( id=row.id, name=row.name, description=row.description, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, summary=row.summary, created_at=row.created_at, updated_at=row.updated_at, diff --git a/src/memu/database/sqlite/repositories/memory_item_repo.py b/src/memu/database/sqlite/repositories/memory_item_repo.py index 0bff124e..4172d551 100644 --- a/src/memu/database/sqlite/repositories/memory_item_repo.py +++ b/src/memu/database/sqlite/repositories/memory_item_repo.py @@ -75,7 +75,7 @@ def get_item(self, item_id: str) -> MemoryItem | None: resource_id=row.resource_id, memory_type=row.memory_type, summary=row.summary, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, created_at=row.created_at, updated_at=row.updated_at, **self._scope_kwargs_from(row), @@ -106,7 +106,7 @@ def list_items(self, where: Mapping[str, Any] | None = None) -> dict[str, Memory resource_id=row.resource_id, 
memory_type=row.memory_type, summary=row.summary, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, created_at=row.created_at, updated_at=row.updated_at, **self._scope_kwargs_from(row), @@ -151,7 +151,7 @@ def list_items_by_ref_ids( resource_id=row.resource_id, memory_type=row.memory_type, summary=row.summary, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, created_at=row.created_at, updated_at=row.updated_at, **self._scope_kwargs_from(row), @@ -185,7 +185,7 @@ def clear_items(self, where: Mapping[str, Any] | None = None) -> dict[str, Memor resource_id=row.resource_id, memory_type=row.memory_type, summary=row.summary, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, created_at=row.created_at, updated_at=row.updated_at, **self._scope_kwargs_from(row), @@ -257,7 +257,7 @@ def create_item( resource_id=resource_id, memory_type=memory_type, summary=summary, - embedding_json=self._prepare_embedding(embedding), + embedding=embedding, extra=extra if extra else {}, created_at=now, updated_at=now, @@ -338,7 +338,7 @@ def create_item_reinforce( resource_id=existing.resource_id, memory_type=existing.memory_type, summary=existing.summary, - embedding=self._normalize_embedding(existing.embedding_json), + embedding=existing.embedding, created_at=existing.created_at, updated_at=existing.updated_at, extra=existing.extra, @@ -360,7 +360,7 @@ def create_item_reinforce( resource_id=resource_id, memory_type=memory_type, summary=summary, - embedding_json=self._prepare_embedding(embedding), + embedding=embedding, extra=item_extra, created_at=now, updated_at=now, @@ -424,7 +424,7 @@ def update_item( if summary is not None: row.summary = summary if embedding is not None: - row.embedding_json = self._prepare_embedding(embedding) + row.embedding = embedding # Merge extra and tool_record into existing extra dict current_extra = row.extra or {} @@ -449,7 +449,7 @@ def update_item( 
resource_id=row.resource_id, memory_type=row.memory_type, summary=row.summary, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, extra=row.extra, created_at=row.created_at, updated_at=row.updated_at, diff --git a/src/memu/database/sqlite/repositories/resource_repo.py b/src/memu/database/sqlite/repositories/resource_repo.py index 7777eefd..ebeb9cbb 100644 --- a/src/memu/database/sqlite/repositories/resource_repo.py +++ b/src/memu/database/sqlite/repositories/resource_repo.py @@ -76,7 +76,7 @@ def list_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, Re modality=row.modality, local_path=row.local_path, caption=row.caption, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, created_at=row.created_at, updated_at=row.updated_at, **self._scope_kwargs_from(row), @@ -111,7 +111,7 @@ def clear_resources(self, where: Mapping[str, Any] | None = None) -> dict[str, R modality=row.modality, local_path=row.local_path, caption=row.caption, - embedding=self._normalize_embedding(row.embedding_json), + embedding=row.embedding, created_at=row.created_at, updated_at=row.updated_at, **self._scope_kwargs_from(row), @@ -163,7 +163,7 @@ def create_resource( modality=modality, local_path=local_path, caption=caption, - embedding_json=self._prepare_embedding(embedding), + embedding=embedding, created_at=now, updated_at=now, **user_data,