def get_service_profile() -> Dict[str, Any]:
    """Return the service profile mapping currently in effect.

    When ``state.service_profile`` holds a dict, that same object is returned
    (it is not copied). Otherwise a deep copy of the configured
    ``SERVICE_PROFILE`` default is returned.
    """
    current = getattr(state, "service_profile", None)
    return current if isinstance(current, dict) else deepcopy(SERVICE_PROFILE)


def get_service_mode() -> str:
    """Return the service mode from ``state``, or ``SERVICE_MODE`` when unset or falsy."""
    configured = getattr(state, "service_mode", SERVICE_MODE)
    return str(configured if configured else SERVICE_MODE)


def get_service_tier() -> str:
    """Return the service tier from ``state``, or ``SERVICE_TIER`` when unset or falsy."""
    configured = getattr(state, "service_tier", SERVICE_TIER)
    return str(configured if configured else SERVICE_TIER)


def _service_capabilities() -> Dict[str, Any]:
    """Return the profile's ``capabilities`` section, or ``{}`` when it is not a dict."""
    section = get_service_profile().get("capabilities", {})
    if not isinstance(section, dict):
        return {}
    return section


def _service_writes_enabled() -> bool:
    """Report whether writes are allowed; a missing flag counts as enabled."""
    flag = _service_capabilities().get("writes_enabled", True)
    return bool(flag)
@app.before_request
def enforce_service_mode() -> Any:
    """Reject mutating requests while the service has writes disabled.

    Returns ``None`` (letting the request proceed) for CORS preflight
    requests, anything under ``/viewer``, any request while writes are
    enabled, and read-only methods. While locked, paths under
    ``_LOCKED_WRITE_PREFIXES`` are still let through when the
    ``archive_export_enabled`` capability is set. Every other request gets
    the locked response built by ``_build_service_locked_response``.
    """
    method = request.method
    path = request.path

    if method == "OPTIONS" or path.startswith("/viewer"):
        return None

    if _service_writes_enabled() or method in {"GET", "HEAD"}:
        return None

    # str.startswith accepts a tuple, which is equivalent to any() over the prefixes.
    targets_export = path.startswith(tuple(_LOCKED_WRITE_PREFIXES))
    if targets_export and _service_capabilities().get("archive_export_enabled", False):
        return None

    return _build_service_locked_response()
