diff --git a/.gitignore b/.gitignore index c2ee7c98..32523fe3 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ node_modules/ /benchmarks/results/ benchmarks/baselines/locomo_baseline.json data/bm25_index.db +.venv-test/ diff --git a/.railway/falkordb.Dockerfile b/.railway/falkordb.Dockerfile index 751bca60..9c95c1e5 100644 --- a/.railway/falkordb.Dockerfile +++ b/.railway/falkordb.Dockerfile @@ -12,7 +12,7 @@ ENV REDIS_ARGS="--save 900 1 --save 300 10 --save 60 10000 --appendonly yes --di EXPOSE 6379 # Health check -HEALTHCHECK --interval=30s --timeout=3s --start-period=30s --retries=3 \ +HEALTHCHECK --interval=30s --timeout=3s --start-period=30s --retries=10 \ CMD redis-cli ping || exit 1 # Volume for persistent data diff --git a/app.py b/app.py index 0e14a505..9bfbe36b 100644 --- a/app.py +++ b/app.py @@ -144,6 +144,8 @@ def _parse_viewer_allowed_origins() -> Any: CONSOLIDATION_TASK_FIELDS, CONSOLIDATION_TICK_SECONDS, DEFAULT_EXPAND_RELATIONS, + DOCUMENT_MAX_BYTES, + DOCUMENT_PRESIGNED_EXPIRES, EMBEDDING_MODEL, ENRICHMENT_ENABLE_SUMMARIES, ENRICHMENT_FAILURE_BACKOFF_SECONDS, @@ -182,6 +184,7 @@ def _parse_viewer_allowed_origins() -> Any: configure_recall_helpers, ) from automem.search.runtime_relations import fetch_relations as _fetch_relations_runtime +from automem.stores.bucket_store import build_bucket_store_from_config from automem.stores.graph_store import _build_graph_tag_predicate from automem.stores.vector_store import _build_qdrant_tag_filter from automem.sync.runtime_bindings import create_sync_runtime @@ -224,6 +227,17 @@ def _parse_viewer_allowed_origins() -> Any: "default_expand": DEFAULT_EXPAND_RELATIONS, } +# Optional S3-compatible bucket store for document originals. The builder +# returns None unless all required S3_* env vars are set; while it is None, +# the /documents endpoints return 503 with a clear message.
+try: + bucket_store = build_bucket_store_from_config() + if bucket_store is not None: + logger.info("Bucket store initialized for document originals") +except Exception: + logger.exception("Bucket store init failed; /documents endpoints disabled") + bucket_store = None + # Search weights are imported from automem.config # Maximum number of results returned by /recall diff --git a/automem/api/documents.py b/automem/api/documents.py new file mode 100644 index 00000000..463027f4 --- /dev/null +++ b/automem/api/documents.py @@ -0,0 +1,574 @@ +"""Flask blueprint for agent-driven document storage. + +Design (Stage 1 — "lean"): +- The file stays in the bucket as an opaque original. +- AutoMem never parses it — no pdfplumber, no trafilatura, no OCR. +- The uploading agent is expected to have read the file already and provided + a human-quality ``title`` + ``summary``; those become the Memory's content + so existing vector/keyword recall surfaces the doc. +- When an agent later wants the raw bytes, it calls ``GET + /documents/:id/download`` to get a short-lived presigned URL and fetches + the file itself (Claude can read PDFs directly from such URLs). + +The "gate" requested by the user is the hard 422 on missing title/summary at +the HTTP boundary — the MCP tool definitions enforce it client-side too via +required parameters. +""" + +from __future__ import annotations + +import json +import logging +import re +import time +import uuid +from typing import Any, Callable, Dict, List, Optional + +from flask import Blueprint, abort, current_app, jsonify, request +from flask.typing import ResponseReturnValue + +logger = logging.getLogger(__name__) + + +# Filename sanitization: keep alphanumerics, dot, dash, underscore. Everything +# else becomes "_" so we never build S3 keys with control characters or path +# traversal surprises. The memory_id prefix keeps keys globally unique. 
+_SAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9._-]+") + + +def _safe_filename(name: str, *, fallback: str = "file") -> str: + cleaned = _SAFE_FILENAME_RE.sub("_", name or "").strip("._-") + return cleaned or fallback + + +def _parse_json_field(value: Optional[str], field_name: str) -> Any: + """Parse an optional JSON-encoded form field. Aborts 400 if malformed.""" + if value is None or value == "": + return None + try: + return json.loads(value) + except (TypeError, ValueError) as exc: + abort(400, description=f"'{field_name}' must be valid JSON ({exc})") + + +def create_documents_blueprint( + *, + bucket_store: Any, + get_memory_graph_fn: Callable[[], Any], + get_qdrant_client_fn: Callable[[], Any], + normalize_tags_fn: Callable[[Any], List[str]], + compute_tag_prefixes_fn: Callable[[List[str]], List[str]], + coerce_importance_fn: Callable[[Any], float], + enqueue_enrichment_fn: Callable[[str], None], + enqueue_embedding_fn: Callable[[str, str], None], + collection_name: str, + utc_now_fn: Callable[[], str], + state: Any, + qdrant_models_obj: Any, + max_bytes: int, + presigned_expires: int, +) -> Blueprint: + """Create the ``/documents`` blueprint. ``bucket_store`` may be None; + when None, all endpoints return 503 with a clear message.""" + + bp = Blueprint("documents", __name__) + + def _require_bucket() -> Any: + if bucket_store is None: + abort( + 503, + description=( + "Document storage is not configured. Set S3_ENDPOINT, " + "S3_BUCKET, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY " + "(Railway Buckets provide all four)." 
+ ), + ) + return bucket_store + + def _validate_memory_id(memory_id: str) -> None: + try: + uuid.UUID(memory_id) + except (ValueError, TypeError): + abort(400, description="memory_id must be a valid UUID") + + def _find_document(memory_id: str) -> Dict[str, Any]: + """Return the Memory node + parsed metadata, or 404.""" + graph = get_memory_graph_fn() + if graph is None: + abort(503, description="FalkorDB is unavailable") + + result = graph.query("MATCH (m:Memory {id: $id}) RETURN m", {"id": memory_id}) + if not getattr(result, "result_set", None): + abort(404, description="Memory not found") + + node = result.result_set[0][0] + props = dict(getattr(node, "properties", {}) or {}) + mem_type = props.get("type") + if mem_type != "Document": + abort( + 404, + description=( + "Memory exists but is not of type=Document " + f"(got {mem_type!r}); use /memory/:id for other types." + ), + ) + raw_meta = props.get("metadata") + if isinstance(raw_meta, str): + try: + meta = json.loads(raw_meta) if raw_meta else {} + except (TypeError, ValueError): + meta = {} + elif isinstance(raw_meta, dict): + meta = raw_meta + else: + meta = {} + props["metadata"] = meta + return props + + # ------------------------------------------------------------------ POST + @bp.route("/documents", methods=["POST"]) + def upload_document() -> ResponseReturnValue: # noqa: WPS430 - Flask view + """Multipart upload: stores file in bucket + creates Document Memory. 
+ + Required form fields: + file: the binary file payload + title: agent-generated short title + summary: agent-generated 1–3 sentence summary + + Optional: + tags: JSON array of strings OR comma-separated string + importance: float 0-1 (default 0.5) + metadata: JSON object merged into Memory metadata + memory_id: UUID to use (otherwise server-generated) + """ + store = _require_bucket() + query_start = time.perf_counter() + + # --- multipart validation -------------------------------------------- + uploaded = request.files.get("file") + if uploaded is None or not uploaded.filename: + abort( + 400, + description=( + "Missing 'file' in multipart form data. Upload with " + "'Content-Type: multipart/form-data' and a 'file' field." + ), + ) + + title = (request.form.get("title") or "").strip() + summary = (request.form.get("summary") or "").strip() + if not title or not summary: + # THE GATE: force agents to read the file and describe it before + # upload. We never extract text server-side, so the agent's + # description IS the searchable content. + abort( + 422, + description=( + "Both 'title' and 'summary' are required. The agent must " + "read the file and generate an accurate title (< 200 chars)" + " and a 1-3 sentence summary BEFORE calling this endpoint." + " AutoMem does not parse file content; the agent's summary" + " is the indexed, searchable representation." + ), + ) + if len(title) > 300: + abort(400, description="'title' must be 300 characters or fewer") + # The summary becomes Memory content, which has its own hard limit; + # but we don't want to accept a 20KB "summary" either. 
+ if len(summary) > 4000: + abort(400, description="'summary' must be 4000 characters or fewer") + + # --- memory_id -------------------------------------------------------- + raw_memory_id = (request.form.get("memory_id") or "").strip() + memory_id = raw_memory_id or str(uuid.uuid4()) + try: + uuid.UUID(memory_id) + except (ValueError, TypeError): + abort(400, description="'memory_id' must be a valid UUID") + + # --- tags / importance / metadata ------------------------------------ + tags_raw: Any = request.form.get("tags") + parsed_tags = _parse_json_field(tags_raw, "tags") + if parsed_tags is None and tags_raw: + # Accept comma-separated shorthand as well + parsed_tags = [t.strip() for t in tags_raw.split(",") if t.strip()] + tags = normalize_tags_fn(parsed_tags) + # Always add the "document" tag so list/filter queries are trivial. + if "document" not in {t.lower() for t in tags}: + tags.append("document") + tags_lower = [t.strip().lower() for t in tags if isinstance(t, str) and t.strip()] + tag_prefixes = compute_tag_prefixes_fn(tags_lower) + + importance = coerce_importance_fn(request.form.get("importance")) + + user_metadata = _parse_json_field(request.form.get("metadata"), "metadata") + if user_metadata is None: + user_metadata = {} + elif not isinstance(user_metadata, dict): + abort(400, description="'metadata' must be a JSON object") + + # --- size guard (cheap pre-check via Content-Length) ------------------ + content_length = request.content_length or 0 + if content_length and content_length > max_bytes: + abort( + 413, + description=( + f"Upload exceeds DOCUMENT_MAX_BYTES={max_bytes}; got " + f"{content_length} bytes." 
+ ), + ) + + # --- bucket upload ---------------------------------------------------- + filename = _safe_filename(uploaded.filename) + mime = uploaded.mimetype or "application/octet-stream" + bucket_key = f"documents/{memory_id}/{filename}" + + try: + upload_info = store.upload( + bucket_key, + uploaded.stream, + mime=mime, + metadata={ + "memory_id": memory_id, + "original_filename": uploaded.filename[:200], + }, + ) + except Exception: + logger.exception( + "Bucket upload failed for memory_id=%s key=%s", + memory_id, + bucket_key, + ) + abort(502, description="Upload to object store failed") + + # Size-after-the-fact check (Content-Length may have been absent on + # chunked uploads). If over the cap, roll the object back. + if upload_info["size"] > max_bytes: + try: + store.delete(bucket_key) + except Exception: + logger.exception("Failed to clean up oversized upload at %s", bucket_key) + abort( + 413, + description=( + f"Uploaded file {upload_info['size']} bytes exceeds " + f"DOCUMENT_MAX_BYTES={max_bytes}." + ), + ) + + # --- assemble Memory node -------------------------------------------- + content = f"{title}\n\n{summary}".strip() + created_at = utc_now_fn() + + storage_metadata: Dict[str, Any] = { + "document": { + "title": title, + "filename": uploaded.filename, + "safe_filename": filename, + "mime": upload_info["content_type"], + "size": upload_info["size"], + "sha256": upload_info["sha256"], + "etag": upload_info["etag"], + "bucket_key": bucket_key, + "source": "upload", + "uploaded_at": created_at, + } + } + # User metadata wins over storage_metadata only if they collide + # outside the "document" key. We keep storage_metadata.document as + # authoritative to avoid clients corrupting our bookkeeping. 
+ merged_metadata: Dict[str, Any] = {**user_metadata, **storage_metadata} + metadata_json = json.dumps(merged_metadata, default=str) + + # --- persist to FalkorDB --------------------------------------------- + graph = get_memory_graph_fn() + if graph is None: + # We already uploaded to the bucket; roll back. + try: + store.delete(bucket_key) + except Exception: + logger.exception("Cleanup failed after FalkorDB unavailable") + abort(503, description="FalkorDB is unavailable") + + try: + graph.query( + """ + MERGE (m:Memory {id: $id}) + SET m.content = $content, + m.timestamp = $timestamp, + m.importance = $importance, + m.tags = $tags, + m.tag_prefixes = $tag_prefixes, + m.type = $type, + m.confidence = $confidence, + m.updated_at = $updated_at, + m.last_accessed = $last_accessed, + m.metadata = $metadata, + m.processed = false + RETURN m + """, + { + "id": memory_id, + "content": content, + "timestamp": created_at, + "importance": importance, + "tags": tags, + "tag_prefixes": tag_prefixes, + "type": "Document", + "confidence": 1.0, # Agent-provided classification + "updated_at": created_at, + "last_accessed": created_at, + "metadata": metadata_json, + }, + ) + except Exception: + logger.exception( + "Failed to persist Document memory in FalkorDB; rolling back " + "bucket upload for memory_id=%s", + memory_id, + ) + try: + store.delete(bucket_key) + except Exception: + logger.exception("Bucket rollback failed for %s", bucket_key) + abort(500, description="Failed to store document memory") + + # --- queue embedding + enrichment ------------------------------------ + qdrant_client = get_qdrant_client_fn() + if qdrant_client is not None: + enqueue_embedding_fn(memory_id, content) + embedding_status = "queued" + else: + embedding_status = "unconfigured" + + try: + enqueue_enrichment_fn(memory_id) + enrichment_status = "queued" if state.enrichment_queue else "disabled" + except Exception: + logger.exception("Failed to enqueue enrichment for %s", memory_id) + 
enrichment_status = "failed" + + # --- presigned URL for immediate agent use --------------------------- + try: + download_url = store.presigned_url(bucket_key, expires_in=presigned_expires) + except Exception: + logger.exception("Failed to generate presigned URL for %s", bucket_key) + download_url = None + + response = { + "status": "success", + "memory_id": memory_id, + "type": "Document", + "title": title, + "summary": summary, + "tags": tags, + "importance": importance, + "document": merged_metadata["document"], + "download_url": download_url, + "download_url_expires_in": presigned_expires if download_url else None, + "embedding_status": embedding_status, + "enrichment": enrichment_status, + "stored_at": created_at, + "query_time_ms": round((time.perf_counter() - query_start) * 1000, 2), + } + + logger.info( + "document_uploaded", + extra={ + "memory_id": memory_id, + "bucket_key": bucket_key, + "size": upload_info["size"], + "mime": upload_info["content_type"], + "latency_ms": response["query_time_ms"], + }, + ) + return jsonify(response), 201 + + # --------------------------------------------------------- GET download + @bp.route("/documents/<memory_id>/download", methods=["GET"]) + def document_download(memory_id: str) -> ResponseReturnValue: + store = _require_bucket() + _validate_memory_id(memory_id) + doc = _find_document(memory_id) + + doc_meta = doc.get("metadata", {}).get("document", {}) or {} + bucket_key = doc_meta.get("bucket_key") + if not bucket_key: + abort( + 500, + description=( + "Document memory is missing document.bucket_key in " + "metadata; cannot generate a download URL." + ), + ) + + try: + expires_in = int(request.args.get("expires_in", presigned_expires)) + except (TypeError, ValueError): + abort(400, description="'expires_in' must be an integer") + expires_in = max(30, min(expires_in, 3600)) # 30s ..
1h + + disposition = request.args.get("disposition") + filename = doc_meta.get("filename") or doc_meta.get("safe_filename") + rcd = None + if disposition == "attachment" and filename: + # Quote filename per RFC 6266; boto3 will pass this through verbatim. + rcd = f'attachment; filename="{_safe_filename(filename)}"' + + try: + url = store.presigned_url( + bucket_key, + expires_in=expires_in, + response_content_disposition=rcd, + ) + except Exception: + logger.exception("Presign failed for %s", bucket_key) + abort(502, description="Failed to generate presigned URL") + + return jsonify( + { + "status": "success", + "memory_id": memory_id, + "download_url": url, + "expires_in": expires_in, + "filename": filename, + "mime": doc_meta.get("mime"), + "size": doc_meta.get("size"), + } + ) + + # -------------------------------------------------------------- GET list + @bp.route("/documents", methods=["GET"]) + def list_documents() -> ResponseReturnValue: + """Paged list of Document memories, filter by tag (any/all).""" + _require_bucket() # returns 503 with a clear message if unconfigured + graph = get_memory_graph_fn() + if graph is None: + abort(503, description="FalkorDB is unavailable") + + raw_tags = request.args.getlist("tags") or request.args.get("tags") + tag_filter: List[str] = [] + if isinstance(raw_tags, list): + tag_filter = [t.strip().lower() for t in raw_tags if t and t.strip()] + elif isinstance(raw_tags, str): + tag_filter = [t.strip().lower() for t in raw_tags.split(",") if t.strip()] + + try: + limit = int(request.args.get("limit", 25)) + except (TypeError, ValueError): + limit = 25 + limit = max(1, min(limit, 200)) + + if tag_filter: + query = """ + MATCH (m:Memory {type: 'Document'}) + WHERE ANY(tag IN coalesce(m.tags, []) WHERE toLower(tag) IN $tags) + RETURN m + ORDER BY m.importance DESC, m.timestamp DESC + LIMIT $limit + """ + params: Dict[str, Any] = {"tags": tag_filter, "limit": limit} + else: + query = """ + MATCH (m:Memory {type: 'Document'}) + 
RETURN m + ORDER BY m.importance DESC, m.timestamp DESC + LIMIT $limit + """ + params = {"limit": limit} + + try: + result = graph.query(query, params) + except Exception: + logger.exception("Document list query failed") + abort(500, description="Failed to list documents") + + docs: List[Dict[str, Any]] = [] + for row in getattr(result, "result_set", []) or []: + node = row[0] + props = dict(getattr(node, "properties", {}) or {}) + raw_meta = props.get("metadata") + if isinstance(raw_meta, str): + try: + meta = json.loads(raw_meta) if raw_meta else {} + except (TypeError, ValueError): + meta = {} + elif isinstance(raw_meta, dict): + meta = raw_meta + else: + meta = {} + doc_meta = meta.get("document") if isinstance(meta, dict) else None + docs.append( + { + "memory_id": props.get("id"), + "title": (doc_meta or {}).get("title"), + "content": props.get("content"), + "tags": props.get("tags") or [], + "importance": props.get("importance"), + "timestamp": props.get("timestamp"), + "updated_at": props.get("updated_at"), + "document": doc_meta, + } + ) + + return jsonify( + { + "status": "success", + "tags": tag_filter, + "count": len(docs), + "documents": docs, + } + ) + + # ---------------------------------------------------------------- DELETE + @bp.route("/documents/<memory_id>", methods=["DELETE"]) + def delete_document(memory_id: str) -> ResponseReturnValue: + store = _require_bucket() + _validate_memory_id(memory_id) + doc = _find_document(memory_id) + + bucket_key = doc.get("metadata", {}).get("document", {}).get("bucket_key") + + graph = get_memory_graph_fn() + if graph is None: + abort(503, description="FalkorDB is unavailable") + + # Drop the graph node first (so subsequent recall can't hand out stale + # presigned URLs); then the bucket object.
+ try: + graph.query("MATCH (m:Memory {id: $id}) DETACH DELETE m", {"id": memory_id}) + except Exception: + logger.exception("Graph delete failed for %s", memory_id) + abort(500, description="Failed to delete document memory") + + # Qdrant vector cleanup + qdrant_client = get_qdrant_client_fn() + if qdrant_client is not None: + try: + if qdrant_models_obj is not None: + selector = qdrant_models_obj.PointIdsList(points=[memory_id]) + else: + selector = {"points": [memory_id]} + qdrant_client.delete(collection_name=collection_name, points_selector=selector) + except Exception: + logger.exception("Qdrant vector delete failed for %s", memory_id) + + bucket_result = "skipped" + if bucket_key: + try: + store.delete(bucket_key) + bucket_result = "deleted" + except Exception: + logger.exception("Bucket delete failed for %s", bucket_key) + bucket_result = "failed" + + return jsonify( + { + "status": "success", + "memory_id": memory_id, + "graph": "deleted", + "bucket": bucket_result, + } + ) + + return bp diff --git a/automem/api/runtime_bootstrap.py b/automem/api/runtime_bootstrap.py index d99495c1..c26e4615 100644 --- a/automem/api/runtime_bootstrap.py +++ b/automem/api/runtime_bootstrap.py @@ -4,6 +4,7 @@ from automem.api.admin import create_admin_blueprint_full from automem.api.consolidation import create_consolidation_blueprint_full +from automem.api.documents import create_documents_blueprint from automem.api.enrichment import create_enrichment_blueprint from automem.api.graph import create_graph_blueprint from automem.api.health import create_health_blueprint @@ -66,6 +67,10 @@ def register_blueprints( consolidation_tick_seconds: int, consolidation_history_limit: int, require_api_token_fn: Callable[[], None], + bucket_store: Any = None, + qdrant_models_obj: Any = None, + document_max_bytes: int = 100 * 1024 * 1024, + document_presigned_expires: int = 300, ) -> None: health_bp = create_health_blueprint( get_memory_graph_fn, @@ -172,6 +177,23 @@ def 
register_blueprints( require_api_token=require_api_token_fn, ) + documents_bp = create_documents_blueprint( + bucket_store=bucket_store, + get_memory_graph_fn=get_memory_graph_fn, + get_qdrant_client_fn=get_qdrant_client_fn, + normalize_tags_fn=normalize_tags_fn, + compute_tag_prefixes_fn=compute_tag_prefixes_fn, + coerce_importance_fn=coerce_importance_fn, + enqueue_enrichment_fn=enqueue_enrichment_fn, + enqueue_embedding_fn=enqueue_embedding_fn, + collection_name=collection_name, + utc_now_fn=utc_now_fn, + state=state, + qdrant_models_obj=qdrant_models_obj, + max_bytes=document_max_bytes, + presigned_expires=document_presigned_expires, + ) + app.register_blueprint(health_bp) app.register_blueprint(enrichment_bp) app.register_blueprint(memory_bp) @@ -180,6 +202,11 @@ def register_blueprints( app.register_blueprint(consolidation_bp) app.register_blueprint(graph_bp) app.register_blueprint(stream_bp) + app.register_blueprint(documents_bp) + logger.info( + "Documents blueprint registered (bucket=%s)", + "configured" if bucket_store is not None else "UNCONFIGURED", + ) if is_viewer_enabled(): viewer_bp = create_viewer_blueprint() diff --git a/automem/config.py b/automem/config.py index cbeb5089..9bd23559 100644 --- a/automem/config.py +++ b/automem/config.py @@ -78,7 +78,11 @@ # Sync configuration (background sync worker) SYNC_CHECK_INTERVAL_SECONDS = int(os.getenv("SYNC_CHECK_INTERVAL_SECONDS", "3600")) # 1 hour -SYNC_AUTO_REPAIR = os.getenv("SYNC_AUTO_REPAIR", "true").lower() not in {"0", "false", "no"} +SYNC_AUTO_REPAIR = os.getenv("SYNC_AUTO_REPAIR", "true").lower() not in { + "0", + "false", + "no", +} # Enrichment configuration ENRICHMENT_MAX_ATTEMPTS = int(os.getenv("ENRICHMENT_MAX_ATTEMPTS", "3")) @@ -115,7 +119,11 @@ RECALL_RELATION_LIMIT = int(os.getenv("RECALL_RELATION_LIMIT", "5")) RECALL_EXPANSION_LIMIT = int(os.getenv("RECALL_EXPANSION_LIMIT", "25")) RECALL_MIN_SCORE = float(os.getenv("RECALL_MIN_SCORE", "0.0")) -RECALL_ADAPTIVE_FLOOR = 
os.getenv("RECALL_ADAPTIVE_FLOOR", "true").lower() in ("true", "1", "yes") +RECALL_ADAPTIVE_FLOOR = os.getenv("RECALL_ADAPTIVE_FLOOR", "true").lower() in ( + "true", + "1", + "yes", +) # Memory content size limits (governs auto-summarization on store) # Soft limit: Content above this triggers auto-summarization @@ -131,8 +139,56 @@ # Target length for summarized content MEMORY_SUMMARY_TARGET_LENGTH = int(os.getenv("MEMORY_SUMMARY_TARGET_LENGTH", "300")) +# ----------------------------------------------------------------------------- +# Document storage (S3-compatible bucket, e.g. Railway Buckets) +# ----------------------------------------------------------------------------- +# When all S3_* vars are configured, /documents endpoints are enabled for +# agent-driven document uploads. Originals live in the bucket; a Memory node of +# type=Document is created with agent-provided title + summary so existing +# vector/keyword recall can find the doc. The agent fetches the original via a +# short-lived presigned URL when it needs to read the actual content. +S3_ENDPOINT: str | None = os.getenv("S3_ENDPOINT") or None +S3_BUCKET: str | None = os.getenv("S3_BUCKET") or None +S3_REGION: str = os.getenv("S3_REGION", "auto") +S3_ACCESS_KEY_ID: str | None = os.getenv("S3_ACCESS_KEY_ID") or None +S3_SECRET_ACCESS_KEY: str | None = os.getenv("S3_SECRET_ACCESS_KEY") or None +# virtual-host (bucket.endpoint) vs path (endpoint/bucket) +S3_URL_STYLE: str = os.getenv("S3_URL_STYLE", "virtual-host") +S3_FORCE_PATH_STYLE: bool = os.getenv("S3_FORCE_PATH_STYLE", "false").lower() in { + "1", + "true", + "yes", +} +# Hard cap on document uploads (bytes). Default 100 MB. +DOCUMENT_MAX_BYTES: int = int(os.getenv("DOCUMENT_MAX_BYTES", str(100 * 1024 * 1024))) +# Default expiry for GET /documents/:id/download presigned URLs (seconds). 
+DOCUMENT_PRESIGNED_EXPIRES: int = int(os.getenv("DOCUMENT_PRESIGNED_EXPIRES", "300")) + + +def is_bucket_configured() -> bool: + """True when all required S3 env vars are present so bucket upload works.""" + return all( + v + for v in ( + S3_ENDPOINT, + S3_BUCKET, + S3_ACCESS_KEY_ID, + S3_SECRET_ACCESS_KEY, + ) + ) + + # Memory types for classification -MEMORY_TYPES = {"Decision", "Pattern", "Preference", "Style", "Habit", "Insight", "Context"} +MEMORY_TYPES = { + "Decision", + "Pattern", + "Preference", + "Style", + "Habit", + "Insight", + "Context", + "Document", +} # Type aliases for normalization (lowercase and legacy types → canonical) # Non-canonical types are auto-mapped to canonical types on store @@ -145,12 +201,12 @@ "habit": "Habit", "insight": "Insight", "context": "Context", + "document": "Document", # Legacy/alternative types "memory": "Context", "milestone": "Context", "analysis": "Insight", "observation": "Insight", - "document": "Context", "meeting_notes": "Context", "template": "Pattern", "project": "Context", diff --git a/automem/runtime_wiring.py b/automem/runtime_wiring.py index 672b7d5b..10758bc7 100644 --- a/automem/runtime_wiring.py +++ b/automem/runtime_wiring.py @@ -53,7 +53,7 @@ def wire_recall_and_blueprints( serialize_node_fn=module._serialize_node, summarize_relation_node_fn=module._summarize_relation_node, update_last_accessed_fn=module.update_last_accessed, - jit_enrich_fn=module.jit_enrich_lightweight if module.JIT_ENRICHMENT_ENABLED else None, + jit_enrich_fn=(module.jit_enrich_lightweight if module.JIT_ENRICHMENT_ENABLED else None), normalize_tags_fn=module._normalize_tags, compute_tag_prefixes_fn=module._compute_tag_prefixes, coerce_importance_fn=module._coerce_importance, @@ -77,6 +77,10 @@ def wire_recall_and_blueprints( consolidation_tick_seconds=module.CONSOLIDATION_TICK_SECONDS, consolidation_history_limit=module.CONSOLIDATION_HISTORY_LIMIT, require_api_token_fn=module.require_api_token, + bucket_store=getattr(module, 
"bucket_store", None), + qdrant_models_obj=getattr(module, "qdrant_models", None), + document_max_bytes=getattr(module, "DOCUMENT_MAX_BYTES", 100 * 1024 * 1024), + document_presigned_expires=getattr(module, "DOCUMENT_PRESIGNED_EXPIRES", 300), ) diff --git a/automem/stores/bucket_store.py b/automem/stores/bucket_store.py new file mode 100644 index 00000000..35a365b0 --- /dev/null +++ b/automem/stores/bucket_store.py @@ -0,0 +1,199 @@ +"""S3-compatible object store for document originals. + +Used by the ``/documents`` endpoints to persist uploaded file bytes alongside a +Memory node that carries the agent-supplied title + summary. Originals are +fetched lazily via short-lived presigned URLs — AutoMem never parses file +content, it just stores and hands out signed URLs to agents who know what to +do with the bytes (e.g. Claude can read PDFs natively). + +Compatible with Railway Buckets, AWS S3, MinIO, Cloudflare R2, Backblaze B2, +Wasabi — anything that speaks the S3 API. +""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass +from typing import Any, BinaryIO, Dict, Optional + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BucketConfig: + """Immutable connection config for the bucket store.""" + + endpoint: str + bucket: str + region: str + access_key_id: str + secret_access_key: str + url_style: str = "virtual-host" + force_path_style: bool = False + + +class BucketStore: + """Thin wrapper around boto3's S3 client that exposes only the ops we use. + + Methods raise the underlying ``botocore.exceptions.ClientError`` for callers + to classify; this keeps the wrapper dependency-free (no custom exception + hierarchy to maintain). + """ + + def __init__(self, config: BucketConfig) -> None: + # boto3 is imported here rather than at module import time so test + # suites not touching document storage don't need to install it. The + # client itself is created right away, in this constructor.
+ import boto3 # noqa: WPS433 - deliberate lazy import + from botocore.config import Config as BotoConfig # noqa: WPS433 + + self._config = config + addressing = ( + "path" if (config.force_path_style or config.url_style == "path") else "virtual" + ) + self._client = boto3.client( + "s3", + endpoint_url=config.endpoint, + region_name=config.region, + aws_access_key_id=config.access_key_id, + aws_secret_access_key=config.secret_access_key, + config=BotoConfig( + s3={"addressing_style": addressing}, + signature_version="s3v4", + retries={"max_attempts": 3, "mode": "standard"}, + ), + ) + + # ------------------------------------------------------------------ upload + def upload( + self, + key: str, + fileobj: BinaryIO, + *, + mime: str = "application/octet-stream", + metadata: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """Stream ``fileobj`` into the bucket at ``key``. + + The file pointer is consumed from its current position to EOF; the + caller is responsible for positioning it (typically at 0). Returns a + dict with ``{key, size, sha256, etag, content_type}``. + """ + # We tee through a hashing wrapper so we capture SHA-256 without + # loading the whole file into memory. boto3 handles chunked multipart + # upload automatically for large files via the managed upload API. + hasher = hashlib.sha256() + size = 0 + + class _HashingReader: + """Wrap fileobj so boto3's managed upload streams through our hasher.""" + + def __init__(self, inner: BinaryIO) -> None: + self._inner = inner + + def read(self, n: int = -1) -> bytes: + nonlocal size + chunk = self._inner.read(n) + if chunk: + hasher.update(chunk) + size += len(chunk) + return chunk + + extra_args: Dict[str, Any] = {"ContentType": mime} + if metadata: + # S3 metadata keys must be ASCII, and values must be str. NOTE(review): + # botocore may reject non-ASCII values instead of encoding them; confirm.
+ extra_args["Metadata"] = {str(k): str(v) for k, v in metadata.items() if v is not None} + + reader = _HashingReader(fileobj) + self._client.upload_fileobj( + Fileobj=reader, + Bucket=self._config.bucket, + Key=key, + ExtraArgs=extra_args, + ) + + head = self._client.head_object(Bucket=self._config.bucket, Key=key) + return { + "key": key, + "size": size or int(head.get("ContentLength", 0)), + "sha256": hasher.hexdigest(), + "etag": (head.get("ETag") or "").strip('"'), + "content_type": head.get("ContentType", mime), + } + + # ----------------------------------------------------------- presigned_url + def presigned_url( + self, + key: str, + *, + expires_in: int = 300, + response_content_disposition: Optional[str] = None, + ) -> str: + """Return a time-limited download URL for ``key`` (default 5 min).""" + params: Dict[str, Any] = {"Bucket": self._config.bucket, "Key": key} + if response_content_disposition: + params["ResponseContentDisposition"] = response_content_disposition + return self._client.generate_presigned_url( + ClientMethod="get_object", + Params=params, + ExpiresIn=int(expires_in), + ) + + # ---------------------------------------------------------------- delete + def delete(self, key: str) -> None: + """Delete the object at ``key``. 
def head(self, key: str) -> Optional[Dict[str, Any]]:
    """Return metadata for ``key``, or ``None`` if the object does not exist.

    Args:
        key: Object key to look up in the configured bucket.

    Returns:
        ``{key, size, etag, content_type, last_modified, metadata}`` built
        from the ``head_object`` response, or ``None`` when the failure is
        recognised as "not found" (error code ``404`` / ``NoSuchKey`` /
        ``NotFound``, or HTTP status 404).

    Raises:
        Exception: Any other failure from ``head_object`` is re-raised
            unchanged, including errors that carry no dict-shaped
            ``response`` payload.
    """
    try:
        response = self._client.head_object(Bucket=self._config.bucket, Key=key)
    except Exception as exc:  # pragma: no cover - boto3 error classes are dynamic
        # Only a dict-shaped ``response`` can tell us this was a 404. For
        # anything else (attribute missing, None, a string, ...) propagate the
        # original error instead of tripping an AttributeError while probing
        # it, which would hide the real cause.
        details = getattr(exc, "response", None)
        if not isinstance(details, dict):
            raise
        error = details.get("Error")
        code = error.get("Code") if isinstance(error, dict) else None
        http_meta = details.get("ResponseMetadata")
        status = http_meta.get("HTTPStatusCode") if isinstance(http_meta, dict) else None
        if code in ("404", "NoSuchKey", "NotFound") or status == 404:
            return None
        raise
    return {
        "key": key,
        "size": int(response.get("ContentLength", 0)),
        "etag": (response.get("ETag") or "").strip('"'),
        "content_type": response.get("ContentType"),
        "last_modified": response.get("LastModified"),
        "metadata": response.get("Metadata") or {},
    }
