diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ecc620..af039fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,76 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Capability namespaces and hierarchical discovery in `CapabilityRegistry`: + dot-notation `capability_id`s now expose `list_namespaces()` / + `list_namespace(prefix)` operations; `register_namespace(prefix, loader=...)` + enables deferred registration for large tool ecosystems (the loader runs + at most once on first access). `search()` gained an `offset` kwarg for + pagination, strips a small stop-word set, and now scores with a + BM25-flavoured ranker that weights `capability_id`/`tags` matches above + `description`. Flat (un-namespaced) capability IDs continue to work + unchanged. (#45) +- Capability marketplace, part 1 — manifest format & local registry: new + `CapabilityDescriptor` and `CapabilityManifest` dataclasses (both + JSON-round-trippable via `to_dict`/`from_dict`), new + `agent_kernel.federation` module with `build_manifest()`, + `import_manifest()`, and `merge_sensitivity()`, and new `Kernel.advertise()` + / `Kernel.import_remote()` methods. `Kernel` gained a `kernel_id` + argument used as the manifest publisher identity. Three trust policies + are honoured at import time (`most_restrictive` (default), `local_only`, + `remote_deferred`); imported capabilities are routed through a + caller-supplied driver and flow through the full local policy → token → + firewall pipeline. HMAC tokens remain kernel-scoped — a token issued by + one kernel cannot be verified by another with a different secret. New + errors `NamespaceNotFound`, `FederationError`, `ManifestError`, + `TrustPolicyError`. (#52) +- New docs: [`docs/federation.md`](docs/federation.md) for the marketplace + protocol and a namespace section in + [`docs/capabilities.md`](docs/capabilities.md). +- Capability marketplace, part 2 — federated discovery: new + `agent_kernel.federation_discovery` module with `discover_peers()`, + `sign_manifest()`, `verify_manifest()`, `serve_manifest_payload()`, and + `DiscoveryRateLimiter`. `Kernel.discover_peers()` fetches one or more + manifests over HTTP from peer URLs or a registry URL. Signed envelopes + (HMAC-SHA256) detect tampering and let importers refuse unsigned + manifests when a verification secret is in play (and vice versa). New + errors `ManifestSignatureError` and `DiscoveryError`. (#51, closes #49) +- OpenTelemetry integration: new `agent_kernel.otel` module with + `instrument_kernel(kernel)` that wraps `Kernel.invoke` and + `Kernel.grant_capability` with OTel spans + metrics (invocation count, + latency histogram, denial counter). No-op when the optional `[otel]` + extra is not installed (`OTEL_AVAILABLE` reports the runtime status). + Idempotent — repeat calls on the same kernel are no-ops. (#38) +- Streaming firewall: new `Firewall.apply_stream()` async-iterator method + that processes driver chunks one-at-a-time, applying PII redaction + per-chunk. New `StreamingDriver` Protocol in `drivers/base.py` extends + `Driver` with an optional `execute_stream()`. New `Kernel.invoke_stream()` + yields `Frame` chunks; the last chunk carries `is_final=True`. Drivers + without `execute_stream` automatically fall back to a single-chunk stream + via `execute()`. `Frame` gained an `is_final: bool` field. (#47) + +### Changed +- Tech debt: `policy_dsl.py` decomposed (was 661 lines). Parsing and + schema dataclasses now live in `policy_dsl_parser.py` + (`PolicyMatch`, `PolicyRule`, `parse_engine_data`, `parse_rule`, + YAML/TOML loaders), and the denial-explanation traversal in + `policy_dsl_explain.py`. The public import surface + (`DeclarativePolicyEngine`, `PolicyMatch`, `PolicyRule`) is unchanged. + `RateLimiter` and rate-limit constants extracted from `policy.py` into + a new `rate_limit.py` module; `policy.py` continues to re-export them + under their original names. (#68) +- Tech debt: `kernel.py` split into the `agent_kernel.kernel` sub-package + to honour AGENTS.md's ≤ 300-line module bar. The `Kernel` class lives + in `kernel/__init__.py`; heavy methods (invoke pipeline, dry-run, + federation, streaming) delegate to sibling modules. Existing + `from agent_kernel import Kernel` / `from agent_kernel.kernel import Kernel` + imports are unchanged. (#68) + +### Tests +- Added explicit dry-run regression tests for `HTTPDriver` and `MCPDriver`, + pinning the kernel's driver-agnostic dry-run short-circuit. (#68) + ## [0.7.0] - 2026-05-20 ### Added diff --git a/docs/capabilities.md b/docs/capabilities.md index db448db..0c85d54 100644 --- a/docs/capabilities.md +++ b/docs/capabilities.md @@ -3,9 +3,60 @@ ## Naming conventions - Use `domain.verb_noun` format: `billing.list_invoices`, `users.get_profile`. +- Prefer fully namespaced IDs (`billing.invoices.list`) over flat ones — + the registry will infer namespace operations from the dot-segments and + large ecosystems benefit from being able to list/search per namespace. - Be specific: prefer `billing.cancel_invoice` over `billing.update`. - Avoid generic names like `billing.execute` or `api.call`. +## Namespaces and discovery + +`CapabilityRegistry` recognises dot-notation namespaces automatically. No +extra registration step is required — `register(Capability(capability_id= +"billing.invoices.list", ...))` is enough to populate the `billing` and +`billing.invoices` namespaces. + +```python +registry.list_namespaces() +# ['billing', 'crm'] + +registry.list_namespace("billing") +# [Capability('billing.invoices.list'), Capability('billing.payments.refund'), …] +``` + +For large tool ecosystems where eagerly registering hundreds of +capabilities is wasteful, declare a deferred loader. The loader runs at +most once, the first time the namespace is searched, listed, or any +capability under it is fetched via `get()`: + +```python +def load_billing() -> list[Capability]: + return [ + Capability(capability_id="billing.invoices.list", …), + Capability(capability_id="billing.invoices.create", …), + Capability(capability_id="billing.payments.refund", …), + ] + +registry.register_namespace( + "billing", + description="Billing and invoicing tools", + loader=load_billing, +) +``` + +Search ranks matches with a BM25-flavoured scorer that weights +`capability_id` and `tags` higher than `description`, strips a small +stop-word set (`a`, `the`, `please`, …), and offers `offset` for +pagination: + +```python +results = registry.search("list invoices", max_results=10, offset=0) +``` + +Search is deterministic — equal-scoring capabilities are returned in +`capability_id` order — and trips any deferred namespace loader whose +prefix shares a token with the query. + ## Granularity Each capability should map to a single, auditable action with clear side-effects. diff --git a/docs/context_firewall.md b/docs/context_firewall.md index 3b1623a..385dd23 100644 --- a/docs/context_firewall.md +++ b/docs/context_firewall.md @@ -118,3 +118,59 @@ manager = BudgetManager(total_budget=128_000, token_counter=tiktoken_counter) The default counter (`default_token_counter`) is a character-based `len(json.dumps(value)) // 4` approximation with no extra dependencies. + +## Streaming + +For large results that arrive incrementally (e.g. SSE-style HTTP responses, +chunked database cursors, line-by-line tool output), `Firewall.apply_stream()` +lets you process chunks one at a time. PII redaction and per-chunk budget +caps apply on every yielded Frame — secrets cannot leak just because they +arrived in chunk N rather than the final aggregate. + +```python +from agent_kernel.drivers.base import ExecutionContext, StreamingDriver + +class MyStreamingDriver: + driver_id = "stream" + + async def execute(self, ctx: ExecutionContext): + # one-shot fallback, called when StreamingDriver isn't used. + ... + + async def execute_stream(self, ctx: ExecutionContext): + async for row in some_async_cursor(ctx): + yield {"row": row} + yield {"__is_final__": True} # explicit sentinel (optional) + + +# isinstance(driver, StreamingDriver) is runtime-checkable. +assert isinstance(MyStreamingDriver(), StreamingDriver) + +async for frame in kernel.invoke_stream(token, principal=p, args={}): + handle_chunk(frame) + if frame.is_final: + break +``` + +When the resolved driver does **not** implement `StreamingDriver`, +`Kernel.invoke_stream` falls back to a single `Driver.execute()` call and +yields exactly one `Frame` with `is_final=True`. Each invocation produces +one `ActionTrace` covering the whole stream. + +## Observability + +`agent_kernel.instrument_kernel(kernel)` installs OpenTelemetry spans and +metric emission on `Kernel.invoke` and `Kernel.grant_capability`: + +```python +from agent_kernel import Kernel, instrument_kernel, OTEL_AVAILABLE + +kernel = Kernel(registry=...) +if OTEL_AVAILABLE: + instrument_kernel(kernel) # no-op when [otel] extra not installed +``` + +Spans: `agent_kernel.invoke`, `agent_kernel.grant`. Metrics: +`agent_kernel.invocations` (counter), `agent_kernel.invocation_duration` +(histogram, ms), `agent_kernel.policy_denials` (counter). The call is +idempotent — repeat invocations on the same kernel are no-ops. diff --git a/docs/federation.md b/docs/federation.md new file mode 100644 index 0000000..5bda6c1 --- /dev/null +++ b/docs/federation.md @@ -0,0 +1,216 @@ +# Capability Federation — Marketplace Part 1 + +> Issue [#52](https://github.com/dgenio/agent-kernel/issues/52) (manifest +> format & local registry) is implemented here. Discovery over a network +> (issue [#51](https://github.com/dgenio/agent-kernel/issues/51)) is **not** +> part of this milestone — `agent-kernel` does not fetch manifests over +> HTTP or sign them on your behalf yet. Bring your own transport for now. + +## What this gives you + +A single kernel can: + +1. **Advertise** its capabilities as a JSON-serialisable + [`CapabilityManifest`](../src/agent_kernel/models.py). +2. **Import** another kernel's manifest, registering each capability locally + and routing invocations through a caller-supplied driver + (typically [`HTTPDriver`](integrations.md) or + [`MCPDriver`](integrations.md)). + +Every imported invocation still runs through the *local* policy → token → +firewall pipeline. The remote endpoint is never trusted to authorise on the +importing kernel's behalf. This keeps weaver-spec invariants intact for +imported capabilities: + +| Invariant | How it's enforced for imports | +|-----------|------------------------------| +| **I-01** — Firewall on every result | The local `Firewall` runs on the driver's `RawResult` exactly as for native capabilities. | +| **I-02** — Authorize + audit each call | The local `PolicyEngine` evaluates every request; the local `TraceStore` records every action. | +| **I-06** — Tokens bind principal + capability + constraints | Tokens are signed with the importing kernel's HMAC secret. A token issued by Kernel A cannot be verified by Kernel B, because their secrets differ. | + +## Publishing a manifest + +```python +from agent_kernel import ( + Capability, CapabilityRegistry, HMACTokenProvider, Kernel, + SafetyClass, SensitivityTag, +) + +registry = CapabilityRegistry() +registry.register( + Capability( + capability_id="billing.invoices.list", + name="List Invoices", + description="List recent invoices", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.PII, + tags=["billing", "invoices"], + ) +) + +kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="…"), + kernel_id="agent-b", +) + +manifest = kernel.advertise(endpoint="https://agent-b.example/kernel") +print(manifest.to_dict()) +# { +# "kernel_id": "agent-b", +# "version": "1", +# "endpoint": "https://agent-b.example/kernel", +# "trust_level": "unverified", +# "capabilities": [ +# { +# "capability_id": "billing.invoices.list", +# "name": "List Invoices", +# … +# } +# ] +# } +``` + +The manifest deliberately omits internal driver IDs, operation names, +`parameters_model` Python references, and `tool_hints`. Only the +[`CapabilityDescriptor`](../src/agent_kernel/models.py) projection of each +capability is published. + +## Importing a manifest + +```python +import json + +import httpx +from agent_kernel import ( + CapabilityManifest, CapabilityRegistry, HMACTokenProvider, Kernel, +) +from agent_kernel.drivers.http import HTTPDriver, HTTPEndpoint + +# 1. Fetch the manifest by whatever transport suits you. +raw = httpx.get("https://agent-b.example/kernel/manifest").json() +manifest = CapabilityManifest.from_dict(raw) + +# 2. Build a local driver pointing at the remote endpoint. +remote = HTTPDriver(driver_id="agent-b") +for cap in manifest.capabilities: + remote.register_endpoint( + cap.capability_id, + HTTPEndpoint(url=f"{manifest.endpoint}/invoke/{cap.capability_id}", + method="POST"), + ) + +# 3. Import. `import_remote` registers the driver and adds routes. +kernel = Kernel( + registry=CapabilityRegistry(), + token_provider=HMACTokenProvider(secret="local-secret"), + kernel_id="agent-a", +) +kernel.import_remote(manifest, driver=remote, trust_policy="most_restrictive") + +# 4. Use imported capabilities exactly like local ones. +for cap in kernel.list_capabilities(): + print(cap.capability_id, "→", cap.impl.driver_id) +``` + +## Trust policies + +`import_remote(manifest, driver=..., trust_policy=...)` accepts three +values for `trust_policy`: + +| Value | Sensitivity handling | When to use | +|-------|---------------------|-------------| +| `"most_restrictive"` *(default)* | Imported capability keeps the remote `SensitivityTag` verbatim — the local firewall will then redact accordingly. | Crossing trust boundaries — when you can't fully verify the remote's policy. | +| `"local_only"` | Imported capability is registered with `SensitivityTag.NONE`; the importing kernel's policy is the only thing that gates the call. | You own both kernels and have a single canonical policy. | +| `"remote_deferred"` | Same sensitivity handling as `most_restrictive` today. Reserved for part 2, when the importing kernel will be able to defer to a remote policy decision before applying its own. | Delegation patterns where the remote owns the authoritative policy. | + +`merge_sensitivity(local, remote)` is exported for callers that maintain +their own capability records and want the canonical strictest-wins union. + +## What is *not* covered yet + +- **No network transport.** `agent-kernel` does not fetch, sign, or + authenticate manifests over HTTP — bring your own transport. Part 2 + (issue #51) adds an opt-in manifest endpoint and a discovery protocol. +- **No remote policy delegation.** `"remote_deferred"` currently behaves + identically to `"most_restrictive"`. The full "remote policy decides + first" semantics need part 2. +- **No automatic re-import.** Manifests are imported once. If the publisher + adds capabilities, the importer must re-fetch and re-import. +- **No identity verification.** `trust_level` is a publisher-declared hint; + it does not authenticate the publisher. Signature verification arrives + with part 2. + +## Reference + +- Models: [`CapabilityDescriptor`](../src/agent_kernel/models.py), + [`CapabilityManifest`](../src/agent_kernel/models.py). +- Functions: [`build_manifest`](../src/agent_kernel/federation.py), + [`import_manifest`](../src/agent_kernel/federation.py), + [`merge_sensitivity`](../src/agent_kernel/federation.py). +- Kernel methods: `Kernel.advertise()`, `Kernel.import_remote()`, + `Kernel.kernel_id`. +- Errors: `FederationError`, `ManifestError`, `TrustPolicyError`. + +## Federated discovery (part 2, issue #51) + +The discovery layer on top of the local marketplace adds two pieces: + +1. **Signed manifests.** `sign_manifest(manifest, secret=...)` wraps a + manifest in an `HMAC-SHA256` envelope. `verify_manifest(envelope, + secret=...)` validates the signature and returns the embedded + `CapabilityManifest`. Tampered or wrong-secret envelopes raise + `ManifestSignatureError`. + +2. **HTTP discovery.** `discover_peers(...)` fetches one or more manifests + over HTTP, either from direct peer URLs or by first resolving a + registry URL that returns a JSON list of peer URLs. + +```python +from agent_kernel import discover_peers, sign_manifest, serve_manifest_payload + +# Publisher side — expose the manifest from any ASGI framework. +@app.get("/kernel/manifest") +async def manifest_endpoint(): + return serve_manifest_payload(kernel.advertise(endpoint="..."), secret=SECRET) + +# Importer side. +manifests = await kernel.discover_peers( + peer_urls=["https://peer-a/manifest", "https://peer-b/manifest"], + secret=SECRET, # mandatory if peers serve signed envelopes +) +for m in manifests: + kernel.import_remote(m, driver=HTTPDriver.from_manifest(m)) +``` + +### Asymmetric signing modes + +`discover_peers` is **strict** about signing: if you pass a `secret`, every +manifest must be signed; if you don't, every manifest must be unsigned. +Receiving the "wrong" shape raises `ManifestSignatureError`. This avoids +the silent-downgrade pitfall where an attacker strips the signature to +serve an unsigned manifest in its place. + +### Rate limiting + +`DiscoveryRateLimiter` (default: 10 calls per 60 seconds) caps how often +`discover_peers` can hit the network. The limiter is per-instance — share +one across calls to enforce a session-wide budget. Exceeding the budget +raises `DiscoveryError`. + +### Security boundary (still holds) + +Discovery does not change the import/invoke pipeline. Even after a +successful `discover_peers` + `import_remote`, every invocation still +flows through the *local* policy → token → firewall pipeline. Discovery +only decides *what* capabilities a kernel might import; it never grants +authority. + +### Reference + +- Functions: [`discover_peers`](../src/agent_kernel/federation_discovery.py), + [`sign_manifest`](../src/agent_kernel/federation_discovery.py), + [`verify_manifest`](../src/agent_kernel/federation_discovery.py), + [`serve_manifest_payload`](../src/agent_kernel/federation_discovery.py). +- Kernel methods: `Kernel.discover_peers()`. +- New errors: `ManifestSignatureError`, `DiscoveryError`. diff --git a/docs/integrations.md b/docs/integrations.md index d8e4a18..f1ae165 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -327,3 +327,38 @@ failures, and hook abort signals are all returned to the LLM as a tool result with `error: true` (Anthropic also sets `is_error: true`). Raised exceptions would crash the surrounding agent loop; the LLM cannot react to an exception. + +## OpenTelemetry + +`agent_kernel.instrument_kernel(kernel)` wraps `Kernel.invoke` and +`Kernel.grant_capability` with OTel spans + metrics. The optional +`[otel]` extra brings in `opentelemetry-api`; everything is a strict +no-op when the extra is not installed. + +```bash +pip install 'weaver-kernel[otel]' # api only — for production +pip install 'weaver-kernel[otel]' opentelemetry-sdk \ + opentelemetry-exporter-otlp # also the SDK + exporter +``` + +```python +from agent_kernel import Kernel, instrument_kernel + +kernel = Kernel(registry=...) +instrument_kernel(kernel) +# Production: rely on global TracerProvider/MeterProvider configured at +# process start. Tests can pass explicit providers: +# instrument_kernel(kernel, tracer_provider=..., meter_provider=...) +``` + +| Telemetry | Name | Notes | +|-----------|------|-------| +| Span | `agent_kernel.invoke` | attrs: `principal_id`, `capability_id`, `response_mode`, `dry_run` | +| Span | `agent_kernel.grant` | attrs: `principal_id`, `capability_id` | +| Counter | `agent_kernel.invocations` | labels: `capability_id`, `status` (`success`/`error`) | +| Histogram | `agent_kernel.invocation_duration` | unit: ms | +| Counter | `agent_kernel.policy_denials` | labels: `capability_id`, `reason_code` | + +`instrument_kernel` is idempotent — calling twice on the same kernel is a +no-op. Use `agent_kernel.otel.reset_instrumentation(kernel)` in tests to +re-instrument with a different provider. diff --git a/pyproject.toml b/pyproject.toml index a034144..1fe4af3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,3 +89,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = "mcp.*" ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "opentelemetry.*" +ignore_missing_imports = true diff --git a/src/agent_kernel/__init__.py b/src/agent_kernel/__init__.py index e7449b5..de4a00b 100644 --- a/src/agent_kernel/__init__.py +++ b/src/agent_kernel/__init__.py @@ -31,6 +31,11 @@ from agent_kernel import OpenAIMiddleware, AnthropicMiddleware +Federation (capability marketplace):: + + from agent_kernel import CapabilityManifest, CapabilityDescriptor + from agent_kernel import build_manifest, import_manifest, TrustPolicy + Errors:: from agent_kernel import ( @@ -39,6 +44,7 @@ PolicyDenied, PolicyConfigError, DriverError, FirewallError, BudgetExhausted, BudgetConfigError, CapabilityNotFound, HandleNotFound, HandleExpired, + NamespaceNotFound, FederationError, ManifestError, TrustPolicyError, ) """ @@ -48,34 +54,59 @@ from .drivers.mcp import MCPDriver from .drivers.memory import InMemoryDriver, make_billing_driver from .enums import SafetyClass, SensitivityTag -from .errors import ( +from .errors import ( # noqa: I001 - keep group ordering stable AdapterParseError, AgentKernelError, BudgetConfigError, BudgetExhausted, CapabilityAlreadyRegistered, CapabilityNotFound, + DiscoveryError, DriverError, + FederationError, FirewallError, HandleExpired, HandleNotFound, + ManifestError, + ManifestSignatureError, + NamespaceNotFound, PolicyConfigError, PolicyDenied, TokenExpired, TokenInvalid, TokenRevoked, TokenScopeError, + TrustPolicyError, +) +from .federation import ( + MANIFEST_VERSION, + TrustPolicy, + build_manifest, + import_manifest, + merge_sensitivity, +) +from .federation_discovery import ( + DiscoveryRateLimiter, + discover_peers, + serve_manifest_payload, + sign_manifest, + verify_manifest, ) from .firewall.budget_manager import BudgetManager from .firewall.budgets import Budgets from .firewall.token_counting import TokenCounter, default_token_counter from .firewall.transform import Firewall from .handles import HandleStore -from .kernel import Kernel +from .kernel import ( + Kernel, + StreamingDriver, # re-export for backwards-compatible imports +) from .models import ( ActionTrace, Capability, + CapabilityDescriptor, CapabilityGrant, + CapabilityManifest, CapabilityRequest, DenialExplanation, DryRunResult, @@ -83,6 +114,7 @@ Frame, Handle, ImplementationRef, + NamespaceMetadata, PolicyDecision, PolicyDecisionTrace, PolicyTraceStep, @@ -92,7 +124,9 @@ ResponseMode, RoutePlan, ToolHints, + TrustLevel, ) +from .otel import OTEL_AVAILABLE, instrument_kernel from .policy import DefaultPolicyEngine, ExplainingPolicyEngine, PolicyEngine from .policy_dsl import DeclarativePolicyEngine, PolicyMatch, PolicyRule from .policy_reasons import AllowReason, DenialReason @@ -112,7 +146,9 @@ "CapabilityRegistry", # models "Capability", + "CapabilityDescriptor", "CapabilityGrant", + "CapabilityManifest", "CapabilityRequest", "CapabilityToken", "DenialExplanation", @@ -121,6 +157,7 @@ "Frame", "Handle", "ImplementationRef", + "NamespaceMetadata", "PolicyDecision", "PolicyDecisionTrace", "PolicyTraceStep", @@ -131,6 +168,7 @@ "RoutePlan", "ActionTrace", "ToolHints", + "TrustLevel", # enums "SafetyClass", "SensitivityTag", @@ -141,16 +179,34 @@ "BudgetExhausted", "CapabilityAlreadyRegistered", "CapabilityNotFound", + "DiscoveryError", "DriverError", + "FederationError", "FirewallError", "HandleExpired", "HandleNotFound", + "ManifestError", + "ManifestSignatureError", + "NamespaceNotFound", "PolicyConfigError", "PolicyDenied", "TokenExpired", "TokenInvalid", "TokenRevoked", "TokenScopeError", + "TrustPolicyError", + # federation + "MANIFEST_VERSION", + "TrustPolicy", + "build_manifest", + "import_manifest", + "merge_sensitivity", + # federation discovery (issue #51) + "DiscoveryRateLimiter", + "discover_peers", + "serve_manifest_payload", + "sign_manifest", + "verify_manifest", # policy "AllowReason", "DefaultPolicyEngine", @@ -183,4 +239,9 @@ # adapters "AnthropicMiddleware", "OpenAIMiddleware", + # observability + "OTEL_AVAILABLE", + "instrument_kernel", + # streaming + "StreamingDriver", ] diff --git a/src/agent_kernel/drivers/base.py b/src/agent_kernel/drivers/base.py index 7b43235..51e23fa 100644 --- a/src/agent_kernel/drivers/base.py +++ b/src/agent_kernel/drivers/base.py @@ -2,8 +2,9 @@ from __future__ import annotations +from collections.abc import AsyncIterator from dataclasses import dataclass, field -from typing import Any, Protocol +from typing import Any, Protocol, runtime_checkable from ..models import RawResult @@ -40,3 +41,52 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: DriverError: If execution fails. """ ... + + +@runtime_checkable +class StreamingDriver(Protocol): + """Optional extension to :class:`Driver` for chunked output. + + Drivers that can produce results incrementally implement this + protocol in addition to :class:`Driver`. :meth:`Kernel.invoke_stream` + uses ``isinstance(driver, StreamingDriver)`` (runtime-checkable) to + detect support and falls back to :meth:`Driver.execute` when a + driver only implements the base protocol. + + Chunks are plain dictionaries; each is run through the firewall + independently so PII redaction applies on a per-chunk basis. A + chunk may carry the synthetic key ``"__is_final__": True`` to mark + the last chunk explicitly — otherwise consumers should treat the + iterator's natural end as end-of-stream. + """ + + @property + def driver_id(self) -> str: # pragma: no cover - protocol stub + ... + + async def execute( + self, ctx: ExecutionContext + ) -> RawResult: # pragma: no cover - protocol stub + ... + + def execute_stream(self, ctx: ExecutionContext) -> AsyncIterator[dict[str, Any]]: + """Execute a capability and yield response chunks. + + Declared with ``def`` (not ``async def``) because async-generator + implementations — the natural shape, using ``async def`` + ``yield`` + — return the async iterator *directly* when called. An + ``async def`` Protocol signature would force callers to first + ``await`` the result, breaking the async-generator idiom. + + Args: + ctx: Execution context including capability ID, args, and constraints. + + Returns: + An async iterator of dictionary payloads — one per chunk. A + chunk may carry the synthetic key ``"__is_final__": True`` + to mark itself as the last one. + + Raises: + DriverError: If execution fails (may be raised mid-stream). + """ + ... diff --git a/src/agent_kernel/errors.py b/src/agent_kernel/errors.py index 7944f7b..a44979d 100644 --- a/src/agent_kernel/errors.py +++ b/src/agent_kernel/errors.py @@ -126,3 +126,49 @@ class HandleNotFound(AgentKernelError): class HandleExpired(AgentKernelError): """Raised when a handle's TTL has elapsed.""" + + +# ── Registry / namespace errors ─────────────────────────────────────────────── + + +class NamespaceNotFound(AgentKernelError): + """Raised when a namespace prefix is not known to the registry.""" + + +# ── Federation errors ───────────────────────────────────────────────────────── + + +class FederationError(AgentKernelError): + """Base class for federation / capability marketplace failures.""" + + +class TrustPolicyError(FederationError): + """Raised when a federation request violates the configured trust policy. + + Examples include an unknown trust policy name, a remote manifest from an + untrusted endpoint, or a token that originated outside the importing + kernel's HMAC scope. + """ + + +class ManifestError(FederationError): + """Raised when a :class:`~agent_kernel.models.CapabilityManifest` cannot be + serialised, parsed, or imported (e.g. missing fields, invalid version, + duplicate capability IDs). + """ + + +class ManifestSignatureError(FederationError): + """Raised when a signed manifest fails HMAC verification. + + A signature mismatch indicates either tampering, a wrong shared + secret, or a bug in the publisher's signing code. Either way the + manifest must not be imported — the importer should reject the + payload outright. + """ + + +class DiscoveryError(FederationError): + """Raised when peer/registry discovery fails (network error, malformed + response, or rate-limit hit). + """ diff --git a/src/agent_kernel/federation.py b/src/agent_kernel/federation.py new file mode 100644 index 0000000..08f327c --- /dev/null +++ b/src/agent_kernel/federation.py @@ -0,0 +1,244 @@ +"""Capability marketplace — manifest format and local-registry federation. + +This module implements *part 1* of the capability marketplace protocol +(issue #52): one kernel can advertise its capabilities as a +:class:`~agent_kernel.models.CapabilityManifest`, and a second kernel can +import that manifest to extend its own registry. Remote invocation is then +performed by routing imported capabilities to a caller-supplied +:class:`~agent_kernel.drivers.base.Driver` (typically an +:class:`~agent_kernel.drivers.http.HTTPDriver` or +:class:`~agent_kernel.drivers.mcp.MCPDriver`) — every imported call still +flows through the *local* policy → token → firewall pipeline, satisfying +weaver-spec I-01 / I-02 / I-06. + +Discovery (part 2 of the marketplace, issue #51) is out of scope here — this +module is purely local. Manifests are constructed by ``Kernel.advertise()`` +and consumed by ``Kernel.import_remote()``; the importing side is free to +fetch them by any transport. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from .enums import SensitivityTag +from .errors import ManifestError, TrustPolicyError +from .models import ( + Capability, + CapabilityDescriptor, + CapabilityManifest, + ImplementationRef, +) + +if TYPE_CHECKING: + from .registry import CapabilityRegistry + +MANIFEST_VERSION = "1" +"""Schema version published by :func:`build_manifest`.""" + +TrustPolicy = Literal["most_restrictive", "local_only", "remote_deferred"] +"""How an importing kernel weighs descriptor metadata against its own policy. + +- ``"most_restrictive"`` (default): the descriptor's sensitivity tag is + honoured as a floor — even if the importing kernel's policy would treat + the capability as ``NONE``, the imported capability keeps the remote tag. + Required by use cases that span trust boundaries. +- ``"local_only"``: the importing kernel ignores the descriptor's + sensitivity tag and registers the imported capability with + :attr:`~agent_kernel.enums.SensitivityTag.NONE`. Use when the importer + owns both kernels and has a single canonical policy. +- ``"remote_deferred"``: the descriptor's sensitivity tag is preserved + verbatim and treated as the remote policy's input; the importing kernel + layers its own policy on top. +""" + +_VALID_TRUST_POLICIES: frozenset[str] = frozenset( + {"most_restrictive", "local_only", "remote_deferred"} +) + +# Ordering used by ``most_restrictive`` when picking between two sensitivity +# tags. The strictest tag wins. +_SENSITIVITY_RANK: dict[SensitivityTag, int] = { + SensitivityTag.NONE: 0, + SensitivityTag.PII: 2, + SensitivityTag.PCI: 3, + SensitivityTag.SECRETS: 4, +} + + +def _stricter(a: SensitivityTag, b: SensitivityTag) -> SensitivityTag: + """Return the stricter of two sensitivity tags (higher rank wins).""" + if _SENSITIVITY_RANK.get(b, 0) > _SENSITIVITY_RANK.get(a, 0): + return b + return a + + +def build_manifest( + *, + kernel_id: str, + registry: CapabilityRegistry, + endpoint: str, + trust_level: Literal["verified", "unverified"] = "unverified", +) -> CapabilityManifest: + """Build a public-facing :class:`CapabilityManifest` for *registry*. + + Internal implementation details (``ImplementationRef``, ``parameters_model`` + Python references, ``tool_hints``) are stripped — only fields safe to share + over the wire are emitted. + + Args: + kernel_id: Stable identifier of the advertising kernel. + registry: The :class:`CapabilityRegistry` whose contents to publish. + endpoint: Transport endpoint at which the advertising kernel can be + reached. Format is transport-specific (e.g. + ``"https://agent-a.example/kernel"``). + trust_level: Publisher-declared trust hint. The importing kernel still + applies its configured trust policy regardless. + + Returns: + A :class:`CapabilityManifest` ready to be serialised with + :meth:`CapabilityManifest.to_dict`. + """ + descriptors = [_descriptor_for(cap) for cap in registry.list_all()] + return CapabilityManifest( + kernel_id=kernel_id, + version=MANIFEST_VERSION, + capabilities=descriptors, + endpoint=endpoint, + trust_level=trust_level, + ) + + +def _descriptor_for(cap: Capability) -> CapabilityDescriptor: + """Project a :class:`Capability` onto its safe-to-share descriptor view.""" + return CapabilityDescriptor( + capability_id=cap.capability_id, + name=cap.name, + description=cap.description, + safety_class=cap.safety_class, + sensitivity=cap.sensitivity, + tags=list(cap.tags), + parameters_schema=cap.parameters_schema, + ) + + +def import_manifest( + *, + manifest: CapabilityManifest, + registry: CapabilityRegistry, + driver_id: str, + trust_policy: TrustPolicy = "most_restrictive", +) -> list[Capability]: + """Register a remote manifest's capabilities into *registry*. + + Each descriptor becomes a regular :class:`Capability` whose + :class:`ImplementationRef` points at the caller-supplied *driver_id*. + The importing kernel must register a matching driver with + :meth:`~agent_kernel.Kernel.register_driver`. Invocations on the + resulting capability flow through the full local pipeline — the remote + endpoint is never trusted to perform policy, token verification, or + firewall transformation on behalf of the importer. + + Args: + manifest: The remote :class:`CapabilityManifest` to import. + registry: The local :class:`CapabilityRegistry` to extend. + driver_id: The local driver ID that will execute imported capabilities. + The caller is responsible for registering a driver with that ID + (typically an :class:`~agent_kernel.drivers.http.HTTPDriver` or + :class:`~agent_kernel.drivers.mcp.MCPDriver` configured for + ``manifest.endpoint``). + trust_policy: How the importer weighs the manifest's sensitivity + metadata. See :data:`TrustPolicy`. + + Returns: + The list of imported :class:`Capability` objects, in manifest order. + + Raises: + TrustPolicyError: If *trust_policy* is not one of the documented values. + ManifestError: If the manifest is malformed (missing fields, wrong + version, or contains a capability ID already registered locally). + """ + if trust_policy not in _VALID_TRUST_POLICIES: + raise TrustPolicyError( + f"Unknown trust_policy '{trust_policy}'. " + f"Expected one of: {sorted(_VALID_TRUST_POLICIES)}." + ) + if manifest.version != MANIFEST_VERSION: + raise ManifestError( + f"Manifest version '{manifest.version}' is not supported by this " + f"kernel (expected '{MANIFEST_VERSION}'). Upgrade agent-kernel or " + "re-publish the manifest with the supported version." + ) + if not manifest.endpoint: + raise ManifestError( + f"Manifest from kernel '{manifest.kernel_id}' has no endpoint. " + "Endpoints are required so the local kernel can route imported " + "capabilities to a driver." + ) + + imported: list[Capability] = [] + for descriptor in manifest.capabilities: + cap = _capability_for_import( + descriptor=descriptor, + driver_id=driver_id, + trust_policy=trust_policy, + ) + registry.register(cap) + imported.append(cap) + return imported + + +def _capability_for_import( + *, + descriptor: CapabilityDescriptor, + driver_id: str, + trust_policy: TrustPolicy, +) -> Capability: + """Materialise a local :class:`Capability` from a remote descriptor.""" + sensitivity = _resolve_sensitivity(descriptor.sensitivity, trust_policy) + # The descriptor never exposes a driver-side operation name; we mirror + # the convention used everywhere else in the kernel: drivers resolve + # ``args.get("operation", capability_id)``. Imported capabilities therefore + # default operation to the capability_id. + impl = ImplementationRef(driver_id=driver_id, operation=descriptor.capability_id) + return Capability( + capability_id=descriptor.capability_id, + name=descriptor.name, + description=descriptor.description, + safety_class=descriptor.safety_class, + sensitivity=sensitivity, + tags=list(descriptor.tags), + impl=impl, + parameters_schema=descriptor.parameters_schema, + ) + + +def _resolve_sensitivity(remote: SensitivityTag, trust_policy: TrustPolicy) -> SensitivityTag: + """Apply *trust_policy* to a remote sensitivity tag. + + ``most_restrictive`` and ``remote_deferred`` both keep the remote tag — + they differ only in which side's *policy engine* is consulted at + invocation time, which is part 2 of the marketplace work. ``local_only`` + strips the remote sensitivity entirely. + """ + if trust_policy == "local_only": + return SensitivityTag.NONE + return remote + + +def merge_sensitivity(local: SensitivityTag, remote: SensitivityTag) -> SensitivityTag: + """Return the stricter of *local* and *remote* sensitivity tags. + + Exposed for callers that maintain their own capability records outside + the registry and want the canonical ``most_restrictive`` union rule. + """ + return _stricter(local, remote) + + +__all__ = [ + "MANIFEST_VERSION", + "TrustPolicy", + "build_manifest", + "import_manifest", + "merge_sensitivity", +] diff --git a/src/agent_kernel/federation_discovery.py b/src/agent_kernel/federation_discovery.py new file mode 100644 index 0000000..a773fde --- /dev/null +++ b/src/agent_kernel/federation_discovery.py @@ -0,0 +1,298 @@ +"""Federated discovery + signed manifests (issue #51). + +Builds on the local marketplace foundation in :mod:`federation` (issue +#52 — :class:`CapabilityManifest`, :func:`build_manifest`, +:func:`import_manifest`). This module adds the network-layer pieces: + +* :func:`sign_manifest` / :func:`verify_manifest` — HMAC-signed payload + envelopes so an importer can detect tampering. +* :func:`discover_peers` — async fetch of manifests from a registry URL + or list of peer URLs, with per-discovery-call rate limiting. +* :func:`serve_manifest_payload` — a transport-agnostic helper that + returns a signed-envelope JSON dict ready to be exposed by any ASGI + framework (Starlette, FastAPI, Litestar, etc.). + +Security boundary +----------------- + +Discovery does not authorise capability execution by itself. Even after a +successful :func:`discover_peers` + :func:`Kernel.import_remote`, every +invocation still flows through the *local* policy → token → firewall +pipeline. Discovery only decides *what* capabilities a kernel might +import — not *how* it executes them. See :ref:`docs/federation.md`. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import hmac +import json +import time +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any + +import httpx + +from .errors import DiscoveryError, ManifestError, ManifestSignatureError +from .models import CapabilityManifest + +SIGNATURE_ALGORITHM = "HMAC-SHA256" +"""Wire-level identifier embedded in every signed envelope.""" + +_DEFAULT_TIMEOUT_SECONDS = 5.0 +"""Default per-request timeout used by :func:`discover_peers`.""" + + +def _hash_payload(payload: bytes) -> bytes: + """Return the SHA-256 digest of *payload* for signing.""" + return hashlib.sha256(payload).digest() + + +def sign_manifest(manifest: CapabilityManifest, *, secret: str) -> dict[str, Any]: + """Wrap *manifest* in a signed envelope ready for transport. + + The envelope shape is:: + + { + "payload": "", + "algorithm": "HMAC-SHA256", + "signature": "", + } + + Splitting the JSON-encoded payload from the dict-level keys means + the signature is over the *exact* bytes the importer will hash — + canonicalisation differences (key ordering, whitespace) don't matter. + + Args: + manifest: The manifest to sign. + secret: Shared secret bound to the publishing kernel. + + Returns: + A dict ready to be serialised as JSON. + """ + payload = json.dumps(manifest.to_dict(), sort_keys=True).encode("utf-8") + signature = hmac.new(secret.encode("utf-8"), payload, hashlib.sha256).hexdigest() + return { + "payload": payload.decode("utf-8"), + "algorithm": SIGNATURE_ALGORITHM, + "signature": signature, + } + + +def verify_manifest(envelope: dict[str, Any], *, secret: str) -> CapabilityManifest: + """Verify *envelope* and return the embedded :class:`CapabilityManifest`. + + Args: + envelope: A signed envelope as produced by :func:`sign_manifest`. + secret: Shared secret used to verify the signature. + + Raises: + ManifestSignatureError: If the signature does not match. + ManifestError: If the envelope is malformed. + """ + if not isinstance(envelope, dict): + raise ManifestError(f"Envelope must be a dict, got {type(envelope).__name__}.") + for key in ("payload", "algorithm", "signature"): + if key not in envelope: + raise ManifestError(f"Envelope missing required key '{key}'.") + if envelope["algorithm"] != SIGNATURE_ALGORITHM: + raise ManifestSignatureError( + f"Unsupported signature algorithm '{envelope['algorithm']}'; " + f"expected '{SIGNATURE_ALGORITHM}'." + ) + + payload_bytes = envelope["payload"].encode("utf-8") + expected_sig = hmac.new(secret.encode("utf-8"), payload_bytes, hashlib.sha256).hexdigest() + if not hmac.compare_digest(expected_sig, envelope["signature"]): + raise ManifestSignatureError( + "Manifest signature mismatch — payload may be tampered, or the " + "verification secret does not match the publisher's signing secret." + ) + + payload_data = json.loads(envelope["payload"]) + return CapabilityManifest.from_dict(payload_data) + + +def serve_manifest_payload( + manifest: CapabilityManifest, + *, + secret: str | None = None, +) -> dict[str, Any]: + """Return a JSON-serialisable payload for a manifest-serving HTTP route. + + If *secret* is provided the result is a signed envelope (per + :func:`sign_manifest`); otherwise the bare manifest is returned. This + helper is transport-agnostic — callers wire it into Starlette, + FastAPI, Litestar, or any other ASGI framework. + + Args: + manifest: The manifest to serve. + secret: Optional HMAC secret for signing. + """ + if secret is None: + return manifest.to_dict() + return sign_manifest(manifest, secret=secret) + + +@dataclass(slots=True) +class _RateLimitState: + """Per-discovery-call rate-limit tracker.""" + + timestamps: list[float] + + +class DiscoveryRateLimiter: + """Sliding-window limiter scoped to discovery calls. + + Default budget: 10 calls per 60 seconds. Configurable per-instance. + Backed by :func:`time.monotonic` so wall-clock changes don't affect + behavior. + """ + + def __init__(self, *, limit: int = 10, window_seconds: float = 60.0) -> None: + self._limit = limit + self._window = window_seconds + self._state = _RateLimitState(timestamps=[]) + + def acquire(self) -> None: + """Record a discovery call. Raises :class:`DiscoveryError` if over budget.""" + now = time.monotonic() + cutoff = now - self._window + self._state.timestamps = [t for t in self._state.timestamps if t > cutoff] + if len(self._state.timestamps) >= self._limit: + raise DiscoveryError( + f"Discovery rate limit exceeded: {self._limit} calls per " + f"{self._window}s. Wait and retry." + ) + self._state.timestamps.append(now) + + +async def _fetch_manifest( + client: httpx.AsyncClient, + url: str, + *, + secret: str | None, +) -> CapabilityManifest: + """Fetch one manifest from *url*. Used by :func:`discover_peers`.""" + try: + response = await client.get(url, timeout=_DEFAULT_TIMEOUT_SECONDS) + except httpx.HTTPError as exc: + raise DiscoveryError(f"Network error fetching '{url}': {exc}") from exc + + if response.status_code != 200: + raise DiscoveryError(f"Manifest endpoint '{url}' returned HTTP {response.status_code}.") + + try: + body = response.json() + except ValueError as exc: + raise DiscoveryError(f"Manifest endpoint '{url}' returned non-JSON.") from exc + + # Auto-detect signed vs. bare manifest. + if isinstance(body, dict) and "signature" in body and "payload" in body: + if secret is None: + raise ManifestSignatureError( + f"Manifest at '{url}' is signed but no verification secret was " + f"provided to discover_peers()." + ) + return verify_manifest(body, secret=secret) + if secret is not None: + raise ManifestSignatureError( + f"Manifest at '{url}' is unsigned but discover_peers() was called " + f"with a verification secret — refusing to trust an unsigned " + f"manifest when signing is expected." + ) + return CapabilityManifest.from_dict(body) + + +async def discover_peers( + *, + peer_urls: Iterable[str] | None = None, + registry_url: str | None = None, + secret: str | None = None, + rate_limiter: DiscoveryRateLimiter | None = None, + client: httpx.AsyncClient | None = None, +) -> list[CapabilityManifest]: + """Fetch one or more :class:`CapabilityManifest` from remote endpoints. + + Either *peer_urls* (direct manifest URLs) or *registry_url* (a URL + returning a JSON list of peer URLs) must be provided. + + Args: + peer_urls: Direct URLs that each return one manifest. + registry_url: URL that returns a JSON list of peer manifest URLs. + secret: HMAC secret for verifying signed manifests. Mandatory if + the manifest endpoint produces signed envelopes; refusing + unsigned manifests in that case is an explicit security + choice (see :func:`_fetch_manifest`). + rate_limiter: Optional :class:`DiscoveryRateLimiter`. Defaults to + a fresh limiter (10 calls / 60s). + client: Optional pre-configured :class:`httpx.AsyncClient` for + test injection. A new ephemeral client is created if omitted. + + Returns: + A list of :class:`CapabilityManifest` objects in the order they + were resolved. + + Raises: + DiscoveryError: If the network call fails, the response is + malformed, or the rate limit is exhausted. + ManifestSignatureError: If a signed manifest fails verification. + """ + if not peer_urls and not registry_url: + raise DiscoveryError("discover_peers() requires peer_urls or registry_url.") + + limiter = rate_limiter or DiscoveryRateLimiter() + owns_client = client is None + transport_client = client or httpx.AsyncClient() + + try: + urls: list[str] = list(peer_urls or []) + if registry_url is not None: + limiter.acquire() + try: + response = await transport_client.get( + registry_url, timeout=_DEFAULT_TIMEOUT_SECONDS + ) + except httpx.HTTPError as exc: + raise DiscoveryError( + f"Network error fetching registry '{registry_url}': {exc}" + ) from exc + if response.status_code != 200: + raise DiscoveryError( + f"Registry '{registry_url}' returned HTTP {response.status_code}." + ) + try: + registry_body = response.json() + except ValueError as exc: + raise DiscoveryError(f"Registry '{registry_url}' returned non-JSON.") from exc + if not isinstance(registry_body, list) or not all( + isinstance(u, str) for u in registry_body + ): + raise DiscoveryError( + f"Registry '{registry_url}' must return a JSON list of URL strings." + ) + urls.extend(registry_body) + + manifests: list[CapabilityManifest] = [] + for url in urls: + limiter.acquire() + manifests.append(await _fetch_manifest(transport_client, url, secret=secret)) + # Yield to the event loop between fetches so the limiter's + # monotonic clock advances measurably even on a fast network. + await asyncio.sleep(0) + return manifests + finally: + if owns_client: + await transport_client.aclose() + + +__all__ = [ + "SIGNATURE_ALGORITHM", + "DiscoveryRateLimiter", + "discover_peers", + "serve_manifest_payload", + "sign_manifest", + "verify_manifest", +] diff --git a/src/agent_kernel/firewall/transform.py b/src/agent_kernel/firewall/transform.py index dc64972..0127af1 100644 --- a/src/agent_kernel/firewall/transform.py +++ b/src/agent_kernel/firewall/transform.py @@ -5,6 +5,7 @@ import datetime import json import logging +from collections.abc import AsyncIterator from typing import Any from ..models import ( @@ -212,7 +213,69 @@ def transform( provenance=provenance, ) - # ── Helpers ─────────────────────────────────────────────────────────────── + async def apply_stream( + self, + response_chunks: AsyncIterator[dict[str, Any]], + *, + action_id: str, + capability_id: str, + principal_id: str, + principal_roles: list[str], + response_mode: ResponseMode, + constraints: dict[str, Any] | None = None, + ) -> AsyncIterator[Frame]: + """Stream chunks through the firewall, applying redaction per chunk. + + Each chunk is wrapped in a synthetic :class:`RawResult` and passed + through :meth:`transform`. The same admin gate, redaction, and + budget caps that apply to a single-shot :meth:`transform` apply to + *every* chunk — PII never leaks even when results stream in. + + Mode escalation across chunks (e.g. dropping from ``table`` to + ``summary`` as budget drains) is the caller's responsibility — the + Firewall itself is stateless. ``Kernel.invoke_stream`` orchestrates + escalation via :class:`BudgetManager.suggested_mode`. + + The synthetic key ``"__is_final__"`` on a chunk is stripped before + firewall processing and re-applied to the yielded Frame's + ``is_final`` attribute. If the iterator ends without ever + producing an explicit final chunk, no extra sentinel is yielded + here — that bookkeeping is left to higher layers. + + Args: + response_chunks: Async iterator of raw chunks from the driver. + action_id: The audit action ID for this stream. + capability_id: Capability being executed. + principal_id: Principal making the request. + principal_roles: Principal's roles (used for admin gate). + response_mode: Current response mode (may differ chunk-to-chunk + if the caller passes pre-escalated modes). + constraints: Active execution constraints. + + Yields: + :class:`Frame` chunks with ``is_final`` set on the last one. + """ + async for chunk in response_chunks: + is_final = bool(chunk.get("__is_final__", False)) + payload = {k: v for k, v in chunk.items() if k != "__is_final__"} + synthetic_raw = RawResult( + capability_id=capability_id, + data=payload, + metadata={"action_id": action_id, "streaming": True}, + ) + frame = self.transform( + synthetic_raw, + action_id=action_id, + principal_id=principal_id, + principal_roles=principal_roles, + response_mode=response_mode, + constraints=constraints, + ) + if is_final: + from dataclasses import replace + + frame = replace(frame, is_final=True) + yield frame def _make_table(self, data: Any, *, max_rows: int) -> list[dict[str, Any]]: """Convert *data* to a list of dicts, capped at *max_rows*.""" diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py deleted file mode 100644 index d8cc3bf..0000000 --- a/src/agent_kernel/kernel.py +++ /dev/null @@ -1,581 +0,0 @@ -"""The Kernel: the main entry point for agent-kernel.""" - -from __future__ import annotations - -import datetime -import logging -import uuid -from typing import Any, Literal, overload - -from .drivers.base import Driver, ExecutionContext -from .enums import SafetyClass -from .errors import AgentKernelError, DriverError -from .firewall.budget_manager import BudgetManager -from .firewall.transform import Firewall -from .handles import HandleStore -from .models import ( - ActionTrace, - Capability, - CapabilityGrant, - CapabilityRequest, - DenialExplanation, - DryRunResult, - Frame, - Handle, - PolicyDecision, - PolicyDecisionTrace, - PolicyTraceStep, - Principal, - ResponseMode, - RoutePlan, -) -from .policy import DefaultPolicyEngine, PolicyEngine -from .policy_reasons import AllowReason -from .registry import CapabilityRegistry -from .router import Router, StaticRouter -from .tokens import CapabilityToken, HMACTokenProvider, TokenProvider -from .trace import TraceStore - -logger = logging.getLogger(__name__) - - -def _frame_payload(frame: Frame) -> Any: - """Return the LLM-facing payload from a :class:`Frame` for token counting. - - Only the data the LLM actually sees is counted — facts, table rows, - or raw data. Provenance metadata, action IDs, and handle IDs are - skipped because they are kernel bookkeeping rather than context. - """ - if frame.response_mode == "raw": - return frame.raw_data - if frame.response_mode == "table": - return frame.table_preview - if frame.response_mode == "handle_only": - return None - return frame.facts - - -class Kernel: - """The central orchestrator for capability-based AI agent security. - - The Kernel wires together the registry, policy engine, token provider, - router, firewall, handle store, and trace store into a single coherent - interface. - - Example:: - - registry = CapabilityRegistry() - registry.register(Capability(...)) - kernel = Kernel(registry) - - requests = kernel.request_capabilities("list invoices") - grant = kernel.grant_capability(requests[0], principal, justification="...") - frame = await kernel.invoke(grant.token, principal=principal, args={"operation": "list_invoices"}) - """ - - def __init__( - self, - registry: CapabilityRegistry, - policy: PolicyEngine | None = None, - token_provider: TokenProvider | None = None, - router: Router | None = None, - firewall: Firewall | None = None, - handle_store: HandleStore | None = None, - trace_store: TraceStore | None = None, - budget_manager: BudgetManager | None = None, - ) -> None: - self._registry = registry - self._policy: PolicyEngine = policy or DefaultPolicyEngine() - self._token_provider: TokenProvider = token_provider or HMACTokenProvider() - self._router: Router = router or StaticRouter() - self._firewall = firewall or Firewall() - self._handle_store = handle_store or HandleStore() - self._trace_store = trace_store or TraceStore() - self._budget_manager = budget_manager - self._drivers: dict[str, Driver] = {} - - # ── Budget accessor ──────────────────────────────────────────────────────── - - @property - def budget(self) -> BudgetManager | None: - """The cross-invocation :class:`BudgetManager`, or ``None`` if none is configured.""" - return self._budget_manager - - # ── Driver registration ──────────────────────────────────────────────────── - - def register_driver(self, driver: Driver) -> None: - """Register a driver with the kernel. - - Args: - driver: Any object implementing the :class:`~agent_kernel.drivers.base.Driver` protocol. - """ - self._drivers[driver.driver_id] = driver - - # ── Public API ───────────────────────────────────────────────────────────── - - def list_capabilities(self) -> list[Capability]: - """Return every capability registered with the kernel. - - Convenience accessor used by LLM adapters that need to enumerate the - full registry (e.g. ``OpenAIMiddleware.get_tools()``) without reaching - into private state. Capabilities are returned in registration order. - """ - return self._registry.list_all() - - def request_capabilities( - self, - goal: str, - *, - context_tags: dict[str, str] | None = None, - ) -> list[CapabilityRequest]: - """Discover capabilities that match a natural-language goal. - - Args: - goal: Free-text description of the agent's intent. - context_tags: Optional metadata to narrow the search (currently unused). - - Returns: - An ordered list of :class:`CapabilityRequest` objects (best match first). - """ - results = self._registry.search(goal) - logger.debug( - "request_capabilities", - extra={ - "goal": goal, - "matches": len(results), - }, - ) - return results - - def grant_capability( - self, - request: CapabilityRequest, - principal: Principal, - *, - justification: str, - ) -> CapabilityGrant: - """Evaluate the policy and, if approved, issue a signed token. - - Args: - request: The capability request to evaluate. - principal: The principal requesting access. - justification: Free-text justification for the request. - - Returns: - A :class:`CapabilityGrant` containing the signed token. - - Raises: - PolicyDenied: If the policy engine rejects the request. - CapabilityNotFound: If the requested capability is not registered. - """ - capability = self._registry.get(request.capability_id) - decision = self._policy.evaluate( - request, capability, principal, justification=justification - ) - audit_id = str(uuid.uuid4()) - token = self._token_provider.issue( - capability.capability_id, - principal.principal_id, - constraints=decision.constraints, - audit_id=audit_id, - ) - logger.info( - "grant_capability", - extra={ - "principal_id": principal.principal_id, - "capability_id": capability.capability_id, - "safety_class": capability.safety_class.value, - "audit_id": audit_id, - "token_id": token.token_id, - }, - ) - return CapabilityGrant( - request=request, - principal=principal, - decision=decision, - token=token, - audit_id=audit_id, - ) - - def get_token( - self, - request: CapabilityRequest, - principal: Principal, - *, - justification: str, - ) -> CapabilityToken: - """Like :meth:`grant_capability` but returns the token directly. - - Convenience wrapper for callers that don't need the full - :class:`CapabilityGrant`. Delegates entirely to - :meth:`grant_capability`; see its docstring for parameter and - exception details. - """ - return self.grant_capability(request, principal, justification=justification).token - - @overload - async def invoke( - self, - token: CapabilityToken, - *, - principal: Principal, - args: dict[str, Any], - response_mode: ResponseMode = ..., - dry_run: Literal[True], - ) -> DryRunResult: ... - - @overload - async def invoke( - self, - token: CapabilityToken, - *, - principal: Principal, - args: dict[str, Any], - response_mode: ResponseMode = ..., - dry_run: Literal[False] = ..., - ) -> Frame: ... - - async def invoke( - self, - token: CapabilityToken, - *, - principal: Principal, - args: dict[str, Any], - response_mode: ResponseMode = "summary", - dry_run: bool = False, - ) -> Frame | DryRunResult: - """Execute a capability using a signed token and return a Frame. - - When ``dry_run=True`` the full pipeline runs (token verification, - capability lookup, route resolution) but the driver is never called. - A :class:`DryRunResult` is returned instead of a :class:`Frame`. - - Args: - token: A signed :class:`CapabilityToken` authorising the invocation. - principal: The principal invoking the capability (must match token). - args: Arguments passed to the driver. - response_mode: How to present the result (``summary``, ``table``, - ``handle_only``, or ``raw``). - dry_run: When ``True``, skip driver execution and return a - :class:`DryRunResult` describing what would happen. - - Returns: - A bounded :class:`Frame`, or :class:`DryRunResult` when - ``dry_run=True``. - - Raises: - TokenRevoked: If the token has been revoked. - TokenExpired: If the token has expired. - TokenInvalid: If the token signature does not verify. - TokenScopeError: If the token belongs to a different principal or capability. - CapabilityNotFound: If the capability is not registered. - DriverError: If all drivers fail (not raised in dry-run mode). - """ - # ── Verify token ────────────────────────────────────────────────────── - self._token_provider.verify( - token, - expected_principal_id=principal.principal_id, - expected_capability_id=token.capability_id, - ) - - capability = self._registry.get(token.capability_id) - plan: RoutePlan = self._router.route(token.capability_id) - - # ── Dry-run short-circuit ───────────────────────────────────────────── - if dry_run: - driver_id = plan.driver_ids[0] if plan.driver_ids else "" - # Mirror driver operation resolution exactly (see InMemoryDriver, - # HTTPDriver, MCPDriver — all read ``args.get("operation", capability_id)``). - # Using ``capability.impl.operation`` here would diverge from what the - # driver actually executes at real-invoke time. - operation = str(args.get("operation", token.capability_id)) - # Mirror Firewall's admin-only gate for ``raw`` mode - # (see firewall/transform.py:108 and docs/agent-context/invariants.md). - # Dry-run must not let non-admin principals probe raw-mode availability. - effective_response_mode: ResponseMode = response_mode - if response_mode == "raw" and "admin" not in principal.roles: - effective_response_mode = "summary" - # Mirror the BudgetManager escalation an actual invoke would apply, - # so dry-run reports the mode the caller would really see. - if self._budget_manager is not None: - effective_response_mode = self._budget_manager.suggested_mode( - effective_response_mode - ) - _cost_map: dict[SafetyClass, Literal["low", "medium", "high"]] = { - SafetyClass.READ: "low", - SafetyClass.WRITE: "medium", - SafetyClass.DESTRUCTIVE: "high", - } - dry_run_trace = PolicyDecisionTrace( - engine="Kernel.invoke[dry_run]", - capability_id=token.capability_id, - principal_id=principal.principal_id, - intent=None, - scope_keys=[], - steps=[ - PolicyTraceStep( - name="token_verified", - outcome="allowed", - detail="Token verified; original policy decision was at grant time.", - reason_code=str(AllowReason.TOKEN_VERIFIED), - ) - ], - final_outcome="allowed", - final_reason_code=str(AllowReason.TOKEN_VERIFIED), - ) - return DryRunResult( - capability_id=token.capability_id, - principal_id=principal.principal_id, - policy_decision=PolicyDecision( - allowed=True, - reason="Token verified. Policy was evaluated at grant time.", - constraints=dict(token.constraints), - reason_code=str(AllowReason.TOKEN_VERIFIED), - trace=dry_run_trace, - ), - driver_id=driver_id, - operation=operation, - resolved_args=args, - response_mode=effective_response_mode, - budget_remaining=( - self._budget_manager.remaining if self._budget_manager is not None else None - ), - estimated_cost=_cost_map[capability.safety_class], - ) - - action_id = str(uuid.uuid4()) - - # ── Mirror Firewall's admin-only ``raw`` gate ───────────────────────── - # The Firewall downgrades raw → summary for non-admin principals - # (see firewall/transform.py and docs/agent-context/invariants.md). - # We must mirror that downgrade *before* deciding whether to store a - # handle and before consulting the budget manager, otherwise a - # non-admin asking for raw would get a summary frame *without* a - # handle (because the kernel skipped handle creation thinking the - # mode was still raw). - effective_mode: ResponseMode = response_mode - if response_mode == "raw" and "admin" not in principal.roles: - effective_mode = "summary" - - # ── Cross-invocation budget allocation & mode escalation ────────────── - # When a BudgetManager is attached, reserve a slice of the cumulative - # session budget before driver execution. The manager raises - # BudgetExhausted if no budget remains. The requested response_mode is - # escalated to a more aggressive tier as the remaining budget shrinks - # (see BudgetManager.suggested_mode). This change is invisible to - # callers without a BudgetManager — the original mode flows through. - reserved_tokens: int | None = None - if self._budget_manager is not None: - reserved_tokens = await self._budget_manager.allocate() - effective_mode = self._budget_manager.suggested_mode(effective_mode) - - _log_ctx = { - "action_id": action_id, - "principal_id": principal.principal_id, - "capability_id": token.capability_id, - } - logger.info( - "invoke_start", - extra={ - **_log_ctx, - "token_id": token.token_id, - "response_mode": response_mode, - "effective_mode": effective_mode, - }, - ) - - # ── Execute with fallback ───────────────────────────────────────────── - raw_result = None - used_driver_id = "" - last_error: Exception | None = None - - for driver_id in plan.driver_ids: - driver = self._drivers.get(driver_id) - if driver is None: - continue - ctx = ExecutionContext( - capability_id=token.capability_id, - principal_id=principal.principal_id, - args=args, - constraints=token.constraints, - action_id=action_id, - ) - try: - raw_result = await driver.execute(ctx) - used_driver_id = driver_id - logger.debug("driver_success", extra={**_log_ctx, "driver_id": driver_id}) - break - except DriverError as exc: - logger.warning( - "driver_failure", - extra={**_log_ctx, "driver_id": driver_id, "error": str(exc)}, - ) - last_error = exc - continue - - if raw_result is None: - # Release any reservation — no tokens were spent by the firewall. - if self._budget_manager is not None and reserved_tokens is not None: - await self._budget_manager.release(reserved_tokens) - err_msg = str(last_error) if last_error else "No drivers available." - logger.warning("invoke_failure", extra={**_log_ctx, "error": err_msg}) - trace = ActionTrace( - action_id=action_id, - capability_id=token.capability_id, - principal_id=principal.principal_id, - token_id=token.token_id, - invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), - args=args, - response_mode=response_mode, - driver_id="", - error=err_msg, - ) - self._trace_store.record(trace) - raise DriverError( - f"All drivers failed for capability '{token.capability_id}'. Last error: {err_msg}" - ) - - # ── Store handle ────────────────────────────────────────────────────── - handle: Handle | None = None - if effective_mode != "raw": - handle = self._handle_store.store( - capability_id=token.capability_id, - data=raw_result.data, - ) - - # ── Firewall transform + budget reconciliation ──────────────────────── - # Both steps run inside a try/finally so a Firewall exception (e.g. - # malformed constraint inputs) still releases any outstanding budget - # reservation. record_usage replaces the reservation with committed - # usage; the finally branch only fires if we never got there. - reservation_consumed = False - try: - frame = self._firewall.transform( - raw_result, - action_id=action_id, - principal_id=principal.principal_id, - principal_roles=list(principal.roles), - response_mode=effective_mode, - constraints=token.constraints, - handle=handle, - ) - if self._budget_manager is not None and reserved_tokens is not None: - actual_tokens = self._budget_manager.count_tokens(_frame_payload(frame)) - await self._budget_manager.record_usage(actual_tokens, reserved=reserved_tokens) - reservation_consumed = True - finally: - if ( - not reservation_consumed - and self._budget_manager is not None - and reserved_tokens is not None - ): - await self._budget_manager.release(reserved_tokens) - - # ── Record trace ────────────────────────────────────────────────────── - trace = ActionTrace( - action_id=action_id, - capability_id=token.capability_id, - principal_id=principal.principal_id, - token_id=token.token_id, - invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), - args=args, - response_mode=frame.response_mode, - driver_id=used_driver_id, - handle_id=handle.handle_id if handle else None, - ) - self._trace_store.record(trace) - - logger.info( - "invoke_success", - extra={**_log_ctx, "response_mode": frame.response_mode, "driver_id": used_driver_id}, - ) - return frame - - def expand(self, handle: Handle, *, query: dict[str, Any]) -> Frame: - """Expand a handle with pagination, field selection, or filtering. - - Args: - handle: The :class:`Handle` to expand. - query: Query parameters (``offset``, ``limit``, ``fields``, ``filter``). - - Returns: - A :class:`Frame` with the requested slice of data. - - Raises: - HandleNotFound: If the handle is unknown. - HandleExpired: If the handle has expired. - """ - logger.info( - "expand", - extra={ - "handle_id": handle.handle_id, - "capability_id": handle.capability_id, - }, - ) - return self._handle_store.expand(handle, query=query) - - def explain(self, action_id: str) -> ActionTrace: - """Retrieve the audit trace for a past invocation. - - Args: - action_id: The unique action identifier returned in a :class:`Frame`. - - Returns: - The :class:`ActionTrace` for that action. - - Raises: - AgentKernelError: If no trace exists for that action ID. - """ - logger.info( - "explain", - extra={ - "action_id": action_id, - }, - ) - return self._trace_store.get(action_id) - - def explain_denial( - self, - request: CapabilityRequest, - principal: Principal, - *, - justification: str = "", - ) -> DenialExplanation: - """Explain why *principal*'s *request* would be denied (or allowed). - - Delegates to the configured policy engine's ``explain()`` method. - Unlike :meth:`grant_capability`, this does not raise - :class:`PolicyDenied` when the policy fails — it returns a - :class:`DenialExplanation` instead. - - Note: Rate-limit state is not reflected here. A request denied due to - rate limits shows as ``denied=False`` in the explanation. - - Args: - request: The capability request to explain. - principal: The principal to evaluate the request for. - justification: Free-text justification (used in policy checks). - - Returns: - :class:`DenialExplanation` with ``denied=False`` if the request - would succeed. - - Raises: - CapabilityNotFound: If the capability is not registered. - AgentKernelError: If the configured policy engine does not - implement ``explain()``. Use :class:`DefaultPolicyEngine` or - :class:`DeclarativePolicyEngine` for structured explanations, - or add an ``explain()`` method to your engine. - """ - capability = self._registry.get(request.capability_id) - explain_fn = getattr(self._policy, "explain", None) - if explain_fn is None: - raise AgentKernelError( - f"Policy engine {type(self._policy).__name__!r} does not implement " - f"explain(); structured denial explanations are unavailable. " - f"Use DefaultPolicyEngine or DeclarativePolicyEngine, or add an " - f"explain() method to your engine." - ) - result = explain_fn(request, capability, principal, justification=justification) - assert isinstance(result, DenialExplanation) - return result diff --git a/src/agent_kernel/kernel/__init__.py b/src/agent_kernel/kernel/__init__.py new file mode 100644 index 0000000..a254b2f --- /dev/null +++ b/src/agent_kernel/kernel/__init__.py @@ -0,0 +1,412 @@ +"""The :class:`Kernel` — the main entry point for agent-kernel. + +The class lives in this package's ``__init__.py`` so existing imports +(``from agent_kernel.kernel import Kernel``) continue to work after the +split from a single-file module into a sub-package. Heavy +implementation is delegated to sibling modules (:mod:`._invoke`, +:mod:`._dry_run`) to honour AGENTS.md's ≤ 300-line module budget. +""" + +from __future__ import annotations + +import logging +import uuid +from collections.abc import AsyncIterator +from typing import Any, Literal, overload + +from ..drivers.base import Driver, StreamingDriver +from ..errors import AgentKernelError +from ..federation import TrustPolicy +from ..firewall.budget_manager import BudgetManager +from ..firewall.transform import Firewall +from ..handles import HandleStore +from ..models import ( + ActionTrace, + Capability, + CapabilityGrant, + CapabilityManifest, + CapabilityRequest, + DenialExplanation, + DryRunResult, + Frame, + Handle, + Principal, + ResponseMode, + RoutePlan, +) +from ..policy import DefaultPolicyEngine, PolicyEngine +from ..registry import CapabilityRegistry +from ..router import Router, StaticRouter +from ..tokens import CapabilityToken, HMACTokenProvider, TokenProvider +from ..trace import TraceStore +from ._dry_run import build_dry_run_result +from ._federation import ( + perform_advertise, + perform_discover_peers, + perform_import_remote, +) +from ._invoke import perform_invoke +from ._stream import invoke_stream_impl + +logger = logging.getLogger(__name__) + + +class Kernel: + """The central orchestrator for capability-based AI agent security. + + The Kernel wires together the registry, policy engine, token provider, + router, firewall, handle store, and trace store into a single coherent + interface. + + Example:: + + registry = CapabilityRegistry() + registry.register(Capability(...)) + kernel = Kernel(registry) + + requests = kernel.request_capabilities("list invoices") + grant = kernel.grant_capability(requests[0], principal, justification="...") + frame = await kernel.invoke( + grant.token, principal=principal, args={"operation": "list_invoices"} + ) + """ + + def __init__( + self, + registry: CapabilityRegistry, + policy: PolicyEngine | None = None, + token_provider: TokenProvider | None = None, + router: Router | None = None, + firewall: Firewall | None = None, + handle_store: HandleStore | None = None, + trace_store: TraceStore | None = None, + budget_manager: BudgetManager | None = None, + kernel_id: str = "agent-kernel", + ) -> None: + self._registry = registry + self._policy: PolicyEngine = policy or DefaultPolicyEngine() + self._token_provider: TokenProvider = token_provider or HMACTokenProvider() + self._router: Router = router or StaticRouter() + self._firewall = firewall or Firewall() + self._handle_store = handle_store or HandleStore() + self._trace_store = trace_store or TraceStore() + self._budget_manager = budget_manager + self._drivers: dict[str, Driver] = {} + self._kernel_id = kernel_id + + @property + def kernel_id(self) -> str: + """Stable identifier used when this kernel advertises its capabilities.""" + return self._kernel_id + + @property + def budget(self) -> BudgetManager | None: + """The cross-invocation :class:`BudgetManager`, or ``None`` if none is configured.""" + return self._budget_manager + + def register_driver(self, driver: Driver) -> None: + """Register a driver with the kernel.""" + self._drivers[driver.driver_id] = driver + + def list_capabilities(self) -> list[Capability]: + """Return every capability registered with the kernel.""" + return self._registry.list_all() + + def request_capabilities( + self, + goal: str, + *, + context_tags: dict[str, str] | None = None, + ) -> list[CapabilityRequest]: + """Discover capabilities that match a natural-language goal.""" + results = self._registry.search(goal) + logger.debug( + "request_capabilities", + extra={"goal": goal, "matches": len(results)}, + ) + return results + + def grant_capability( + self, + request: CapabilityRequest, + principal: Principal, + *, + justification: str, + ) -> CapabilityGrant: + """Evaluate the policy and, if approved, issue a signed token.""" + capability = self._registry.get(request.capability_id) + decision = self._policy.evaluate( + request, capability, principal, justification=justification + ) + audit_id = str(uuid.uuid4()) + token = self._token_provider.issue( + capability.capability_id, + principal.principal_id, + constraints=decision.constraints, + audit_id=audit_id, + ) + logger.info( + "grant_capability", + extra={ + "principal_id": principal.principal_id, + "capability_id": capability.capability_id, + "safety_class": capability.safety_class.value, + "audit_id": audit_id, + "token_id": token.token_id, + }, + ) + return CapabilityGrant( + request=request, + principal=principal, + decision=decision, + token=token, + audit_id=audit_id, + ) + + def get_token( + self, + request: CapabilityRequest, + principal: Principal, + *, + justification: str, + ) -> CapabilityToken: + """Like :meth:`grant_capability` but returns the token directly.""" + return self.grant_capability(request, principal, justification=justification).token + + @overload + async def invoke( + self, + token: CapabilityToken, + *, + principal: Principal, + args: dict[str, Any], + response_mode: ResponseMode = ..., + dry_run: Literal[True], + ) -> DryRunResult: ... + + @overload + async def invoke( + self, + token: CapabilityToken, + *, + principal: Principal, + args: dict[str, Any], + response_mode: ResponseMode = ..., + dry_run: Literal[False] = ..., + ) -> Frame: ... + + async def invoke( + self, + token: CapabilityToken, + *, + principal: Principal, + args: dict[str, Any], + response_mode: ResponseMode = "summary", + dry_run: bool = False, + ) -> Frame | DryRunResult: + """Execute a capability using a signed token. + + When ``dry_run=True`` the full pipeline runs (token verification, + capability lookup, route resolution) but the driver is never called; + a :class:`DryRunResult` is returned instead of a :class:`Frame`. + """ + self._token_provider.verify( + token, + expected_principal_id=principal.principal_id, + expected_capability_id=token.capability_id, + ) + capability = self._registry.get(token.capability_id) + plan: RoutePlan = self._router.route(token.capability_id) + if dry_run: + return build_dry_run_result( + token=token, + principal=principal, + capability=capability, + plan=plan, + args=args, + response_mode=response_mode, + budget_manager=self._budget_manager, + ) + return await perform_invoke( + self, + token=token, + principal=principal, + args=args, + response_mode=response_mode, + plan=plan, + ) + + async def invoke_stream( + self, + token: CapabilityToken, + *, + principal: Principal, + args: dict[str, Any], + response_mode: ResponseMode = "summary", + ) -> AsyncIterator[Frame]: + """Stream a capability invocation. + + Yields :class:`Frame` chunks as they arrive from the driver. The last + yielded frame has ``is_final=True``. Falls back to wrapping a + single-shot :meth:`Driver.execute` when the resolved driver does not + implement :class:`~agent_kernel.drivers.base.StreamingDriver`. + + The same security pipeline applies as in :meth:`invoke`: token + verification, admin-mode gate, budget escalation, firewall + redaction on *every* chunk, and one :class:`ActionTrace` for the + whole stream. + + Args: + token: A signed token authorising the invocation. + principal: The invoking principal (must match the token). + args: Driver arguments. + response_mode: Initial response mode. May be escalated mid-stream + if a :class:`BudgetManager` is attached and runs low on budget. + + Yields: + :class:`Frame` chunks. Consumers should look at ``is_final`` to + detect end-of-stream. + """ + self._token_provider.verify( + token, + expected_principal_id=principal.principal_id, + expected_capability_id=token.capability_id, + ) + capability = self._registry.get(token.capability_id) + plan: RoutePlan = self._router.route(token.capability_id) + async for frame in invoke_stream_impl( + kernel=self, + token=token, + principal=principal, + capability=capability, + plan=plan, + args=args, + response_mode=response_mode, + ): + yield frame + + def expand(self, handle: Handle, *, query: dict[str, Any]) -> Frame: + """Expand a handle with pagination, field selection, or filtering.""" + logger.info( + "expand", + extra={"handle_id": handle.handle_id, "capability_id": handle.capability_id}, + ) + return self._handle_store.expand(handle, query=query) + + def explain(self, action_id: str) -> ActionTrace: + """Retrieve the audit trace for a past invocation.""" + logger.info("explain", extra={"action_id": action_id}) + return self._trace_store.get(action_id) + + def explain_denial( + self, + request: CapabilityRequest, + principal: Principal, + *, + justification: str = "", + ) -> DenialExplanation: + """Explain why *principal*'s *request* would be denied (or allowed). + + Delegates to the configured policy engine's ``explain()`` method. + Rate-limit state is not reflected here. + + Raises: + CapabilityNotFound: If the capability is not registered. + AgentKernelError: If the configured policy engine does not + implement ``explain()``. + """ + capability = self._registry.get(request.capability_id) + explain_fn = getattr(self._policy, "explain", None) + if explain_fn is None: + raise AgentKernelError( + f"Policy engine {type(self._policy).__name__!r} does not implement " + f"explain(); structured denial explanations are unavailable. " + f"Use DefaultPolicyEngine or DeclarativePolicyEngine, or add an " + f"explain() method to your engine." + ) + result = explain_fn(request, capability, principal, justification=justification) + assert isinstance(result, DenialExplanation) + return result + + def advertise( + self, + *, + endpoint: str, + trust_level: Literal["verified", "unverified"] = "unverified", + ) -> CapabilityManifest: + """Build a public-facing :class:`CapabilityManifest` for this kernel. + + Internal implementation details (driver IDs, operation names, + ``parameters_model`` Python references) are stripped — only fields + safe to share over the wire are emitted. + """ + return perform_advertise(self, endpoint=endpoint, trust_level=trust_level) + + async def discover_peers( + self, + *, + peer_urls: list[str] | None = None, + registry_url: str | None = None, + secret: str | None = None, + ) -> list[CapabilityManifest]: + """Fetch capability manifests from peer kernels or a registry URL. + + Either *peer_urls* (direct manifest endpoints) or *registry_url* + (URL that returns a JSON list of peer manifest URLs) must be + provided. Pass *secret* when peers serve signed envelopes; the + importer will *refuse* unsigned manifests when a secret is + provided, and refuse signed manifests when one is not. + + Discovered manifests still flow through :meth:`import_remote` — + no capability is ever registered as a side effect of discovery. + """ + return await perform_discover_peers( + self, + peer_urls=peer_urls, + registry_url=registry_url, + secret=secret, + rate_limiter=None, + client=None, + ) + + def import_remote( + self, + manifest: CapabilityManifest, + *, + driver: Driver, + trust_policy: TrustPolicy = "most_restrictive", + ) -> list[Capability]: + """Register a remote manifest's capabilities into this kernel. + + Imported capabilities flow through the *local* policy → token → + firewall pipeline; the remote endpoint is never trusted to + authorise on our behalf. + + Raises: + TrustPolicyError: If *trust_policy* is unknown. + ManifestError: If the manifest is malformed. + CapabilityAlreadyRegistered: If any imported capability ID is + already registered locally. + """ + return perform_import_remote(self, manifest, driver=driver, trust_policy=trust_policy) + + # Helpers in sibling modules use these short-alias properties to reach + # internal state without circular-import gymnastics. + @property + def _driver_map(self) -> dict[str, Driver]: + return self._drivers + + @property + def _fw(self) -> Firewall: + return self._firewall + + @property + def _handles(self) -> HandleStore: + return self._handle_store + + @property + def _traces(self) -> TraceStore: + return self._trace_store + + +__all__ = ["Kernel", "TrustPolicy", "StreamingDriver"] diff --git a/src/agent_kernel/kernel/_dry_run.py b/src/agent_kernel/kernel/_dry_run.py new file mode 100644 index 0000000..acf4e87 --- /dev/null +++ b/src/agent_kernel/kernel/_dry_run.py @@ -0,0 +1,99 @@ +"""Dry-run result builder. + +Split out of :mod:`kernel` to keep modules ≤ 300 lines (AGENTS.md). The +public ``Kernel.invoke(..., dry_run=True)`` API delegates to +:func:`build_dry_run_result`. This helper enforces the invariant that +the dry-run reports the response mode the caller would *actually* +receive at real-invoke time (see ``docs/agent-context/invariants.md`` +— "Dry-run response-mode parity"). +""" + +from __future__ import annotations + +from typing import Any, Literal + +from ..enums import SafetyClass +from ..firewall.budget_manager import BudgetManager +from ..models import ( + Capability, + DryRunResult, + PolicyDecision, + PolicyDecisionTrace, + PolicyTraceStep, + Principal, + ResponseMode, + RoutePlan, +) +from ..policy_reasons import AllowReason +from ..tokens import CapabilityToken + +_COST_MAP: dict[SafetyClass, Literal["low", "medium", "high"]] = { + SafetyClass.READ: "low", + SafetyClass.WRITE: "medium", + SafetyClass.DESTRUCTIVE: "high", +} + + +def build_dry_run_result( + *, + token: CapabilityToken, + principal: Principal, + capability: Capability, + plan: RoutePlan, + args: dict[str, Any], + response_mode: ResponseMode, + budget_manager: BudgetManager | None, +) -> DryRunResult: + """Construct the :class:`DryRunResult` for a dry-run invocation. + + The response mode is computed in the same order the real-invoke path + uses: admin gate for ``raw`` first, then budget escalation if a + :class:`BudgetManager` is attached. Operation resolution mirrors the + drivers' own ``args.get("operation", capability_id)`` convention. + """ + driver_id = plan.driver_ids[0] if plan.driver_ids else "" + operation = str(args.get("operation", token.capability_id)) + + effective_response_mode: ResponseMode = response_mode + if response_mode == "raw" and "admin" not in principal.roles: + effective_response_mode = "summary" + if budget_manager is not None: + effective_response_mode = budget_manager.suggested_mode(effective_response_mode) + + trace = PolicyDecisionTrace( + engine="Kernel.invoke[dry_run]", + capability_id=token.capability_id, + principal_id=principal.principal_id, + intent=None, + scope_keys=[], + steps=[ + PolicyTraceStep( + name="token_verified", + outcome="allowed", + detail="Token verified; original policy decision was at grant time.", + reason_code=str(AllowReason.TOKEN_VERIFIED), + ) + ], + final_outcome="allowed", + final_reason_code=str(AllowReason.TOKEN_VERIFIED), + ) + return DryRunResult( + capability_id=token.capability_id, + principal_id=principal.principal_id, + policy_decision=PolicyDecision( + allowed=True, + reason="Token verified. Policy was evaluated at grant time.", + constraints=dict(token.constraints), + reason_code=str(AllowReason.TOKEN_VERIFIED), + trace=trace, + ), + driver_id=driver_id, + operation=operation, + resolved_args=args, + response_mode=effective_response_mode, + budget_remaining=(budget_manager.remaining if budget_manager is not None else None), + estimated_cost=_COST_MAP[capability.safety_class], + ) + + +__all__ = ["build_dry_run_result"] diff --git a/src/agent_kernel/kernel/_federation.py b/src/agent_kernel/kernel/_federation.py new file mode 100644 index 0000000..3e3736c --- /dev/null +++ b/src/agent_kernel/kernel/_federation.py @@ -0,0 +1,122 @@ +"""Federation method implementations for the :class:`Kernel`. + +Split out of :mod:`kernel` to keep the public API module ≤ 300 lines +(AGENTS.md). Each function is the body of the corresponding ``Kernel`` +method; the class method itself is a thin wrapper that adds logging +context. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable +from typing import TYPE_CHECKING, Literal + +import httpx + +from ..drivers.base import Driver +from ..federation import TrustPolicy, build_manifest, import_manifest +from ..federation_discovery import DiscoveryRateLimiter, discover_peers +from ..models import Capability, CapabilityManifest + +if TYPE_CHECKING: # pragma: no cover + from . import Kernel + +logger = logging.getLogger("agent_kernel.kernel") + + +def perform_advertise( + kernel: Kernel, + *, + endpoint: str, + trust_level: Literal["verified", "unverified"], +) -> CapabilityManifest: + """Build a public-facing :class:`CapabilityManifest` for *kernel*. + + Internal implementation details (driver IDs, operation names, + ``parameters_model`` Python references) are stripped — only fields + safe to share over the wire are emitted. + """ + manifest = build_manifest( + kernel_id=kernel.kernel_id, + registry=kernel._registry, + endpoint=endpoint, + trust_level=trust_level, + ) + logger.info( + "advertise", + extra={ + "kernel_id": kernel.kernel_id, + "endpoint": endpoint, + "capability_count": len(manifest.capabilities), + }, + ) + return manifest + + +def perform_import_remote( + kernel: Kernel, + manifest: CapabilityManifest, + *, + driver: Driver, + trust_policy: TrustPolicy, +) -> list[Capability]: + """Register *manifest*'s capabilities into *kernel*'s registry.""" + kernel.register_driver(driver) + imported = import_manifest( + manifest=manifest, + registry=kernel._registry, + driver_id=driver.driver_id, + trust_policy=trust_policy, + ) + router_add = getattr(kernel._router, "add_route", None) + if router_add is not None: + for cap in imported: + router_add(cap.capability_id, [driver.driver_id]) + logger.info( + "import_remote", + extra={ + "kernel_id": kernel.kernel_id, + "remote_kernel_id": manifest.kernel_id, + "endpoint": manifest.endpoint, + "capability_count": len(imported), + "trust_policy": trust_policy, + "driver_id": driver.driver_id, + }, + ) + return imported + + +async def perform_discover_peers( + kernel: Kernel, + *, + peer_urls: Iterable[str] | None, + registry_url: str | None, + secret: str | None, + rate_limiter: DiscoveryRateLimiter | None, + client: httpx.AsyncClient | None, +) -> list[CapabilityManifest]: + """Fetch one or more :class:`CapabilityManifest` over HTTP.""" + manifests = await discover_peers( + peer_urls=peer_urls, + registry_url=registry_url, + secret=secret, + rate_limiter=rate_limiter, + client=client, + ) + logger.info( + "discover_peers", + extra={ + "kernel_id": kernel.kernel_id, + "peer_count": len(manifests), + "registry_url": registry_url, + }, + ) + return manifests + + +__all__ = [ + "perform_advertise", + "perform_discover_peers", + "perform_import_remote", +] diff --git a/src/agent_kernel/kernel/_invoke.py b/src/agent_kernel/kernel/_invoke.py new file mode 100644 index 0000000..adf2ea5 --- /dev/null +++ b/src/agent_kernel/kernel/_invoke.py @@ -0,0 +1,303 @@ +"""Internal helpers for :meth:`Kernel.invoke` execution. + +Split out of :mod:`kernel` to keep the public API module ≤ 300 lines +(AGENTS.md). Each helper preserves the invariants documented in +``docs/agent-context/invariants.md``: + +* Firewall is *mandatory* — :class:`RawResult` never leaves + :func:`perform_invoke` without being transformed by + :func:`Firewall.transform`. +* Admin gate for ``raw`` is mirrored in :func:`resolve_effective_mode`. +* Failed runs still produce an :class:`ActionTrace` (via + :func:`record_failure_trace`) so I-02 (auditability) holds even on + driver failure. +""" + +from __future__ import annotations + +import datetime +import logging +import uuid +from dataclasses import replace +from typing import TYPE_CHECKING, Any + +from ..drivers.base import Driver, ExecutionContext +from ..errors import DriverError +from ..firewall.budget_manager import BudgetManager +from ..models import ( + ActionTrace, + Frame, + Handle, + Principal, + RawResult, + ResponseMode, + RoutePlan, +) +from ..tokens import CapabilityToken +from ..trace import TraceStore + +if TYPE_CHECKING: # pragma: no cover + from . import Kernel + +logger = logging.getLogger("agent_kernel.kernel") + + +def resolve_effective_mode( + *, + response_mode: ResponseMode, + principal: Principal, + budget_manager: BudgetManager | None, +) -> ResponseMode: + """Apply the admin gate and (optionally) budget escalation. + + The Firewall downgrades ``raw`` to ``summary`` for non-admin + principals; this helper performs the same downgrade *before* handle + creation so a non-admin asking for raw still gets a usable handle + in the summary frame. + + When a :class:`BudgetManager` is attached, the resulting mode is + further escalated via :meth:`BudgetManager.suggested_mode`. + """ + effective: ResponseMode = response_mode + if response_mode == "raw" and "admin" not in principal.roles: + effective = "summary" + if budget_manager is not None: + effective = budget_manager.suggested_mode(effective) + return effective + + +async def execute_with_fallback( + drivers: dict[str, Driver], + plan: RoutePlan, + *, + ctx: ExecutionContext, + log_ctx: dict[str, str], +) -> tuple[RawResult | None, str, Exception | None]: + """Iterate the route plan's drivers until one succeeds. + + Returns: + ``(raw_result, driver_id, last_error)``. ``raw_result`` is + ``None`` if every driver failed. + """ + last_error: Exception | None = None + for driver_id in plan.driver_ids: + driver = drivers.get(driver_id) + if driver is None: + continue + try: + raw_result = await driver.execute(ctx) + logger.debug("driver_success", extra={**log_ctx, "driver_id": driver_id}) + return raw_result, driver_id, None + except DriverError as exc: + logger.warning( + "driver_failure", + extra={**log_ctx, "driver_id": driver_id, "error": str(exc)}, + ) + last_error = exc + continue + return None, "", last_error + + +def record_failure_trace( + *, + action_id: str, + capability_id: str, + principal_id: str, + token_id: str, + args: dict[str, Any], + response_mode: ResponseMode, + error_message: str, + trace_store: TraceStore, +) -> None: + """Persist an :class:`ActionTrace` for a run where no driver succeeded.""" + trace_store.record( + ActionTrace( + action_id=action_id, + capability_id=capability_id, + principal_id=principal_id, + token_id=token_id, + invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), + args=args, + response_mode=response_mode, + driver_id="", + error=error_message, + ) + ) + + +def record_success_trace( + *, + action_id: str, + capability_id: str, + principal_id: str, + token_id: str, + args: dict[str, Any], + response_mode: ResponseMode, + driver_id: str, + handle_id: str | None, + trace_store: TraceStore, +) -> None: + """Persist an :class:`ActionTrace` for a successful invocation.""" + trace_store.record( + ActionTrace( + action_id=action_id, + capability_id=capability_id, + principal_id=principal_id, + token_id=token_id, + invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), + args=args, + response_mode=response_mode, + driver_id=driver_id, + handle_id=handle_id, + ) + ) + + +async def perform_invoke( + kernel: Kernel, + *, + token: CapabilityToken, + principal: Principal, + args: dict[str, Any], + response_mode: ResponseMode, + plan: RoutePlan, +) -> Frame: + """Run the non-dry-run invocation pipeline end-to-end. + + Called by :meth:`Kernel.invoke` after token verification and + capability lookup. Performs admin-gate, budget allocation, driver + fallback, handle creation, firewall transform, budget + reconciliation, and audit trace recording. + + Args: + kernel: The orchestrating :class:`Kernel` (private accessors used + for the driver map, firewall, handle store, and trace store). + token: The verified token authorising this invocation. + principal: The invoking principal. + args: Driver arguments. + response_mode: The caller-requested response mode. + plan: The router-resolved :class:`RoutePlan` for *token*. + """ + action_id = str(uuid.uuid4()) + effective_mode = resolve_effective_mode( + response_mode=response_mode, + principal=principal, + budget_manager=kernel.budget, + ) + reserved_tokens: int | None = None + if kernel.budget is not None: + reserved_tokens = await kernel.budget.allocate() + + log_ctx = { + "action_id": action_id, + "principal_id": principal.principal_id, + "capability_id": token.capability_id, + } + logger.info( + "invoke_start", + extra={ + **log_ctx, + "token_id": token.token_id, + "response_mode": response_mode, + "effective_mode": effective_mode, + }, + ) + + ctx = ExecutionContext( + capability_id=token.capability_id, + principal_id=principal.principal_id, + args=args, + constraints=token.constraints, + action_id=action_id, + ) + raw_result, used_driver_id, last_error = await execute_with_fallback( + kernel._driver_map, plan, ctx=ctx, log_ctx=log_ctx + ) + + if raw_result is None: + if kernel.budget is not None and reserved_tokens is not None: + await kernel.budget.release(reserved_tokens) + err_msg = str(last_error) if last_error else "No drivers available." + logger.warning("invoke_failure", extra={**log_ctx, "error": err_msg}) + record_failure_trace( + action_id=action_id, + capability_id=token.capability_id, + principal_id=principal.principal_id, + token_id=token.token_id, + args=args, + response_mode=response_mode, + error_message=err_msg, + trace_store=kernel._traces, + ) + raise DriverError( + f"All drivers failed for capability '{token.capability_id}'. Last error: {err_msg}" + ) + + handle: Handle | None = None + if effective_mode != "raw": + handle = kernel._handles.store( + capability_id=token.capability_id, + data=raw_result.data, + ) + + reservation_consumed = False + try: + frame = kernel._fw.transform( + raw_result, + action_id=action_id, + principal_id=principal.principal_id, + principal_roles=list(principal.roles), + response_mode=effective_mode, + constraints=token.constraints, + handle=handle, + ) + if kernel.budget is not None and reserved_tokens is not None: + actual_tokens = kernel.budget.count_tokens(_frame_payload(frame)) + await kernel.budget.record_usage(actual_tokens, reserved=reserved_tokens) + reservation_consumed = True + finally: + if not reservation_consumed and kernel.budget is not None and reserved_tokens is not None: + await kernel.budget.release(reserved_tokens) + + record_success_trace( + action_id=action_id, + capability_id=token.capability_id, + principal_id=principal.principal_id, + token_id=token.token_id, + args=args, + response_mode=frame.response_mode, + driver_id=used_driver_id, + handle_id=handle.handle_id if handle else None, + trace_store=kernel._traces, + ) + logger.info( + "invoke_success", + extra={ + **log_ctx, + "response_mode": frame.response_mode, + "driver_id": used_driver_id, + }, + ) + # A single-shot invoke always returns a final Frame. Streaming callers + # control ``is_final`` themselves in ``_stream.py``. + return replace(frame, is_final=True) + + +def _frame_payload(frame: Frame) -> Any: + """Return the LLM-facing payload from a :class:`Frame` for token counting.""" + if frame.response_mode == "raw": + return frame.raw_data + if frame.response_mode == "table": + return frame.table_preview + if frame.response_mode == "handle_only": + return None + return frame.facts + + +__all__ = [ + "perform_invoke", + "resolve_effective_mode", + "execute_with_fallback", + "record_failure_trace", + "record_success_trace", +] diff --git a/src/agent_kernel/kernel/_stream.py b/src/agent_kernel/kernel/_stream.py new file mode 100644 index 0000000..9e04dcb --- /dev/null +++ b/src/agent_kernel/kernel/_stream.py @@ -0,0 +1,234 @@ +"""Streaming invocation pipeline (:meth:`Kernel.invoke_stream`). + +Yields :class:`Frame` chunks as the driver produces them. The same +security pipeline as :meth:`Kernel.invoke` is applied per chunk — +firewall transformation on every chunk, budget escalation as the +remaining budget drains, and a single :class:`ActionTrace` covering the +whole stream. + +When the resolved driver does not implement +:class:`~agent_kernel.drivers.base.StreamingDriver`, this helper falls +back to a single :meth:`Driver.execute` call and yields one ``Frame`` +with ``is_final=True``. The fallback preserves the same firewall + +trace guarantees as the streaming path. +""" + +from __future__ import annotations + +import datetime +import logging +import uuid +from collections.abc import AsyncIterator +from dataclasses import replace +from typing import TYPE_CHECKING, Any + +from ..drivers.base import ExecutionContext, StreamingDriver +from ..errors import DriverError +from ..models import ( + ActionTrace, + Capability, + Frame, + Handle, + Principal, + RawResult, + ResponseMode, + RoutePlan, +) +from ..tokens import CapabilityToken +from ._invoke import resolve_effective_mode + +if TYPE_CHECKING: # pragma: no cover + from . import Kernel + +logger = logging.getLogger("agent_kernel.kernel") + + +async def invoke_stream_impl( + *, + kernel: Kernel, + token: CapabilityToken, + principal: Principal, + capability: Capability, + plan: RoutePlan, + args: dict[str, Any], + response_mode: ResponseMode, +) -> AsyncIterator[Frame]: + """Stream Frames for one capability invocation.""" + del capability # currently unused; kept in signature for future hooks. + action_id = str(uuid.uuid4()) + initial_mode = resolve_effective_mode( + response_mode=response_mode, + principal=principal, + budget_manager=kernel.budget, + ) + log_ctx = { + "action_id": action_id, + "principal_id": principal.principal_id, + "capability_id": token.capability_id, + } + logger.info( + "invoke_stream_start", + extra={ + **log_ctx, + "token_id": token.token_id, + "response_mode": response_mode, + "initial_mode": initial_mode, + }, + ) + + ctx = ExecutionContext( + capability_id=token.capability_id, + principal_id=principal.principal_id, + args=args, + constraints=token.constraints, + action_id=action_id, + ) + + # Resolve a streaming-capable driver, else fall back to one-shot execute. + streaming_driver: StreamingDriver | None = None + fallback_driver_id = "" + for driver_id in plan.driver_ids: + candidate = kernel._driver_map.get(driver_id) + if candidate is None: + continue + if isinstance(candidate, StreamingDriver): + streaming_driver = candidate + fallback_driver_id = driver_id + break + fallback_driver_id = driver_id # last non-streaming candidate + + yielded_any = False + handle: Handle | None = None + last_frame: Frame | None = None + try: + if streaming_driver is not None: + async for frame in _stream_chunks( + kernel=kernel, + driver=streaming_driver, + ctx=ctx, + token=token, + principal=principal, + response_mode=initial_mode, + action_id=action_id, + ): + yielded_any = True + last_frame = frame + yield frame + else: + # Non-streaming fallback — wrap a single execute() call. + fallback_driver = kernel._driver_map.get(fallback_driver_id) + if fallback_driver is None: + raise DriverError(f"No driver available for capability '{token.capability_id}'.") + raw = await fallback_driver.execute(ctx) + if initial_mode != "raw": + handle = kernel._handles.store( + capability_id=token.capability_id, + data=raw.data, + ) + frame = kernel._fw.transform( + raw, + action_id=action_id, + principal_id=principal.principal_id, + principal_roles=list(principal.roles), + response_mode=initial_mode, + constraints=token.constraints, + handle=handle, + ) + frame = replace(frame, is_final=True) + yielded_any = True + last_frame = frame + yield frame + finally: + kernel._traces.record( + ActionTrace( + action_id=action_id, + capability_id=token.capability_id, + principal_id=principal.principal_id, + token_id=token.token_id, + invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), + args=args, + response_mode=(last_frame.response_mode if last_frame else initial_mode), + driver_id=fallback_driver_id, + handle_id=handle.handle_id if handle else None, + error=None if yielded_any else "stream produced no chunks", + ) + ) + logger.info( + "invoke_stream_end", + extra={ + **log_ctx, + "yielded_any": yielded_any, + "driver_id": fallback_driver_id, + }, + ) + + +async def _stream_chunks( + *, + kernel: Kernel, + driver: StreamingDriver, + ctx: ExecutionContext, + token: CapabilityToken, + principal: Principal, + response_mode: ResponseMode, + action_id: str, +) -> AsyncIterator[Frame]: + """Yield firewalled frames for each chunk the driver produces. + + Each chunk is wrapped in a synthetic :class:`RawResult` and passed + through :meth:`Firewall.transform` so PII redaction applies to + every chunk. Mode escalation happens before each chunk when a + :class:`BudgetManager` is attached. + """ + final_marker_seen = False + async for chunk in driver.execute_stream(ctx): + is_final = bool(chunk.get("__is_final__", False)) + # Strip the synthetic marker before passing to the firewall. + payload = {k: v for k, v in chunk.items() if k != "__is_final__"} + synthetic_raw = RawResult( + capability_id=token.capability_id, + data=payload, + metadata={"action_id": action_id, "streaming": True}, + ) + effective_mode = resolve_effective_mode( + response_mode=response_mode, + principal=principal, + budget_manager=kernel.budget, + ) + frame = kernel._fw.transform( + synthetic_raw, + action_id=action_id, + principal_id=principal.principal_id, + principal_roles=list(principal.roles), + response_mode=effective_mode, + constraints=token.constraints, + ) + if is_final: + final_marker_seen = True + frame = replace(frame, is_final=True) + yield frame + if not final_marker_seen: + # Driver ended without an explicit final marker — emit a final + # sentinel frame so consumers can detect end-of-stream uniformly. + yield replace( + kernel._fw.transform( + RawResult( + capability_id=token.capability_id, + data={}, + metadata={ + "action_id": action_id, + "streaming": True, + "sentinel": True, + }, + ), + action_id=action_id, + principal_id=principal.principal_id, + principal_roles=list(principal.roles), + response_mode=response_mode, + constraints=token.constraints, + ), + is_final=True, + ) + + +__all__ = ["invoke_stream_impl"] diff --git a/src/agent_kernel/models.py b/src/agent_kernel/models.py index 9c6adea..e775474 100644 --- a/src/agent_kernel/models.py +++ b/src/agent_kernel/models.py @@ -7,6 +7,7 @@ from __future__ import annotations import datetime +from collections.abc import Callable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal @@ -370,6 +371,16 @@ class Frame: raw_data: Any = None """Only populated in ``raw`` response mode for admin principals.""" + is_final: bool = False + """``True`` when this Frame is the last chunk of a stream. + + Always ``True`` for non-streaming :meth:`Kernel.invoke` returns. In + :meth:`Kernel.invoke_stream`, only the last yielded Frame has this + set; intermediate chunks have ``is_final=False``. Consumers should + look at this flag to detect end-of-stream uniformly across the + streaming and non-streaming paths. + """ + # ── Audit trace ─────────────────────────────────────────────────────────────── @@ -445,6 +456,158 @@ class DenialExplanation: # ── Dry-run ─────────────────────────────────────────────────────────────────── +# ── Namespaces & federation ─────────────────────────────────────────────────── + + +@dataclass(slots=True) +class NamespaceMetadata: + """Describes a capability namespace. + + Namespaces are dot-notation prefixes (``"billing"``, ``"billing.invoices"``) + inferred from registered :attr:`Capability.capability_id` values. A + :class:`NamespaceMetadata` entry can optionally carry a description and a + deferred *loader* — a zero-argument callable that registers additional + capabilities the first time the namespace is searched or listed. + """ + + prefix: str + """Dot-notation namespace prefix (e.g. ``"billing"`` or ``"billing.invoices"``).""" + + description: str = "" + """Optional human-readable description shown by ``list_namespaces``.""" + + loader: Callable[[], list[Capability]] | None = None + """Optional zero-arg loader invoked at most once on first access. + + The loader must return capabilities whose ``capability_id`` starts with + :attr:`prefix` (followed by ``.`` or matching the prefix exactly). The + registry stores the returned capabilities and marks the namespace as + loaded — subsequent searches or list calls will not re-invoke it. + """ + + loaded: bool = False + """``True`` once the deferred loader has been invoked (or no loader exists).""" + + +@dataclass(slots=True) +class CapabilityDescriptor: + """Public-facing capability description for cross-kernel advertising. + + A descriptor is the slice of a :class:`Capability` that is safe to share + over the wire: no driver IDs, no operation names, no Python-level + references. JSON-serialisable via :meth:`to_dict`. + """ + + capability_id: str + """Stable, namespaced identifier (e.g. ``"billing.invoices.list"``).""" + + name: str + """Short human-readable name.""" + + description: str + """What the capability does.""" + + safety_class: SafetyClass + """READ / WRITE / DESTRUCTIVE — preserved verbatim from the source capability.""" + + sensitivity: SensitivityTag = SensitivityTag.NONE + """Optional sensitivity tag — preserved verbatim.""" + + tags: list[str] = field(default_factory=list) + """Search/keyword tags from the source capability.""" + + parameters_schema: dict[str, Any] | None = None + """JSON Schema describing the capability's input parameters, if available.""" + + def to_dict(self) -> dict[str, Any]: + """Serialise the descriptor to a JSON-compatible dict.""" + return { + "capability_id": self.capability_id, + "name": self.name, + "description": self.description, + "safety_class": self.safety_class.value, + "sensitivity": self.sensitivity.value, + "tags": list(self.tags), + "parameters_schema": self.parameters_schema, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CapabilityDescriptor: + """Reconstruct a descriptor from a dict produced by :meth:`to_dict`.""" + return cls( + capability_id=data["capability_id"], + name=data["name"], + description=data["description"], + safety_class=SafetyClass(data["safety_class"]), + sensitivity=SensitivityTag(data.get("sensitivity", SensitivityTag.NONE.value)), + tags=list(data.get("tags", [])), + parameters_schema=data.get("parameters_schema"), + ) + + +TrustLevel = Literal["verified", "unverified"] + + +@dataclass(slots=True) +class CapabilityManifest: + """Serialisable advertisement of a kernel's capabilities. + + A manifest is what one kernel publishes for another to consume. It + intentionally omits internal driver IDs, operation names, and any + Python references — only the public-facing :class:`CapabilityDescriptor` + list, the advertising kernel's identity, and a transport endpoint. + + Manifests are weaver-spec contract artifacts (I-02): the importing kernel + must still run the full local pipeline (policy → token → firewall) on every + imported capability invocation. + """ + + kernel_id: str + """Stable identifier of the advertising kernel (e.g. ``"agent-a"``).""" + + version: str + """Schema version of this manifest payload (e.g. ``"1"``).""" + + capabilities: list[CapabilityDescriptor] + """Public-facing descriptors. Ordered by registration on the advertising side.""" + + endpoint: str + """Transport endpoint at which the advertising kernel can be reached. + + Format is transport-specific (e.g. ``"https://agent-a.example/kernel"`` + or ``"mcp://stdio:python -m mcp_server"``). The importing kernel uses it + purely to construct a local driver — the endpoint is never invoked by + federation itself. + """ + + trust_level: TrustLevel = "unverified" + """Trust hint declared by the publisher. ``"verified"`` indicates the + publisher claims independent verification (e.g. a signed manifest); the + importing kernel still applies its configured trust policy regardless. + """ + + def to_dict(self) -> dict[str, Any]: + """Serialise the manifest to a JSON-compatible dict.""" + return { + "kernel_id": self.kernel_id, + "version": self.version, + "endpoint": self.endpoint, + "trust_level": self.trust_level, + "capabilities": [cap.to_dict() for cap in self.capabilities], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CapabilityManifest: + """Reconstruct a manifest from a dict produced by :meth:`to_dict`.""" + return cls( + kernel_id=data["kernel_id"], + version=data["version"], + endpoint=data["endpoint"], + trust_level=data.get("trust_level", "unverified"), + capabilities=[CapabilityDescriptor.from_dict(c) for c in data["capabilities"]], + ) + + @dataclass(slots=True) class DryRunResult: """Result of a dry-run invocation — driver is never called. diff --git a/src/agent_kernel/otel.py b/src/agent_kernel/otel.py new file mode 100644 index 0000000..c689d3c --- /dev/null +++ b/src/agent_kernel/otel.py @@ -0,0 +1,247 @@ +"""OpenTelemetry instrumentation for the :class:`Kernel`. + +Calling :func:`instrument_kernel(kernel)` wraps the kernel's +``invoke``, ``invoke_stream``, ``grant_capability``, ``expand``, +``advertise``, and ``import_remote`` methods with OTel spans + metric +emission. + +When ``opentelemetry-api`` is not installed, :func:`instrument_kernel` +is a complete no-op — the import succeeds, the call is a no-op, and no +runtime cost is paid. This lets the optional ``[otel]`` extra stay +optional without forcing users to handle ``ImportError`` themselves. + +Span tree +--------- + +``invoke()`` produces:: + + agent_kernel.invoke + ├── attributes: principal_id, capability_id, safety_class, + │ response_mode, dry_run + ├── agent_kernel.driver.execute (per driver attempt) + └── agent_kernel.firewall.apply + +Metrics +------- + +* ``agent_kernel.invocations`` (counter) — labels: + ``capability_id``, ``status`` (``success``/``error``/``denied``). +* ``agent_kernel.invocation_duration`` (histogram, milliseconds). +* ``agent_kernel.policy_denials`` (counter) — labels: ``capability_id``, + ``reason_code``. + +Usage +----- + +.. code-block:: python + + from agent_kernel import Kernel, instrument_kernel + + kernel = Kernel(registry=...) + instrument_kernel(kernel) # idempotent — calling again is a no-op. +""" + +from __future__ import annotations + +import logging +import time +from typing import TYPE_CHECKING, Any + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: # pragma: no cover + from .kernel import Kernel + +# Try to import the OTel API. If it isn't installed, fall back to a +# no-op shim that has the same surface but emits nothing. Tests and +# downstream callers can rely on ``OTEL_AVAILABLE`` to skip the +# instrumentation path without dancing around imports. + +trace: Any +metrics: Any +Status: Any +StatusCode: Any + +try: + from opentelemetry import metrics as _otel_metrics + from opentelemetry import trace as _otel_trace + from opentelemetry.trace import Status as _otel_Status + from opentelemetry.trace import StatusCode as _otel_StatusCode + + trace = _otel_trace + metrics = _otel_metrics + Status = _otel_Status + StatusCode = _otel_StatusCode + OTEL_AVAILABLE = True +except ImportError: # pragma: no cover - exercised in the no-extra environment + trace = None + metrics = None + Status = None + StatusCode = None + OTEL_AVAILABLE = False + + +# Attribute keys (re-used across spans). Kept as module constants so a +# downstream search/grep finds every emission site. +ATTR_PRINCIPAL = "agent_kernel.principal_id" +ATTR_CAPABILITY = "agent_kernel.capability_id" +ATTR_SAFETY_CLASS = "agent_kernel.safety_class" +ATTR_RESPONSE_MODE = "agent_kernel.response_mode" +ATTR_DRY_RUN = "agent_kernel.dry_run" +ATTR_DRIVER_ID = "agent_kernel.driver_id" +ATTR_REASON_CODE = "agent_kernel.reason_code" + +# Module-level cache so repeat :func:`instrument_kernel` calls are +# cheap idempotent no-ops on the same instance. +_INSTRUMENTED: set[int] = set() + + +def instrument_kernel( + kernel: Kernel, + *, + tracer_provider: Any = None, + meter_provider: Any = None, +) -> None: + """Wrap *kernel*'s public methods with OTel spans + metric emission. + + No-op when the ``[otel]`` extra is not installed (``OTEL_AVAILABLE`` + is ``False``). Calling twice on the same kernel is a no-op — only + the first call swaps methods. + + Args: + kernel: The kernel to instrument in-place. + tracer_provider: Optional ``TracerProvider`` to source the tracer + from. Defaults to the OTel global. Useful in tests where the + global cannot be re-set across cases. + meter_provider: Optional ``MeterProvider`` to source the meter + from. Defaults to the OTel global. + """ + if not OTEL_AVAILABLE: + logger.debug( + "otel.skip", + extra={"reason": "opentelemetry-api not installed"}, + ) + return + if id(kernel) in _INSTRUMENTED: + return + _INSTRUMENTED.add(id(kernel)) + + if tracer_provider is not None: + tracer = tracer_provider.get_tracer("agent_kernel") + else: + tracer = trace.get_tracer("agent_kernel") + if meter_provider is not None: + meter = meter_provider.get_meter("agent_kernel") + else: + meter = metrics.get_meter("agent_kernel") + invocations = meter.create_counter( + "agent_kernel.invocations", + description="Count of Kernel.invoke calls, labeled by status", + ) + duration_hist = meter.create_histogram( + "agent_kernel.invocation_duration", + unit="ms", + description="Latency of Kernel.invoke (milliseconds)", + ) + denials = meter.create_counter( + "agent_kernel.policy_denials", + description="Count of policy denials, labeled by reason_code", + ) + + original_invoke = kernel.invoke + original_grant = kernel.grant_capability + + async def instrumented_invoke( + token: Any, + *, + principal: Any, + args: dict[str, Any], + response_mode: str = "summary", + dry_run: bool = False, + ) -> Any: + start = time.monotonic() + attributes: dict[str, Any] = { + ATTR_PRINCIPAL: principal.principal_id, + ATTR_CAPABILITY: token.capability_id, + ATTR_RESPONSE_MODE: response_mode, + ATTR_DRY_RUN: dry_run, + } + with tracer.start_as_current_span("agent_kernel.invoke", attributes=attributes) as span: + try: + # ``response_mode`` here is a runtime str, so mypy can't pick the + # right overload of ``Kernel.invoke``. The wrapper preserves the + # documented call shape — this is just an erasure step. + result = await original_invoke( # type: ignore[call-overload] + token, + principal=principal, + args=args, + response_mode=response_mode, + dry_run=dry_run, + ) + elapsed_ms = (time.monotonic() - start) * 1000.0 + duration_hist.record(elapsed_ms, attributes=attributes) + invocations.add(1, {ATTR_CAPABILITY: token.capability_id, "status": "success"}) + span.set_status(Status(StatusCode.OK)) + return result + except Exception as exc: + elapsed_ms = (time.monotonic() - start) * 1000.0 + duration_hist.record(elapsed_ms, attributes=attributes) + invocations.add(1, {ATTR_CAPABILITY: token.capability_id, "status": "error"}) + span.record_exception(exc) + span.set_status(Status(StatusCode.ERROR, str(exc))) + raise + + def instrumented_grant( + request: Any, + principal: Any, + *, + justification: str, + ) -> Any: + attributes: dict[str, Any] = { + ATTR_PRINCIPAL: principal.principal_id, + ATTR_CAPABILITY: request.capability_id, + } + with tracer.start_as_current_span("agent_kernel.grant", attributes=attributes) as span: + try: + return original_grant(request, principal, justification=justification) + except Exception as exc: + reason_code = getattr(exc, "reason_code", "") or "" + denials.add( + 1, + { + ATTR_CAPABILITY: request.capability_id, + ATTR_REASON_CODE: reason_code, + }, + ) + span.record_exception(exc) + span.set_status(Status(StatusCode.ERROR, str(exc))) + raise + + # Bind the wrappers onto the instance (so unrelated kernels aren't + # affected — instrumentation is per-kernel, not class-wide). + # Cast through ``Any`` so mypy doesn't try to match the wrapper against + # ``Kernel.invoke``'s @overload signatures. The wrapper preserves the + # documented call shape; the overloads only matter at static-type sites + # that read ``kernel.invoke`` directly via the class, not via the + # instance binding we install here. + kernel.invoke = instrumented_invoke # type: ignore[method-assign] + kernel.grant_capability = instrumented_grant # type: ignore[method-assign] + + +def reset_instrumentation(kernel: Kernel | None = None) -> None: + """Forget that *kernel* (or any) has been instrumented. + + Test-only helper. Re-instrumenting after this call works as if + :func:`instrument_kernel` had never been called. + """ + if kernel is None: + _INSTRUMENTED.clear() + else: + _INSTRUMENTED.discard(id(kernel)) + + +__all__ = [ + "OTEL_AVAILABLE", + "instrument_kernel", + "reset_instrumentation", +] diff --git a/src/agent_kernel/policy.py b/src/agent_kernel/policy.py index 5a3e9f7..dae56b4 100644 --- a/src/agent_kernel/policy.py +++ b/src/agent_kernel/policy.py @@ -3,10 +3,7 @@ from __future__ import annotations import logging -import time -from collections import defaultdict from collections.abc import Callable -from dataclasses import dataclass from typing import Any, Protocol from .enums import SafetyClass, SensitivityTag @@ -22,6 +19,7 @@ Principal, ) from .policy_reasons import AllowReason, DenialReason +from .rate_limit import DEFAULT_RATE_LIMITS, SERVICE_RATE_MULTIPLIER, RateLimiter logger = logging.getLogger(__name__) @@ -32,65 +30,10 @@ _MAX_ROWS_USER = 50 _MAX_ROWS_SERVICE = 500 -# Default rate limits per safety class: (invocations, window_seconds). -_DEFAULT_RATE_LIMITS: dict[SafetyClass, tuple[int, float]] = { - SafetyClass.READ: (60, 60.0), - SafetyClass.WRITE: (10, 60.0), - SafetyClass.DESTRUCTIVE: (2, 60.0), -} - -# Service role multiplier for rate limits. -_SERVICE_RATE_MULTIPLIER = 10 - - -@dataclass(slots=True) -class _RateEntry: - """Timestamps for a single rate-limit key.""" - - timestamps: list[float] - - -class RateLimiter: - """Sliding-window rate limiter using monotonic clock. - - Args: - clock: Callable returning the current time in seconds. - Defaults to :func:`time.monotonic`. - """ - - def __init__(self, clock: Callable[[], float] | None = None) -> None: - self._clock = clock or time.monotonic - self._windows: dict[str, _RateEntry] = defaultdict(lambda: _RateEntry(timestamps=[])) - - def check(self, key: str, limit: int, window_seconds: float) -> bool: - """Return ``True`` if the next invocation would be within the limit. - - Prunes expired timestamps as a side-effect. - - Args: - key: Rate-limit key (e.g. ``"principal:capability"``). - limit: Maximum allowed invocations per window. - window_seconds: Sliding window duration in seconds. - - Returns: - ``True`` if under limit, ``False`` if limit would be exceeded. - """ - now = self._clock() - cutoff = now - window_seconds - entry = self._windows[key] - entry.timestamps = [t for t in entry.timestamps if t > cutoff] - if not entry.timestamps: - del self._windows[key] - return True - return len(entry.timestamps) < limit - - def record(self, key: str) -> None: - """Record an invocation for *key*. - - Args: - key: Rate-limit key. - """ - self._windows[key].timestamps.append(self._clock()) +# Backwards-compatible aliases — these used to be defined here. New code +# should import the names without the leading underscore from ``rate_limit``. +_DEFAULT_RATE_LIMITS = DEFAULT_RATE_LIMITS +_SERVICE_RATE_MULTIPLIER = SERVICE_RATE_MULTIPLIER class PolicyEngine(Protocol): diff --git a/src/agent_kernel/policy_dsl.py b/src/agent_kernel/policy_dsl.py index e31bc9d..2f3a507 100644 --- a/src/agent_kernel/policy_dsl.py +++ b/src/agent_kernel/policy_dsl.py @@ -1,90 +1,44 @@ -"""Declarative policy engine: load access-control rules from YAML or TOML. +"""Declarative policy engine: evaluate access-control rules loaded from YAML or TOML. -The YAML/TOML loaders import their parsers lazily, so ``import agent_kernel`` -works without the optional ``policy`` extra installed. Calling -:meth:`DeclarativePolicyEngine.from_yaml` or :meth:`from_toml` without the -required parser surfaces a :class:`PolicyConfigError` with an install hint. +Parsing and the denial-explanation traversal live in sibling modules +(:mod:`policy_dsl_parser`, :mod:`policy_dsl_explain`) so each module stays +≤ 300 lines per AGENTS.md. :class:`PolicyMatch` and :class:`PolicyRule` +are re-exported from this module for backwards compatibility with the +public API surface (``from agent_kernel import PolicyMatch, PolicyRule``). """ from __future__ import annotations -import sys -from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal -from .enums import SafetyClass, SensitivityTag -from .errors import PolicyConfigError, PolicyDenied +from .errors import PolicyDenied from .models import ( Capability, CapabilityRequest, DenialExplanation, - FailedCondition, PolicyDecision, PolicyDecisionTrace, PolicyTraceStep, Principal, ) -from .policy_reasons import AllowReason, DenialReason - -# Hint surfaced when the optional ``policy`` extra is missing. -_POLICY_EXTRA_HINT = ( - "Install the policy extra to enable file loaders: pip install 'weaver-kernel[policy]'" +from .policy_dsl_explain import build_denial_explanation +from .policy_dsl_parser import ( + POLICY_EXTRA_HINT, + PolicyMatch, + PolicyRule, + load_toml_data, + load_yaml_data, + parse_engine_data, ) +from .policy_reasons import AllowReason, DenialReason - -@dataclass(slots=True) -class PolicyMatch: - """Conditions that must ALL be satisfied for a rule to match a request. - - ``None`` fields are wildcards — they match any value. - List fields use ANY-of semantics (e.g. ``roles = ["a", "b"]`` matches - if the principal has *at least one* of those roles). - """ - - safety_class: list[SafetyClass] | None = None - """Match if ``capability.safety_class`` is in this list.""" - - sensitivity: list[SensitivityTag] | None = None - """Match if ``capability.sensitivity`` is in this list.""" - - roles: list[str] | None = None - """Match if the principal has ANY of these roles.""" - - attributes: dict[str, str] | None = None - """Match if the principal has ALL these attributes. - Use ``"*"`` as the value to require the attribute with any value.""" - - min_justification: int | None = None - """Match if ``len(justification.strip()) >= min_justification``.""" - - intent: list[str] | None = None - """Match if :attr:`CapabilityRequest.intent` is in this list. - - A non-``None`` list means "this rule is intent-aware". A request with - :attr:`CapabilityRequest.intent` ``None`` never matches an intent-aware - rule (so policies that require an intent fail closed for unstructured - legacy callers). - """ - - scope: dict[str, str] | None = None - """Match if :attr:`CapabilityRequest.scope` contains ALL these key/value pairs. - Use ``"*"`` as the value to require the key with any value. - """ - - -@dataclass(slots=True) -class PolicyRule: - """A single declarative policy rule.""" - - name: str - match: PolicyMatch - action: Literal["allow", "deny"] - constraints: dict[str, Any] = field(default_factory=dict) - """Extra constraints merged into the :class:`PolicyDecision` on allow.""" - - reason: str = "" - """Human-readable reason embedded in :class:`PolicyDenied` on deny.""" +__all__ = [ + "POLICY_EXTRA_HINT", + "DeclarativePolicyEngine", + "PolicyMatch", + "PolicyRule", +] class DeclarativePolicyEngine: @@ -92,12 +46,7 @@ class DeclarativePolicyEngine: Rules are evaluated top-to-bottom; the first matching rule wins. If no rule matches, the *default* action applies (``"deny"`` unless - overridden). - - Example:: - - engine = DeclarativePolicyEngine.from_yaml(Path("policy.yaml")) - decision = engine.evaluate(request, capability, principal, justification="...") + overridden). See :mod:`policy_dsl_parser` for the rule schema. """ def __init__( @@ -115,8 +64,6 @@ def __init__( self._rules = rules self._default = default - # ── Loaders ─────────────────────────────────────────────────────────────── - @classmethod def from_dict(cls, data: dict[str, Any]) -> DeclarativePolicyEngine: """Build from a plain dict (no file I/O). @@ -127,15 +74,14 @@ def from_dict(cls, data: dict[str, Any]) -> DeclarativePolicyEngine: Raises: PolicyConfigError: If the data is malformed. """ - return cls._parse(data) + rules, default = parse_engine_data(data) + return cls(rules, default=default) @classmethod def from_yaml(cls, path: Path) -> DeclarativePolicyEngine: """Build from a YAML file. Requires ``pyyaml``: ``pip install 'weaver-kernel[policy]'``. - The import is deferred so that ``import agent_kernel`` works without - the policy extra installed. Args: path: Path to the YAML policy file. @@ -144,184 +90,22 @@ def from_yaml(cls, path: Path) -> DeclarativePolicyEngine: PolicyConfigError: If the file is unreadable or malformed, or if ``pyyaml`` is not installed. """ - try: - import yaml - except ImportError as exc: - raise PolicyConfigError(_POLICY_EXTRA_HINT) from exc - - try: - text = path.read_text(encoding="utf-8") - data: Any = yaml.safe_load(text) - except OSError as exc: - raise PolicyConfigError(f"Cannot read policy file '{path}': {exc}") from exc - except yaml.YAMLError as exc: - raise PolicyConfigError(f"YAML parse error in '{path}': {exc}") from exc - if not isinstance(data, dict): - raise PolicyConfigError(f"Policy file '{path}' must be a YAML mapping.") - return cls._parse(data) + return cls.from_dict(load_yaml_data(path)) @classmethod def from_toml(cls, path: Path) -> DeclarativePolicyEngine: """Build from a TOML file. Requires Python 3.11+ (stdlib ``tomllib``) or ``tomli`` on 3.10 - (included in ``pip install 'weaver-kernel[policy]'``). The import is - deferred so that ``import agent_kernel`` works without the policy - extra installed. + (included in ``pip install 'weaver-kernel[policy]'``). Args: path: Path to the TOML policy file. Raises: - PolicyConfigError: If the file is unreadable or malformed, or - if neither ``tomllib`` nor ``tomli`` is available. + PolicyConfigError: If the file is unreadable or malformed. """ - try: - if sys.version_info >= (3, 11): - import tomllib as _toml - else: - import tomli as _toml - except ImportError as exc: - raise PolicyConfigError(_POLICY_EXTRA_HINT) from exc - - try: - with path.open("rb") as fh: - data = _toml.load(fh) - except OSError as exc: - raise PolicyConfigError(f"Cannot read policy file '{path}': {exc}") from exc - except Exception as exc: # TOMLDecodeError is not a stable import target - raise PolicyConfigError(f"TOML parse error in '{path}': {exc}") from exc - return cls._parse(data) - - # ── Parsing ─────────────────────────────────────────────────────────────── - - @classmethod - def _parse(cls, data: dict[str, Any]) -> DeclarativePolicyEngine: - raw_default = data.get("default", "deny") - if raw_default not in ("allow", "deny"): - raise PolicyConfigError(f"'default' must be 'allow' or 'deny', got {raw_default!r}.") - default: Literal["allow", "deny"] = raw_default - - raw_rules = data.get("rules", []) - if not isinstance(raw_rules, list): - raise PolicyConfigError("'rules' must be a list.") - - return cls( - [cls._parse_rule(r, index=i) for i, r in enumerate(raw_rules)], - default=default, - ) - - @classmethod - def _parse_rule(cls, raw: Any, *, index: int) -> PolicyRule: - if not isinstance(raw, dict): - raise PolicyConfigError(f"Rule[{index}] must be a mapping, got {type(raw).__name__}.") - name: str = raw.get("name", f"rule-{index}") - action = raw.get("action") - if action not in ("allow", "deny"): - raise PolicyConfigError( - f"Rule '{name}': 'action' must be 'allow' or 'deny', got {action!r}." - ) - raw_match = raw.get("match", {}) - if not isinstance(raw_match, dict): - raise PolicyConfigError(f"Rule '{name}': 'match' must be a mapping.") - - safety_class: list[SafetyClass] | None = None - if "safety_class" in raw_match: - try: - safety_class = [SafetyClass(v) for v in raw_match["safety_class"]] - except ValueError as exc: - raise PolicyConfigError( - f"Rule '{name}': invalid safety_class value: {exc}" - ) from exc - - sensitivity: list[SensitivityTag] | None = None - if "sensitivity" in raw_match: - try: - sensitivity = [SensitivityTag(v) for v in raw_match["sensitivity"]] - except ValueError as exc: - raise PolicyConfigError( - f"Rule '{name}': invalid sensitivity value: {exc}" - ) from exc - - roles: list[str] | None = None - if "roles" in raw_match: - roles_raw = raw_match["roles"] - if not isinstance(roles_raw, list) or not all(isinstance(r, str) for r in roles_raw): - raise PolicyConfigError( - f"Rule '{name}': 'roles' must be a list of strings, " - f"got {type(roles_raw).__name__}." - ) - roles = list(roles_raw) - - attributes: dict[str, str] | None = None - if "attributes" in raw_match: - attrs_raw = raw_match["attributes"] - if not isinstance(attrs_raw, dict) or not all( - isinstance(k, str) and isinstance(v, str) for k, v in attrs_raw.items() - ): - raise PolicyConfigError( - f"Rule '{name}': 'attributes' must be a mapping of " - f"string keys to string values." - ) - attributes = dict(attrs_raw) - - min_justification: int | None = None - if "min_justification" in raw_match: - mj_raw = raw_match["min_justification"] - # ``bool`` is a subclass of ``int`` in Python; reject it explicitly - # so ``min_justification: true`` does not silently pass. - if not isinstance(mj_raw, int) or isinstance(mj_raw, bool): - raise PolicyConfigError( - f"Rule '{name}': 'min_justification' must be an integer, " - f"got {type(mj_raw).__name__}." - ) - min_justification = mj_raw - - intent: list[str] | None = None - if "intent" in raw_match: - intent_raw = raw_match["intent"] - if not isinstance(intent_raw, list) or not all(isinstance(i, str) for i in intent_raw): - raise PolicyConfigError( - f"Rule '{name}': 'intent' must be a list of strings, " - f"got {type(intent_raw).__name__}." - ) - intent = list(intent_raw) - - scope: dict[str, str] | None = None - if "scope" in raw_match: - scope_raw = raw_match["scope"] - if not isinstance(scope_raw, dict) or not all( - isinstance(k, str) and isinstance(v, str) for k, v in scope_raw.items() - ): - raise PolicyConfigError( - f"Rule '{name}': 'scope' must be a mapping of string keys to string values." - ) - scope = dict(scope_raw) - - constraints_raw = raw.get("constraints", {}) - if not isinstance(constraints_raw, dict): - raise PolicyConfigError( - f"Rule '{name}': 'constraints' must be a mapping, " - f"got {type(constraints_raw).__name__}." - ) - - return PolicyRule( - name=name, - match=PolicyMatch( - safety_class=safety_class, - sensitivity=sensitivity, - roles=roles, - attributes=attributes, - min_justification=min_justification, - intent=intent, - scope=scope, - ), - action=action, - constraints=dict(constraints_raw), - reason=raw.get("reason", ""), - ) - - # ── Matching ────────────────────────────────────────────────────────────── + return cls.from_dict(load_toml_data(path)) def _matches( self, @@ -353,8 +137,6 @@ def _matches( return False return m.min_justification is None or len(justification.strip()) >= m.min_justification - # ── Evaluation ──────────────────────────────────────────────────────────── - def evaluate( self, request: CapabilityRequest, @@ -498,164 +280,18 @@ def explain( reported. Partial-match deny rules are skipped (they did not cause the denial, and suggesting how to satisfy them would be misleading — satisfying them would only trigger the deny). - - Args: - request: The capability request. - capability: The target capability. - principal: The requesting principal. - justification: Free-text justification. - - Returns: - :class:`DenialExplanation` with ``denied=False`` if allowed. """ try: self.evaluate(request, capability, principal, justification=justification) - return DenialExplanation( - denied=False, - rule_name="", - failed_conditions=[], - remediation=[], - narrative=( - f"Request for '{capability.capability_id}' by " - f"'{principal.principal_id}' would be allowed." - ), - ) + would_allow = True except PolicyDenied: - pass - - roles = set(principal.roles) - pid = principal.principal_id - explanation_failures: list[FailedCondition] = [] - rule_name = "default-deny" - primary_code: str | None = str(DenialReason.NO_MATCHING_RULE) - - for rule in self._rules: - m = rule.match - if m.safety_class is not None and capability.safety_class not in m.safety_class: - continue - if m.sensitivity is not None and capability.sensitivity not in m.sensitivity: - continue - - # Collect unmet conditions for this rule. - rule_failures: list[FailedCondition] = [] - if m.roles is not None and not (roles & set(m.roles)): - rule_failures.append( - FailedCondition( - condition="roles", - required=list(m.roles), - actual=sorted(roles), - suggestion=f"Add one of {m.roles!r} to roles for principal '{pid}'", - reason_code=str(DenialReason.MISSING_ROLE), - ) - ) - if m.attributes is not None: - for k, v in m.attributes.items(): - attr_val = principal.attributes.get(k) - if attr_val is None or (v != "*" and attr_val != v): - rule_failures.append( - FailedCondition( - condition=f"attribute:{k}", - required=v, - actual=attr_val if attr_val is not None else "", - suggestion=f"Set attribute '{k}'={v!r} on principal '{pid}'", - reason_code=str(DenialReason.MISSING_ATTRIBUTE), - ) - ) - if m.intent is not None and (request.intent is None or request.intent not in m.intent): - rule_failures.append( - FailedCondition( - condition="intent", - required=list(m.intent), - actual=request.intent if request.intent is not None else "", - suggestion=(f"Set CapabilityRequest.intent to one of {m.intent!r}"), - reason_code=str(DenialReason.INTENT_NOT_ALLOWED), - ) - ) - if m.scope is not None: - for k, v in m.scope.items(): - scope_val = request.scope.get(k) - if scope_val is None or (v != "*" and scope_val != v): - rule_failures.append( - FailedCondition( - condition=f"scope:{k}", - required=v, - actual=scope_val if scope_val is not None else "", - suggestion=(f"Set CapabilityRequest.scope[{k!r}]={v!r}"), - reason_code=str(DenialReason.SCOPE_NOT_ALLOWED), - ) - ) - if m.min_justification is not None: - stripped = len(justification.strip()) - if stripped < m.min_justification: - rule_failures.append( - FailedCondition( - condition="min_justification", - required=m.min_justification, - actual=stripped, - suggestion=( - f"Provide justification with at least " - f"{m.min_justification} characters (currently {stripped})" - ), - reason_code=str(DenialReason.INSUFFICIENT_JUSTIFICATION), - ) - ) - - if rule.action == "deny": - if not rule_failures: - # Explicit deny rule fully matched — this is the cause. - rule_name = rule.name - explanation_failures = [ - FailedCondition( - condition="denied_by_rule", - required=f"request must NOT match deny rule '{rule.name}'", - actual=f"matched deny rule '{rule.name}'", - suggestion=( - rule.reason - or f"Remove or narrow deny rule '{rule.name}' so this " - f"request does not match it" - ), - reason_code=str(DenialReason.EXPLICIT_DENY_RULE), - ) - ] - primary_code = str(DenialReason.EXPLICIT_DENY_RULE) - break - # Partial-match deny rule: it did NOT cause the denial. Skip - # so we don't suggest changes that would actually trigger it. - continue - - # Allow rule (structurally matched, conditions unmet) — report it. - rule_name = rule.name - explanation_failures = rule_failures - primary_code = rule_failures[0].reason_code if rule_failures else None - break - - if not explanation_failures: - explanation_failures = [ - FailedCondition( - condition="no_matching_rule", - required="an allow rule matching this capability", - actual="no rule matched", - suggestion=( - f"Add an allow rule for safety_class=" - f"{capability.safety_class.value!r} in your policy file" - ), - reason_code=str(DenialReason.NO_MATCHING_RULE), - ) - ] - primary_code = str(DenialReason.NO_MATCHING_RULE) - - remediation = [fc.suggestion for fc in explanation_failures] - narrative = ( - f"Request for '{capability.capability_id}' by '{pid}' would be denied " - f"(rule: '{rule_name}'): " - + "; ".join(fc.suggestion for fc in explanation_failures) - + "." - ) - return DenialExplanation( - denied=True, - rule_name=rule_name, - failed_conditions=explanation_failures, - remediation=remediation, - narrative=narrative, - reason_code=primary_code, + would_allow = False + + return build_denial_explanation( + self._rules, + request, + capability, + principal, + justification=justification, + would_allow=would_allow, ) diff --git a/src/agent_kernel/policy_dsl_explain.py b/src/agent_kernel/policy_dsl_explain.py new file mode 100644 index 0000000..79d39ab --- /dev/null +++ b/src/agent_kernel/policy_dsl_explain.py @@ -0,0 +1,214 @@ +"""Denial-explanation logic for the declarative policy engine. + +Split out of :mod:`policy_dsl` to keep the engine module ≤ 300 lines +(AGENTS.md). The :func:`build_denial_explanation` function performs the +rule traversal that drives :meth:`DeclarativePolicyEngine.explain`. The +engine itself remains the only public entry point — this module is an +implementation detail. +""" + +from __future__ import annotations + +from .errors import PolicyDenied +from .models import ( + Capability, + CapabilityRequest, + DenialExplanation, + FailedCondition, + Principal, +) +from .policy_dsl_parser import PolicyMatch, PolicyRule +from .policy_reasons import DenialReason + + +def build_denial_explanation( + rules: list[PolicyRule], + request: CapabilityRequest, + capability: Capability, + principal: Principal, + *, + justification: str, + would_allow: bool, +) -> DenialExplanation: + """Build a :class:`DenialExplanation` for a request that just denied. + + Args: + rules: The engine's ordered rule list. + request: The capability request. + capability: The target capability. + principal: The requesting principal. + justification: Free-text justification. + would_allow: ``True`` if a fresh :meth:`evaluate` call would allow. + Callers use this to short-circuit the explanation early; this + function still receives it because it must return the "would + be allowed" explanation in that case. + """ + if would_allow: + return DenialExplanation( + denied=False, + rule_name="", + failed_conditions=[], + remediation=[], + narrative=( + f"Request for '{capability.capability_id}' by " + f"'{principal.principal_id}' would be allowed." + ), + ) + + roles = set(principal.roles) + pid = principal.principal_id + explanation_failures: list[FailedCondition] = [] + rule_name = "default-deny" + primary_code: str | None = str(DenialReason.NO_MATCHING_RULE) + + for rule in rules: + m = rule.match + if m.safety_class is not None and capability.safety_class not in m.safety_class: + continue + if m.sensitivity is not None and capability.sensitivity not in m.sensitivity: + continue + + rule_failures = _collect_rule_failures( + m, + roles=roles, + request=request, + principal=principal, + justification=justification, + ) + + if rule.action == "deny": + if not rule_failures: + rule_name = rule.name + explanation_failures = [ + FailedCondition( + condition="denied_by_rule", + required=f"request must NOT match deny rule '{rule.name}'", + actual=f"matched deny rule '{rule.name}'", + suggestion=( + rule.reason + or f"Remove or narrow deny rule '{rule.name}' so this " + f"request does not match it" + ), + reason_code=str(DenialReason.EXPLICIT_DENY_RULE), + ) + ] + primary_code = str(DenialReason.EXPLICIT_DENY_RULE) + break + # Partial-match deny rule: it did NOT cause the denial. Skip + # so we don't suggest changes that would actually trigger it. + continue + + # Allow rule (structurally matched, conditions unmet) — report it. + rule_name = rule.name + explanation_failures = rule_failures + primary_code = rule_failures[0].reason_code if rule_failures else None + break + + if not explanation_failures: + explanation_failures = [ + FailedCondition( + condition="no_matching_rule", + required="an allow rule matching this capability", + actual="no rule matched", + suggestion=( + f"Add an allow rule for safety_class=" + f"{capability.safety_class.value!r} in your policy file" + ), + reason_code=str(DenialReason.NO_MATCHING_RULE), + ) + ] + primary_code = str(DenialReason.NO_MATCHING_RULE) + + remediation = [fc.suggestion for fc in explanation_failures] + narrative = ( + f"Request for '{capability.capability_id}' by '{pid}' would be denied " + f"(rule: '{rule_name}'): " + "; ".join(fc.suggestion for fc in explanation_failures) + "." + ) + return DenialExplanation( + denied=True, + rule_name=rule_name, + failed_conditions=explanation_failures, + remediation=remediation, + narrative=narrative, + reason_code=primary_code, + ) + + +def _collect_rule_failures( + m: PolicyMatch, + *, + roles: set[str], + request: CapabilityRequest, + principal: Principal, + justification: str, +) -> list[FailedCondition]: + """Return the list of unmet conditions on ``m`` for this principal+request.""" + pid = principal.principal_id + failures: list[FailedCondition] = [] + + if m.roles is not None and not (roles & set(m.roles)): + failures.append( + FailedCondition( + condition="roles", + required=list(m.roles), + actual=sorted(roles), + suggestion=f"Add one of {m.roles!r} to roles for principal '{pid}'", + reason_code=str(DenialReason.MISSING_ROLE), + ) + ) + if m.attributes is not None: + for k, v in m.attributes.items(): + attr_val = principal.attributes.get(k) + if attr_val is None or (v != "*" and attr_val != v): + failures.append( + FailedCondition( + condition=f"attribute:{k}", + required=v, + actual=attr_val if attr_val is not None else "", + suggestion=f"Set attribute '{k}'={v!r} on principal '{pid}'", + reason_code=str(DenialReason.MISSING_ATTRIBUTE), + ) + ) + if m.intent is not None and (request.intent is None or request.intent not in m.intent): + failures.append( + FailedCondition( + condition="intent", + required=list(m.intent), + actual=request.intent if request.intent is not None else "", + suggestion=f"Set CapabilityRequest.intent to one of {m.intent!r}", + reason_code=str(DenialReason.INTENT_NOT_ALLOWED), + ) + ) + if m.scope is not None: + for k, v in m.scope.items(): + scope_val = request.scope.get(k) + if scope_val is None or (v != "*" and scope_val != v): + failures.append( + FailedCondition( + condition=f"scope:{k}", + required=v, + actual=scope_val if scope_val is not None else "", + suggestion=f"Set CapabilityRequest.scope[{k!r}]={v!r}", + reason_code=str(DenialReason.SCOPE_NOT_ALLOWED), + ) + ) + if m.min_justification is not None: + stripped = len(justification.strip()) + if stripped < m.min_justification: + failures.append( + FailedCondition( + condition="min_justification", + required=m.min_justification, + actual=stripped, + suggestion=( + f"Provide justification with at least " + f"{m.min_justification} characters (currently {stripped})" + ), + reason_code=str(DenialReason.INSUFFICIENT_JUSTIFICATION), + ) + ) + return failures + + +# Re-export so policy_dsl.py only imports one symbol from this module. +__all__ = ["build_denial_explanation", "PolicyDenied"] diff --git a/src/agent_kernel/policy_dsl_parser.py b/src/agent_kernel/policy_dsl_parser.py new file mode 100644 index 0000000..3d537ac --- /dev/null +++ b/src/agent_kernel/policy_dsl_parser.py @@ -0,0 +1,277 @@ +"""Parsing and schema for the declarative policy engine. + +Split out of :mod:`policy_dsl` to keep individual modules ≤ 300 lines +(AGENTS.md quality bar). The :class:`PolicyMatch` and :class:`PolicyRule` +dataclasses live here because the parser produces them; the engine in +:mod:`policy_dsl` re-exports them so the public API is unchanged. +""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +from .enums import SafetyClass, SensitivityTag +from .errors import PolicyConfigError + +POLICY_EXTRA_HINT = ( + "Install the policy extra to enable file loaders: pip install 'weaver-kernel[policy]'" +) +"""Install hint surfaced when the optional ``policy`` extra is missing.""" + + +@dataclass(slots=True) +class PolicyMatch: + """Conditions that must ALL be satisfied for a rule to match a request. + + ``None`` fields are wildcards — they match any value. + List fields use ANY-of semantics (e.g. ``roles = ["a", "b"]`` matches + if the principal has *at least one* of those roles). + """ + + safety_class: list[SafetyClass] | None = None + """Match if ``capability.safety_class`` is in this list.""" + + sensitivity: list[SensitivityTag] | None = None + """Match if ``capability.sensitivity`` is in this list.""" + + roles: list[str] | None = None + """Match if the principal has ANY of these roles.""" + + attributes: dict[str, str] | None = None + """Match if the principal has ALL these attributes. + Use ``"*"`` as the value to require the attribute with any value.""" + + min_justification: int | None = None + """Match if ``len(justification.strip()) >= min_justification``.""" + + intent: list[str] | None = None + """Match if :attr:`CapabilityRequest.intent` is in this list. + + A non-``None`` list means "this rule is intent-aware". A request with + :attr:`CapabilityRequest.intent` ``None`` never matches an intent-aware + rule (so policies that require an intent fail closed for unstructured + legacy callers). + """ + + scope: dict[str, str] | None = None + """Match if :attr:`CapabilityRequest.scope` contains ALL these key/value pairs. + Use ``"*"`` as the value to require the key with any value. + """ + + +@dataclass(slots=True) +class PolicyRule: + """A single declarative policy rule.""" + + name: str + match: PolicyMatch + action: Literal["allow", "deny"] + constraints: dict[str, Any] = field(default_factory=dict) + """Extra constraints merged into the :class:`PolicyDecision` on allow.""" + + reason: str = "" + """Human-readable reason embedded in :class:`PolicyDenied` on deny.""" + + +def parse_engine_data(data: dict[str, Any]) -> tuple[list[PolicyRule], Literal["allow", "deny"]]: + """Parse a top-level policy mapping into ``(rules, default)``. + + Args: + data: Mapping with a ``rules`` list and an optional ``default`` key. + + Returns: + A ``(rules, default)`` pair. ``default`` is the literal ``"allow"`` or + ``"deny"``; ``rules`` is the parsed rule list (may be empty). + + Raises: + PolicyConfigError: If the mapping is malformed. + """ + raw_default = data.get("default", "deny") + if raw_default not in ("allow", "deny"): + raise PolicyConfigError(f"'default' must be 'allow' or 'deny', got {raw_default!r}.") + default: Literal["allow", "deny"] = raw_default + + raw_rules = data.get("rules", []) + if not isinstance(raw_rules, list): + raise PolicyConfigError("'rules' must be a list.") + + rules = [parse_rule(r, index=i) for i, r in enumerate(raw_rules)] + return rules, default + + +def parse_rule(raw: Any, *, index: int) -> PolicyRule: + """Parse a single rule mapping into a :class:`PolicyRule`. + + Args: + raw: The rule mapping. + index: Position in the rules list (used in error messages when the + rule has no explicit ``name``). + + Raises: + PolicyConfigError: If the rule is malformed. + """ + if not isinstance(raw, dict): + raise PolicyConfigError(f"Rule[{index}] must be a mapping, got {type(raw).__name__}.") + name: str = raw.get("name", f"rule-{index}") + action = raw.get("action") + if action not in ("allow", "deny"): + raise PolicyConfigError( + f"Rule '{name}': 'action' must be 'allow' or 'deny', got {action!r}." + ) + raw_match = raw.get("match", {}) + if not isinstance(raw_match, dict): + raise PolicyConfigError(f"Rule '{name}': 'match' must be a mapping.") + + match = PolicyMatch( + safety_class=_parse_enum_list(raw_match, "safety_class", SafetyClass, rule_name=name), + sensitivity=_parse_enum_list(raw_match, "sensitivity", SensitivityTag, rule_name=name), + roles=_parse_str_list(raw_match, "roles", rule_name=name), + attributes=_parse_str_map(raw_match, "attributes", rule_name=name), + min_justification=_parse_min_justification(raw_match, rule_name=name), + intent=_parse_str_list(raw_match, "intent", rule_name=name), + scope=_parse_str_map(raw_match, "scope", rule_name=name), + ) + + constraints_raw = raw.get("constraints", {}) + if not isinstance(constraints_raw, dict): + raise PolicyConfigError( + f"Rule '{name}': 'constraints' must be a mapping, " + f"got {type(constraints_raw).__name__}." + ) + + return PolicyRule( + name=name, + match=match, + action=action, + constraints=dict(constraints_raw), + reason=raw.get("reason", ""), + ) + + +def load_yaml_data(path: Path) -> dict[str, Any]: + """Read a YAML file into a top-level mapping. + + Requires ``pyyaml``: ``pip install 'weaver-kernel[policy]'``. The import is + deferred so that ``import agent_kernel`` works without the policy extra. + + Raises: + PolicyConfigError: If the file is unreadable, malformed, or pyyaml + is not installed. + """ + try: + import yaml + except ImportError as exc: + raise PolicyConfigError(POLICY_EXTRA_HINT) from exc + + try: + text = path.read_text(encoding="utf-8") + data: Any = yaml.safe_load(text) + except OSError as exc: + raise PolicyConfigError(f"Cannot read policy file '{path}': {exc}") from exc + except yaml.YAMLError as exc: + raise PolicyConfigError(f"YAML parse error in '{path}': {exc}") from exc + if not isinstance(data, dict): + raise PolicyConfigError(f"Policy file '{path}' must be a YAML mapping.") + return data + + +def load_toml_data(path: Path) -> dict[str, Any]: + """Read a TOML file into a top-level mapping. + + Uses stdlib ``tomllib`` on 3.11+ and ``tomli`` on 3.10. The latter is + included by the ``policy`` extra. + + Raises: + PolicyConfigError: If the file is unreadable, malformed, or no TOML + parser is available. + """ + try: + if sys.version_info >= (3, 11): + import tomllib as _toml + else: + import tomli as _toml + except ImportError as exc: + raise PolicyConfigError(POLICY_EXTRA_HINT) from exc + + try: + with path.open("rb") as fh: + data = _toml.load(fh) + except OSError as exc: + raise PolicyConfigError(f"Cannot read policy file '{path}': {exc}") from exc + except Exception as exc: # TOMLDecodeError is not a stable import target + raise PolicyConfigError(f"TOML parse error in '{path}': {exc}") from exc + if not isinstance(data, dict): + raise PolicyConfigError(f"Policy file '{path}' must be a TOML table.") + return data + + +# ── Field validators ──────────────────────────────────────────────────────── + + +def _parse_enum_list( + raw_match: dict[str, Any], + key: str, + enum_cls: type[SafetyClass] | type[SensitivityTag], + *, + rule_name: str, +) -> list[Any] | None: + if key not in raw_match: + return None + try: + return [enum_cls(v) for v in raw_match[key]] + except (ValueError, TypeError) as exc: + raise PolicyConfigError(f"Rule '{rule_name}': invalid {key} value: {exc}") from exc + + +def _parse_str_list(raw_match: dict[str, Any], key: str, *, rule_name: str) -> list[str] | None: + if key not in raw_match: + return None + value = raw_match[key] + if not isinstance(value, list) or not all(isinstance(item, str) for item in value): + raise PolicyConfigError( + f"Rule '{rule_name}': {key!r} must be a list of strings, got {type(value).__name__}." + ) + return list(value) + + +def _parse_str_map( + raw_match: dict[str, Any], key: str, *, rule_name: str +) -> dict[str, str] | None: + if key not in raw_match: + return None + value = raw_match[key] + if not isinstance(value, dict) or not all( + isinstance(k, str) and isinstance(v, str) for k, v in value.items() + ): + raise PolicyConfigError( + f"Rule '{rule_name}': {key!r} must be a mapping of string keys to string values." + ) + return dict(value) + + +def _parse_min_justification(raw_match: dict[str, Any], *, rule_name: str) -> int | None: + if "min_justification" not in raw_match: + return None + value = raw_match["min_justification"] + # ``bool`` is a subclass of ``int`` in Python; reject it explicitly + # so ``min_justification: true`` does not silently pass. + if not isinstance(value, int) or isinstance(value, bool): + raise PolicyConfigError( + f"Rule '{rule_name}': 'min_justification' must be an integer, " + f"got {type(value).__name__}." + ) + return value + + +__all__ = [ + "POLICY_EXTRA_HINT", + "PolicyMatch", + "PolicyRule", + "parse_engine_data", + "parse_rule", + "load_yaml_data", + "load_toml_data", +] diff --git a/src/agent_kernel/rate_limit.py b/src/agent_kernel/rate_limit.py new file mode 100644 index 0000000..b74d664 --- /dev/null +++ b/src/agent_kernel/rate_limit.py @@ -0,0 +1,78 @@ +"""Sliding-window rate limiter used by :class:`DefaultPolicyEngine`. + +Split out of :mod:`policy` to keep modules ≤ 300 lines (AGENTS.md). +The default per-safety-class limits and the service multiplier live here +because they are tightly bound to :class:`RateLimiter` semantics. +""" + +from __future__ import annotations + +import time +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass + +from .enums import SafetyClass + +DEFAULT_RATE_LIMITS: dict[SafetyClass, tuple[int, float]] = { + SafetyClass.READ: (60, 60.0), + SafetyClass.WRITE: (10, 60.0), + SafetyClass.DESTRUCTIVE: (2, 60.0), +} +"""Default rate limits per safety class: ``(invocations, window_seconds)``.""" + +SERVICE_RATE_MULTIPLIER = 10 +"""Multiplier applied for principals with the ``service`` role.""" + + +@dataclass(slots=True) +class _RateEntry: + """Timestamps for a single rate-limit key.""" + + timestamps: list[float] + + +class RateLimiter: + """Sliding-window rate limiter using a monotonic clock. + + Args: + clock: Callable returning the current time in seconds. + Defaults to :func:`time.monotonic`. + """ + + def __init__(self, clock: Callable[[], float] | None = None) -> None: + self._clock = clock or time.monotonic + self._windows: dict[str, _RateEntry] = defaultdict(lambda: _RateEntry(timestamps=[])) + + def check(self, key: str, limit: int, window_seconds: float) -> bool: + """Return ``True`` if the next invocation would be within the limit. + + Prunes expired timestamps as a side-effect. + + Args: + key: Rate-limit key (e.g. ``"principal:capability"``). + limit: Maximum allowed invocations per window. + window_seconds: Sliding window duration in seconds. + + Returns: + ``True`` if under limit, ``False`` if limit would be exceeded. + """ + now = self._clock() + cutoff = now - window_seconds + entry = self._windows[key] + entry.timestamps = [t for t in entry.timestamps if t > cutoff] + if not entry.timestamps: + del self._windows[key] + return True + return len(entry.timestamps) < limit + + def record(self, key: str) -> None: + """Record an invocation for *key*.""" + self._windows[key].timestamps.append(self._clock()) + + +__all__ = [ + "DEFAULT_RATE_LIMITS", + "SERVICE_RATE_MULTIPLIER", + "RateLimiter", +] diff --git a/src/agent_kernel/registry.py b/src/agent_kernel/registry.py index 8687ae8..d603634 100644 --- a/src/agent_kernel/registry.py +++ b/src/agent_kernel/registry.py @@ -1,22 +1,99 @@ -"""Capability registry: register, lookup, and keyword-based matching.""" +"""Capability registry: register, lookup, namespaced discovery, and ranked search. + +Supports dot-notation namespaces (``"billing.invoices.list"``), deferred +namespace loaders for large tool ecosystems, and a BM25-flavoured score +that weights matches on ``capability_id`` and ``tags`` higher than +``description``. Flat (un-namespaced) capability IDs continue to work — they +are treated as living in a single-segment namespace. +""" from __future__ import annotations +import math import re +from collections.abc import Callable + +from .errors import ( + CapabilityAlreadyRegistered, + CapabilityNotFound, + NamespaceNotFound, +) +from .models import Capability, CapabilityRequest, NamespaceMetadata + +# Common English stop words that add noise to keyword search. Kept small +# (only words an LLM would routinely type into a goal) to avoid suppressing +# domain terms. +_STOP_WORDS: frozenset[str] = frozenset( + { + "a", + "an", + "and", + "any", + "are", + "as", + "at", + "be", + "by", + "for", + "from", + "get", + "give", + "i", + "in", + "is", + "it", + "me", + "my", + "of", + "on", + "or", + "please", + "show", + "that", + "the", + "this", + "to", + "want", + "with", + } +) -from .errors import CapabilityAlreadyRegistered, CapabilityNotFound -from .models import Capability, CapabilityRequest +# Field weights for BM25-flavoured scoring. Matches on capability_id and tags +# carry the most signal; description text is the noisiest. +_WEIGHT_ID = 4.0 +_WEIGHT_NAME = 2.0 +_WEIGHT_TAGS = 3.0 +_WEIGHT_DESCRIPTION = 1.0 + +# BM25 tunables (Lucene defaults). Held constant — randomness in matching is +# forbidden by AGENTS.md. +_BM25_K1 = 1.5 +_BM25_B = 0.75 class CapabilityRegistry: """Stores and retrieves :class:`Capability` objects. - Capabilities are registered by their ``capability_id`` and can be looked - up directly or discovered via keyword search against the goal description. + Capabilities are registered by their dot-notation ``capability_id`` + (e.g. ``"billing.invoices.list"``) and can be: + + - looked up directly via :meth:`get`, + - enumerated globally via :meth:`list_all`, + - enumerated per namespace via :meth:`list_namespaces` / + :meth:`list_namespace`, + - discovered via ranked text search (:meth:`search`). + + Flat IDs without a ``"."`` continue to work — they live in a single- + segment namespace named after themselves. """ def __init__(self) -> None: self._store: dict[str, Capability] = {} + self._namespaces: dict[str, NamespaceMetadata] = {} + # Reset cached search statistics when registrations change. + self._search_cache_dirty: bool = True + self._avg_doc_len: float = 0.0 + self._doc_freq: dict[str, int] = {} # ── Registration ────────────────────────────────────────────────────────── @@ -35,6 +112,7 @@ def register(self, capability: Capability) -> None: "Use a unique capability_id." ) self._store[capability.capability_id] = capability + self._search_cache_dirty = True def register_many(self, capabilities: list[Capability]) -> None: """Register multiple capabilities at once. @@ -45,11 +123,50 @@ def register_many(self, capabilities: list[Capability]) -> None: for cap in capabilities: self.register(cap) + def register_namespace( + self, + prefix: str, + *, + description: str = "", + loader: Callable[[], list[Capability]] | None = None, + ) -> None: + """Declare a namespace, optionally with a deferred loader. + + The loader (if given) is invoked exactly once, the first time the + namespace is searched, listed, or otherwise traversed. This lets a + host process advertise hundreds of namespaces without paying the + registration cost up front. + + Args: + prefix: Dot-notation namespace prefix (e.g. ``"billing"``). + description: Optional human-readable description. + loader: Optional zero-arg callable returning capabilities to + register on first access. Every returned capability's + ``capability_id`` must start with ``prefix`` (followed by ``.``) + or equal ``prefix`` exactly. + + Raises: + CapabilityAlreadyRegistered: If the namespace is already declared. + """ + if prefix in self._namespaces: + raise CapabilityAlreadyRegistered( + f"Namespace '{prefix}' is already declared. Choose a unique prefix." + ) + self._namespaces[prefix] = NamespaceMetadata( + prefix=prefix, + description=description, + loader=loader, + loaded=loader is None, + ) + # ── Lookup ──────────────────────────────────────────────────────────────── def get(self, capability_id: str) -> Capability: """Retrieve a capability by its ID. + If the capability ID falls under a declared namespace whose deferred + loader has not yet run, the loader is invoked first. + Args: capability_id: The capability's stable identifier. @@ -59,6 +176,8 @@ def get(self, capability_id: str) -> Capability: Raises: CapabilityNotFound: If no capability with that ID exists. """ + if capability_id not in self._store: + self._maybe_load_for(capability_id) try: return self._store[capability_id] except KeyError: @@ -68,21 +187,86 @@ def get(self, capability_id: str) -> Capability: ) from None def list_all(self) -> list[Capability]: - """Return all registered capabilities in registration order.""" + """Return every registered capability in registration order. + + Deferred-loader namespaces are *not* expanded by this call — to keep + ``list_all`` cheap. Use :meth:`list_namespace` to force a load. + """ return list(self._store.values()) + # ── Namespaces ──────────────────────────────────────────────────────────── + + def list_namespaces(self) -> list[str]: + """Return every top-level namespace prefix present in the registry. + + Combines namespaces inferred from registered capability IDs with + explicitly declared (:meth:`register_namespace`) prefixes. Returned + sorted for deterministic output. + """ + prefixes: set[str] = set() + for cap_id in self._store: + head, _, _ = cap_id.partition(".") + prefixes.add(head) + for ns in self._namespaces: + prefixes.add(ns.split(".", 1)[0]) + return sorted(prefixes) + + def list_namespace(self, prefix: str) -> list[Capability]: + """Return every capability whose ID lives under *prefix*. + + Triggers any deferred loader for *prefix* (or for the deepest declared + ancestor of *prefix*) before returning. A capability_id ``cap`` is + considered to live under ``prefix`` when ``cap == prefix`` or + ``cap.startswith(prefix + ".")``. + + Args: + prefix: Dot-notation namespace prefix. + + Returns: + Capabilities under the prefix, in registration order. + + Raises: + NamespaceNotFound: If no declared namespace or registered capability + lives under *prefix*. + """ + self._maybe_load_namespace(prefix) + results = [ + cap + for cap_id, cap in self._store.items() + if cap_id == prefix or cap_id.startswith(prefix + ".") + ] + if not results and prefix not in self._namespaces: + raise NamespaceNotFound( + f"Namespace '{prefix}' has no registered capabilities and is not declared. " + "Use register_namespace(prefix=...) or register a capability under it." + ) + return results + # ── Keyword matching ────────────────────────────────────────────────────── - def search(self, goal: str, *, max_results: int = 10) -> list[CapabilityRequest]: - """Search for capabilities matching a goal string. + def search( + self, + goal: str, + *, + max_results: int = 10, + offset: int = 0, + ) -> list[CapabilityRequest]: + """Search for capabilities matching *goal*. + + Tokenises *goal* (lower-cased word tokens, stop-words stripped) and + scores every capability using a BM25-flavoured ranker that weights + matches on ``capability_id`` and ``tags`` more heavily than + ``description``. Capabilities tied on score are returned in + ``capability_id`` order for determinism. - Splits *goal* into tokens and scores capabilities by how many tokens - appear in their ``capability_id``, ``name``, ``description``, or - ``tags``. Returns the top results as :class:`CapabilityRequest` objects. + Triggers any deferred namespace loader whose prefix overlaps the goal + tokens before scoring. Args: goal: Free-text description of the user's intent. max_results: Maximum number of results to return. + offset: Number of leading results to skip (paginates large + registries). Returns: Ordered list (highest score first) of :class:`CapabilityRequest`. @@ -91,13 +275,21 @@ def search(self, goal: str, *, max_results: int = 10) -> list[CapabilityRequest] if not tokens: return [] - scored: list[tuple[int, Capability]] = [] + self._load_namespaces_overlapping(tokens) + + if self._search_cache_dirty: + self._rebuild_search_index() + + scored: list[tuple[float, Capability]] = [] for cap in self._store.values(): score = self._score(cap, tokens) if score > 0: scored.append((score, cap)) scored.sort(key=lambda x: (-x[0], x[1].capability_id)) + + if offset: + scored = scored[offset:] return [ CapabilityRequest(capability_id=cap.capability_id, goal=goal) for _, cap in scored[:max_results] @@ -107,18 +299,91 @@ def search(self, goal: str, *, max_results: int = 10) -> list[CapabilityRequest] @staticmethod def _tokenize(text: str) -> list[str]: - """Split text into lower-case word tokens.""" - return re.findall(r"[a-z0-9]+", text.lower()) + """Split *text* into lower-case word tokens with stop-words removed.""" + return [t for t in re.findall(r"[a-z0-9]+", text.lower()) if t not in _STOP_WORDS] @staticmethod - def _score(cap: Capability, tokens: list[str]) -> int: - """Return a match score for a capability against query tokens.""" - corpus = " ".join( - [ - cap.capability_id, - cap.name, - cap.description, - ] - + cap.tags - ).lower() - return sum(1 for t in tokens if t in corpus) + def _corpus_fields(cap: Capability) -> tuple[list[str], list[str], list[str], list[str]]: + """Return per-field token lists used for scoring (id, name, tags, description).""" + tokenize = CapabilityRegistry._tokenize + return ( + tokenize(cap.capability_id.replace(".", " ").replace("_", " ")), + tokenize(cap.name), + tokenize(" ".join(cap.tags)), + tokenize(cap.description), + ) + + def _rebuild_search_index(self) -> None: + """Refresh BM25 document statistics after the registry mutates.""" + total_len = 0 + doc_freq: dict[str, int] = {} + for cap in self._store.values(): + id_tokens, name_tokens, tag_tokens, desc_tokens = self._corpus_fields(cap) + total_len += len(id_tokens) + len(name_tokens) + len(tag_tokens) + len(desc_tokens) + unique_tokens = set(id_tokens) | set(name_tokens) | set(tag_tokens) | set(desc_tokens) + for tok in unique_tokens: + doc_freq[tok] = doc_freq.get(tok, 0) + 1 + n = len(self._store) or 1 + self._avg_doc_len = total_len / n + self._doc_freq = doc_freq + self._search_cache_dirty = False + + def _score(self, cap: Capability, tokens: list[str]) -> float: + """Return a BM25-flavoured match score for *cap* against query *tokens*.""" + id_tokens, name_tokens, tag_tokens, desc_tokens = self._corpus_fields(cap) + doc_tokens = id_tokens + name_tokens + tag_tokens + desc_tokens + if not doc_tokens: + return 0.0 + doc_len = len(doc_tokens) + n = len(self._store) or 1 + score = 0.0 + for tok in tokens: + df = self._doc_freq.get(tok, 0) + if df == 0: + continue + # Per-field term frequency with field-specific weights. + tf = ( + _WEIGHT_ID * id_tokens.count(tok) + + _WEIGHT_NAME * name_tokens.count(tok) + + _WEIGHT_TAGS * tag_tokens.count(tok) + + _WEIGHT_DESCRIPTION * desc_tokens.count(tok) + ) + if tf == 0: + continue + idf = math.log(1 + (n - df + 0.5) / (df + 0.5)) + norm = 1 - _BM25_B + _BM25_B * (doc_len / (self._avg_doc_len or 1.0)) + score += idf * ((tf * (_BM25_K1 + 1)) / (tf + _BM25_K1 * norm)) + # Exact-prefix bonus: capability_id starts with the joined query. + joined = ".".join(tokens) + if joined and cap.capability_id.startswith(joined): + score += 1.0 + return score + + def _maybe_load_for(self, capability_id: str) -> None: + """Trigger any deferred loader whose prefix covers *capability_id*.""" + head, _, _ = capability_id.partition(".") + candidates = [head, capability_id] + for prefix in candidates: + if prefix in self._namespaces: + self._maybe_load_namespace(prefix) + + def _maybe_load_namespace(self, prefix: str) -> None: + """Invoke the deferred loader for *prefix* if it has not run yet.""" + meta = self._namespaces.get(prefix) + if meta is None or meta.loaded or meta.loader is None: + return + loader = meta.loader + # Mark as loaded *before* calling so a recursive load doesn't re-enter. + meta.loaded = True + for cap in loader(): + self.register(cap) + + def _load_namespaces_overlapping(self, tokens: list[str]) -> None: + """Load any deferred namespace whose prefix shares a token with *tokens*.""" + token_set = set(tokens) + for prefix, meta in list(self._namespaces.items()): + if meta.loaded: + continue + head_tokens = set(self._tokenize(prefix.replace(".", " ").replace("_", " "))) + if head_tokens & token_set: + self._maybe_load_namespace(prefix) diff --git a/tests/test_federation.py b/tests/test_federation.py new file mode 100644 index 0000000..71ce26d --- /dev/null +++ b/tests/test_federation.py @@ -0,0 +1,392 @@ +"""Tests for the capability marketplace — manifest format & local registry (#52).""" + +from __future__ import annotations + +import asyncio + +import pytest + +from agent_kernel import ( + Capability, + CapabilityAlreadyRegistered, + CapabilityDescriptor, + CapabilityManifest, + CapabilityRegistry, + HMACTokenProvider, + ImplementationRef, + InMemoryDriver, + Kernel, + ManifestError, + Principal, + SafetyClass, + SensitivityTag, + StaticRouter, + TokenInvalid, + TrustPolicyError, + build_manifest, + import_manifest, + merge_sensitivity, +) +from agent_kernel.drivers.base import ExecutionContext +from agent_kernel.federation import MANIFEST_VERSION +from agent_kernel.models import CapabilityRequest + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +def _make_cap(cap_id: str, **kwargs: object) -> Capability: + defaults: dict[str, object] = { + "name": cap_id.replace(".", " ").title(), + "description": f"Description for {cap_id}", + "safety_class": SafetyClass.READ, + } + defaults.update(kwargs) + return Capability(capability_id=cap_id, **defaults) # type: ignore[arg-type] + + +def _remote_kernel_with(*caps: Capability) -> Kernel: + reg = CapabilityRegistry() + for cap in caps: + reg.register(cap) + return Kernel( + registry=reg, + token_provider=HMACTokenProvider(secret="remote-kernel-secret"), + router=StaticRouter(), + kernel_id="agent-b", + ) + + +# ── Manifest serialisation ──────────────────────────────────────────────────── + + +def test_capability_descriptor_roundtrip() -> None: + descriptor = CapabilityDescriptor( + capability_id="billing.invoices.list", + name="List Invoices", + description="List recent invoices", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.PII, + tags=["billing", "invoices"], + parameters_schema={"type": "object", "properties": {"limit": {"type": "integer"}}}, + ) + restored = CapabilityDescriptor.from_dict(descriptor.to_dict()) + assert restored == descriptor + + +def test_capability_manifest_to_dict_is_json_compatible() -> None: + import json + + manifest = CapabilityManifest( + kernel_id="agent-a", + version=MANIFEST_VERSION, + endpoint="https://agent-a.example/kernel", + trust_level="verified", + capabilities=[ + CapabilityDescriptor( + capability_id="billing.list_invoices", + name="List Invoices", + description="List recent invoices", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.PII, + tags=["billing"], + ), + ], + ) + payload = json.dumps(manifest.to_dict()) + restored = CapabilityManifest.from_dict(json.loads(payload)) + assert restored == manifest + + +def test_build_manifest_strips_internal_implementation_details() -> None: + reg = CapabilityRegistry() + reg.register( + Capability( + capability_id="billing.list_invoices", + name="List Invoices", + description="List recent invoices", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.PII, + tags=["billing"], + impl=ImplementationRef(driver_id="secret_internal_driver", operation="op_x"), + ) + ) + manifest = build_manifest( + kernel_id="agent-a", + registry=reg, + endpoint="https://agent-a.example/kernel", + ) + payload = manifest.to_dict() + serialised = repr(payload) + assert "secret_internal_driver" not in serialised + assert "op_x" not in serialised + # Public-facing fields are present. + cap_dict = payload["capabilities"][0] + assert cap_dict["capability_id"] == "billing.list_invoices" + assert cap_dict["sensitivity"] == SensitivityTag.PII.value + + +def test_build_manifest_preserves_registration_order() -> None: + reg = CapabilityRegistry() + for cid in ["c.three", "a.one", "b.two"]: + reg.register(_make_cap(cid)) + manifest = build_manifest(kernel_id="agent-a", registry=reg, endpoint="https://agent-a/k") + assert [c.capability_id for c in manifest.capabilities] == ["c.three", "a.one", "b.two"] + + +# ── Importing manifests ─────────────────────────────────────────────────────── + + +def test_import_manifest_registers_capabilities_with_driver_routing() -> None: + remote_kernel = _remote_kernel_with(_make_cap("billing.list_invoices")) + manifest = remote_kernel.advertise(endpoint="https://agent-b.example/kernel") + + local_reg = CapabilityRegistry() + imported = import_manifest( + manifest=manifest, + registry=local_reg, + driver_id="remote_b", + trust_policy="most_restrictive", + ) + assert len(imported) == 1 + cap = local_reg.get("billing.list_invoices") + assert cap.impl is not None + assert cap.impl.driver_id == "remote_b" + assert cap.impl.operation == "billing.list_invoices" + + +def test_import_manifest_rejects_unknown_trust_policy() -> None: + manifest = CapabilityManifest( + kernel_id="agent-b", + version=MANIFEST_VERSION, + endpoint="https://agent-b/k", + capabilities=[], + ) + with pytest.raises(TrustPolicyError, match="Unknown trust_policy"): + import_manifest( + manifest=manifest, + registry=CapabilityRegistry(), + driver_id="x", + trust_policy="totally_made_up", # type: ignore[arg-type] + ) + + +def test_import_manifest_rejects_unsupported_version() -> None: + manifest = CapabilityManifest( + kernel_id="agent-b", + version="999", + endpoint="https://agent-b/k", + capabilities=[], + ) + with pytest.raises(ManifestError, match="version '999' is not supported"): + import_manifest(manifest=manifest, registry=CapabilityRegistry(), driver_id="x") + + +def test_import_manifest_rejects_empty_endpoint() -> None: + manifest = CapabilityManifest( + kernel_id="agent-b", + version=MANIFEST_VERSION, + endpoint="", + capabilities=[], + ) + with pytest.raises(ManifestError, match="has no endpoint"): + import_manifest(manifest=manifest, registry=CapabilityRegistry(), driver_id="x") + + +def test_import_manifest_duplicate_capability_raises() -> None: + local = CapabilityRegistry() + local.register(_make_cap("billing.list_invoices")) + remote = _remote_kernel_with(_make_cap("billing.list_invoices")) + manifest = remote.advertise(endpoint="https://agent-b/k") + with pytest.raises(CapabilityAlreadyRegistered): + import_manifest(manifest=manifest, registry=local, driver_id="remote_b") + + +# ── Trust policies ──────────────────────────────────────────────────────────── + + +def test_trust_policy_most_restrictive_preserves_sensitivity() -> None: + remote = _remote_kernel_with(_make_cap("crm.contacts.list", sensitivity=SensitivityTag.PII)) + manifest = remote.advertise(endpoint="https://agent-b/k") + local_reg = CapabilityRegistry() + import_manifest( + manifest=manifest, + registry=local_reg, + driver_id="remote_b", + trust_policy="most_restrictive", + ) + assert local_reg.get("crm.contacts.list").sensitivity == SensitivityTag.PII + + +def test_trust_policy_local_only_strips_sensitivity() -> None: + remote = _remote_kernel_with(_make_cap("crm.contacts.list", sensitivity=SensitivityTag.PII)) + manifest = remote.advertise(endpoint="https://agent-b/k") + local_reg = CapabilityRegistry() + import_manifest( + manifest=manifest, + registry=local_reg, + driver_id="remote_b", + trust_policy="local_only", + ) + assert local_reg.get("crm.contacts.list").sensitivity == SensitivityTag.NONE + + +def test_trust_policy_remote_deferred_preserves_sensitivity() -> None: + remote = _remote_kernel_with(_make_cap("crm.contacts.list", sensitivity=SensitivityTag.PII)) + manifest = remote.advertise(endpoint="https://agent-b/k") + local_reg = CapabilityRegistry() + import_manifest( + manifest=manifest, + registry=local_reg, + driver_id="remote_b", + trust_policy="remote_deferred", + ) + assert local_reg.get("crm.contacts.list").sensitivity == SensitivityTag.PII + + +def test_merge_sensitivity_picks_strictest() -> None: + assert merge_sensitivity(SensitivityTag.NONE, SensitivityTag.PII) == SensitivityTag.PII + assert merge_sensitivity(SensitivityTag.PII, SensitivityTag.NONE) == SensitivityTag.PII + assert merge_sensitivity(SensitivityTag.PII, SensitivityTag.PCI) == SensitivityTag.PCI + assert merge_sensitivity(SensitivityTag.PCI, SensitivityTag.SECRETS) == SensitivityTag.SECRETS + assert merge_sensitivity(SensitivityTag.NONE, SensitivityTag.NONE) == SensitivityTag.NONE + + +# ── Kernel.advertise() / Kernel.import_remote() ─────────────────────────────── + + +def test_kernel_advertise_uses_kernel_id() -> None: + reg = CapabilityRegistry() + reg.register(_make_cap("billing.list_invoices")) + kernel = Kernel( + registry=reg, + token_provider=HMACTokenProvider(secret="k1"), + kernel_id="my-fancy-kernel", + ) + manifest = kernel.advertise(endpoint="https://my-kernel/k") + assert manifest.kernel_id == "my-fancy-kernel" + assert manifest.endpoint == "https://my-kernel/k" + assert manifest.version == MANIFEST_VERSION + + +def test_kernel_import_remote_registers_driver_and_route() -> None: + remote = _remote_kernel_with(_make_cap("billing.list_invoices")) + manifest = remote.advertise(endpoint="https://agent-b/k") + + local_reg = CapabilityRegistry() + local_router = StaticRouter(routes={}) + local = Kernel( + registry=local_reg, + token_provider=HMACTokenProvider(secret="local-secret"), + router=local_router, + kernel_id="agent-a", + ) + + remote_driver = InMemoryDriver(driver_id="remote_b") + remote_driver.register_handler( + "billing.list_invoices", + lambda ctx: [{"id": "INV-1", "amount": 10.0}], + ) + + imported = local.import_remote(manifest, driver=remote_driver, trust_policy="local_only") + assert [c.capability_id for c in imported] == ["billing.list_invoices"] + + # The driver-routing wiring is correct. + plan = local_router.route("billing.list_invoices") + assert plan.driver_ids == ["remote_b"] + + +def test_imported_capability_invokes_through_local_pipeline() -> None: + remote = _remote_kernel_with(_make_cap("billing.list_invoices")) + manifest = remote.advertise(endpoint="https://agent-b/k") + + local_reg = CapabilityRegistry() + local = Kernel( + registry=local_reg, + token_provider=HMACTokenProvider(secret="local-secret"), + router=StaticRouter(), + kernel_id="agent-a", + ) + driver = InMemoryDriver(driver_id="remote_b") + invoked = {"called_with": None} + + def list_invoices(ctx: ExecutionContext) -> list[dict[str, object]]: + invoked["called_with"] = ctx.capability_id # type: ignore[assignment] + return [{"id": "INV-1", "amount": 100.0, "email": "x@y.z"}] + + driver.register_handler("billing.list_invoices", list_invoices) + local.import_remote(manifest, driver=driver, trust_policy="local_only") + + principal = Principal(principal_id="alice", roles=["reader"], attributes={"tenant": "acme"}) + request = CapabilityRequest(capability_id="billing.list_invoices", goal="check invoices") + token = local.get_token(request, principal, justification="") + + async def run() -> object: + return await local.invoke( + token, + principal=principal, + args={"operation": "billing.list_invoices"}, + response_mode="table", + ) + + frame = asyncio.run(run()) + # Capability was routed to the imported driver. + assert invoked["called_with"] == "billing.list_invoices" + # Trace was recorded by the local kernel. + trace = local.explain(frame.action_id) # type: ignore[attr-defined] + assert trace.capability_id == "billing.list_invoices" + assert trace.driver_id == "remote_b" + + +def test_imported_capability_keeps_remote_sensitivity_under_most_restrictive() -> None: + """A `most_restrictive` import floors the imported cap's sensitivity at the remote tag. + + This is what makes the firewall apply the same redaction to imported PII + capabilities as the remote would. + """ + remote = _remote_kernel_with(_make_cap("crm.contacts.list", sensitivity=SensitivityTag.PII)) + manifest = remote.advertise(endpoint="https://agent-b/k") + local = Kernel( + registry=CapabilityRegistry(), + token_provider=HMACTokenProvider(secret="local-secret"), + kernel_id="agent-a", + ) + local.import_remote(manifest, driver=InMemoryDriver(driver_id="remote_b")) + imported_cap = local.list_capabilities()[0] + assert imported_cap.sensitivity == SensitivityTag.PII + + +# ── Token isolation across kernels (kernel-scoped HMAC) ─────────────────────── + + +def test_tokens_are_kernel_scoped_by_hmac_secret() -> None: + """A token signed by kernel A's HMAC provider must not verify on kernel B. + + `Kernel` instances with different secrets produce tokens that fail + signature verification on the other side, which is what makes + "kernel-scoped" tokens safe across an imported capability boundary. + """ + reg_a = CapabilityRegistry() + reg_a.register(_make_cap("billing.list_invoices")) + kernel_a = Kernel( + registry=reg_a, + token_provider=HMACTokenProvider(secret="secret-a"), + router=StaticRouter(), + kernel_id="agent-a", + ) + + reg_b = CapabilityRegistry() + reg_b.register(_make_cap("billing.list_invoices")) + kernel_b_provider = HMACTokenProvider(secret="secret-b") + + principal = Principal(principal_id="alice", roles=["reader"], attributes={"tenant": "acme"}) + token = kernel_a.get_token( + CapabilityRequest(capability_id="billing.list_invoices", goal="x"), + principal, + justification="", + ) + with pytest.raises(TokenInvalid, match="invalid signature"): + kernel_b_provider.verify( + token, + expected_principal_id="alice", + expected_capability_id="billing.list_invoices", + ) diff --git a/tests/test_federation_discovery.py b/tests/test_federation_discovery.py new file mode 100644 index 0000000..f54a07e --- /dev/null +++ b/tests/test_federation_discovery.py @@ -0,0 +1,390 @@ +"""Tests for federated discovery + signed manifests (issue #51). + +Uses `httpx.MockTransport` so tests are fully offline. Each scenario +pins one contract from the issue's acceptance criteria: + +* Signed manifest round-trip + tamper detection. +* `discover_peers` via direct peer URLs. +* `discover_peers` via a registry URL. +* Rate limiting on discovery. +* Kernel-scoped HMAC isolation when an imported capability is invoked. +""" + +from __future__ import annotations + +import datetime +import json +from typing import Any + +import httpx +import pytest + +from agent_kernel import ( + Capability, + CapabilityRegistry, + DiscoveryError, + DiscoveryRateLimiter, + HMACTokenProvider, + InMemoryDriver, + Kernel, + ManifestSignatureError, + SafetyClass, + discover_peers, + serve_manifest_payload, + sign_manifest, + verify_manifest, +) +from agent_kernel.federation import build_manifest +from agent_kernel.models import CapabilityManifest + + +def _build_test_manifest() -> CapabilityManifest: + cap = Capability( + capability_id="metrics.read", + name="read", + description="Read a metric.", + safety_class=SafetyClass.READ, + ) + registry = CapabilityRegistry() + registry.register(cap) + return build_manifest( + kernel_id="peer-a", + registry=registry, + endpoint="https://peer-a.invalid/kernel", + trust_level="unverified", + ) + + +def test_sign_and_verify_round_trip() -> None: + """`sign_manifest` then `verify_manifest` returns the original manifest.""" + manifest = _build_test_manifest() + envelope = sign_manifest(manifest, secret="shared-secret") + assert envelope["algorithm"] == "HMAC-SHA256" + assert envelope["signature"] + assert envelope["payload"] + + decoded = verify_manifest(envelope, secret="shared-secret") + assert decoded.kernel_id == manifest.kernel_id + assert decoded.endpoint == manifest.endpoint + assert [c.capability_id for c in decoded.capabilities] == [ + c.capability_id for c in manifest.capabilities + ] + + +def test_verify_manifest_rejects_tampered_payload() -> None: + """Modifying the payload after signing must fail verification.""" + manifest = _build_test_manifest() + envelope = sign_manifest(manifest, secret="shared-secret") + tampered_payload = json.loads(envelope["payload"]) + tampered_payload["kernel_id"] = "attacker" + envelope["payload"] = json.dumps(tampered_payload, sort_keys=True) + + with pytest.raises(ManifestSignatureError, match="signature mismatch"): + verify_manifest(envelope, secret="shared-secret") + + +def test_verify_manifest_rejects_wrong_secret() -> None: + """Wrong verification secret must fail signature check.""" + manifest = _build_test_manifest() + envelope = sign_manifest(manifest, secret="publisher-secret") + with pytest.raises(ManifestSignatureError, match="signature mismatch"): + verify_manifest(envelope, secret="other-secret") + + +def test_serve_manifest_payload_signed_and_unsigned() -> None: + """`serve_manifest_payload` produces a bare dict or signed envelope.""" + manifest = _build_test_manifest() + bare = serve_manifest_payload(manifest) + assert "signature" not in bare + assert bare["kernel_id"] == "peer-a" + + signed = serve_manifest_payload(manifest, secret="s") + assert signed["algorithm"] == "HMAC-SHA256" + assert signed["signature"] + + +@pytest.mark.asyncio +async def test_discover_peers_via_direct_peer_urls() -> None: + """`discover_peers(peer_urls=...)` returns one manifest per URL.""" + manifest = _build_test_manifest() + payload = json.dumps(manifest.to_dict()) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=payload, headers={"content-type": "application/json"}) + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as client: + result = await discover_peers( + peer_urls=["http://peer-a.example/manifest", "http://peer-b.example/manifest"], + client=client, + ) + + assert len(result) == 2 + assert all(m.kernel_id == "peer-a" for m in result) + + +@pytest.mark.asyncio +async def test_discover_peers_via_registry_url() -> None: + """`discover_peers(registry_url=...)` first fetches the registry list.""" + manifest = _build_test_manifest() + manifest_payload = json.dumps(manifest.to_dict()) + + def handler(request: httpx.Request) -> httpx.Response: + if str(request.url).endswith("/registry"): + return httpx.Response( + 200, + text=json.dumps(["http://peer-a.example/manifest"]), + headers={"content-type": "application/json"}, + ) + return httpx.Response( + 200, text=manifest_payload, headers={"content-type": "application/json"} + ) + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as client: + result = await discover_peers( + registry_url="http://central.example/registry", + client=client, + ) + + assert len(result) == 1 + assert result[0].kernel_id == "peer-a" + + +@pytest.mark.asyncio +async def test_discover_peers_rejects_unsigned_when_secret_provided() -> None: + """Calling with a secret but receiving an unsigned manifest is an error.""" + manifest = _build_test_manifest() + payload = json.dumps(manifest.to_dict()) # Unsigned. + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=payload, headers={"content-type": "application/json"}) + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(ManifestSignatureError, match="unsigned"): + await discover_peers( + peer_urls=["http://peer-a.example/manifest"], + secret="expected-secret", + client=client, + ) + + +@pytest.mark.asyncio +async def test_discover_peers_rejects_signed_when_no_secret_provided() -> None: + """Calling without a secret but receiving a signed envelope is an error.""" + manifest = _build_test_manifest() + signed_payload = json.dumps(sign_manifest(manifest, secret="s")) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, text=signed_payload, headers={"content-type": "application/json"} + ) + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(ManifestSignatureError, match="signed but no verification"): + await discover_peers( + peer_urls=["http://peer-a.example/manifest"], + client=client, + ) + + +@pytest.mark.asyncio +async def test_discover_peers_handles_signed_manifest_end_to_end() -> None: + """Signed envelopes are verified and the embedded manifest is returned.""" + manifest = _build_test_manifest() + signed_payload = json.dumps(sign_manifest(manifest, secret="shared")) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, text=signed_payload, headers={"content-type": "application/json"} + ) + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as client: + result = await discover_peers( + peer_urls=["http://peer-a.example/manifest"], + secret="shared", + client=client, + ) + assert len(result) == 1 + assert result[0].kernel_id == "peer-a" + + +@pytest.mark.asyncio +async def test_discover_peers_network_error_raises_discovery_error() -> None: + """HTTP errors are wrapped in :class:`DiscoveryError`.""" + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(503, text="service unavailable") + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(DiscoveryError, match="HTTP 503"): + await discover_peers( + peer_urls=["http://peer-a.example/manifest"], + client=client, + ) + + +@pytest.mark.asyncio +async def test_discover_peers_rate_limit() -> None: + """`DiscoveryRateLimiter` rejects calls beyond the configured limit.""" + manifest = _build_test_manifest() + payload = json.dumps(manifest.to_dict()) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=payload, headers={"content-type": "application/json"}) + + transport = httpx.MockTransport(handler) + limiter = DiscoveryRateLimiter(limit=2, window_seconds=60.0) + async with httpx.AsyncClient(transport=transport) as client: + await discover_peers( + peer_urls=["http://peer-a.example/manifest"], + rate_limiter=limiter, + client=client, + ) + await discover_peers( + peer_urls=["http://peer-b.example/manifest"], + rate_limiter=limiter, + client=client, + ) + with pytest.raises(DiscoveryError, match="rate limit exceeded"): + await discover_peers( + peer_urls=["http://peer-c.example/manifest"], + rate_limiter=limiter, + client=client, + ) + + +@pytest.mark.asyncio +async def test_discover_peers_requires_some_input() -> None: + """Calling with neither peer_urls nor registry_url is an error.""" + with pytest.raises(DiscoveryError, match="requires peer_urls or registry_url"): + await discover_peers() + + +@pytest.mark.asyncio +async def test_kernel_discover_peers_integration() -> None: + """`Kernel.discover_peers` is wired to fetch manifests over HTTP.""" + manifest = _build_test_manifest() + payload = json.dumps(manifest.to_dict()) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=payload, headers={"content-type": "application/json"}) + + # Build a kernel without registering the peer-a capability locally. + registry = CapabilityRegistry() + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="local-test-secret"), + ) + + # Monkey-patch where the kernel sub-module imported `discover_peers`, + # not the source module — `_federation.py` already bound a local name. + import agent_kernel.kernel._federation as kf + + original = kf.discover_peers + + async def patched_discover(**kwargs: Any) -> list[CapabilityManifest]: + # `perform_discover_peers` passes ``client=None`` explicitly, so + # ``setdefault`` won't override — replace the key unconditionally. + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + kwargs["client"] = client + return await original(**kwargs) + + kf.discover_peers = patched_discover # type: ignore[assignment] + try: + result = await kernel.discover_peers(peer_urls=["http://peer-a.example/manifest"]) + finally: + kf.discover_peers = original # type: ignore[assignment] + + assert len(result) == 1 + assert result[0].kernel_id == "peer-a" + + +def test_kernel_scoped_hmac_isolation_for_imported_capability() -> None: + """A token issued by kernel A must not validate against kernel B. + + Confused-deputy regression test that pins issue #51's acceptance + criterion: HMAC tokens are kernel-scoped. + """ + # Kernel A — publishes a capability. + cap = Capability( + capability_id="metrics.read", + name="read", + description="Read.", + safety_class=SafetyClass.READ, + ) + registry_a = CapabilityRegistry() + registry_a.register(cap) + kernel_a = Kernel( + registry=registry_a, + token_provider=HMACTokenProvider(secret="kernel-a-secret"), + kernel_id="kernel-a", + ) + kernel_a.register_driver(InMemoryDriver(driver_id="dummy-a")) + + # Kernel B — imports the manifest, different secret. + manifest = kernel_a.advertise(endpoint="https://kernel-a.invalid/") + registry_b = CapabilityRegistry() + kernel_b = Kernel( + registry=registry_b, + token_provider=HMACTokenProvider(secret="kernel-b-secret"), + kernel_id="kernel-b", + ) + kernel_b.import_remote(manifest, driver=InMemoryDriver(driver_id="dummy-b")) + + # Mint a token on kernel A. + from agent_kernel import Principal + from agent_kernel.errors import TokenInvalid + from agent_kernel.models import CapabilityRequest + + principal = Principal(principal_id="alice", roles=["reader"]) + req = CapabilityRequest(capability_id="metrics.read", goal="t") + token_from_a = kernel_a.get_token(req, principal, justification="") + + # Kernel B refuses tokens signed by kernel A's secret. + with pytest.raises(TokenInvalid): + kernel_b._token_provider.verify( # type: ignore[attr-defined] + token_from_a, + expected_principal_id="alice", + expected_capability_id="metrics.read", + ) + + +@pytest.mark.asyncio +async def test_signed_envelope_payload_is_canonical_json() -> None: + """The signed envelope's payload is sorted-key JSON for determinism.""" + manifest = _build_test_manifest() + envelope1 = sign_manifest(manifest, secret="s") + envelope2 = sign_manifest(manifest, secret="s") + # Same manifest, same secret → byte-identical envelope. + assert envelope1 == envelope2 + + +@pytest.mark.asyncio +async def test_verify_manifest_rejects_malformed_envelope() -> None: + """Missing keys produce a clear error.""" + from agent_kernel.errors import ManifestError + + with pytest.raises(ManifestError, match="missing required key"): + verify_manifest({"signature": "x"}, secret="s") + + with pytest.raises(ManifestError, match="must be a dict"): + verify_manifest([], secret="s") # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_verify_manifest_rejects_unknown_algorithm() -> None: + """Unknown algorithm is rejected before signature check.""" + with pytest.raises(ManifestSignatureError, match="Unsupported"): + verify_manifest( + {"payload": "{}", "algorithm": "NOT-A-REAL-ALG", "signature": "x"}, + secret="s", + ) + + +# silence flake about unused datetime import — kept for parity with peer modules. +_ = datetime diff --git a/tests/test_firewall_stream.py b/tests/test_firewall_stream.py new file mode 100644 index 0000000..aaffd4f --- /dev/null +++ b/tests/test_firewall_stream.py @@ -0,0 +1,198 @@ +"""Tests for the streaming firewall (:meth:`Firewall.apply_stream`) and +:meth:`Kernel.invoke_stream` (issue #47). + +The streaming API is *additive*: non-streaming drivers still work via +:meth:`Driver.execute`. These tests pin the contracts the issue calls out: + +* every chunk is firewalled (PII never leaks even in streaming mode), +* the last yielded :class:`Frame` carries ``is_final=True``, +* :meth:`Kernel.invoke_stream` falls back to a single-chunk stream when the + driver does not implement :class:`StreamingDriver`, +* a ``StreamingDriver`` produces multiple firewalled chunks. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +import pytest + +from agent_kernel import ( + Budgets, + Capability, + CapabilityRegistry, + Firewall, + HMACTokenProvider, + InMemoryDriver, + Kernel, + Principal, + SafetyClass, + StaticRouter, + StreamingDriver, +) +from agent_kernel.drivers.base import ExecutionContext +from agent_kernel.models import CapabilityRequest, RawResult + + +class _FakeStreamingDriver: + """Test double that yields predetermined chunks. + + Implements both :meth:`execute` (single-shot) and :meth:`execute_stream` + so the kernel's streaming-capability check passes. + """ + + driver_id = "stream-test" + + def __init__(self, chunks: list[dict[str, Any]]) -> None: + self._chunks = chunks + + async def execute(self, ctx: ExecutionContext) -> RawResult: + # Aggregated single-shot fallback: combine all chunks into one payload. + combined = {"chunks": [dict(c) for c in self._chunks]} + return RawResult(capability_id=ctx.capability_id, data=combined, provenance={}) + + async def execute_stream(self, ctx: ExecutionContext) -> AsyncIterator[dict[str, Any]]: + for chunk in self._chunks: + yield chunk + + +def _build_streaming_kernel( + driver: object, +) -> tuple[Kernel, Principal]: + cap = Capability( + capability_id="stream.read", + name="read", + description="Streamed read.", + safety_class=SafetyClass.READ, + ) + registry = CapabilityRegistry() + registry.register(cap) + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="stream-test-secret"), + router=StaticRouter(routes={"stream.read": [driver.driver_id]}), # type: ignore[attr-defined] + ) + kernel.register_driver(driver) # type: ignore[arg-type] + return kernel, Principal(principal_id="streamer", roles=["reader"]) + + +@pytest.mark.asyncio +async def test_streaming_driver_protocol_runtime_check() -> None: + """`isinstance(x, StreamingDriver)` must work since the kernel uses it.""" + streamer = _FakeStreamingDriver([{"foo": 1}]) + assert isinstance(streamer, StreamingDriver) + + # A plain non-streaming driver must NOT satisfy the protocol. + plain = InMemoryDriver(driver_id="plain") + assert not isinstance(plain, StreamingDriver) + + +@pytest.mark.asyncio +async def test_firewall_apply_stream_redacts_each_chunk() -> None: + """`Firewall.apply_stream` runs each chunk through redaction. + + Synthetic PII / secret values must NOT appear in any yielded Frame's + summary facts or table preview. + """ + fw = Firewall(budgets=Budgets(max_chars=4000, max_rows=10, max_fields=10)) + + async def chunks() -> AsyncIterator[dict[str, Any]]: + yield {"rows": [{"public_id": 1, "email": "leaked@example.com"}]} + yield { + "rows": [{"public_id": 2, "api_token": "Bearer abc-secret-123"}], + "__is_final__": True, + } + + frames: list[Any] = [] + async for frame in fw.apply_stream( + chunks(), + action_id="act-1", + capability_id="stream.read", + principal_id="p1", + principal_roles=["reader"], + response_mode="table", + constraints={"allowed_fields": ["public_id"]}, + ): + frames.append(frame) + + assert len(frames) == 2 + assert frames[-1].is_final is True + assert frames[0].is_final is False + + rendered = repr([f.table_preview for f in frames]) + repr([f.facts for f in frames]) + assert "leaked@example.com" not in rendered + assert "abc-secret-123" not in rendered + + +@pytest.mark.asyncio +async def test_kernel_invoke_stream_with_streaming_driver_yields_multiple_frames() -> None: + """A streaming driver produces one Frame per chunk; last has is_final.""" + driver = _FakeStreamingDriver( + chunks=[ + {"chunk": 1, "row": {"id": "a"}}, + {"chunk": 2, "row": {"id": "b"}}, + {"chunk": 3, "row": {"id": "c"}, "__is_final__": True}, + ] + ) + kernel, principal = _build_streaming_kernel(driver) + req = CapabilityRequest(capability_id="stream.read", goal="t") + token = kernel.get_token(req, principal, justification="") + + frames: list[Any] = [] + async for frame in kernel.invoke_stream( + token, principal=principal, args={}, response_mode="summary" + ): + frames.append(frame) + + assert len(frames) == 3 + assert all(not f.is_final for f in frames[:-1]) + assert frames[-1].is_final is True + # Every frame carries the same audit action_id. + assert len({f.action_id for f in frames}) == 1 + + +@pytest.mark.asyncio +async def test_kernel_invoke_stream_fallback_for_non_streaming_driver() -> None: + """A driver without execute_stream yields exactly one final Frame.""" + plain = InMemoryDriver(driver_id="plain") + + def handler(ctx: ExecutionContext) -> dict[str, object]: + return {"row_count": 7, "rows": [{"id": "x"}]} + + plain.register_handler("stream.read", handler) + + kernel, principal = _build_streaming_kernel(plain) + req = CapabilityRequest(capability_id="stream.read", goal="t") + token = kernel.get_token(req, principal, justification="") + + frames: list[Any] = [] + async for frame in kernel.invoke_stream( + token, principal=principal, args={}, response_mode="summary" + ): + frames.append(frame) + + assert len(frames) == 1 + assert frames[0].is_final is True + + +@pytest.mark.asyncio +async def test_kernel_invoke_stream_emits_trace_event() -> None: + """A streaming invocation records exactly one ActionTrace covering the stream.""" + driver = _FakeStreamingDriver(chunks=[{"chunk": 1, "__is_final__": True}]) + kernel, principal = _build_streaming_kernel(driver) + req = CapabilityRequest(capability_id="stream.read", goal="t") + token = kernel.get_token(req, principal, justification="") + + captured: list[Any] = [] + async for frame in kernel.invoke_stream( + token, principal=principal, args={}, response_mode="summary" + ): + captured.append(frame) + + assert len(captured) >= 1 + trace = kernel.explain(captured[0].action_id) + assert trace.capability_id == "stream.read" + assert trace.principal_id == "streamer" + assert trace.driver_id == "stream-test" + assert trace.error is None diff --git a/tests/test_kernel.py b/tests/test_kernel.py index d99e154..00b0cc7 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -797,3 +797,113 @@ async def test_dry_run_policy_decision_has_trace( assert result.policy_decision.trace.final_outcome == "allowed" assert result.policy_decision.trace.final_reason_code == "token_verified" assert result.policy_decision.reason_code == "token_verified" + + +# ── Dry-run with HTTP and MCP drivers (issue #68 part E) ─────────────────────── + + +@pytest.mark.asyncio +async def test_dry_run_with_http_driver_does_not_call_execute() -> None: + """Dry-run short-circuits before driver dispatch — HTTPDriver edition. + + The short-circuit at ``kernel.invoke`` runs before driver lookup, so + the mode is provably driver-agnostic. This test pins the contract so + a future refactor that moved driver dispatch above the dry-run check + cannot land unnoticed (per issue #68 part E acceptance criteria). + """ + from unittest.mock import AsyncMock, patch + + from agent_kernel.drivers.http import HTTPDriver, HTTPEndpoint + from agent_kernel.models import DryRunResult + + cap = Capability( + capability_id="external.fetch_user", + name="fetch_user", + description="Fetch user from external HTTP API.", + safety_class=SafetyClass.READ, + ) + registry = CapabilityRegistry() + registry.register(cap) + + http_driver = HTTPDriver(driver_id="http") + http_driver.register_endpoint( + "external.fetch_user", + HTTPEndpoint(method="GET", url="https://example.invalid/u/1"), + ) + + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=StaticRouter(routes={"external.fetch_user": ["http"]}), + ) + kernel.register_driver(http_driver) + + principal = Principal(principal_id="alice", roles=["reader"]) + req = CapabilityRequest(capability_id="external.fetch_user", goal="t") + token = kernel.get_token(req, principal, justification="") + + with patch.object(http_driver, "execute", new_callable=AsyncMock) as mock_exec: + result = await kernel.invoke(token, principal=principal, args={}, dry_run=True) + mock_exec.assert_not_called() + + assert isinstance(result, DryRunResult) + assert result.driver_id == "http" + assert result.operation == "external.fetch_user" + assert result.capability_id == "external.fetch_user" + assert result.policy_decision.allowed is True + + +@pytest.mark.asyncio +async def test_dry_run_with_mcp_driver_does_not_call_execute() -> None: + """Dry-run short-circuits before driver dispatch — MCPDriver edition. + + MCPDriver is constructed with a stub session factory so we never + open a real subprocess or HTTP connection. The assertion is that + ``execute`` is never called regardless — the short-circuit happens + before the kernel looks up which driver to dispatch to. + """ + from unittest.mock import AsyncMock, patch + + from agent_kernel.drivers.mcp import MCPDriver + from agent_kernel.models import DryRunResult + + cap = Capability( + capability_id="mcp.echo", + name="echo", + description="Echo tool from an MCP server.", + safety_class=SafetyClass.READ, + ) + registry = CapabilityRegistry() + registry.register(cap) + + # Stub session factory — never invoked by dry-run. + def _fake_session_factory() -> object: # pragma: no cover - never called + raise AssertionError("session_factory must not run during dry-run") + + mcp_driver = MCPDriver( + driver_id="mcp:test", + session_factory=_fake_session_factory, # type: ignore[arg-type] + server_name="test", + transport="stdio", + ) + + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=StaticRouter(routes={"mcp.echo": ["mcp:test"]}), + ) + kernel.register_driver(mcp_driver) + + principal = Principal(principal_id="alice", roles=["reader"]) + req = CapabilityRequest(capability_id="mcp.echo", goal="t") + token = kernel.get_token(req, principal, justification="") + + with patch.object(mcp_driver, "execute", new_callable=AsyncMock) as mock_exec: + result = await kernel.invoke(token, principal=principal, args={}, dry_run=True) + mock_exec.assert_not_called() + + assert isinstance(result, DryRunResult) + assert result.driver_id == "mcp:test" + assert result.operation == "mcp.echo" + assert result.capability_id == "mcp.echo" + assert result.policy_decision.allowed is True diff --git a/tests/test_otel.py b/tests/test_otel.py new file mode 100644 index 0000000..362239c --- /dev/null +++ b/tests/test_otel.py @@ -0,0 +1,185 @@ +"""Tests for OpenTelemetry instrumentation (:func:`agent_kernel.instrument_kernel`). + +These tests use the OpenTelemetry SDK's ``InMemorySpanExporter`` and +``InMemoryMetricReader`` so we don't need a running collector. They +assert two specific contracts: + +1. **Instrumented kernel emits the expected spans/metrics.** Span names, + parent–child structure, and key attributes are pinned. + +2. **Uninstrumented kernel emits zero spans.** This guards against the + instrumentation accidentally leaking into every kernel by class-level + monkey-patching (it should be instance-level only). +""" + +from __future__ import annotations + +import pytest + +from agent_kernel import ( + OTEL_AVAILABLE, + Capability, + CapabilityRegistry, + HMACTokenProvider, + InMemoryDriver, + Kernel, + Principal, + SafetyClass, + StaticRouter, + instrument_kernel, +) +from agent_kernel.drivers.base import ExecutionContext +from agent_kernel.models import CapabilityRequest +from agent_kernel.otel import reset_instrumentation + +if not OTEL_AVAILABLE: # pragma: no cover - skipped without the [otel] extra + pytest.skip( + "opentelemetry-api not installed; install the [otel] extra to run.", + allow_module_level=True, + ) + +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + +def _build_kernel() -> tuple[Kernel, Principal]: + cap = Capability( + capability_id="metrics.read", + name="read", + description="Read a metric.", + safety_class=SafetyClass.READ, + ) + registry = CapabilityRegistry() + registry.register(cap) + + driver = InMemoryDriver(driver_id="memory") + + def handler(ctx: ExecutionContext) -> dict[str, object]: + return {"value": 42, "capability_id": ctx.capability_id} + + driver.register_handler("metrics.read", handler) + + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="otel-test-secret"), + router=StaticRouter(routes={"metrics.read": ["memory"]}), + ) + kernel.register_driver(driver) + + principal = Principal(principal_id="otel-user", roles=["reader"]) + return kernel, principal + + +@pytest.fixture() +def otel_exporters() -> tuple[ + InMemorySpanExporter, InMemoryMetricReader, TracerProvider, MeterProvider +]: + """Per-test span/metric exporters with provider instances. + + The OTel API disallows overriding the global ``TracerProvider`` / + ``MeterProvider`` after they've been set; instead each test gets its + own providers and passes them explicitly to :func:`instrument_kernel`. + """ + span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + + metric_reader = InMemoryMetricReader() + meter_provider = MeterProvider(metric_readers=[metric_reader]) + + reset_instrumentation() + yield span_exporter, metric_reader, tracer_provider, meter_provider + reset_instrumentation() + + +@pytest.mark.asyncio +async def test_instrumented_invoke_emits_span( + otel_exporters: tuple[ + InMemorySpanExporter, InMemoryMetricReader, TracerProvider, MeterProvider + ], +) -> None: + """`Kernel.invoke()` produces one ``agent_kernel.invoke`` span.""" + spans, _, tp, mp = otel_exporters + kernel, principal = _build_kernel() + instrument_kernel(kernel, tracer_provider=tp, meter_provider=mp) + + req = CapabilityRequest(capability_id="metrics.read", goal="t") + token = kernel.get_token(req, principal, justification="") + await kernel.invoke(token, principal=principal, args={}) + + finished = spans.get_finished_spans() + invoke_spans = [s for s in finished if s.name == "agent_kernel.invoke"] + assert len(invoke_spans) == 1, [s.name for s in finished] + attrs = invoke_spans[0].attributes or {} + assert attrs.get("agent_kernel.principal_id") == "otel-user" + assert attrs.get("agent_kernel.capability_id") == "metrics.read" + assert attrs.get("agent_kernel.response_mode") == "summary" + assert attrs.get("agent_kernel.dry_run") is False + + +@pytest.mark.asyncio +async def test_uninstrumented_invoke_emits_no_span( + otel_exporters: tuple[ + InMemorySpanExporter, InMemoryMetricReader, TracerProvider, MeterProvider + ], +) -> None: + """A kernel that was never wrapped emits zero ``agent_kernel.*`` spans.""" + spans, _, _, _ = otel_exporters + kernel, principal = _build_kernel() + # Deliberately do NOT call instrument_kernel. + + req = CapabilityRequest(capability_id="metrics.read", goal="t") + token = kernel.get_token(req, principal, justification="") + await kernel.invoke(token, principal=principal, args={}) + + finished = spans.get_finished_spans() + invoke_spans = [s for s in finished if s.name.startswith("agent_kernel.")] + assert invoke_spans == [] + + +def test_instrumented_grant_records_denial( + otel_exporters: tuple[ + InMemorySpanExporter, InMemoryMetricReader, TracerProvider, MeterProvider + ], +) -> None: + """A denied grant records a span with ERROR status and a denial counter.""" + spans, _metric_reader, tp, mp = otel_exporters + kernel, _ = _build_kernel() + instrument_kernel(kernel, tracer_provider=tp, meter_provider=mp) + + # Build a "deny" path: WRITE capability + reader principal. + write_cap = Capability( + capability_id="metrics.write", + name="write", + description="Write a metric.", + safety_class=SafetyClass.WRITE, + ) + kernel._registry.register(write_cap) # type: ignore[attr-defined] + + reader = Principal(principal_id="reader-only", roles=["reader"]) + req = CapabilityRequest(capability_id="metrics.write", goal="t") + + from agent_kernel.errors import PolicyDenied + + with pytest.raises(PolicyDenied): + kernel.grant_capability(req, reader, justification="too short") + + finished = spans.get_finished_spans() + grant_spans = [s for s in finished if s.name == "agent_kernel.grant"] + assert len(grant_spans) == 1 + assert grant_spans[0].status.status_code.name == "ERROR" + + +def test_instrument_kernel_is_idempotent() -> None: + """Calling :func:`instrument_kernel` twice does not double-wrap.""" + kernel, _ = _build_kernel() + original_invoke = kernel.invoke + instrument_kernel(kernel) + wrapped_once = kernel.invoke + assert wrapped_once is not original_invoke + instrument_kernel(kernel) + assert kernel.invoke is wrapped_once + reset_instrumentation(kernel) diff --git a/tests/test_registry.py b/tests/test_registry.py index 6e63597..f2892a8 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -110,3 +110,196 @@ def test_search_goal_preserved(registry: CapabilityRegistry) -> None: goal = "list all billing invoices please" results = registry.search(goal) assert all(r.goal == goal for r in results) + + +# ── Namespace operations (#45) ──────────────────────────────────────────────── + + +def test_list_namespaces_from_registered_capabilities() -> None: + reg = CapabilityRegistry() + reg.register(_make_cap("billing.invoices.list")) + reg.register(_make_cap("billing.invoices.create")) + reg.register(_make_cap("crm.contacts.search")) + reg.register(_make_cap("flat_id")) + assert reg.list_namespaces() == ["billing", "crm", "flat_id"] + + +def test_list_namespace_returns_capabilities_under_prefix() -> None: + reg = CapabilityRegistry() + reg.register(_make_cap("billing.invoices.list")) + reg.register(_make_cap("billing.invoices.create")) + reg.register(_make_cap("billing.payments.refund")) + reg.register(_make_cap("crm.contacts.search")) + + billing = [c.capability_id for c in reg.list_namespace("billing")] + assert sorted(billing) == [ + "billing.invoices.create", + "billing.invoices.list", + "billing.payments.refund", + ] + + invoices = [c.capability_id for c in reg.list_namespace("billing.invoices")] + assert sorted(invoices) == [ + "billing.invoices.create", + "billing.invoices.list", + ] + + +def test_list_namespace_exact_match_is_included() -> None: + reg = CapabilityRegistry() + reg.register(_make_cap("billing")) + reg.register(_make_cap("billing.invoices.list")) + ids = [c.capability_id for c in reg.list_namespace("billing")] + assert "billing" in ids + assert "billing.invoices.list" in ids + + +def test_list_namespace_unknown_prefix_raises() -> None: + from agent_kernel import NamespaceNotFound + + reg = CapabilityRegistry() + reg.register(_make_cap("billing.invoices.list")) + with pytest.raises(NamespaceNotFound, match="no registered capabilities"): + reg.list_namespace("never.declared") + + +def test_register_namespace_duplicate_raises() -> None: + reg = CapabilityRegistry() + reg.register_namespace("billing", description="Billing tools") + with pytest.raises(CapabilityAlreadyRegistered, match="already declared"): + reg.register_namespace("billing") + + +def test_deferred_loader_called_exactly_once_on_first_access() -> None: + call_count = {"n": 0} + + def loader() -> list[Capability]: + call_count["n"] += 1 + return [_make_cap("ondemand.list"), _make_cap("ondemand.create")] + + reg = CapabilityRegistry() + reg.register_namespace("ondemand", description="Lazy loaded", loader=loader) + + # First access triggers the loader. + caps = reg.list_namespace("ondemand") + assert {c.capability_id for c in caps} == {"ondemand.list", "ondemand.create"} + assert call_count["n"] == 1 + + # Second access does not re-invoke. + reg.list_namespace("ondemand") + assert call_count["n"] == 1 + + +def test_deferred_loader_triggers_on_get() -> None: + def loader() -> list[Capability]: + return [_make_cap("lazy.thing")] + + reg = CapabilityRegistry() + reg.register_namespace("lazy", loader=loader) + cap = reg.get("lazy.thing") + assert cap.capability_id == "lazy.thing" + + +def test_deferred_loader_triggers_on_search_overlap() -> None: + call_count = {"n": 0} + + def loader() -> list[Capability]: + call_count["n"] += 1 + return [_make_cap("billing.weekly_report", description="weekly revenue report")] + + reg = CapabilityRegistry() + reg.register_namespace("billing", loader=loader) + results = reg.search("weekly billing") + ids = [r.capability_id for r in results] + assert "billing.weekly_report" in ids + assert call_count["n"] == 1 + + +# ── Search scoring & pagination (#45) ───────────────────────────────────────── + + +def test_search_id_match_ranks_above_description_only() -> None: + reg = CapabilityRegistry() + reg.register(_make_cap("invoices.list", description="unrelated text")) + reg.register( + _make_cap( + "ledger.report", + description="invoices summary", + ) + ) + results = reg.search("invoices") + assert [r.capability_id for r in results][:2] == ["invoices.list", "ledger.report"] + + +def test_search_pagination_offset() -> None: + reg = CapabilityRegistry() + for i in range(15): + reg.register(_make_cap(f"billing.invoice{i:02d}", tags=["invoice"])) + page1 = reg.search("invoice", max_results=5, offset=0) + page2 = reg.search("invoice", max_results=5, offset=5) + page3 = reg.search("invoice", max_results=5, offset=10) + assert len(page1) == 5 + assert len(page2) == 5 + assert len(page3) == 5 + ids = {r.capability_id for r in page1 + page2 + page3} + assert len(ids) == 15 + + +def test_search_pagination_offset_does_not_overlap() -> None: + reg = CapabilityRegistry() + for i in range(10): + reg.register(_make_cap(f"billing.invoice{i:02d}", tags=["invoice"])) + page1 = {r.capability_id for r in reg.search("invoice", max_results=4, offset=0)} + page2 = {r.capability_id for r in reg.search("invoice", max_results=4, offset=4)} + assert page1.isdisjoint(page2) + + +def test_search_stop_words_are_stripped() -> None: + reg = CapabilityRegistry() + reg.register(_make_cap("billing.list_invoices")) + # "to" / "the" / "please" must not contribute matches. + results = reg.search("the to please") + assert results == [] + + +def test_search_tags_outrank_description() -> None: + reg = CapabilityRegistry() + reg.register( + _make_cap( + "alpha.report", + description="quarterly revenue summary", + tags=["analytics"], + ) + ) + reg.register( + _make_cap( + "beta.report", + description="alpha analytics description", + tags=["unrelated"], + ) + ) + results = reg.search("analytics") + assert [r.capability_id for r in results][0] == "alpha.report" + + +def test_search_scales_to_500_capabilities() -> None: + """Sanity check: search over 500 capabilities completes quickly.""" + import time + + reg = CapabilityRegistry() + for i in range(500): + ns = "billing" if i % 2 == 0 else "crm" + reg.register( + _make_cap( + f"{ns}.thing{i:04d}", + description=f"deterministic stuff for record {i}", + tags=[ns, "thing"], + ) + ) + start = time.perf_counter() + results = reg.search("billing thing", max_results=10) + elapsed = time.perf_counter() - start + assert len(results) == 10 + # Generous bound: BM25 over 500 docs with ~5 tokens each should be + # well under a second on any developer machine. + assert elapsed < 1.0, f"search took {elapsed:.3f}s for 500 capabilities"