annotations +import gzip import json +import os +from pathlib import Path from typing import Any, Callable, Dict, List, Set +from uuid import uuid4 -from flask import Blueprint, abort, jsonify, request +from flask import Blueprint, abort, jsonify, request, send_file def _parse_metadata(raw: Any) -> Dict[str, Any]: @@ -58,6 +62,108 @@ def _get_all_qdrant_ids(qdrant_client: Any, collection_name: str) -> Set[str]: return all_ids +def _export_dir() -> Path: + directory = Path(os.getenv("AUTOMEM_EXPORT_DIR", "./exports")).expanduser().resolve() + directory.mkdir(parents=True, exist_ok=True) + return directory + + +def _export_bundle_path(export_id: str) -> Path: + return _export_dir() / f"{export_id}.json.gz" + + +def _export_manifest_path(export_id: str) -> Path: + return _export_dir() / f"{export_id}.manifest.json" + + +def _load_export_manifest(export_id: str) -> Dict[str, Any]: + path = _export_manifest_path(export_id) + if not path.exists(): + abort(404, description="Export not found") + return json.loads(path.read_text(encoding="utf-8")) + + +def _export_graph_snapshot(graph: Any, graph_name: str) -> Dict[str, Any]: + nodes_result = graph.query( + """ + MATCH (n) + RETURN labels(n) AS labels, + properties(n) AS props + """ + ) + relationships_result = graph.query( + """ + MATCH (a)-[r]->(b) + RETURN properties(a).id AS source_id, + type(r) AS rel_type, + properties(b).id AS target_id, + properties(r) AS props + """ + ) + + nodes = [] + for row in getattr(nodes_result, "result_set", []) or []: + labels = row[0] if len(row) > 0 else [] + props = row[1] if len(row) > 1 else {} + nodes.append({"labels": labels or [], "properties": props or {}}) + + relationships = [] + for row in getattr(relationships_result, "result_set", []) or []: + relationships.append( + { + "source_id": row[0] if len(row) > 0 else None, + "type": row[1] if len(row) > 1 else None, + "target_id": row[2] if len(row) > 2 else None, + "properties": row[3] if len(row) > 3 else {}, + } + ) + + return { 
+ "graph_name": graph_name, + "nodes": nodes, + "relationships": relationships, + "stats": { + "node_count": len(nodes), + "relationship_count": len(relationships), + }, + } + + +def _export_qdrant_snapshot(qdrant_client: Any, collection_name: str) -> Dict[str, Any] | None: + if qdrant_client is None: + return None + + points = [] + offset = None + while True: + batch, next_offset = qdrant_client.scroll( + collection_name=collection_name, + limit=100, + offset=offset, + with_payload=True, + with_vectors=True, + ) + for point in batch: + points.append( + { + "id": point.id, + "vector": getattr(point, "vector", None), + "payload": getattr(point, "payload", None), + } + ) + if next_offset is None: + break + offset = next_offset + + return { + "collection_name": collection_name, + "points": points, + "stats": { + "points_count": len(points), + }, + } + + def create_admin_blueprint_full( require_admin_token: Callable[[], None], init_openai: Callable[[], None], @@ -70,6 +176,10 @@ def create_admin_blueprint_full( embedding_model: str, utc_now: Callable[[], str], logger: Any, + graph_name: str, + get_service_profile: Callable[[], Dict[str, Any]], + get_service_mode: Callable[[], str], + get_service_tier: Callable[[], str], ) -> Blueprint: bp = Blueprint("admin", __name__) @@ -404,4 +514,89 @@ def sync_missing() -> Any: response["failed_ids_truncated"] = True return jsonify(response) + @bp.route("/admin/exports", methods=["POST"]) + def create_export() -> Any: + require_admin_token() + + payload = request.get_json(silent=True) or {} + include_vectors = bool(payload.get("include_vectors", True)) + reason = str(payload.get("reason") or "").strip() or None + export_id = str(payload.get("export_id") or uuid4().hex) + + graph = get_memory_graph() + if graph is None: + abort(503, description="FalkorDB is unavailable") + + qdrant_client = get_qdrant_client() if include_vectors else None + graph_snapshot = _export_graph_snapshot(graph, graph_name) + qdrant_snapshot = 
# Both views take an ``export_id`` argument, so each URL rule must declare
# the matching ``<export_id>`` variable; without it Flask has nothing to
# pass to the view function.
@bp.route("/admin/exports/<export_id>", methods=["GET"])
def export_status(export_id: str) -> Any:
    """Return the stored manifest for a previously created export.

    Requires the admin token. The manifest loader aborts with HTTP 404 when
    no manifest exists for *export_id*.
    """
    require_admin_token()
    return jsonify(_load_export_manifest(export_id))


@bp.route("/admin/exports/<export_id>/download", methods=["GET"])
def download_export(export_id: str) -> Any:
    """Send the gzip export bundle for *export_id* as a file attachment.

    Requires the admin token. Aborts with HTTP 404 when the bundle file does
    not exist.
    """
    require_admin_token()

    bundle_path = _export_bundle_path(export_id)
    if not bundle_path.exists():
        abort(404, description="Export bundle not found")

    return send_file(
        bundle_path,
        mimetype="application/gzip",
        as_attachment=True,
        download_name=bundle_path.name,
    )
def _merge_tags(raw_tags: Any, forced_tags: Optional[List[str]] = None) -> List[str]:
    """Normalize *raw_tags* and append any forced tags not already present.

    Presence is decided case-insensitively on the stripped tag text. Forced
    tags that are blank after stripping are ignored; the rest are appended in
    their stripped, original-case form. With no forced tags the normalized
    list is returned untouched.
    """
    merged = normalize_tags(raw_tags)
    if forced_tags:
        known = {
            str(existing).strip().lower()
            for existing in merged
            if isinstance(existing, str) and existing.strip()
        }
        for candidate in forced_tags:
            label = str(candidate).strip()
            key = label.lower()
            if label and key not in known:
                merged.append(label)
                known.add(key)
    return merged
[t.strip().lower() for t in tags if isinstance(t, str) and t.strip()] tag_prefixes = compute_tag_prefixes(tags_lower) importance = coerce_importance(payload.get("importance")) @@ -360,6 +379,80 @@ def store() -> Any: "enrichment_queued": bool(state.enrichment_queue), }, ) + return response + + def _associate_memories( + *, + memory1_id: str, + memory2_id: str, + relation_type: str, + strength: Any, + payload: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + relation_type = (relation_type or "RELATES_TO").upper() + strength_value = coerce_importance(strength) + + if not memory1_id or not memory2_id: + abort(400, description="'memory1_id' and 'memory2_id' are required") + _validate_memory_id(memory1_id) + _validate_memory_id(memory2_id) + if memory1_id == memory2_id: + abort(400, description="Cannot associate a memory with itself") + if relation_type not in set(authorable_relations): + abort( + 400, + description=f"Relation type must be one of {sorted(authorable_relations)}", + ) + + graph = get_memory_graph() + if graph is None: + abort(503, description="FalkorDB is unavailable") + + timestamp = utc_now() + relationship_props = {"strength": strength_value, "updated_at": timestamp} + relation_config = relation_types.get(relation_type, {}) + source_payload = payload or {} + if "properties" in relation_config: + for prop in relation_config["properties"]: + if prop in source_payload: + relationship_props[prop] = source_payload[prop] + + set_clauses = [f"r.{key} = ${key}" for key in relationship_props] + set_clause = ", ".join(set_clauses) + + try: + result = graph.query( + f""" + MATCH (m1:Memory {{id: $id1}}) + MATCH (m2:Memory {{id: $id2}}) + MERGE (m1)-[r:{relation_type}]->(m2) + SET {set_clause} + RETURN r + """, + {"id1": memory1_id, "id2": memory2_id, **relationship_props}, + ) + except Exception: + logger.exception("Failed to create association") + abort(500, description="Failed to create association") + + if not getattr(result, "result_set", None): + 
@bp.route("/api/v1/preseed", methods=["POST"])
def preseed() -> Any:
    """Bulk-create onboarding memories and link them by list position.

    The request must carry the header ``X-AutoMem-Internal: preseed`` and a
    JSON object with a non-empty ``memories`` array plus an optional
    ``associations`` array. Each association names two memories by their
    index in ``memories`` (``from_index`` / ``to_index``) and may give a
    ``relationship`` (or ``type``) and ``strength``.

    The shape of every association and every memory entry is checked before
    anything is stored, so a malformed entry is rejected with HTTP 400
    without leaving earlier entries half-created. Every stored memory is
    tagged ``onboarding``.

    NOTE(review): ``_store_memory_payload`` and ``_associate_memories`` may
    still abort part-way through a batch; memories stored before that point
    are not rolled back here.

    Returns a JSON summary with the created memory ids and HTTP 201.
    Aborts with 403 when the internal header is missing or wrong and with
    400 on malformed input.
    """

    def _coerce_index(raw: Any) -> Optional[int]:
        """Return *raw* as an int, or None when it is not a whole number.

        Booleans, fractional or non-finite floats and non-numeric strings
        are rejected rather than silently truncated or coerced.
        """
        if isinstance(raw, bool):
            return None
        if isinstance(raw, int):
            return raw
        if isinstance(raw, float):
            # is_integer() is False for NaN and infinities as well.
            return int(raw) if raw.is_integer() else None
        if isinstance(raw, str):
            try:
                return int(raw)
            except ValueError:
                return None
        return None

    internal_header = (request.headers.get("X-AutoMem-Internal") or "").strip().lower()
    if internal_header != "preseed":
        abort(403, description="Internal preseed authorization required")

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        abort(400, description="JSON body with 'memories' array required")

    memories_input = payload.get("memories")
    if not isinstance(memories_input, list) or not memories_input:
        abort(400, description="'memories' must be a non-empty array")

    associations_input = payload.get("associations", [])
    if not isinstance(associations_input, list):
        abort(400, description="'associations' must be an array when provided")

    allowed_relations = set(authorable_relations)
    memory_total = len(memories_input)

    normalized_associations: List[Dict[str, Any]] = []
    for index, association in enumerate(associations_input):
        if not isinstance(association, dict):
            abort(400, description=f"Association at index {index} must be an object")

        from_index = _coerce_index(association.get("from_index"))
        to_index = _coerce_index(association.get("to_index"))
        if from_index is None or to_index is None:
            abort(
                400,
                description=(
                    f"Association at index {index} must include integer "
                    "'from_index' and 'to_index' fields"
                ),
            )

        if from_index < 0 or to_index < 0:
            abort(400, description=f"Association at index {index} cannot use negative indexes")
        if from_index >= memory_total or to_index >= memory_total:
            abort(
                400,
                description=(
                    f"Association at index {index} references an out-of-range memory index"
                ),
            )
        if from_index == to_index:
            abort(
                400,
                description=f"Association at index {index} cannot reference the same memory",
            )

        relation_type = str(
            association.get("relationship") or association.get("type") or "RELATES_TO"
        ).upper()
        if relation_type not in allowed_relations:
            abort(
                400,
                description=(
                    f"Association at index {index} must use one of "
                    f"{sorted(authorable_relations)}"
                ),
            )

        normalized_associations.append(
            {
                **association,
                "from_index": from_index,
                "to_index": to_index,
                "type": relation_type,
            }
        )

    # Check every memory entry before the first write so that a malformed
    # entry late in the list cannot leave the earlier ones already stored.
    for index, memory_payload in enumerate(memories_input):
        if not isinstance(memory_payload, dict):
            abort(400, description=f"Memory at index {index} must be an object")

    created_responses: List[Dict[str, Any]] = [
        _store_memory_payload(memory_payload, forced_tags=["onboarding"])
        for memory_payload in memories_input
    ]
    memory_ids = [response["memory_id"] for response in created_responses]

    for association in normalized_associations:
        _associate_memories(
            memory1_id=memory_ids[association["from_index"]],
            memory2_id=memory_ids[association["to_index"]],
            relation_type=association["type"],
            strength=association.get("strength", 0.5),
            payload=association,
        )

    return (
        jsonify(
            {
                "status": "success",
                "memories_created": len(memory_ids),
                "associations_created": len(normalized_associations),
                "memory_ids": memory_ids,
            }
        ),
        201,
    )
COUNT(m) as count + ORDER BY count DESC + """ + ) + tag_result = graph.query( + """ + MATCH (m:Memory) + UNWIND coalesce(m.tags, []) AS tag + RETURN toLower(tag) AS tag, count(*) AS count + ORDER BY count DESC, tag ASC + LIMIT 10 + """ + ) + last_activity_result = graph.query( + """ + MATCH (m:Memory) + RETURN max(coalesce(m.updated_at, m.last_accessed, m.timestamp)) AS last_activity + """ ) except Exception: - logger.exception("Failed to create association") - abort(500, description="Failed to create association") + logger.exception("Failed to build dashboard stats") + abort(500, description="Failed to build stats") - if not getattr(result, "result_set", None): - abort(404, description="One or both memories do not exist") - - response = { - "status": "success", - "message": f"Association created between {memory1_id} and {memory2_id}", - "relation_type": relation_type, - "strength": strength, + memories_stored = ( + int(memory_count_result.result_set[0][0]) if memory_count_result.result_set else 0 + ) + associations = ( + int(association_count_result.result_set[0][0]) + if association_count_result.result_set + else 0 + ) + memory_types = { + str(row[0]): int(row[1]) + for row in getattr(type_result, "result_set", []) or [] + if row and row[0] is not None } - for prop in relation_config.get("properties", []): - if prop in relationship_props: - response[prop] = relationship_props[prop] - return jsonify(response), 201 + top_tags = [ + {"tag": str(row[0]), "count": int(row[1])} + for row in getattr(tag_result, "result_set", []) or [] + if row and row[0] + ] + last_activity = None + if getattr(last_activity_result, "result_set", None): + raw_last_activity = last_activity_result.result_set[0][0] + last_activity = str(raw_last_activity) if raw_last_activity is not None else None + + graph_density = round((associations / memories_stored), 4) if memories_stored else 0.0 + + return jsonify( + { + "status": "success", + "memories_stored": memories_stored, + "associations": 
associations, + "memory_types": memory_types, + "top_tags": top_tags, + "last_activity": last_activity, + "graph_density": graph_density, + } + ) @bp.route("/memory/batch", methods=["POST"]) def store_batch() -> Any: diff --git a/automem/api/runtime_bootstrap.py b/automem/api/runtime_bootstrap.py index d99495c..8b1c46e 100644 --- a/automem/api/runtime_bootstrap.py +++ b/automem/api/runtime_bootstrap.py @@ -9,6 +9,7 @@ from automem.api.health import create_health_blueprint from automem.api.memory import create_memory_blueprint_full from automem.api.recall import create_recall_blueprint +from automem.api.service_profile import create_service_blueprint from automem.api.stream import create_stream_blueprint from automem.api.viewer import create_viewer_blueprint, is_viewer_enabled @@ -59,6 +60,10 @@ def register_blueprints( init_openai_fn: Callable[[], None], effective_vector_size_fn: Callable[[], int], embedding_model: str, + service_profile_fn: Callable[[], dict[str, Any]], + service_mode_fn: Callable[[], str], + service_tier_fn: Callable[[], str], + embedding_provider_name_fn: Callable[[], str | None], build_consolidator_from_config_fn: Callable[[Any, Any], Any], persist_consolidation_run_fn: Callable[[Any, dict[str, Any]], None], build_scheduler_from_graph_fn: Callable[[Any], Any], @@ -76,6 +81,19 @@ def register_blueprints( utc_now_fn, ) + service_bp = create_service_blueprint( + get_memory_graph=get_memory_graph_fn, + get_qdrant_client=get_qdrant_client_fn, + graph_name=graph_name, + collection_name=collection_name, + embedding_model=embedding_model, + utc_now=utc_now_fn, + get_service_profile=service_profile_fn, + get_service_mode=service_mode_fn, + get_service_tier=service_tier_fn, + get_embedding_provider_name=embedding_provider_name_fn, + ) + enrichment_bp = create_enrichment_blueprint( require_admin_token_fn, state, @@ -145,6 +163,10 @@ def register_blueprints( embedding_model, utc_now_fn, logger, + graph_name, + service_profile_fn, + service_mode_fn, + 
service_tier_fn, ) consolidation_bp = create_consolidation_blueprint_full( @@ -173,6 +195,7 @@ def register_blueprints( ) app.register_blueprint(health_bp) + app.register_blueprint(service_bp) app.register_blueprint(enrichment_bp) app.register_blueprint(memory_bp) app.register_blueprint(admin_bp) diff --git a/automem/api/service_profile.py b/automem/api/service_profile.py new file mode 100644 index 0000000..e81ecfb --- /dev/null +++ b/automem/api/service_profile.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Callable, Dict + +from flask import Blueprint, jsonify + + +def create_service_blueprint( + *, + get_memory_graph: Callable[[], Any], + get_qdrant_client: Callable[[], Any], + graph_name: str, + collection_name: str, + embedding_model: str, + utc_now: Callable[[], str], + get_service_profile: Callable[[], Dict[str, Any]], + get_service_mode: Callable[[], str], + get_service_tier: Callable[[], str], + get_embedding_provider_name: Callable[[], str | None], +) -> Blueprint: + bp = Blueprint("service", __name__) + + @bp.route("/service/profile", methods=["GET"]) + def service_profile() -> Any: + graph_available = get_memory_graph() is not None + qdrant_available = get_qdrant_client() is not None + profile = deepcopy(get_service_profile()) + capabilities = deepcopy(profile.get("capabilities", {})) + profile["capabilities"] = capabilities + + return jsonify( + { + "status": "success", + "service": { + "tier": get_service_tier(), + "mode": get_service_mode(), + "profile": profile, + "graph": { + "name": graph_name, + "available": graph_available, + }, + "vector_store": { + "collection": collection_name, + "available": qdrant_available, + "expected": bool(capabilities.get("qdrant_expected", True)), + }, + "embedding": { + "provider": get_embedding_provider_name(), + "model": embedding_model, + "tier": profile.get("embedding_tier"), + }, + "consolidation": { + "tier": profile.get("consolidation_tier"), + }, 
+ "timestamp": utc_now(), + }, + } + ) + + return bp diff --git a/automem/config.py b/automem/config.py index cbeb508..a697a17 100644 --- a/automem/config.py +++ b/automem/config.py @@ -1,8 +1,9 @@ from __future__ import annotations import os +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, Mapping from dotenv import load_dotenv @@ -10,6 +11,121 @@ load_dotenv() load_dotenv(Path.home() / ".config" / "automem" / ".env") + +def _env_bool(env: Mapping[str, str], key: str, default: bool) -> bool: + raw = env.get(key) + if raw is None: + return default + return raw.strip().lower() not in {"0", "false", "no", "off"} + + +@dataclass(frozen=True) +class RuntimeProfile: + tier: str + mode: str + embedding_tier: str + consolidation_tier: str + qdrant_expected: bool + writes_enabled: bool + admin_mutations_enabled: bool + archive_export_enabled: bool + self_service_export_enabled: bool + + def to_dict(self) -> Dict[str, Any]: + return { + "tier": self.tier, + "mode": self.mode, + "embedding_tier": self.embedding_tier, + "consolidation_tier": self.consolidation_tier, + "capabilities": { + "qdrant_expected": self.qdrant_expected, + "writes_enabled": self.writes_enabled, + "admin_mutations_enabled": self.admin_mutations_enabled, + "archive_export_enabled": self.archive_export_enabled, + "self_service_export_enabled": self.self_service_export_enabled, + }, + } + + +def resolve_runtime_profile(env: Mapping[str, str] | None = None) -> RuntimeProfile: + env_map = dict(env or os.environ) + tier_defaults: Dict[str, Dict[str, Any]] = { + "trial": { + "embedding_tier": "managed", + "consolidation_tier": "standard", + "qdrant_expected": True, + "self_service_export_enabled": False, + }, + "pro": { + "embedding_tier": "managed", + "consolidation_tier": "standard", + "qdrant_expected": True, + "self_service_export_enabled": False, + }, + "ultimate": { + "embedding_tier": "premium", + 
"consolidation_tier": "full", + "qdrant_expected": True, + "self_service_export_enabled": True, + }, + "archived": { + "embedding_tier": "disabled", + "consolidation_tier": "disabled", + "qdrant_expected": True, + "self_service_export_enabled": False, + }, + } + + tier = (env_map.get("AUTOMEM_SERVICE_TIER", "pro") or "pro").strip().lower() + if tier not in tier_defaults: + tier = "pro" + + default_mode = "archived" if tier == "archived" else "active" + mode = (env_map.get("AUTOMEM_SERVICE_MODE", default_mode) or default_mode).strip().lower() + if mode not in {"active", "read_only", "archived"}: + mode = default_mode + + defaults = tier_defaults[tier] + writes_enabled = mode == "active" + admin_mutations_enabled = mode == "active" + + return RuntimeProfile( + tier=tier, + mode=mode, + embedding_tier=( + env_map.get("AUTOMEM_EMBEDDING_TIER", defaults["embedding_tier"]) + or defaults["embedding_tier"] + ) + .strip() + .lower(), + consolidation_tier=( + env_map.get("AUTOMEM_CONSOLIDATION_TIER", defaults["consolidation_tier"]) + or defaults["consolidation_tier"] + ) + .strip() + .lower(), + qdrant_expected=_env_bool( + env_map, "AUTOMEM_QDRANT_EXPECTED", bool(defaults["qdrant_expected"]) + ), + writes_enabled=_env_bool(env_map, "AUTOMEM_WRITES_ENABLED", writes_enabled), + admin_mutations_enabled=_env_bool( + env_map, "AUTOMEM_ADMIN_MUTATIONS_ENABLED", admin_mutations_enabled + ), + archive_export_enabled=_env_bool(env_map, "AUTOMEM_ARCHIVE_EXPORT_ENABLED", True), + self_service_export_enabled=_env_bool( + env_map, + "AUTOMEM_SELF_SERVICE_EXPORT_ENABLED", + bool(defaults["self_service_export_enabled"]), + ), + ) + + +RUNTIME_PROFILE = resolve_runtime_profile() +SERVICE_TIER = RUNTIME_PROFILE.tier +SERVICE_MODE = RUNTIME_PROFILE.mode +SERVICE_PROFILE = RUNTIME_PROFILE.to_dict() +UPGRADE_URL = os.getenv("AUTOMEM_UPGRADE_URL", "https://automem.ai/subscribe").strip() + # Qdrant / FalkorDB configuration COLLECTION_NAME = os.getenv("QDRANT_COLLECTION", "memories") VECTOR_SIZE 
= int(os.getenv("VECTOR_SIZE") or os.getenv("QDRANT_VECTOR_SIZE", "1024")) diff --git a/automem/runtime_wiring.py b/automem/runtime_wiring.py index 672b7d5..e49ad70 100644 --- a/automem/runtime_wiring.py +++ b/automem/runtime_wiring.py @@ -70,6 +70,10 @@ def wire_recall_and_blueprints( init_openai_fn=module.init_openai, effective_vector_size_fn=lambda: module.state.effective_vector_size, embedding_model=module.EMBEDDING_MODEL, + service_profile_fn=module.get_service_profile, + service_mode_fn=module.get_service_mode, + service_tier_fn=module.get_service_tier, + embedding_provider_name_fn=module.get_embedding_provider_name, build_consolidator_from_config_fn=module._build_consolidator_from_config, persist_consolidation_run_fn=module._persist_consolidation_run, build_scheduler_from_graph_fn=module._build_scheduler_from_graph, diff --git a/automem/service_state.py b/automem/service_state.py index ddd1702..f11c45f 100644 --- a/automem/service_state.py +++ b/automem/service_state.py @@ -1,5 +1,6 @@ from __future__ import annotations +from copy import deepcopy from dataclasses import dataclass, field from queue import Queue from threading import Event, Lock, Thread @@ -8,7 +9,7 @@ from falkordb import FalkorDB from qdrant_client import QdrantClient -from automem.config import VECTOR_SIZE +from automem.config import SERVICE_MODE, SERVICE_PROFILE, SERVICE_TIER, VECTOR_SIZE from automem.embedding.provider import EmbeddingProvider from automem.utils.time import utc_now @@ -82,3 +83,7 @@ class ServiceState: sync_last_result: Optional[Dict[str, Any]] = None # Effective vector size (auto-detected from existing collection or config default) effective_vector_size: int = VECTOR_SIZE + service_tier: str = SERVICE_TIER + service_mode: str = SERVICE_MODE + service_profile: Dict[str, Any] = field(default_factory=lambda: deepcopy(SERVICE_PROFILE)) + export_registry: Dict[str, Dict[str, Any]] = field(default_factory=dict) diff --git a/docs/API.md b/docs/API.md index a7f2988..3cce8f3 
100644 --- a/docs/API.md +++ b/docs/API.md @@ -21,6 +21,12 @@ Memory - Body: `{ "content": "...", "tags": ["tag"], "importance": 0.7, "metadata": {} }` - Response: `{ "status": "success", "memory_id": "...", ... }` +- POST `/api/v1/preseed` + - Headers: standard API token plus `X-AutoMem-Internal: preseed`. + - Body: `{ "memories": [...], "associations": [{ "from_index": 0, "to_index": 1, "relationship": "RELATES_TO", "strength": 0.7 }] }` + - Response: `{ "status": "success", "memories_created": N, "associations_created": M, "memory_ids": [...] }` + - All memories created through this endpoint are automatically tagged with `onboarding`. + - GET `/memory/{id}` - Response: `{ "status": "success", "memory": { ... } }` - Errors: `404` if memory is missing, `500` on query failure, `503` if graph database is unavailable. @@ -111,6 +117,32 @@ Admin - Body: `{ "batch_size": 32, "limit": 100, "force": false }` - Response: `{ "status": "complete", "processed": N, "failed": K }` +- POST `/admin/exports` + - Headers: requires both API and Admin tokens. + - Body (optional): `{ "export_id": "trial-expiry-export", "reason": "trial_expired", "include_vectors": true }` + - Response: export manifest with `export_id`, `status_url`, and `download_url`. + - This endpoint remains available in `read_only` and `archived` service modes when archive export support is enabled. + +- GET `/admin/exports/{export_id}` + - Headers: requires both API and Admin tokens. + - Response: persisted export manifest/status for polling. + +- GET `/admin/exports/{export_id}/download` + - Headers: requires both API and Admin tokens. + - Response: gzip-compressed JSON export bundle containing service metadata plus FalkorDB and optional Qdrant snapshots. + +Service + +- GET `/service/profile` + - Auth: standard API token required. 
+ - Response: `{ "status": "success", "service": { "tier": "...", "mode": "...", "profile": {...}, "vector_store": {...}, "embedding": {...} } }` + - Intended for control-plane provisioning checks, dashboard status, and hosted tier introspection. + +- GET `/api/v1/stats` + - Auth: standard API token required. + - Response: `{ "status": "success", "memories_stored": N, "associations": M, "memory_types": {...}, "top_tags": [...], "last_activity": "...", "graph_density": 0.5 }` + - Intended for lightweight dashboard and CLI summaries without the full `/analyze` payload. + Consolidation - POST `/consolidate` @@ -126,3 +158,5 @@ Notes - Exclusion filtering (`exclude_tags`) removes any memory containing ANY of the excluded tags, supporting both exact and prefix matching. - Time filtering accepts ISO timestamps (`start`, `end`) or a natural expression via `time_query`. - Context hints boost matching preferences (e.g., Python coding style) and guarantee at least one anchor memory when applicable; responses echo what was applied via `context_priority`. +- When the hosted service mode is `read_only` or `archived`, mutating endpoints return HTTP `423` and include `service_mode` / `service_tier` in the JSON error payload; the exception is `POST /admin/exports`, which stays available while archive export support is enabled. +- Archived trial pods return `error: "trial_expired"` and include `reason` plus `upgrade_url` for upgrade prompts while still allowing read access. diff --git a/docs/ENVIRONMENT_VARIABLES.md b/docs/ENVIRONMENT_VARIABLES.md index d0ffffe..c6d30e6 100644 --- a/docs/ENVIRONMENT_VARIABLES.md +++ b/docs/ENVIRONMENT_VARIABLES.md @@ -47,6 +47,29 @@ curl -X POST \ https://automem.up.railway.app/enrichment/reprocess ``` +### Managed Service Profiles + +These settings are intended for hosted or reseller deployments where a separate control plane needs to understand pod capabilities and lifecycle state.
+ +| Variable | Description | Default | Example | +|----------|-------------|---------|---------| +| `AUTOMEM_SERVICE_TIER` | Hosted tier profile | `pro` | `trial`, `pro`, `ultimate`, `archived` | +| `AUTOMEM_SERVICE_MODE` | Mutability mode | tier-dependent | `active`, `read_only`, `archived` | +| `AUTOMEM_EMBEDDING_TIER` | Embedding capability label exposed via `/service/profile` | profile default | `managed`, `premium`, `disabled` | +| `AUTOMEM_CONSOLIDATION_TIER` | Consolidation capability label exposed via `/service/profile` | profile default | `standard`, `full`, `disabled` | +| `AUTOMEM_QDRANT_EXPECTED` | Whether the hosted profile expects vector search | profile default | `true`, `false` | +| `AUTOMEM_WRITES_ENABLED` | Override write access independently of the profile default | profile default | `false` | +| `AUTOMEM_ADMIN_MUTATIONS_ENABLED` | Override mutating admin capability flag | profile default | `false` | +| `AUTOMEM_ARCHIVE_EXPORT_ENABLED` | Allow `/admin/exports` while locked | `true` | `false` | +| `AUTOMEM_SELF_SERVICE_EXPORT_ENABLED` | Advertise user-facing export support in `/service/profile` | profile default | `true` | +| `AUTOMEM_UPGRADE_URL` | Upgrade URL returned for archived trial pods | `https://automem.ai/subscribe` | `https://automem.ai/subscribe` | + +Defaults: + +- `trial` and `pro`: managed embeddings, standard consolidation, writes enabled +- `ultimate`: premium embeddings, full consolidation, self-service export enabled +- `archived`: writes disabled by default, archive export still enabled for control-plane workflows + ### Embedding Providers AutoMem supports five embedding backends with automatic fallback.
diff --git a/mcp-sse-server/server.js b/mcp-sse-server/server.js index 7f453f9..61de355 100644 --- a/mcp-sse-server/server.js +++ b/mcp-sse-server/server.js @@ -67,8 +67,8 @@ function isRetryableFetchError(error) { return name === 'AbortError' || name === 'TimeoutError' || error instanceof TypeError; } -class UpstreamRequestError extends Error { - constructor(message, { status, requestId, kind, retryable = false, endpoint, cause } = {}) { +export class UpstreamRequestError extends Error { + constructor(message, { status, requestId, kind, retryable = false, endpoint, cause, details } = {}) { super(message); this.name = 'UpstreamRequestError'; this.status = status; @@ -77,12 +77,24 @@ class UpstreamRequestError extends Error { this.retryable = retryable; this.endpoint = endpoint; this.cause = cause; + this.details = details; } } -function formatToolError(error, requestId) { +export function formatToolError(error, requestId) { const suffix = requestId ? ` (request_id: ${requestId})` : ''; if (error instanceof UpstreamRequestError) { + const serviceMode = error?.details?.service_mode; + const serviceLocked = error?.status === 423 || error?.details?.error === 'service_locked'; + if (serviceLocked) { + if (serviceMode === 'archived') { + return `AutoMem is archived. Recall may remain available, but memory writes are disabled.${suffix}`; + } + if (serviceMode === 'read_only') { + return `AutoMem is in read-only mode. Recall is still available, but memory writes are disabled.${suffix}`; + } + return `AutoMem is currently locked for write operations.${suffix}`; + } if (error.kind === 'timeout') { return `AutoMem request timed out. 
The service may be slow or restarting.${suffix}`; } @@ -152,6 +164,7 @@ async function fetchWithRetry(url, { method, headers, body, requestId, timeoutMs kind: 'http', retryable, endpoint: url, + details: data, }); } catch (error) { if (error instanceof UpstreamRequestError) { diff --git a/mcp-sse-server/test/server.test.js b/mcp-sse-server/test/server.test.js index 227c31e..a2131be 100644 --- a/mcp-sse-server/test/server.test.js +++ b/mcp-sse-server/test/server.test.js @@ -1,6 +1,12 @@ import test from "node:test"; import assert from "node:assert/strict"; -import { AutoMemClient, createApp, formatRecallAsItems } from "../server.js"; +import { + AutoMemClient, + UpstreamRequestError, + createApp, + formatRecallAsItems, + formatToolError, +} from "../server.js"; async function withServer(app, fn) { const server = await new Promise((resolve) => { @@ -156,6 +162,24 @@ test("AutoMemClient._request retries transient upstream errors", async () => { } }); +test("formatToolError explains archived and read-only lock states clearly", () => { + const archived = new UpstreamRequestError("Service is archived", { + status: 423, + kind: "http", + details: { error: "service_locked", service_mode: "archived" }, + }); + assert.match(formatToolError(archived, "req-archived"), /archived/i); + assert.match(formatToolError(archived, "req-archived"), /writes are disabled/i); + + const readOnly = new UpstreamRequestError("Service is read only", { + status: 423, + kind: "http", + details: { error: "service_locked", service_mode: "read_only" }, + }); + assert.match(formatToolError(readOnly, "req-readonly"), /read-only mode/i); + assert.match(formatToolError(readOnly, "req-readonly"), /recall is still available/i); +}); + // ============================================================================= // Streamable HTTP Transport Tests (MCP 2025-03-26) // ============================================================================= diff --git a/tests/contracts/test_routes_contract.py 
b/tests/contracts/test_routes_contract.py index 2bd15c4..2340b00 100644 --- a/tests/contracts/test_routes_contract.py +++ b/tests/contracts/test_routes_contract.py @@ -6,6 +6,7 @@ ("GET", "/health"), ("POST", "/memory"), ("POST", "/memory/batch"), + ("POST", "/api/v1/preseed"), ("GET", "/memory/"), ("PATCH", "/memory/"), ("DELETE", "/memory/"), @@ -17,6 +18,9 @@ ("GET", "/memories//related"), ("POST", "/admin/reembed"), ("POST", "/admin/sync"), + ("POST", "/admin/exports"), + ("GET", "/admin/exports/"), + ("GET", "/admin/exports//download"), ("POST", "/consolidate"), ("GET", "/consolidate/status"), ("GET", "/enrichment/status"), @@ -24,12 +28,14 @@ ("GET", "/graph/snapshot"), ("GET", "/graph/neighbors/"), ("GET", "/graph/stats"), + ("GET", "/api/v1/stats"), ("GET", "/graph/types"), ("GET", "/graph/relations"), ("GET", "/viewer/"), ("GET", "/viewer/"), ("GET", "/stream"), ("GET", "/stream/status"), + ("GET", "/service/profile"), } diff --git a/tests/support/fake_graph.py b/tests/support/fake_graph.py index b4cb8a2..8df3f02 100644 --- a/tests/support/fake_graph.py +++ b/tests/support/fake_graph.py @@ -195,6 +195,37 @@ def query(self, query: str, params: Dict[str, Any] | None = None, **kwargs: Any) } return FakeResult([[FakeNode(self.memories[memory_id])]]) + if ( + "MATCH (n)" in query + and "labels(n) AS labels" in query + and "properties(n) AS props" in query + ): + rows = [] + for memory in self.memories.values(): + rows.append([["Memory"], dict(memory)]) + return FakeResult(rows) + + if ( + "MATCH (a)-[r]->(b)" in query + and "properties(a).id AS source_id" in query + and "properties(r) AS props" in query + ): + rows = [] + for rel in self.relationships: + rows.append( + [ + rel.get("id1"), + rel.get("type"), + rel.get("id2"), + { + key: value + for key, value in rel.items() + if key not in {"id1", "id2", "type"} + }, + ] + ) + return FakeResult(rows) + # Memory update if "MATCH (m:Memory {id:" in query and "SET m.content" in query: memory_id = str(params["id"]) 
@@ -310,6 +341,41 @@ def query(self, query: str, params: Dict[str, Any] | None = None, **kwargs: Any) ) return FakeResult(rows) + if "MATCH (m:Memory)" in query and "RETURN count(m) as count" in query: + return FakeResult([[len(self.memories)]]) + + if "MATCH ()-[r]->() RETURN count(r)" in query: + return FakeResult([[len(self.relationships)]]) + + if "UNWIND coalesce(m.tags, []) AS tag" in query and "RETURN toLower(tag) AS tag" in query: + tag_counts: Dict[str, int] = {} + for memory in self.memories.values(): + for tag in memory.get("tags") or []: + normalized = str(tag).strip().lower() + if not normalized: + continue + tag_counts[normalized] = tag_counts.get(normalized, 0) + 1 + rows = [[tag, count] for tag, count in tag_counts.items()] + rows.sort(key=lambda row: (-int(row[1]), str(row[0]))) + return FakeResult(rows[:10]) + + if ( + "RETURN max(coalesce(m.updated_at, m.last_accessed, m.timestamp)) AS last_activity" + in query + ): + values = [] + for memory in self.memories.values(): + values.append( + str( + memory.get("updated_at") + or memory.get("last_accessed") + or memory.get("timestamp") + or "" + ) + ) + values = [value for value in values if value] + return FakeResult([[max(values) if values else None]]) + # Startup recall query patterns if "WHERE 'critical' IN m.tags OR 'lesson' IN m.tags OR 'ai-assistant' IN m.tags" in query: rows = [] diff --git a/tests/test_preseed_and_stats.py b/tests/test_preseed_and_stats.py new file mode 100644 index 0000000..3d60eaf --- /dev/null +++ b/tests/test_preseed_and_stats.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import pytest + +import app +from tests.support.fake_graph import FakeGraph + + +def _make_state(): + state = app.ServiceState() + state.memory_graph = FakeGraph() + state.qdrant = None + return state + + +def _auth_headers() -> dict[str, str]: + return {"Authorization": "Bearer test-token"} + + +def _preseed_headers() -> dict[str, str]: + return { + "Authorization": "Bearer test-token", + 
"X-AutoMem-Internal": "preseed", + } + + +@pytest.fixture +def client(): + app.app.config["TESTING"] = True + return app.app.test_client() + + +def test_preseed_creates_memories_and_associations(client, monkeypatch): + state = _make_state() + monkeypatch.setattr(app, "state", state) + monkeypatch.setattr(app, "API_TOKEN", "test-token") + + response = client.post( + "/api/v1/preseed", + json={ + "memories": [ + { + "content": "User is a senior engineer at Acme", + "importance": 0.9, + "tags": ["role"], + "type": "fact", + }, + { + "content": "User prefers concise communication", + "importance": 0.85, + "tags": ["preferences"], + "type": "preference", + }, + ], + "associations": [ + { + "from_index": 0, + "to_index": 1, + "relationship": "RELATES_TO", + "strength": 0.7, + } + ], + }, + headers=_preseed_headers(), + ) + + assert response.status_code == 201 + body = response.get_json() + assert body["status"] == "success" + assert body["memories_created"] == 2 + assert body["associations_created"] == 1 + assert len(body["memory_ids"]) == 2 + + stored = list(state.memory_graph.memories.values()) + assert len(stored) == 2 + assert all("onboarding" in (memory.get("tags") or []) for memory in stored) + + relationship = state.memory_graph.relationships[0] + assert relationship["type"] == "RELATES_TO" + assert relationship["id1"] == body["memory_ids"][0] + assert relationship["id2"] == body["memory_ids"][1] + + +def test_preseed_requires_internal_header(client, monkeypatch): + state = _make_state() + monkeypatch.setattr(app, "state", state) + monkeypatch.setattr(app, "API_TOKEN", "test-token") + + response = client.post( + "/api/v1/preseed", + json={"memories": [{"content": "Test", "type": "fact"}]}, + headers=_auth_headers(), + ) + + assert response.status_code == 403 + assert "Internal preseed authorization required" in response.get_json()["message"] + + +def test_api_v1_stats_returns_dashboard_shape(client, monkeypatch): + state = _make_state() + 
state.memory_graph.memories["11111111-1111-1111-1111-111111111111"] = { + "id": "11111111-1111-1111-1111-111111111111", + "content": "Onboarding role memory", + "tags": ["onboarding", "role"], + "importance": 0.9, + "type": "fact", + "timestamp": "2026-04-01T10:00:00Z", + "updated_at": "2026-04-02T10:00:00Z", + } + state.memory_graph.memories["22222222-2222-2222-2222-222222222222"] = { + "id": "22222222-2222-2222-2222-222222222222", + "content": "Preference memory", + "tags": ["onboarding", "preferences"], + "importance": 0.8, + "type": "preference", + "timestamp": "2026-04-03T10:00:00Z", + "last_accessed": "2026-04-03T12:00:00Z", + } + state.memory_graph.relationships.append( + { + "id1": "11111111-1111-1111-1111-111111111111", + "id2": "22222222-2222-2222-2222-222222222222", + "type": "RELATES_TO", + "strength": 0.7, + } + ) + monkeypatch.setattr(app, "state", state) + monkeypatch.setattr(app, "API_TOKEN", "test-token") + + response = client.get("/api/v1/stats", headers=_auth_headers()) + + assert response.status_code == 200 + body = response.get_json() + assert body["status"] == "success" + assert body["memories_stored"] == 2 + assert body["associations"] == 1 + assert body["memory_types"] == {"fact": 1, "preference": 1} + assert body["top_tags"][0] == {"tag": "onboarding", "count": 2} + assert body["last_activity"] == "2026-04-03T12:00:00Z" + assert body["graph_density"] == 0.5 diff --git a/tests/test_service_profiles.py b/tests/test_service_profiles.py new file mode 100644 index 0000000..01e3363 --- /dev/null +++ b/tests/test_service_profiles.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import gzip +import json +from types import SimpleNamespace + +import pytest + +import app +from automem.config import resolve_runtime_profile +from tests.support.fake_graph import FakeGraph + + +class DummyEmbeddingProvider: + def provider_name(self) -> str: + return "openai" + + +class ExportQdrantClient: + def __init__(self) -> None: + self.points: dict[str, 
dict[str, object]] = {} + + def upsert(self, collection_name, points): + del collection_name + for point in points: + self.points[str(point.id)] = { + "payload": point.payload, + "vector": point.vector, + } + + def scroll(self, collection_name, limit=100, offset=None, with_payload=True, with_vectors=True): + del collection_name, limit, offset, with_payload, with_vectors + rows = [] + for point_id, point in self.points.items(): + rows.append( + SimpleNamespace( + id=point_id, + payload=point["payload"], + vector=point["vector"], + ) + ) + return rows, None + + +@pytest.fixture(autouse=True) +def reset_state(monkeypatch): + state = app.ServiceState() + state.memory_graph = FakeGraph() + state.qdrant = ExportQdrantClient() + state.embedding_provider = DummyEmbeddingProvider() + monkeypatch.setattr(app, "state", state) + monkeypatch.setattr(app, "init_falkordb", lambda: None) + monkeypatch.setattr(app, "init_qdrant", lambda: None) + monkeypatch.setattr(app, "API_TOKEN", "test-token") + monkeypatch.setattr(app, "ADMIN_TOKEN", "test-admin-token") + return state + + +@pytest.fixture +def client(): + app.app.config["TESTING"] = True + return app.app.test_client() + + +@pytest.fixture +def auth_headers(): + return {"Authorization": "Bearer test-token"} + + +@pytest.fixture +def admin_headers(): + return {"Authorization": "Bearer test-token", "X-Admin-Token": "test-admin-token"} + + +def test_resolve_runtime_profile_defaults_and_overrides() -> None: + trial = resolve_runtime_profile({"AUTOMEM_SERVICE_TIER": "trial"}) + assert trial.tier == "trial" + assert trial.mode == "active" + assert trial.writes_enabled is True + assert trial.self_service_export_enabled is False + + archived = resolve_runtime_profile( + { + "AUTOMEM_SERVICE_TIER": "ultimate", + "AUTOMEM_SERVICE_MODE": "archived", + "AUTOMEM_SELF_SERVICE_EXPORT_ENABLED": "false", + } + ) + assert archived.mode == "archived" + assert archived.writes_enabled is False + assert archived.admin_mutations_enabled is False + 
assert archived.self_service_export_enabled is False + + +def test_service_profile_endpoint_reports_runtime_details( + client, auth_headers, reset_state +) -> None: + reset_state.service_tier = "trial" + reset_state.service_mode = "active" + reset_state.service_profile = resolve_runtime_profile( + {"AUTOMEM_SERVICE_TIER": "trial"} + ).to_dict() + + response = client.get("/service/profile", headers=auth_headers) + + assert response.status_code == 200 + body = response.get_json() + assert body["status"] == "success" + assert body["service"]["tier"] == "trial" + assert body["service"]["mode"] == "active" + assert body["service"]["profile"]["capabilities"]["writes_enabled"] is True + assert body["service"]["vector_store"]["available"] is True + assert body["service"]["embedding"]["provider"] == "openai" + + +def test_read_only_mode_blocks_writes_but_allows_recall( + client, auth_headers, admin_headers, reset_state +) -> None: + reset_state.memory_graph.memories["11111111-1111-1111-1111-111111111111"] = { + "id": "11111111-1111-1111-1111-111111111111", + "content": "Read only memories can still be recalled", + "tags": ["locked"], + "importance": 0.7, + "type": "Context", + "timestamp": "2026-04-03T00:00:00Z", + "metadata": "{}", + "confidence": 0.8, + } + reset_state.service_mode = "read_only" + reset_state.service_profile = resolve_runtime_profile( + {"AUTOMEM_SERVICE_TIER": "pro", "AUTOMEM_SERVICE_MODE": "read_only"} + ).to_dict() + reset_state.qdrant = None + + write_response = client.post( + "/memory", + json={"content": "blocked"}, + headers=auth_headers, + ) + assert write_response.status_code == 423 + write_body = write_response.get_json() + assert write_body["error"] == "service_locked" + assert write_body["service_mode"] == "read_only" + + recall_response = client.get("/recall?query=recalled", headers=auth_headers) + assert recall_response.status_code == 200 + assert recall_response.get_json()["status"] == "success" + + admin_response = 
client.post("/admin/reembed", json={}, headers=admin_headers) + assert admin_response.status_code == 423 + + +def test_archived_mode_allows_admin_export_flow( + client, admin_headers, reset_state, monkeypatch, tmp_path +) -> None: + export_id = "trial-expiry-export" + reset_state.service_tier = "trial" + reset_state.service_mode = "archived" + reset_state.service_profile = resolve_runtime_profile( + {"AUTOMEM_SERVICE_TIER": "trial", "AUTOMEM_SERVICE_MODE": "archived"} + ).to_dict() + reset_state.memory_graph.memories["aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"] = { + "id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "content": "Trial memory to archive", + "tags": ["trial"], + "importance": 0.9, + "type": "Context", + "timestamp": "2026-04-03T00:00:00Z", + "metadata": "{}", + "confidence": 0.9, + } + reset_state.memory_graph.relationships.append( + { + "id1": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "id2": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "type": "RELATES_TO", + "strength": 1.0, + } + ) + reset_state.qdrant.upsert( + "memories", + [ + SimpleNamespace( + id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + vector=[0.1, 0.2], + payload={"content": "Trial memory to archive"}, + ) + ], + ) + monkeypatch.setenv("AUTOMEM_EXPORT_DIR", str(tmp_path)) + + create_response = client.post( + "/admin/exports", + json={"export_id": export_id, "reason": "trial_expired"}, + headers=admin_headers, + ) + assert create_response.status_code == 201 + created = create_response.get_json() + assert created["status"] == "complete" + assert created["export_id"] == export_id + assert created["service"]["mode"] == "archived" + assert created["graph"]["node_count"] == 1 + assert created["include_vectors"] is True + + status_response = client.get(f"/admin/exports/{export_id}", headers=admin_headers) + assert status_response.status_code == 200 + assert status_response.get_json()["export_id"] == export_id + + download_response = client.get(f"/admin/exports/{export_id}/download", headers=admin_headers) + 
assert download_response.status_code == 200 + bundle = json.loads(gzip.decompress(download_response.data).decode("utf-8")) + assert bundle["reason"] == "trial_expired" + assert bundle["service"]["mode"] == "archived" + assert bundle["graph"]["stats"]["node_count"] == 1 + + +def test_archived_trial_mode_returns_upgrade_payload(client, auth_headers, reset_state) -> None: + reset_state.service_tier = "trial" + reset_state.service_mode = "archived" + reset_state.service_profile = resolve_runtime_profile( + {"AUTOMEM_SERVICE_TIER": "trial", "AUTOMEM_SERVICE_MODE": "archived"} + ).to_dict() + + response = client.post("/memory", json={"content": "blocked"}, headers=auth_headers) + + assert response.status_code == 423 + body = response.get_json() + assert body["error"] == "trial_expired" + assert body["reason"] == "trial_expired" + assert body["upgrade_url"] == "https://automem.ai/subscribe" + assert "30-day trial" in body["message"]