"automem-mcp-sse-server", - "version": "0.1.0", + "version": "0.2.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "automem-mcp-sse-server", - "version": "0.1.0", + "version": "0.2.0", "dependencies": { "@modelcontextprotocol/sdk": "^1.20.0", "express": "^4.19.2" diff --git a/mcp-sse-server/server.js b/mcp-sse-server/server.js index 7f453f97..67c52626 100644 --- a/mcp-sse-server/server.js +++ b/mcp-sse-server/server.js @@ -240,6 +240,36 @@ class InMemoryEventStore { } } +// Some MCP client transports JSON-encode nested object/array arguments as +// strings before calling tools, even when the input schema declares them as +// `object` or `array`. The upstream AutoMem service rejects those with +// `'metadata' must be an object` (etc.) because its strict JSON body parse +// expects native types. Coerce known offenders back to their native form +// before forwarding. Exported for unit tests. +export function coerceJsonFields(obj, fields) { + if (!obj || typeof obj !== 'object') return obj; + const out = { ...obj }; + for (const field of fields) { + const value = out[field]; + if (typeof value === 'string' && value.length > 0) { + try { + const parsed = JSON.parse(value); + // Only accept parses that match the expected native shape (object + // for metadata, array for embedding/tags). This avoids accidentally + // turning a plain string value into a number or other primitive. + if (field === 'metadata' && parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + out[field] = parsed; + } else if ((field === 'embedding' || field === 'tags') && Array.isArray(parsed)) { + out[field] = parsed; + } + } catch (_) { + // Leave as-is — the upstream service will return a clear 400. 
+ } + } + } + return out; +} + // Simple AutoMem HTTP client (mirrors the npm package behavior but inline to avoid version conflicts) export class AutoMemClient { constructor(config) { @@ -264,20 +294,21 @@ export class AutoMemClient { }); } async storeMemory(args, options) { + const coerced = coerceJsonFields(args || {}, ['metadata', 'embedding', 'tags']); const body = { - content: args.content, - tags: args.tags || [], - importance: args.importance, - embedding: args.embedding, - metadata: args.metadata, - timestamp: args.timestamp, - type: args.type, - confidence: args.confidence, - id: args.id, - t_valid: args.t_valid, - t_invalid: args.t_invalid, - updated_at: args.updated_at, - last_accessed: args.last_accessed + content: coerced.content, + tags: coerced.tags || [], + importance: coerced.importance, + embedding: coerced.embedding, + metadata: coerced.metadata, + timestamp: coerced.timestamp, + type: coerced.type, + confidence: coerced.confidence, + id: coerced.id, + t_valid: coerced.t_valid, + t_invalid: coerced.t_invalid, + updated_at: coerced.updated_at, + last_accessed: coerced.last_accessed }; const r = await this._request('POST', 'memory', body, options); return { memory_id: r.memory_id || r.id, message: r.message || 'Memory stored successfully' }; @@ -327,7 +358,8 @@ export class AutoMemClient { return { success: true, message: r.message || 'Association created successfully' }; } async updateMemory(args, options) { - const { memory_id, ...updates } = args; + const coerced = coerceJsonFields(args || {}, ['metadata', 'embedding', 'tags']); + const { memory_id, ...updates } = coerced; const r = await this._request('PATCH', `memory/${memory_id}`, updates, options); return { memory_id: r.memory_id || memory_id, message: r.message || 'Memory updated successfully' }; } @@ -342,6 +374,118 @@ export class AutoMemClient { maxRetries: options.maxRetries ?? 
0, }); } + + // -------------------- Document storage (Stage 1) ----------------------- + // Agent-driven uploads: the caller must read the file and provide an accurate + // ``title`` + ``summary`` BEFORE calling. AutoMem never parses file content. + + async uploadDocument(args, options = {}) { + const { file_base64, filename, mime, title, summary } = args; + if (!file_base64 || !filename) { + throw new Error('upload_document requires file_base64 and filename'); + } + if (!title || !summary) { + throw new Error( + 'upload_document requires title and summary. Read the file first and generate an accurate title (<200 chars) and 1-3 sentence summary.' + ); + } + + const bytes = Buffer.from(file_base64, 'base64'); + const blob = new Blob([bytes], { type: mime || 'application/octet-stream' }); + const form = new FormData(); + form.append('file', blob, filename); + form.append('title', title); + form.append('summary', summary); + if (args.tags !== undefined) { + form.append('tags', Array.isArray(args.tags) ? JSON.stringify(args.tags) : String(args.tags)); + } + if (args.importance !== undefined) form.append('importance', String(args.importance)); + if (args.metadata !== undefined) { + form.append( + 'metadata', + typeof args.metadata === 'string' ? args.metadata : JSON.stringify(args.metadata) + ); + } + if (args.memory_id) form.append('memory_id', String(args.memory_id)); + + const url = `${this.config.endpoint.replace(/\/$/, '')}/documents`; + const headers = {}; + if (this.config.apiKey) headers['Authorization'] = `Bearer ${this.config.apiKey}`; + const requestId = options.requestId || randomUUID(); + const timeoutMs = options.timeoutMs ?? readIntEnv('UPSTREAM_TIMEOUT_MS', DEFAULT_UPSTREAM_TIMEOUT_MS); + + // No retries for multipart — uploads are expensive to repeat and the + // single attempt gives clearer error semantics to the caller. 
+ const controller = new AbortController(); + const timer = setTimeout(() => controller.abort(), timeoutMs); + try { + log('info', 'document_upload', { + reqId: requestId, + url: sanitizeUrlForLog(url), + filename, + size: bytes.length, + mime, + }); + const res = await fetch(url, { + method: 'POST', + headers, + body: form, + signal: controller.signal, + }); + const data = await parseResponseBody(res); + if (!res.ok) { + const message = summarizeUpstreamErrorBody(res.status, data); + log('error', 'document_upload_failed', { + reqId: requestId, + status: res.status, + message, + }); + throw new UpstreamRequestError(message, { + status: res.status, + requestId, + kind: 'http', + retryable: false, + endpoint: url, + }); + } + return data; + } catch (error) { + if (error instanceof UpstreamRequestError) throw error; + throw new UpstreamRequestError( + error?.name === 'AbortError' + ? `document upload timed out after ${timeoutMs}ms` + : `document upload failed: ${error?.message || error}`, + { requestId, kind: error?.name === 'AbortError' ? 'timeout' : 'network', endpoint: url, cause: error } + ); + } finally { + clearTimeout(timer); + } + } + + async getDocumentUrl(args, options) { + const id = args.memory_id; + if (!id) throw new Error('get_document_url requires memory_id'); + const p = new URLSearchParams(); + if (args.expires_in !== undefined) p.set('expires_in', String(args.expires_in)); + if (args.disposition) p.set('disposition', String(args.disposition)); + const path = p.toString() ? `documents/${id}/download?${p.toString()}` : `documents/${id}/download`; + return this._request('GET', path, undefined, options); + } + + async listDocuments(args = {}, options) { + const p = new URLSearchParams(); + if (Array.isArray(args.tags)) args.tags.forEach(t => p.append('tags', t)); + else if (typeof args.tags === 'string' && args.tags.trim()) p.set('tags', args.tags.trim()); + if (args.limit !== undefined) p.set('limit', String(args.limit)); + const path = p.toString() ? 
`documents?${p.toString()}` : 'documents'; + return this._request('GET', path, undefined, options); + } + + async deleteDocument(args, options) { + const id = args.memory_id; + if (!id) throw new Error('delete_document requires memory_id'); + return this._request('DELETE', `documents/${id}`, undefined, options); + } } export function formatRecallAsItems(results, { detailed = false } = {}) { @@ -543,7 +687,102 @@ export function buildMcpServer(client) { description: 'Check AutoMem service health (FalkorDB, Qdrant, embedding provider)', annotations: { readOnlyHint: true, destructiveHint: false }, inputSchema: { type: 'object', properties: {} } - } + }, + { + name: 'upload_document', + description: + 'Upload a document (PDF, DOCX, image, text, etc.) to AutoMem. IMPORTANT: BEFORE calling this tool you MUST read the file yourself and provide an accurate human-quality title AND a 1-3 sentence summary of what it contains. AutoMem does not parse file content; your title and summary become the searchable representation. The file bytes are stored as an opaque original; later you (or another agent) can retrieve a short-lived download URL via get_document_url when you need the raw bytes. Reject the task if the user has not given you enough context to write a faithful summary.', + annotations: { readOnlyHint: false, destructiveHint: false }, + inputSchema: { + type: 'object', + properties: { + file_base64: { type: 'string', description: 'Base64-encoded file bytes (binary-safe).' }, + filename: { type: 'string', description: 'Original filename, e.g. "report-2026-q2.pdf"' }, + mime: { + type: 'string', + description: 'MIME type, e.g. "application/pdf". Defaults to application/octet-stream.', + }, + title: { + type: 'string', + description: + 'A concise human-readable title (<200 chars). 
You must have read the file before writing this; do not echo the filename.', + }, + summary: { + type: 'string', + description: + '1-3 sentence summary capturing what this document is, who it is for, and what is inside. Becomes the vector+keyword search representation — write it well.', + }, + tags: { + type: 'array', + items: { type: 'string' }, + description: 'Optional tags for filtering (the tag "document" is added automatically).', + }, + importance: { + type: 'number', + minimum: 0, + maximum: 1, + description: 'Importance score 0-1 (default 0.5).', + }, + metadata: { + type: 'object', + description: 'Arbitrary extra metadata merged into the Memory node.', + }, + memory_id: { + type: 'string', + description: 'Optional UUID for the memory. Auto-generated if omitted.', + }, + }, + required: ['file_base64', 'filename', 'title', 'summary'], + }, + }, + { + name: 'get_document_url', + description: + 'Return a short-lived (default 5 min) presigned HTTPS URL to download the original file for a Document memory. Use this when you need to read the raw bytes (e.g. open the PDF).', + annotations: { readOnlyHint: true, destructiveHint: false }, + inputSchema: { + type: 'object', + properties: { + memory_id: { type: 'string', description: 'UUID of the Document memory.' 
}, + expires_in: { + type: 'integer', + minimum: 30, + maximum: 3600, + description: 'Seconds the URL remains valid (30-3600, default 300).', + }, + disposition: { + type: 'string', + enum: ['inline', 'attachment'], + description: 'How downstream browsers should handle the response.', + }, + }, + required: ['memory_id'], + }, + }, + { + name: 'list_documents', + description: 'List Document memories, optionally filtered by tags.', + annotations: { readOnlyHint: true, destructiveHint: false }, + inputSchema: { + type: 'object', + properties: { + tags: { type: 'array', items: { type: 'string' } }, + limit: { type: 'integer', minimum: 1, maximum: 200, default: 25 }, + }, + }, + }, + { + name: 'delete_document', + description: 'Delete a Document memory and remove its original file from the bucket.', + annotations: { readOnlyHint: false, destructiveHint: true }, + inputSchema: { + type: 'object', + properties: { + memory_id: { type: 'string', description: 'UUID of the Document memory to remove.' }, + }, + required: ['memory_id'], + }, + }, ]; server.setRequestHandler(ListToolsRequestSchema, async () => ({ tools })); @@ -643,6 +882,53 @@ export function buildMcpServer(client) { const r = await client.checkHealth({ requestId }); return { content: [{ type: 'text', text: JSON.stringify(r) }] }; } + case 'upload_document': { + const r = await client.uploadDocument(args || {}, { requestId }); + const doc = r?.document || {}; + const lines = [ + `Document stored: ${r.memory_id}`, + r.title ? `Title: ${r.title}` : null, + doc.filename ? `Filename: ${doc.filename}` : null, + doc.size !== undefined ? `Size: ${doc.size} bytes` : null, + doc.mime ? `MIME: ${doc.mime}` : null, + doc.sha256 ? `SHA-256: ${doc.sha256}` : null, + r.download_url ? 
`Download URL (${r.download_url_expires_in}s): ${r.download_url}` : null, + ].filter(Boolean); + return { content: [{ type: 'text', text: lines.join('\n') }] }; + } + case 'get_document_url': { + const r = await client.getDocumentUrl(args || {}, { requestId }); + const lines = [ + `Download URL (expires in ${r.expires_in}s): ${r.download_url}`, + r.filename ? `Filename: ${r.filename}` : null, + r.mime ? `MIME: ${r.mime}` : null, + r.size !== undefined ? `Size: ${r.size} bytes` : null, + ].filter(Boolean); + return { content: [{ type: 'text', text: lines.join('\n') }] }; + } + case 'list_documents': { + const r = await client.listDocuments(args || {}, { requestId }); + const docs = Array.isArray(r.documents) ? r.documents : []; + if (!docs.length) { + return { content: [{ type: 'text', text: 'No documents found.' }] }; + } + const lines = docs.map((d, i) => { + const meta = d.document || {}; + return [ + `${i + 1}. ${meta.title || '(untitled)'}`, + ` ID: ${d.memory_id}`, + meta.filename ? ` File: ${meta.filename} (${meta.size || '?'} bytes, ${meta.mime || 'unknown'})` : null, + Array.isArray(d.tags) && d.tags.length ? 
` Tags: ${d.tags.join(', ')}` : null, + ].filter(Boolean).join('\n'); + }); + return { + content: [{ type: 'text', text: `${docs.length} document(s):\n\n${lines.join('\n\n')}` }], + }; + } + case 'delete_document': { + const r = await client.deleteDocument(args || {}, { requestId }); + return { content: [{ type: 'text', text: `Deleted ${r.memory_id} (graph=${r.graph}, bucket=${r.bucket})` }] }; + } default: throw new Error(`Unknown tool: ${name}`); } diff --git a/mcp-sse-server/test/server.test.js b/mcp-sse-server/test/server.test.js index 227c31e9..19caed23 100644 --- a/mcp-sse-server/test/server.test.js +++ b/mcp-sse-server/test/server.test.js @@ -1,6 +1,11 @@ import test from "node:test"; import assert from "node:assert/strict"; -import { AutoMemClient, createApp, formatRecallAsItems } from "../server.js"; +import { + AutoMemClient, + coerceJsonFields, + createApp, + formatRecallAsItems, +} from "../server.js"; async function withServer(app, fn) { const server = await new Promise((resolve) => { @@ -497,3 +502,81 @@ test("GET /mcp/sse returns an SSE stream and endpoint event", async () => { process.env.AUTOMEM_API_URL = prevEndpoint; } }); + +// ----------------------------------------------------------------------------- +// coerceJsonFields — the MCP transport sometimes JSON-encodes nested args as +// strings. Ensure the client parses them back to native shape before sending +// to AutoMem, which otherwise returns "'metadata' must be an object". 
+// ----------------------------------------------------------------------------- + +test("coerceJsonFields parses stringified metadata objects", () => { + const input = { + content: "hello", + metadata: '{"project":"automem","k":1}', + tags: ["a", "b"], + }; + const out = coerceJsonFields(input, ["metadata", "embedding", "tags"]); + assert.deepEqual(out.metadata, { project: "automem", k: 1 }); + assert.equal(out.content, "hello"); + assert.deepEqual(out.tags, ["a", "b"]); +}); + +test("coerceJsonFields parses stringified tags arrays", () => { + const out = coerceJsonFields( + { tags: '["finance","q2"]' }, + ["metadata", "embedding", "tags"] + ); + assert.deepEqual(out.tags, ["finance", "q2"]); +}); + +test("coerceJsonFields leaves already-native values alone", () => { + const meta = { nested: { deep: true } }; + const out = coerceJsonFields( + { metadata: meta, tags: ["x"] }, + ["metadata", "embedding", "tags"] + ); + assert.strictEqual(out.metadata, meta); + assert.deepEqual(out.tags, ["x"]); +}); + +test("coerceJsonFields leaves malformed JSON strings as-is", () => { + const out = coerceJsonFields( + { metadata: "not json" }, + ["metadata", "embedding", "tags"] + ); + assert.equal(out.metadata, "not json"); +}); + +test("coerceJsonFields rejects shape mismatches (object expected, array given)", () => { + const out = coerceJsonFields( + { metadata: '["not","an","object"]' }, + ["metadata", "embedding", "tags"] + ); + assert.equal(out.metadata, '["not","an","object"]'); +}); + +test("AutoMemClient.storeMemory coerces stringified metadata before POST", async () => { + const client = new AutoMemClient({ + endpoint: "http://example.test", + apiKey: "k", + }); + + let capturedBody = null; + client._request = async (_method, _path, body) => { + capturedBody = body; + return { memory_id: "abc" }; + }; + + await client.storeMemory({ + content: "hi", + metadata: '{"project":"automem","priority":"high"}', + tags: '["a","b"]', + }); + + 
class FakeBucketStore:
    """Dict-backed test double exposing the bucket-store methods used in tests.

    Uploaded payloads are kept verbatim in ``objects`` (keyed by bucket key)
    together with their mime type and metadata. Presigned URLs are
    deterministic strings that embed the bucket and key so tests can inspect
    them. Call counters are kept for assertions.
    """

    def __init__(self, *, bucket: str = "test-bucket") -> None:
        self.bucket = bucket
        self.objects: Dict[str, Dict[str, Any]] = {}
        # Counters that tests read to verify how often each operation ran.
        self.upload_calls = 0
        self.delete_calls = 0
        self.presign_calls = 0

    # ------------------------------------------------------------------ API
    def upload(
        self,
        key: str,
        fileobj: BinaryIO,
        *,
        mime: str = "application/octet-stream",
        metadata: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Read ``fileobj`` to EOF, store the bytes under ``key``, describe them."""
        payload = fileobj.read()
        byte_count = len(payload)
        digest = hashlib.sha256(payload).hexdigest()
        entry: Dict[str, Any] = {}
        entry["data"] = payload
        entry["mime"] = mime
        entry["metadata"] = dict(metadata or {})
        self.objects[key] = entry
        self.upload_calls += 1
        # The fake "etag" is just a prefix of the content hash.
        summary: Dict[str, Any] = {"key": key, "size": byte_count}
        summary["sha256"] = digest
        summary["etag"] = digest[:16]
        summary["content_type"] = mime
        return summary

    def presigned_url(
        self,
        key: str,
        *,
        expires_in: int = 300,
        response_content_disposition: Optional[str] = None,
    ) -> str:
        """Return a deterministic fake URL carrying the key and options."""
        self.presign_calls += 1
        params = [f"expires_in={int(expires_in)}"]
        if response_content_disposition:
            # Spaces become "+"; good enough for assertions, no server is called.
            encoded = response_content_disposition.replace(" ", "+")
            params.append(f"disposition={encoded}")
        query = "&".join(params)
        return f"https://fake.bucket.invalid/{self.bucket}/{key}?{query}"

    def delete(self, key: str) -> None:
        """Forget ``key``; a missing key is not an error."""
        self.delete_calls += 1
        self.objects.pop(key, None)

    def head(self, key: str) -> Optional[Dict[str, Any]]:
        """Describe the stored object at ``key``, or return None if absent."""
        entry = self.objects.get(key)
        if not entry:
            return None
        blob = entry["data"]
        return {
            "key": key,
            "size": len(blob),
            "etag": hashlib.sha256(blob).hexdigest()[:16],
            "content_type": entry["mime"],
            "last_modified": None,
            "metadata": entry["metadata"],
        }
def _normalize_tags(value) -> List[str]:
    """Turn a tags argument into a list of non-empty, stripped strings.

    A string is split on commas; a list has each item stringified. Any other
    input (including ``None``) yields an empty list.
    """
    if isinstance(value, str):
        candidates = value.split(",")
    elif isinstance(value, list):
        candidates = [str(item) for item in value]
    else:
        return []
    trimmed = (candidate.strip() for candidate in candidates)
    return [tag for tag in trimmed if tag]
def _multipart(
    *,
    title="Test Doc",
    summary="A test document containing example content.",
    data=b"hello world",
    filename="test.txt",
    mime="text/plain",
    **extra,
):
    """Build the form dict for a multipart ``POST /documents`` test request.

    The ``file`` entry is a ``(stream, filename, mime)`` tuple wrapping
    ``data``. ``title`` / ``summary`` are left out when passed as ``None``;
    any extra keyword arguments are added as additional form fields.
    """
    optional_fields = {"title": title, "summary": summary}
    present = {name: text for name, text in optional_fields.items() if text is not None}
    return {
        "file": (io.BytesIO(data), filename, mime),
        **present,
        **extra,
    }
-------------------------------------------------------------- tests: gate +def test_upload_without_title_returns_422(client, fake_bucket): + response = client.post( + "/documents", + data=_multipart(title=""), + content_type="multipart/form-data", + ) + assert response.status_code == 422 + assert fake_bucket.upload_calls == 0 + + +def test_upload_without_summary_returns_422(client, fake_bucket): + response = client.post( + "/documents", + data=_multipart(summary=""), + content_type="multipart/form-data", + ) + assert response.status_code == 422 + assert fake_bucket.upload_calls == 0 + + +def test_upload_without_file_returns_400(client, fake_bucket): + response = client.post( + "/documents", + data={"title": "t", "summary": "s"}, + content_type="multipart/form-data", + ) + assert response.status_code == 400 + assert fake_bucket.upload_calls == 0 + + +def test_upload_title_too_long_returns_400(client, fake_bucket): + response = client.post( + "/documents", + data=_multipart(title="x" * 301), + content_type="multipart/form-data", + ) + assert response.status_code == 400 + assert fake_bucket.upload_calls == 0 + + +# ------------------------------------------------------- tests: happy path +def test_upload_creates_memory_and_bucket_object(client, fake_bucket, fake_graph): + response = client.post( + "/documents", + data=_multipart( + title="Q2 Budget Report", + summary="Quarterly budget numbers for marketing.", + data=b"PDF pretend bytes", + filename="q2.pdf", + mime="application/pdf", + tags=json.dumps(["finance", "q2"]), + importance="0.8", + ), + content_type="multipart/form-data", + ) + assert response.status_code == 201, response.get_data(as_text=True) + body = response.get_json() + + assert body["status"] == "success" + assert body["type"] == "Document" + assert body["title"] == "Q2 Budget Report" + assert body["document"]["mime"] == "application/pdf" + assert body["document"]["size"] == len(b"PDF pretend bytes") + assert body["document"]["sha256"] + assert 
body["download_url"].startswith("https://fake.bucket.invalid/") + tags_lower = [t.lower() for t in body["tags"]] + assert "document" in tags_lower # auto-added + assert "finance" in tags_lower + + assert fake_bucket.upload_calls == 1 + key = body["document"]["bucket_key"] + assert key in fake_bucket.objects + assert fake_bucket.objects[key]["data"] == b"PDF pretend bytes" + assert fake_bucket.objects[key]["mime"] == "application/pdf" + + # Memory node created in graph with type=Document + mem = fake_graph.memories.get(body["memory_id"]) + assert mem is not None + assert mem["type"] == "Document" + assert mem["content"].startswith("Q2 Budget Report") + + +def test_upload_uses_provided_memory_id(client, fake_bucket): + custom_id = "11111111-1111-1111-1111-111111111111" + response = client.post( + "/documents", + data=_multipart(memory_id=custom_id), + content_type="multipart/form-data", + ) + assert response.status_code == 201 + assert response.get_json()["memory_id"] == custom_id + + +def test_upload_rejects_invalid_memory_id(client): + response = client.post( + "/documents", + data=_multipart(memory_id="not-a-uuid"), + content_type="multipart/form-data", + ) + assert response.status_code == 400 + + +# ----------------------------------------------------- tests: download URL +def test_download_returns_presigned_url(client, fake_bucket): + upload = client.post( + "/documents", + data=_multipart(), + content_type="multipart/form-data", + ) + memory_id = upload.get_json()["memory_id"] + + response = client.get(f"/documents/{memory_id}/download") + assert response.status_code == 200 + body = response.get_json() + assert body["download_url"].startswith("https://fake.bucket.invalid/") + assert "expires_in=" in body["download_url"] + assert body["mime"] == "text/plain" + assert fake_bucket.presign_calls >= 2 # upload returned one, download one more + + +def test_download_missing_doc_returns_404(client): + missing = "22222222-2222-2222-2222-222222222222" + response = 
client.get(f"/documents/{missing}/download") + assert response.status_code == 404 + + +def test_download_invalid_uuid_returns_400(client): + response = client.get("/documents/not-a-uuid/download") + assert response.status_code == 400 + + +# -------------------------------------------------------------- tests: delete +def test_delete_removes_memory_and_bucket(client, fake_bucket, fake_graph): + upload = client.post( + "/documents", + data=_multipart(), + content_type="multipart/form-data", + ) + memory_id = upload.get_json()["memory_id"] + key = upload.get_json()["document"]["bucket_key"] + assert key in fake_bucket.objects + + response = client.delete(f"/documents/{memory_id}") + assert response.status_code == 200 + body = response.get_json() + assert body["graph"] == "deleted" + assert body["bucket"] == "deleted" + assert key not in fake_bucket.objects + assert fake_bucket.delete_calls == 1 + assert memory_id not in fake_graph.memories + + +def test_delete_missing_doc_returns_404(client): + response = client.delete("/documents/33333333-3333-3333-3333-333333333333") + assert response.status_code == 404 + + +# --------------------------------------------------- tests: unconfigured +def test_upload_returns_503_when_bucket_unconfigured(noop_bucket_app): + response = noop_bucket_app.post( + "/documents", + data=_multipart(), + content_type="multipart/form-data", + ) + assert response.status_code == 503 + text = response.get_data(as_text=True) + assert "S3" in text or "bucket" in text.lower() + + +def test_download_returns_503_when_bucket_unconfigured(noop_bucket_app): + response = noop_bucket_app.get("/documents/11111111-1111-1111-1111-111111111111/download") + assert response.status_code == 503