diff --git a/app.py b/app.py index 0e14a50..0927809 100644 --- a/app.py +++ b/app.py @@ -267,6 +267,9 @@ def require_api_token() -> None: if request.path.startswith("/viewer"): return + if request.path == "/backup": + return + _require_api_token_helper( request_obj=request, api_token=API_TOKEN, diff --git a/automem/api/backup.py b/automem/api/backup.py new file mode 100644 index 0000000..007fd5e --- /dev/null +++ b/automem/api/backup.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import hashlib +import json +import time +from typing import Any, Callable + +from flask import Blueprint, Response, abort, request + +from automem.backup import ( + InvalidBackupInclude, + backup_timestamp, + parse_backup_include, + stream_backup_tar_gz, +) + + +def _admin_key_fingerprint() -> str | None: + token = ( + request.headers.get("X-Admin-Token") + or request.headers.get("X-Admin-Api-Key") + or request.args.get("admin_token") + ) + if not token: + return None + return hashlib.sha256(token.encode("utf-8")).hexdigest()[:12] + + +def create_backup_blueprint( + require_admin_token: Callable[[], None], + get_memory_graph: Callable[[], Any], + get_qdrant_client: Callable[[], Any], + graph_name: str, + collection_name: str, + logger: Any, +) -> Blueprint: + bp = Blueprint("backup", __name__) + + @bp.route("/backup", methods=["GET"]) + def backup() -> Response: + require_admin_token() + + try: + includes = parse_backup_include(request.args.get("include")) + except InvalidBackupInclude as exc: + abort(400, description=str(exc)) + + graph = get_memory_graph() if "falkordb" in includes else None + if "falkordb" in includes and graph is None: + abort(503, description="FalkorDB is unavailable") + + qdrant_client = get_qdrant_client() if "qdrant" in includes else None + if "qdrant" in includes and qdrant_client is None: + abort(503, description="Qdrant is unavailable") + + started = time.perf_counter() + timestamp = backup_timestamp() + audit_base = { + "event": "backup.request", + "key_fingerprint": _admin_key_fingerprint(), + "include": list(includes), + "timestamp": timestamp, + } + + def audit_complete(stats: dict[str, Any]) -> None: + artifacts = stats.get("artifacts") or {} + falkordb_stats = artifacts.get("falkordb") or {} + qdrant_stats = artifacts.get("qdrant") or {} + audit = { + **audit_base, + "status": stats.get("status"), + "byte_count": stats.get("bytes", 0), + "duration_ms": round((time.perf_counter() - started) * 1000, 2), + "node_count": falkordb_stats.get("node_count"), + "relationship_count": falkordb_stats.get("relationship_count"), + "point_count": qdrant_stats.get("points_count"), + } + if stats.get("error"): + audit["error"] = stats["error"] + logger.info("backup.request %s", json.dumps(audit, sort_keys=True)) + + stream = stream_backup_tar_gz( + includes=includes, + timestamp=timestamp, + graph=graph, + graph_name=graph_name, + qdrant_client=qdrant_client, + collection_name=collection_name, + logger=logger, + on_complete=audit_complete, + ) + + return Response( + stream, + mimetype="application/gzip", + headers={ + "Content-Disposition": f'attachment; filename="automem-backup-{timestamp}.tar.gz"', + "Cache-Control": "no-store", + }, + ) + + return bp diff --git a/automem/api/runtime_bootstrap.py b/automem/api/runtime_bootstrap.py index d99495c..88d929e 100644 --- a/automem/api/runtime_bootstrap.py +++ b/automem/api/runtime_bootstrap.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional from automem.api.admin import create_admin_blueprint_full +from automem.api.backup import create_backup_blueprint from automem.api.consolidation import create_consolidation_blueprint_full from automem.api.enrichment import create_enrichment_blueprint from automem.api.graph import create_graph_blueprint @@ -147,6 +148,15 @@ def register_blueprints( logger, ) + backup_bp = create_backup_blueprint( + require_admin_token_fn, + get_memory_graph_fn, + get_qdrant_client_fn, + graph_name, + collection_name, + logger, + ) + consolidation_bp = create_consolidation_blueprint_full( get_memory_graph_fn, get_qdrant_client_fn, @@ -176,6 +186,7 @@ def register_blueprints( app.register_blueprint(enrichment_bp) app.register_blueprint(memory_bp) app.register_blueprint(admin_bp) + app.register_blueprint(backup_bp) app.register_blueprint(recall_bp) app.register_blueprint(consolidation_bp) app.register_blueprint(graph_bp) diff --git a/automem/backup.py b/automem/backup.py new file mode 100644 index 0000000..2cc40b6 --- /dev/null +++ b/automem/backup.py @@ -0,0 +1,425 @@ +from __future__ import annotations + +import gzip +import io +import json +import queue +import tarfile +import threading +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Callable, Iterable, Iterator, Optional + +VALID_BACKUP_INCLUDES = ("falkordb", "qdrant") + + +class BackupError(RuntimeError): + """Raised when backup creation fails.""" + + +class InvalidBackupInclude(ValueError): + """Raised when the backup include query parameter is invalid.""" + + +@dataclass(frozen=True) +class BackupArtifact: + service: str + member_name: str + data: bytes + stats: dict[str, Any] + + +@dataclass(frozen=True) +class BackupFile: + service: str + path: Path + stats: dict[str, Any] + + +def backup_timestamp() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + + +def parse_backup_include(raw_include: Optional[str]) -> tuple[str, ...]: + """Parse a comma-separated include list, defaulting to both stores when absent.""" + if raw_include is None: + return VALID_BACKUP_INCLUDES + + parts = [part.strip().lower() for part in raw_include.split(",")] + includes = tuple(include for include in VALID_BACKUP_INCLUDES if include in parts) + invalid = [part for part in parts if part and part not in VALID_BACKUP_INCLUDES] + + if invalid or not includes or any(not part for part in parts): + valid = ",".join(VALID_BACKUP_INCLUDES) + raise InvalidBackupInclude(f"include must be a comma-separated subset of: {valid}") + + return includes + + +def _gzip_json(data: dict[str, Any]) -> bytes: + buffer = io.BytesIO() + with gzip.GzipFile(fileobj=buffer, mode="wb") as gz: + with io.TextIOWrapper(gz, encoding="utf-8") as writer: + json.dump(data, writer, indent=2, default=str) + return buffer.getvalue() + + +def _query_rows(result: Any) -> list[Any]: + return list(getattr(result, "result_set", []) or []) + + +def export_falkordb_artifact( + *, + graph: Any, + graph_name: str, + timestamp: str, + batch_size: int = 10000, + logger: Any = None, +) -> BackupArtifact: + """Export FalkorDB graph data as a compressed JSON backup artifact.""" + if graph is None: + raise BackupError("FalkorDB is unavailable") + + nodes: list[dict[str, Any]] = [] + offset = 0 + + while True: + rows = _query_rows( + graph.query( + f""" + MATCH (n) + RETURN + id(n) as id, + labels(n) as labels, + properties(n) as props + SKIP {offset} LIMIT {batch_size} + """ + ) + ) + if not rows: + break + + for row in rows: + nodes.append({"id": row[0], "labels": row[1], "properties": row[2]}) + + if logger: + logger.info( + "Exported FalkorDB node batch: %d nodes (total: %d)", + len(rows), + len(nodes), + ) + if len(rows) < batch_size: + break + offset += batch_size + + relationships: list[dict[str, Any]] = [] + offset = 0 + + while True: + rows = _query_rows( + graph.query( + f""" + MATCH (a)-[r]->(b) + RETURN + id(a) as source_id, + type(r) as rel_type, + id(b) as target_id, + properties(r) as props + SKIP {offset} LIMIT {batch_size} + """ + ) + ) + if not rows: + break + + for row in rows: + relationships.append( + { + "source_id": row[0], + "type": row[1], + "target_id": row[2], + "properties": row[3], + } + ) + + if logger: + logger.info( + "Exported FalkorDB relationship batch: %d relationships (total: %d)", + len(rows), + len(relationships), + ) + if len(rows) < batch_size: + break + offset += batch_size + + stats = { + "node_count": len(nodes), + "relationship_count": len(relationships), + } + backup_data = { + "timestamp": timestamp, + "graph_name": graph_name, + "nodes": nodes, + "relationships": relationships, + "stats": stats, + } + return BackupArtifact( + service="falkordb", + member_name=f"falkordb/falkordb_{timestamp}.json.gz", + data=_gzip_json(backup_data), + stats=stats, + ) + + +def _vector_size_from_collection_info(collection_info: Any) -> Any: + vectors = getattr( + getattr(getattr(collection_info, "config", None), "params", None), + "vectors", + None, + ) + if isinstance(vectors, dict): + first = next(iter(vectors.values()), None) + return getattr(first, "size", None) + return getattr(vectors, "size", None) + + +def export_qdrant_artifact( + *, + client: Any, + collection_name: str, + timestamp: str, + batch_size: int = 100, + logger: Any = None, +) -> BackupArtifact: + """Export Qdrant collection data as a compressed JSON backup artifact.""" + if client is None: + raise BackupError("Qdrant is unavailable") + + all_points: list[dict[str, Any]] = [] + offset = None + + while True: + points, next_offset = client.scroll( + collection_name=collection_name, + limit=batch_size, + offset=offset, + with_payload=True, + with_vectors=True, + ) + + for point in points: + all_points.append( + { + "id": point.id, + "vector": point.vector, + "payload": point.payload, + } + ) + + if logger and points: + logger.info( + "Exported Qdrant point batch: %d points (total: %d)", + len(points), + len(all_points), + ) + + if next_offset is None: + break + offset = next_offset + + collection_info = client.get_collection(collection_name) + stats = { + "points_count": len(all_points), + "vector_size": _vector_size_from_collection_info(collection_info), + } + backup_data = { + "timestamp": timestamp, + "collection_name": collection_name, + "points": all_points, + "stats": stats, + } + return BackupArtifact( + service="qdrant", + member_name=f"qdrant/qdrant_{timestamp}.json.gz", + data=_gzip_json(backup_data), + stats=stats, + ) + + +def create_backup_artifacts( + *, + includes: Iterable[str], + timestamp: str, + graph: Any = None, + graph_name: str = "memories", + qdrant_client: Any = None, + collection_name: str = "memories", + logger: Any = None, +) -> list[BackupArtifact]: + artifacts: list[BackupArtifact] = [] + include_set = set(includes) + + if "falkordb" in include_set: + artifacts.append( + export_falkordb_artifact( + graph=graph, + graph_name=graph_name, + timestamp=timestamp, + logger=logger, + ) + ) + if "qdrant" in include_set: + artifacts.append( + export_qdrant_artifact( + client=qdrant_client, + collection_name=collection_name, + timestamp=timestamp, + logger=logger, + ) + ) + + return artifacts + + +def write_backup_artifact(backup_dir: Path, artifact: BackupArtifact) -> BackupFile: + path = backup_dir / artifact.member_name + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(artifact.data) + return BackupFile(service=artifact.service, path=path, stats=artifact.stats) + + +def write_falkordb_backup_file( + *, + backup_dir: Path, + graph: Any, + graph_name: str, + timestamp: str, + logger: Any = None, +) -> BackupFile: + return write_backup_artifact( + backup_dir, + export_falkordb_artifact( + graph=graph, + graph_name=graph_name, + timestamp=timestamp, + logger=logger, + ), + ) + + +def write_qdrant_backup_file( + *, + backup_dir: Path, + qdrant_client: Any, + collection_name: str, + timestamp: str, + logger: Any = None, +) -> BackupFile: + return write_backup_artifact( + backup_dir, + export_qdrant_artifact( + client=qdrant_client, + collection_name=collection_name, + timestamp=timestamp, + logger=logger, + ), + ) + + +def cleanup_old_backup_files(*, backup_dir: Path, keep: int, logger: Any = None) -> None: + for backup_type in VALID_BACKUP_INCLUDES: + backup_path = backup_dir / backup_type + backup_files = sorted( + backup_path.glob("*.json.gz"), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + for old_file in backup_files[keep:]: + if logger: + logger.info("Removing old %s backup: %s", backup_type, old_file.name) + old_file.unlink() + + if logger: + kept = min(len(backup_files), keep) + removed = max(0, len(backup_files) - keep) + logger.info("%s backup cleanup: kept %d, removed %d", backup_type, kept, removed) + + +class _QueueWriter: + def __init__(self, output_queue: "queue.Queue[Any]") -> None: + self.output_queue = output_queue + self.bytes_written = 0 + + def write(self, data: bytes) -> int: + if data: + chunk = bytes(data) + self.bytes_written += len(chunk) + self.output_queue.put(chunk) + return len(data) + + def flush(self) -> None: + return None + + +def _add_artifact_to_tar(tar: tarfile.TarFile, artifact: BackupArtifact) -> None: + info = tarfile.TarInfo(artifact.member_name) + info.size = len(artifact.data) + info.mtime = int(datetime.now(timezone.utc).timestamp()) + tar.addfile(info, io.BytesIO(artifact.data)) + + +def stream_backup_tar_gz( + *, + includes: Iterable[str], + timestamp: str, + graph: Any = None, + graph_name: str = "memories", + qdrant_client: Any = None, + collection_name: str = "memories", + logger: Any = None, + on_complete: Optional[Callable[[dict[str, Any]], None]] = None, +) -> Iterator[bytes]: + """Stream a tar.gz archive containing restore-compatible backup files.""" + output_queue: "queue.Queue[Any]" = queue.Queue() + + def worker() -> None: + writer = _QueueWriter(output_queue) + stats: dict[str, Any] = { + "status": "complete", + "bytes": 0, + "artifacts": {}, + } + try: + with tarfile.open(fileobj=writer, mode="w|gz") as tar: + artifacts = create_backup_artifacts( + includes=includes, + timestamp=timestamp, + graph=graph, + graph_name=graph_name, + qdrant_client=qdrant_client, + collection_name=collection_name, + logger=logger, + ) + for artifact in artifacts: + _add_artifact_to_tar(tar, artifact) + stats["artifacts"][artifact.service] = artifact.stats + except Exception as exc: # pragma: no cover - exercised by Flask streaming internals + stats["status"] = "failed" + stats["error"] = str(exc) + output_queue.put(exc) + finally: + stats["bytes"] = writer.bytes_written + if on_complete: + on_complete(stats) + output_queue.put(None) + + threading.Thread(target=worker, daemon=True).start() + + while True: + item = output_queue.get() + if item is None: + break + if isinstance(item, Exception): + raise item + yield item diff --git a/docs/API.md b/docs/API.md index ab9a130..6976fff 100644 --- a/docs/API.md +++ b/docs/API.md @@ -1,6 +1,6 @@ # AutoMem API Reference -This document lists the primary API endpoints and examples. All JSON responses include `status` and primary payload fields for LLM-friendliness. +This document lists the primary API endpoints and examples. JSON responses include `status` and primary payload fields for LLM-friendliness unless an endpoint explicitly returns a binary payload. Authentication @@ -9,6 +9,7 @@ Authentication - `X-API-Key: ` header - `?api_key=` query parameter - Admin endpoints additionally require `X-Admin-Token: ` header. +- `GET /backup` is admin-only and does not accept the regular API token by itself. Health @@ -109,6 +110,13 @@ Enrichment Admin +- GET `/backup` + - Headers: requires `X-Admin-Token: ` or `X-Admin-Api-Key: `. + - Query: optional `include=falkordb,qdrant`; defaults to both. Use `include=falkordb` or `include=qdrant` for a partial export. + - Response: binary `application/gzip` attachment named `automem-backup-.tar.gz`. + - Archive contents are restore-compatible: `falkordb/falkordb_.json.gz` and/or `qdrant/qdrant_.json.gz`. + - Example: `curl -H "X-Admin-Token: $ADMIN_API_TOKEN" "$AUTOMEM_API_URL/backup" -o snapshot.tar.gz` + - POST `/admin/reembed` - Headers: requires both API and Admin tokens. - Body: `{ "batch_size": 32, "limit": 100, "force": false }` diff --git a/docs/MONITORING_AND_BACKUPS.md b/docs/MONITORING_AND_BACKUPS.md index 18c3cdb..f317d20 100644 --- a/docs/MONITORING_AND_BACKUPS.md +++ b/docs/MONITORING_AND_BACKUPS.md @@ -146,6 +146,20 @@ python scripts/health_monitor.py \ For portable backups that cover both databases, use the `backup_automem.py` script: +#### API Backup Export + +When AutoMem itself has internal-network access to FalkorDB and Qdrant, operators can pull a full portable backup through the API without opening database proxies: + +```bash +curl -H "X-Admin-Token: $ADMIN_API_TOKEN" \ + "$AUTOMEM_API_URL/backup" \ + -o snapshot.tar.gz + +python scripts/restore_from_backup.py --backup-dir snapshot.tar.gz --force +``` + +`GET /backup` requires the admin token because it exports the full corpus. Add `?include=falkordb` or `?include=qdrant` to export only one store. + #### Local Backups (Development) The `backup_automem.py` script exports both FalkorDB and Qdrant to compressed JSON files: diff --git a/scripts/backup_automem.py b/scripts/backup_automem.py index d79b524..82ecfc9 100755 --- a/scripts/backup_automem.py +++ b/scripts/backup_automem.py @@ -16,7 +16,6 @@ """ import argparse -import gzip import json import logging import os @@ -29,6 +28,12 @@ from falkordb import FalkorDB from qdrant_client import QdrantClient +from automem.backup import ( + cleanup_old_backup_files, + write_falkordb_backup_file, + write_qdrant_backup_file, +) + # Load environment load_dotenv() load_dotenv(Path.home() / ".config" / "automem" / ".env") @@ -53,7 +58,7 @@ class AutoMemBackup: - """Handles backup and restoration of AutoMem data.""" + """Handles backup of AutoMem data.""" def __init__(self, backup_dir: Path, s3_bucket: Optional[str] = None): self.backup_dir = backup_dir @@ -78,102 +83,21 @@ def backup_falkordb(self) -> Path: ) graph = db.select_graph(FALKORDB_GRAPH) - # Export all nodes (using LIMIT to handle large graphs in batches) - # Note: FalkorDB has a default result limit, so we need to paginate - nodes = [] - batch_size = 10000 - offset = 0 - - while True: - nodes_result = graph.query( - f""" - MATCH (n) - RETURN - id(n) as id, - labels(n) as labels, - properties(n) as props - SKIP {offset} LIMIT {batch_size} - """ - ) - - if not nodes_result.result_set: - break - - batch_count = 0 - for row in nodes_result.result_set: - nodes.append({"id": row[0], "labels": row[1], "properties": row[2]}) - batch_count += 1 - - logger.info(f" Exported batch: {batch_count} nodes (total: {len(nodes)})") - - if batch_count < batch_size: - break # Last batch - - offset += batch_size - - # Export all relationships (using LIMIT to handle large graphs in batches) - # Note: FalkorDB has a default result limit, so we need to paginate - relationships = [] - batch_size = 10000 - offset = 0 - - while True: - rels_result = graph.query( - f""" - MATCH (a)-[r]->(b) - RETURN - id(a) as source_id, - type(r) as rel_type, - id(b) as target_id, - properties(r) as props - SKIP {offset} LIMIT {batch_size} - """ - ) - - if not rels_result.result_set: - break - - batch_count = 0 - for row in rels_result.result_set: - relationships.append( - { - "source_id": row[0], - "type": row[1], - "target_id": row[2], - "properties": row[3], - } - ) - batch_count += 1 - - logger.info( - f" Exported batch: {batch_count} relationships (total: {len(relationships)})" - ) - - if batch_count < batch_size: - break # Last batch - - offset += batch_size - - # Create backup data - backup_data = { - "timestamp": self.timestamp, - "graph_name": FALKORDB_GRAPH, - "nodes": nodes, - "relationships": relationships, - "stats": { - "node_count": len(nodes), - "relationship_count": len(relationships), - }, - } - - # Write to compressed file - backup_file = self.backup_dir / "falkordb" / f"falkordb_{self.timestamp}.json.gz" - with gzip.open(backup_file, "wt", encoding="utf-8") as f: - json.dump(backup_data, f, indent=2, default=str) - + backup = write_falkordb_backup_file( + backup_dir=self.backup_dir, + graph=graph, + graph_name=FALKORDB_GRAPH, + timestamp=self.timestamp, + logger=logger, + ) + backup_file = backup.path size_mb = backup_file.stat().st_size / 1024 / 1024 logger.info(f"โœ… FalkorDB backup saved: {backup_file.name} ({size_mb:.2f} MB)") - logger.info(f" Nodes: {len(nodes)}, Relationships: {len(relationships)}") + logger.info( + " Nodes: %d, Relationships: %d", + backup.stats["node_count"], + backup.stats["relationship_count"], + ) return backup_file @@ -188,55 +112,17 @@ def backup_qdrant(self) -> Path: try: client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY) - # Fetch all points - all_points = [] - offset = None - batch_size = 100 - - while True: - result = client.scroll( - collection_name=QDRANT_COLLECTION, - limit=batch_size, - offset=offset, - with_payload=True, - with_vectors=True, - ) - - points, next_offset = result - - for point in points: - all_points.append( - { - "id": point.id, - "vector": point.vector, - "payload": point.payload, - } - ) - - if next_offset is None: - break - offset = next_offset - - # Create backup data - collection_info = client.get_collection(QDRANT_COLLECTION) - backup_data = { - "timestamp": self.timestamp, - "collection_name": QDRANT_COLLECTION, - "points": all_points, - "stats": { - "points_count": len(all_points), - "vector_size": collection_info.config.params.vectors.size, - }, - } - - # Write to compressed file - backup_file = self.backup_dir / "qdrant" / f"qdrant_{self.timestamp}.json.gz" - with gzip.open(backup_file, "wt", encoding="utf-8") as f: - json.dump(backup_data, f, indent=2, default=str) - + backup = write_qdrant_backup_file( + backup_dir=self.backup_dir, + qdrant_client=client, + collection_name=QDRANT_COLLECTION, + timestamp=self.timestamp, + logger=logger, + ) + backup_file = backup.path size_mb = backup_file.stat().st_size / 1024 / 1024 logger.info(f"โœ… Qdrant backup saved: {backup_file.name} ({size_mb:.2f} MB)") - logger.info(f" Points: {len(all_points)}") + logger.info(" Points: %d", backup.stats["points_count"]) return backup_file @@ -268,25 +154,7 @@ def upload_to_s3(self, file_path: Path): def cleanup_old_backups(self, keep: int = 7): """Remove old backup files, keeping only the most recent N.""" logger.info(f"๐Ÿงน Cleaning up old backups (keeping last {keep})...") - - for backup_type in ["falkordb", "qdrant"]: - backup_path = self.backup_dir / backup_type - - # Get all backup files sorted by modification time - backup_files = sorted( - backup_path.glob("*.json.gz"), - key=lambda p: p.stat().st_mtime, - reverse=True, - ) - - # Remove old files - for old_file in backup_files[keep:]: - logger.info(f" ๐Ÿ—‘๏ธ Removing old backup: {old_file.name}") - old_file.unlink() - - kept = min(len(backup_files), keep) - removed = max(0, len(backup_files) - keep) - logger.info(f" โœ… {backup_type}: kept {kept}, removed {removed}") + cleanup_old_backup_files(backup_dir=self.backup_dir, keep=keep, logger=logger) def run_backup(self, cleanup: bool = False, keep: int = 7) -> Dict[str, Any]: """Run full backup process.""" diff --git a/scripts/restore_from_backup.py b/scripts/restore_from_backup.py index d9937df..41971a3 100755 --- a/scripts/restore_from_backup.py +++ b/scripts/restore_from_backup.py @@ -10,6 +10,9 @@ # Restore from specific backup python scripts/restore_from_backup.py --backup-timestamp 20251019_085625 + # Restore from downloaded API backup tarball + python scripts/restore_from_backup.py --backup-dir snapshot.tar.gz + # Dry run (show what would be restored) python scripts/restore_from_backup.py --dry-run @@ -23,9 +26,12 @@ import logging import os import sys +import tarfile +import tempfile +from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Iterator from dotenv import load_dotenv from falkordb import FalkorDB @@ -55,6 +61,36 @@ QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "memories") +def _is_tar_gz(path: Path) -> bool: + name = path.name.lower() + return name.endswith(".tar.gz") or name.endswith(".tgz") + + +def _safe_extract_tar_gz(archive_path: Path, target_dir: Path) -> None: + target_root = target_dir.resolve() + with tarfile.open(archive_path, "r:gz") as archive: + for member in archive.getmembers(): + if member.issym() or member.islnk(): + raise ValueError(f"Refusing to extract link from backup archive: {member.name}") + destination = (target_root / member.name).resolve() + if not destination.is_relative_to(target_root): + raise ValueError(f"Refusing to extract unsafe backup path: {member.name}") + archive.extractall(target_root, filter="data") + + +@contextmanager +def resolve_backup_dir(backup_path: Path) -> Iterator[Path]: + """Resolve a backup directory or extract a downloaded tar.gz backup to a temp dir.""" + if backup_path.is_file() and _is_tar_gz(backup_path): + with tempfile.TemporaryDirectory(prefix="automem-restore-") as temp_dir: + extracted = Path(temp_dir) + _safe_extract_tar_gz(backup_path, extracted) + yield extracted + return + + yield backup_path + + class AutoMemRestore: """Handles restoration of AutoMem data from backups.""" @@ -461,6 +497,9 @@ def main(): # Restore from specific timestamp python scripts/restore_from_backup.py --backup-timestamp 20251019_085625 + # Restore from downloaded API backup tarball + python scripts/restore_from_backup.py --backup-dir snapshot.tar.gz + # Dry run (preview only) python scripts/restore_from_backup.py --dry-run @@ -478,7 +517,7 @@ def main(): "--backup-dir", type=str, default=str(BACKUP_DIR), - help="Directory containing backup files (default: ./backups)", + help="Directory containing backup files or downloaded .tar.gz (default: ./backups)", ) parser.add_argument( "--backup-timestamp", @@ -510,19 +549,19 @@ def main(): args = parser.parse_args() - restore = AutoMemRestore( - backup_dir=Path(args.backup_dir), - dry_run=args.dry_run, - force=args.force, - merge=args.merge, - ) - try: - results = restore.run_restore( - timestamp=args.backup_timestamp, - falkordb_only=args.falkordb_only, - qdrant_only=args.qdrant_only, - ) + with resolve_backup_dir(Path(args.backup_dir)) as backup_dir: + restore = AutoMemRestore( + backup_dir=backup_dir, + dry_run=args.dry_run, + force=args.force, + merge=args.merge, + ) + results = restore.run_restore( + timestamp=args.backup_timestamp, + falkordb_only=args.falkordb_only, + qdrant_only=args.qdrant_only, + ) print(json.dumps(results, indent=2)) sys.exit(0) except Exception as e: diff --git a/tests/contracts/test_routes_contract.py b/tests/contracts/test_routes_contract.py index db6e956..a4f7003 100644 --- a/tests/contracts/test_routes_contract.py +++ b/tests/contracts/test_routes_contract.py @@ -18,6 +18,7 @@ ("GET", "/memories//related"), ("POST", "/admin/reembed"), ("POST", "/admin/sync"), + ("GET", "/backup"), ("POST", "/consolidate"), ("GET", "/consolidate/status"), ("GET", "/enrichment/status"), diff --git a/tests/support/fake_graph.py b/tests/support/fake_graph.py index 2b82422..a8fc168 100644 --- a/tests/support/fake_graph.py +++ b/tests/support/fake_graph.py @@ -25,6 +25,13 @@ def _returns_whole_memory_node(query: str) -> bool: return re.search(r"\bRETURN\s+m\b(?![\w.])", query) is not None +def _skip_limit(query: str) -> tuple[int, int | None]: + match = re.search(r"\bSKIP\s+(\d+)\s+LIMIT\s+(\d+)", query) + if not match: + return 0, None + return int(match.group(1)), int(match.group(2)) + + class FakeGraph: """Shared fake FalkorDB graph used across unit tests. @@ -171,6 +178,43 @@ def query(self, query: str, params: Dict[str, Any] | None = None, **kwargs: Any) ) return FakeResult([]) + # Full graph backup export + if "MATCH (n)" in query and "id(n) as id" in query and "properties(n) as props" in query: + memory_items = sorted(self.memories.items(), key=lambda item: item[0]) + rows = [ + [index, ["Memory"], dict(memory)] + for index, (_memory_id, memory) in enumerate(memory_items) + ] + skip, limit = _skip_limit(query) + return FakeResult(rows[skip : None if limit is None else skip + limit]) + + if ( + "MATCH (a)-[r]->(b)" in query + and "id(a) as source_id" in query + and "properties(r) as props" in query + ): + memory_ids = [memory_id for memory_id, _memory in sorted(self.memories.items())] + backup_ids = {memory_id: index for index, memory_id in enumerate(memory_ids)} + rows = [] + for rel in self.relationships: + source_id = str(rel.get("id1") or "") + target_id = str(rel.get("id2") or "") + if source_id not in backup_ids or target_id not in backup_ids: + continue + props = { + key: value for key, value in rel.items() if key not in {"id1", "id2", "type"} + } + rows.append( + [ + backup_ids[source_id], + str(rel.get("type") or "RELATES_TO"), + backup_ids[target_id], + props, + ] + ) + skip, limit = _skip_limit(query) + return FakeResult(rows[skip : None if limit is None else skip + limit]) + # Memory create/upsert if "MERGE (m:Memory {id:" in query or "CREATE (m:Memory {id:" in query: memory_id = str(params["id"]) @@ -355,7 +399,13 @@ def query(self, query: str, params: Dict[str, Any] | None = None, **kwargs: Any) if isinstance(tag, str) and tag.strip() ] if any(tag in {"system", "memory-recall"} for tag in tags): - rows.append([memory.get("id"), memory.get("content"), memory.get("tags", [])]) + rows.append( + [ + memory.get("id"), + memory.get("content"), + memory.get("tags", []), + ] + ) return FakeResult(rows[:5]) # Association creation @@ -446,17 +496,20 @@ def _importance(memory: Dict[str, Any]) -> float: results.sort(key=lambda memory: (_timestamp_key(memory), -_importance(memory))) elif "ORDER BY m.timestamp DESC" in query: results.sort( - key=lambda memory: (_timestamp_key(memory), _importance(memory)), reverse=True + key=lambda memory: (_timestamp_key(memory), _importance(memory)), + reverse=True, ) elif "ORDER BY coalesce(m.updated_at, m.timestamp) ASC" in query: results.sort(key=lambda memory: (_timestamp_key(memory), -_importance(memory))) elif "ORDER BY coalesce(m.updated_at, m.timestamp) DESC" in query: results.sort( - key=lambda memory: (_timestamp_key(memory), _importance(memory)), reverse=True + key=lambda memory: (_timestamp_key(memory), _importance(memory)), + reverse=True, ) else: results.sort( - key=lambda memory: (_importance(memory), _timestamp_key(memory)), reverse=True + key=lambda memory: (_importance(memory), _timestamp_key(memory)), + reverse=True, ) limit_param = params.get("limit") diff --git a/tests/test_backup_endpoint.py b/tests/test_backup_endpoint.py new file mode 100644 index 0000000..ac61e99 --- /dev/null +++ b/tests/test_backup_endpoint.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import gzip +import io +import json +import tarfile +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest + +import app +from scripts.restore_from_backup import AutoMemRestore, resolve_backup_dir +from tests.support.fake_graph import FakeGraph + + +class FakeQdrantClient: + def __init__(self, vector_size: int = 3) -> None: + self.points: dict[str, dict[str, Any]] = {} + self.vector_size = vector_size + + def scroll( + self, + collection_name: str, + limit: int = 100, + offset: int | None = None, + with_payload: bool = True, + with_vectors: bool = False, + **_kwargs: Any, + ) -> tuple[list[Any], int | None]: + del collection_name + ids = sorted(self.points) + start = int(offset or 0) + selected = ids[start : start + limit] + points = [] + for point_id in selected: + point = self.points[point_id] + points.append( + SimpleNamespace( + id=point_id, + vector=point["vector"] if with_vectors else None, + payload=point["payload"] if with_payload else None, + ) + ) + next_offset = start + len(selected) + if next_offset >= len(ids): + return points, None + return points, next_offset + + def get_collection(self, collection_name: str) -> Any: + del collection_name + return SimpleNamespace( + points_count=len(self.points), + config=SimpleNamespace( + params=SimpleNamespace(vectors=SimpleNamespace(size=self.vector_size)) + ), + ) + + +@pytest.fixture +def backup_state(monkeypatch: pytest.MonkeyPatch) -> Any: + state = app.ServiceState() + state.memory_graph = FakeGraph() + state.qdrant = FakeQdrantClient() + + monkeypatch.setattr(app, "state", state) + monkeypatch.setattr(app, "init_falkordb", lambda: None) + monkeypatch.setattr(app, "init_qdrant", lambda: None) + monkeypatch.setattr(app, "init_openai", lambda: None) + monkeypatch.setattr(app, "API_TOKEN", "test-token") + monkeypatch.setattr(app, "ADMIN_TOKEN", "test-admin-token") + + original_testing = app.app.config.get("TESTING") + app.app.config["TESTING"] = True + yield state + app.app.config["TESTING"] = original_testing + + +@pytest.fixture +def client(backup_state: Any) -> Any: + del backup_state + with app.app.test_client() as test_client: + yield test_client + + +def _admin_headers() -> dict[str, str]: + return {"X-Admin-Token": "test-admin-token"} + + +def _archive_members(raw: bytes) -> dict[str, bytes]: + with tarfile.open(fileobj=io.BytesIO(raw), mode="r:gz") as archive: + return { + member.name: archive.extractfile(member).read() + for member in archive.getmembers() + if member.isfile() + } + + +def _read_backup_json(members: dict[str, bytes], prefix: str) -> dict[str, Any]: + matching = [name for name in members if name.startswith(prefix)] + assert len(matching) == 1 + return json.loads(gzip.decompress(members[matching[0]]).decode("utf-8")) + + +def test_backup_requires_admin_token_not_api_token(client: Any) -> None: + response = client.get("/backup", headers={"Authorization": "Bearer test-token"}) + + assert response.status_code == 401 + + +def test_backup_missing_admin_token_returns_401(client: Any) -> None: + response = client.get("/backup") + + assert response.status_code == 401 + + +def test_backup_unconfigured_admin_token_returns_403( + client: Any, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(app, "ADMIN_TOKEN", None) + + response = client.get("/backup", headers=_admin_headers()) + + assert response.status_code == 403 + + +def test_backup_happy_path_tar_contains_falkordb_and_qdrant(client: Any, backup_state: Any) -> None: + memory_id = "10000000-0000-0000-0000-000000000001" + backup_state.memory_graph.memories[memory_id] = { + "id": memory_id, + "content": "Backup me", + "tags": ["backup"], + "importance": 0.8, + "timestamp": "2026-01-01T00:00:00+00:00", + "type": "Context", + } + backup_state.qdrant.points[memory_id] = { + "vector": [0.1, 0.2, 0.3], + "payload": {"content": "Backup me", "tags": ["backup"]}, + } + + response = client.get("/backup", headers=_admin_headers()) + + assert response.status_code == 200 + assert response.mimetype == "application/gzip" + assert "automem-backup-" in response.headers["Content-Disposition"] + + members = _archive_members(response.get_data()) + assert any(name.startswith("falkordb/falkordb_") for name in members) + assert any(name.startswith("qdrant/qdrant_") for name in members) + + falkordb = _read_backup_json(members, "falkordb/falkordb_") + qdrant = _read_backup_json(members, "qdrant/qdrant_") + assert falkordb["stats"]["node_count"] == 1 + assert falkordb["nodes"][0]["properties"]["content"] == "Backup me" + assert qdrant["stats"]["points_count"] == 1 + assert qdrant["points"][0]["vector"] == [0.1, 0.2, 0.3] + + +def test_backup_include_falkordb_only_does_not_require_qdrant( + client: Any, backup_state: Any +) -> None: + backup_state.qdrant = None + + response = client.get("/backup?include=falkordb", headers=_admin_headers()) + + assert response.status_code == 200 + members = _archive_members(response.get_data()) + assert any(name.startswith("falkordb/falkordb_") for name in members) + assert not any(name.startswith("qdrant/qdrant_") for name in members) + + +@pytest.mark.parametrize("include", ["", "unknown", "falkordb,,qdrant"]) +def test_backup_invalid_include_returns_400(client: Any, include: str) -> None: + response = client.get(f"/backup?include={include}", headers=_admin_headers()) + + assert response.status_code == 400 + + +def test_backup_empty_corpus(client: Any) -> None: + response = client.get("/backup", headers=_admin_headers()) + + assert response.status_code == 200 + members = _archive_members(response.get_data()) + falkordb = _read_backup_json(members, "falkordb/falkordb_") + qdrant = _read_backup_json(members, "qdrant/qdrant_") + assert falkordb["stats"]["node_count"] == 0 + assert falkordb["stats"]["relationship_count"] == 0 + assert qdrant["stats"]["points_count"] == 0 + + +def test_restore_accepts_downloaded_tar_gz_backup(tmp_path: Path) -> None: + falkordb_payload = { + "timestamp": "20260101_000000", + "graph_name": "memories", + "nodes": [], + "relationships": [], + "stats": {"node_count": 0, "relationship_count": 0}, + } + compressed = gzip.compress(json.dumps(falkordb_payload).encode("utf-8")) + archive_path = tmp_path / "snapshot.tar.gz" + + with tarfile.open(archive_path, mode="w:gz") as archive: + info = tarfile.TarInfo("falkordb/falkordb_20260101_000000.json.gz") + info.size = len(compressed) + archive.addfile(info, io.BytesIO(compressed)) + + with resolve_backup_dir(archive_path) as backup_dir: + restore = AutoMemRestore(backup_dir=backup_dir, dry_run=True, force=True) + result = restore.run_restore(falkordb_only=True) + + assert result["falkordb"]["nodes"] == 0 + assert result["falkordb"]["relationships"] == 0