Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
import time
import uuid
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from queue import Empty, Queue
from threading import Event, Lock, Thread
Expand Down Expand Up @@ -167,9 +168,13 @@ def _parse_viewer_allowed_origins() -> Any:
SEARCH_WEIGHT_RECENCY,
SEARCH_WEIGHT_TAG,
SEARCH_WEIGHT_VECTOR,
SERVICE_MODE,
SERVICE_PROFILE,
SERVICE_TIER,
SYNC_AUTO_REPAIR,
SYNC_CHECK_INTERVAL_SECONDS,
TYPE_ALIASES,
UPGRADE_URL,
VECTOR_SIZE,
normalize_memory_type,
)
Expand Down Expand Up @@ -242,6 +247,72 @@ def _parse_viewer_allowed_origins() -> Any:

state = ServiceState()

_LOCKED_WRITE_PREFIXES = ("/admin/exports",)


def get_service_profile() -> Dict[str, Any]:
profile = getattr(state, "service_profile", None)
if isinstance(profile, dict):
return profile
return deepcopy(SERVICE_PROFILE)


def get_service_mode() -> str:
return str(getattr(state, "service_mode", SERVICE_MODE) or SERVICE_MODE)


def get_service_tier() -> str:
return str(getattr(state, "service_tier", SERVICE_TIER) or SERVICE_TIER)


def _service_capabilities() -> Dict[str, Any]:
capabilities = get_service_profile().get("capabilities", {})
return capabilities if isinstance(capabilities, dict) else {}


def _service_writes_enabled() -> bool:
return bool(_service_capabilities().get("writes_enabled", True))


def _build_service_locked_response() -> tuple[Any, int]:
tier = get_service_tier()
mode = get_service_mode()
is_trial_expired = tier == "trial" and mode == "archived"

if is_trial_expired:
error = "trial_expired"
message = (
"Your 30-day trial has ended. Your memories are safe. "
f"Subscribe at {UPGRADE_URL} to unlock them."
)
elif mode == "archived":
error = "service_locked"
message = "Service is archived and does not accept write operations"
elif mode == "read_only":
error = "service_locked"
message = "Service is in read-only mode and does not accept write operations"
else:
error = "service_locked"
message = "Service is locked and does not accept write operations"

body = {
"status": "error",
"code": 423,
"error": error,
"message": message,
"service_mode": mode,
"service_tier": tier,
"capabilities": _service_capabilities(),
}
if is_trial_expired:
body["reason"] = "trial_expired"
body["upgrade_url"] = UPGRADE_URL

return (
jsonify(body),
423,
)


def _extract_api_token() -> Optional[str]:
return _extract_api_token_helper(request, API_TOKEN)
Expand All @@ -251,6 +322,16 @@ def get_openai_client() -> Optional[OpenAI]:
return state.openai_client


def get_embedding_provider_name() -> Optional[str]:
provider = getattr(state, "embedding_provider", None)
if provider is None:
return None
try:
return provider.provider_name()
except Exception:
return None


def _require_admin_token() -> None:
_require_admin_token_helper(
request_obj=request,
Expand All @@ -275,6 +356,27 @@ def require_api_token() -> None:
)


@app.before_request
def enforce_service_mode() -> Any:
if request.method == "OPTIONS":
return None

if request.path.startswith("/viewer"):
return None

if _service_writes_enabled():
return None

if request.method in {"GET", "HEAD"}:
return None

if any(request.path.startswith(prefix) for prefix in _LOCKED_WRITE_PREFIXES):
if _service_capabilities().get("archive_export_enabled", False):
return None

return _build_service_locked_response()


_service_runtime = create_service_runtime(
get_state_fn=lambda: state,
logger=logger,
Expand Down
197 changes: 196 additions & 1 deletion automem/api/admin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import gzip
import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Set
from uuid import uuid4

from flask import Blueprint, abort, jsonify, request
from flask import Blueprint, abort, jsonify, request, send_file


def _parse_metadata(raw: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -58,6 +62,108 @@ def _get_all_qdrant_ids(qdrant_client: Any, collection_name: str) -> Set[str]:
return all_ids


def _export_dir() -> Path:
directory = Path(os.getenv("AUTOMEM_EXPORT_DIR", "./exports")).expanduser().resolve()
directory.mkdir(parents=True, exist_ok=True)
return directory


def _export_bundle_path(export_id: str) -> Path:
return _export_dir() / f"{export_id}.json.gz"


def _export_manifest_path(export_id: str) -> Path:
return _export_dir() / f"{export_id}.manifest.json"


def _load_export_manifest(export_id: str) -> Dict[str, Any]:
path = _export_manifest_path(export_id)
if not path.exists():
abort(404, description="Export not found")
return json.loads(path.read_text(encoding="utf-8"))


def _export_graph_snapshot(graph: Any, graph_name: str) -> Dict[str, Any]:
nodes_result = graph.query(
"""
MATCH (n)
RETURN labels(n) AS labels,
properties(n) AS props
"""
)
relationships_result = graph.query(
"""
MATCH (a)-[r]->(b)
RETURN properties(a).id AS source_id,
type(r) AS rel_type,
properties(b).id AS target_id,
properties(r) AS props
"""
)

nodes = []
for row in getattr(nodes_result, "result_set", []) or []:
labels = row[0] if len(row) > 0 else []
props = row[1] if len(row) > 1 else {}
nodes.append({"labels": labels or [], "properties": props or {}})

relationships = []
for row in getattr(relationships_result, "result_set", []) or []:
relationships.append(
{
"source_id": row[0] if len(row) > 0 else None,
"type": row[1] if len(row) > 1 else None,
"target_id": row[2] if len(row) > 2 else None,
"properties": row[3] if len(row) > 3 else {},
}
)

return {
"graph_name": graph_name,
"nodes": nodes,
"relationships": relationships,
"stats": {
"node_count": len(nodes),
"relationship_count": len(relationships),
},
}


def _export_qdrant_snapshot(qdrant_client: Any, collection_name: str) -> Dict[str, Any] | None:
if qdrant_client is None:
return None

points = []
offset = None
while True:
batch, next_offset = qdrant_client.scroll(
collection_name=collection_name,
limit=100,
offset=offset,
with_payload=True,
with_vectors=True,
)
for point in batch:
points.append(
{
"id": point.id,
"vector": getattr(point, "vector", None),
"payload": getattr(point, "payload", None),
}
)
if next_offset is None:
break
offset = next_offset

return {
"collection_name": collection_name,
"points": points,
"stats": {
"points_count": len(points),
},
}


def create_admin_blueprint_full(
require_admin_token: Callable[[], None],
init_openai: Callable[[], None],
Expand All @@ -70,6 +176,10 @@ def create_admin_blueprint_full(
embedding_model: str,
utc_now: Callable[[], str],
logger: Any,
graph_name: str,
get_service_profile: Callable[[], Dict[str, Any]],
get_service_mode: Callable[[], str],
get_service_tier: Callable[[], str],
) -> Blueprint:
bp = Blueprint("admin", __name__)

Expand Down Expand Up @@ -404,4 +514,89 @@ def sync_missing() -> Any:
response["failed_ids_truncated"] = True
return jsonify(response)

@bp.route("/admin/exports", methods=["POST"])
def create_export() -> Any:
require_admin_token()

payload = request.get_json(silent=True) or {}
include_vectors = bool(payload.get("include_vectors", True))
reason = str(payload.get("reason") or "").strip() or None
export_id = str(payload.get("export_id") or uuid4().hex)

graph = get_memory_graph()
if graph is None:
abort(503, description="FalkorDB is unavailable")

qdrant_client = get_qdrant_client() if include_vectors else None
graph_snapshot = _export_graph_snapshot(graph, graph_name)
qdrant_snapshot = _export_qdrant_snapshot(qdrant_client, collection_name)

bundle = {
"export_id": export_id,
"created_at": utc_now(),
"reason": reason,
"service": {
"tier": get_service_tier(),
"mode": get_service_mode(),
"profile": get_service_profile(),
},
"graph": graph_snapshot,
"qdrant": qdrant_snapshot,
}

bundle_path = _export_bundle_path(export_id)
with gzip.open(bundle_path, "wt", encoding="utf-8") as handle:
json.dump(bundle, handle, indent=2, default=str)

manifest = {
"status": "complete",
"export_id": export_id,
"created_at": bundle["created_at"],
"reason": reason,
"service": bundle["service"],
"graph": graph_snapshot["stats"],
"qdrant": qdrant_snapshot["stats"] if qdrant_snapshot else None,
"include_vectors": qdrant_snapshot is not None,
"bundle": {
"filename": bundle_path.name,
"bytes": bundle_path.stat().st_size,
},
"download_url": f"{request.host_url.rstrip('/')}/admin/exports/{export_id}/download",
"status_url": f"{request.host_url.rstrip('/')}/admin/exports/{export_id}",
}

manifest_path = _export_manifest_path(export_id)
manifest_path.write_text(json.dumps(manifest, indent=2, default=str), encoding="utf-8")

logger.info(
"Created service export",
extra={
"export_id": export_id,
"tier": get_service_tier(),
"mode": get_service_mode(),
"include_vectors": bool(qdrant_snapshot is not None),
},
)
return jsonify(manifest), 201

@bp.route("/admin/exports/<export_id>", methods=["GET"])
def export_status(export_id: str) -> Any:
require_admin_token()
return jsonify(_load_export_manifest(export_id))

@bp.route("/admin/exports/<export_id>/download", methods=["GET"])
def download_export(export_id: str) -> Any:
require_admin_token()

bundle_path = _export_bundle_path(export_id)
if not bundle_path.exists():
abort(404, description="Export bundle not found")

return send_file(
bundle_path,
mimetype="application/gzip",
as_attachment=True,
download_name=bundle_path.name,
)

return bp
14 changes: 14 additions & 0 deletions automem/api/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ def health() -> Any:
"processed": state.enrichment_stats.successes,
"failed": state.enrichment_stats.failures,
},
"service": {
"tier": getattr(state, "service_tier", "pro"),
"mode": getattr(state, "service_mode", "active"),
"writes_enabled": bool(
getattr(state, "service_profile", {})
.get("capabilities", {})
.get("writes_enabled", True)
),
"self_service_export_enabled": bool(
getattr(state, "service_profile", {})
.get("capabilities", {})
.get("self_service_export_enabled", False)
),
},
"timestamp": utc_now(),
"graph": graph_name,
}
Expand Down
Loading
Loading