diff --git a/internal/cloud/cloud.go b/internal/cloud/cloud.go index 2587882c..e0accbd7 100644 --- a/internal/cloud/cloud.go +++ b/internal/cloud/cloud.go @@ -114,3 +114,7 @@ func SendActivity(ctx context.Context, config *config.ConfigInfo, event *Activit func SendRequestPackageInstallation(ctx context.Context, config *config.ConfigInfo, event *RequestPackageInstallationEvent) error { return sendEvent(ctx, RequestPackageInstallationEndpoint, config, event) } + +func SendAiUsageStats(ctx context.Context, config *config.ConfigInfo, event *AiUsageStatsEvent) error { + return sendEvent(ctx, ReportAiStatsEndpoint, config, event) +} diff --git a/internal/cloud/cloud_test.go b/internal/cloud/cloud_test.go index de5083c5..1f611e47 100644 --- a/internal/cloud/cloud_test.go +++ b/internal/cloud/cloud_test.go @@ -2,8 +2,11 @@ package cloud import ( "context" + "encoding/json" + "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/AikidoSec/safechain-internals/internal/config" @@ -75,6 +78,69 @@ func TestUnauthorizedTokenIsSuppressedUntilTokenChanges(t *testing.T) { } } +func TestSendAiUsageStatsHitsCorrectEndpointWithExpectedShape(t *testing.T) { + withTempRunDir(t) + Init() + + var ( + gotPath string + gotBody []byte + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + gotBody = body + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := &config.ConfigInfo{ + Token: "good-token", + DeviceID: "device-1", + BaseURL: server.URL, + } + + event := &AiUsageStatsEvent{ + Models: []AiUsageModel{ + {Provider: "anthropic", Model: "claude-opus-4-7"}, + }, + } + + if err := SendAiUsageStats(context.Background(), cfg, event); err != nil { + t.Fatalf("expected SendAiUsageStats to succeed, got %v", err) + } + + if want := "/" + ReportAiStatsEndpoint; gotPath != want { + 
t.Fatalf("expected POST to %q, got %q", want, gotPath) + } + + var decoded AiUsageStatsEvent + if err := json.Unmarshal(gotBody, &decoded); err != nil { + t.Fatalf("failed to decode body as AiUsageStatsEvent: %v\nraw: %s", err, string(gotBody)) + } + if len(decoded.Models) != 1 { + t.Fatalf("expected one model, got %d", len(decoded.Models)) + } + got := decoded.Models[0] + if got.Provider != "anthropic" || got.Model != "claude-opus-4-7" { + t.Fatalf("expected anthropic/claude-opus-4-7, got %s/%s", got.Provider, got.Model) + } + + // Wire format must match Wout's spec: `{models: [{provider, model}]}` — + // no agent-side timestamp until websockets/SSE. + if !strings.Contains(string(gotBody), `"models"`) { + t.Fatalf("expected payload to contain top-level `models` field, got: %s", string(gotBody)) + } + for _, banned := range []string{`"last_seen_at"`, `"last_seen_ms"`, `"first_seen_ms"`, `"count"`, `"ts_ms"`} { + if strings.Contains(string(gotBody), banned) { + t.Fatalf("payload should not contain %s yet: %s", banned, string(gotBody)) + } + } +} + // Simulates the race where a 401 for a token-in-flight arrives after the user // has already replaced the token. The new token must not be marked unauthorized. 
func TestStale401DoesNotPoisonReplacedToken(t *testing.T) { diff --git a/internal/cloud/constants.go b/internal/cloud/constants.go index 145d3415..b75011e0 100644 --- a/internal/cloud/constants.go +++ b/internal/cloud/constants.go @@ -7,4 +7,5 @@ const ( ActivityEndpoint = "api/endpoint_protection/callbacks/reportActivity" RequestPackageInstallationEndpoint = "api/endpoint_protection/callbacks/requestPackageInstallation" UploadDeviceLogsEndpoint = "api/endpoint_protection/callbacks/uploadDeviceLogs" + ReportAiStatsEndpoint = "api/endpoint_protection/callbacks/reportAiStats" ) diff --git a/internal/cloud/events.go b/internal/cloud/events.go index 3472859e..78e95dc1 100644 --- a/internal/cloud/events.go +++ b/internal/cloud/events.go @@ -54,3 +54,13 @@ type RequestPackageInstallationEvent struct { Ecosystems []EcosystemPackages `json:"ecosystems"` } `json:"sbom"` } + +type AiUsageModel struct { + Provider string `json:"provider"` + Model string `json:"model"` +} + +// AiUsageStatsEvent is the body sent to reportAiStats. 
+type AiUsageStatsEvent struct {
+	Models []AiUsageModel `json:"models"`
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 678d45ec..6320394c 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -21,6 +21,7 @@ type ConfigInfo struct {
 	LastHeartbeatReportTime  time.Time `json:"last_heartbeat_report_time"`
 	LastSBOMReportTime       time.Time `json:"last_sbom_report_time"`
 	LastSetupWizardShownTime time.Time `json:"last_setup_wizard_shown_time"`
+	LastAiUsageReportTime    time.Time `json:"last_ai_usage_report_time"`
 
 	BaseURL   string `json:"base_url,omitempty"`
 	ProxyMode string `json:"proxy_mode,omitempty"`
diff --git a/internal/constants/constants.go b/internal/constants/constants.go
index dbe74dbb..b3895c42 100644
--- a/internal/constants/constants.go
+++ b/internal/constants/constants.go
@@ -10,5 +10,6 @@ const (
 	HeartbeatReportInterval   = 3 * time.Minute
 	SBOMReportInterval        = 24 * time.Hour
 	SetupWizardReshowInterval = 24 * time.Hour
+	AiUsageReportInterval     = 10 * time.Minute
 	ProxyStartMaxRetries      = 100
 )
diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go
index 49af4e98..c26f9905 100644
--- a/internal/daemon/daemon.go
+++ b/internal/daemon/daemon.go
@@ -468,6 +468,42 @@ func (d *Daemon) heartbeat() error {
 		}
 		return nil
 	})
+	d.runIfIntervalExceeded(&d.config.LastAiUsageReportTime, constants.AiUsageReportInterval, d.reportAiUsage)
+	return nil
+}
+
+// reportAiUsage sends the cloud every (provider, model) pair observed since
+// LastAiUsageReportTime. It sends nothing when no pair qualifies and returns
+// an error when no token is configured or the upload fails.
+//
+// NOTE(review): the cutoff assumes runIfIntervalExceeded advances
+// LastAiUsageReportTime only after this function returns nil; confirm.
+func (d *Daemon) reportAiUsage() error {
+	if d.config.Token == "" {
+		return fmt.Errorf("Token is not set, skipping AI usage report")
+	}
+
+	cutoffMs := d.config.LastAiUsageReportTime.UnixMilli()
+	snapshot := d.ingress.AiUsageEvents()
+	models := make([]cloud.AiUsageModel, 0, len(snapshot))
+	for _, ev := range snapshot {
+		if ev.TsMs <= cutoffMs {
+			continue
+		}
+		models = append(models, cloud.AiUsageModel{
+			Provider: ev.Provider,
+			Model:    ev.Model,
+		})
+	}
+	if len(models) == 0 {
+		return nil
+	}
+	event := &cloud.AiUsageStatsEvent{Models: models}
+	if err := cloud.SendAiUsageStats(d.ctx, d.config, event); err != nil {
+		// %w keeps the underlying error available to errors.Is/errors.As.
+		return fmt.Errorf("Failed to report AI usage: %w", err)
+	}
+	log.Printf("AI usage report sent successfully (%d model(s))", len(models))
 	return nil
 }
 
diff --git a/internal/ingress/ai_usage_handler.go b/internal/ingress/ai_usage_handler.go
new file mode 100644
index 00000000..e6520f5f
--- /dev/null
+++ b/internal/ingress/ai_usage_handler.go
@@ -0,0 +1,62 @@
+package ingress
+
+import (
+	"encoding/json"
+	"log"
+	"net/http"
+	"time"
+)
+
+// maxAiUsageEventBytes caps the request body read by handleAiUsage; the
+// expected payload is a single small JSON object.
+const maxAiUsageEventBytes = 64 << 10
+
+// handleAiUsage stores one AI usage observation posted to
+// POST /events/ai-usage. It answers 400 when the body is oversized, is not
+// valid JSON, or has an empty provider or model, and 200 once it is stored.
+func (s *Server) handleAiUsage(w http.ResponseWriter, r *http.Request) {
+	log.Printf("ai-usage: POST /events/ai-usage from %s", r.RemoteAddr)
+
+	var event AiUsageEvent
+	body := http.MaxBytesReader(w, r.Body, maxAiUsageEventBytes)
+	if err := json.NewDecoder(body).Decode(&event); err != nil {
+		log.Printf("ai-usage: invalid JSON body: %v", err)
+		http.Error(w, "invalid request body", http.StatusBadRequest)
+		return
+	}
+
+	if event.Provider == "" || event.Model == "" {
+		log.Printf("ai-usage: rejecting event with missing fields: provider=%q model=%q", event.Provider, event.Model)
+		http.Error(w, "provider and model are required", http.StatusBadRequest)
+		return
+	}
+
+	// reportAiUsage only forwards rows whose ts_ms is newer than the last
+	// report, so an absent timestamp is replaced with the receive time.
+	if event.TsMs <= 0 {
+		event.TsMs = time.Now().UnixMilli()
+	}
+
+	stored, isNew := s.aiUsageStore.Add(event)
+	// %q: provider and model are client-supplied and may hold control characters.
+	if isNew {
+		log.Printf("ai-usage: first observation: provider=%q model=%q", stored.Provider, stored.Model)
+	} else {
+		log.Printf("ai-usage: repeat observation: provider=%q model=%q", stored.Provider, stored.Model)
+	}
+
+	w.WriteHeader(http.StatusOK)
+}
+
+// handleAiUsageEvents returns the stored AI usage rows as a JSON array to
+// callers that pass validateUIToken.
+func (s *Server) handleAiUsageEvents(w http.ResponseWriter, r *http.Request) {
+	if !s.validateUIToken(w, r) {
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+	if err := json.NewEncoder(w).Encode(s.aiUsageStore.List()); err != nil {
+		log.Printf("failed to encode ai-usage events: %v", err)
+	}
+}
diff --git a/internal/ingress/ai_usage_store.go b/internal/ingress/ai_usage_store.go
new file mode 100644
index 00000000..31f5cb68
--- /dev/null
+++ b/internal/ingress/ai_usage_store.go
@@ -0,0 +1,53 @@
+package ingress
+
+import (
+	"fmt"
+	"sync"
+)
+
+type aiUsageEventStore struct {
+	mu     sync.RWMutex
+	events []AiUsageEvent
+}
+
+func aiUsageEventID(provider, model string) string {
+	return fmt.Sprintf("ai-usage-%s-%s", provider, model)
+}
+
+// Add records an observation for the given (provider, model) pair. Repeats
+// never move the stored timestamp backwards. Returns the stored row and
+// `isNew=true` when this was the first observation of that (provider, model).
+func (s *aiUsageEventStore) Add(ev AiUsageEvent) (AiUsageEvent, bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for i := range s.events {
+		// Match on the fields: the joined ID is ambiguous ("a-b"+"c" == "a"+"b-c").
+		if s.events[i].Provider == ev.Provider && s.events[i].Model == ev.Model {
+			// An older ts_ms must not overwrite a newer one.
+			s.events[i].TsMs = max(s.events[i].TsMs, ev.TsMs)
+			return s.events[i], false
+		}
+	}
+
+	stored := AiUsageEvent{
+		ID:       aiUsageEventID(ev.Provider, ev.Model),
+		TsMs:     ev.TsMs,
+		Provider: ev.Provider,
+		Model:    ev.Model,
+	}
+	s.events = append(s.events, stored)
+	return stored, true
+}
+
+func (s *aiUsageEventStore) List() []AiUsageEvent {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	out := make([]AiUsageEvent, len(s.events))
+	copy(out, s.events)
+	return out
+}
+
+func (s *Server) AiUsageEvents() []AiUsageEvent {
+	return s.aiUsageStore.List()
+}
diff --git a/internal/ingress/ai_usage_store_test.go b/internal/ingress/ai_usage_store_test.go
new file mode 100644
index 00000000..fae3b8cb
--- /dev/null
+++ b/internal/ingress/ai_usage_store_test.go
@@ -0,0 +1,98 @@
+package ingress
+
+import "testing"
+
+func TestAiUsageEventStoreAddAssignsStableIDAndReportsNew(t *testing.T) {
+	store := &aiUsageEventStore{}
+
+	stored, isNew := store.Add(AiUsageEvent{
+		TsMs:     100,
+		Provider: "anthropic",
+		Model:    "claude-3-5-sonnet-20241022",
+	})
+
+	want := "ai-usage-anthropic-claude-3-5-sonnet-20241022"
+	if stored.ID != want {
+		t.Fatalf("expected stable id %q, got %q", want, stored.ID)
+	}
+	if !isNew {
+		t.Fatalf("expected first observation to report isNew=true")
+	}
+	if stored.TsMs != 100 {
+		t.Fatalf("expected stored ts_ms=100, got %d", stored.TsMs)
+	}
+}
+
+func
TestAiUsageEventStoreAddCollapsesSameModelAndRefreshesTimestamp(t *testing.T) { + store := &aiUsageEventStore{} + + first, firstIsNew := store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"}) + second, secondIsNew := store.Add(AiUsageEvent{TsMs: 250, Provider: "anthropic", Model: "claude-opus-4-7"}) + third, thirdIsNew := store.Add(AiUsageEvent{TsMs: 400, Provider: "anthropic", Model: "claude-opus-4-7"}) + + if !firstIsNew { + t.Fatalf("expected first call to report isNew=true") + } + if secondIsNew || thirdIsNew { + t.Fatalf("expected repeats to report isNew=false, got second=%v third=%v", secondIsNew, thirdIsNew) + } + if first.ID != second.ID || second.ID != third.ID { + t.Fatalf("expected stable aggregate id, got %q / %q / %q", first.ID, second.ID, third.ID) + } + if len(store.List()) != 1 { + t.Fatalf("expected one aggregate event, got %d", len(store.List())) + } + if third.TsMs != 400 { + t.Fatalf("expected ts_ms to refresh to 400, got %d", third.TsMs) + } +} + +func TestAiUsageEventStoreAddSeparatesEntriesPerModel(t *testing.T) { + store := &aiUsageEventStore{} + + store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"}) + store.Add(AiUsageEvent{TsMs: 200, Provider: "anthropic", Model: "claude-haiku-4-5"}) + + if len(store.List()) != 2 { + t.Fatalf("expected one aggregate event per model, got %d", len(store.List())) + } +} + +func TestAiUsageEventStoreAddSeparatesEntriesPerProvider(t *testing.T) { + store := &aiUsageEventStore{} + + store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "shared-name"}) + store.Add(AiUsageEvent{TsMs: 200, Provider: "openai", Model: "shared-name"}) + + if len(store.List()) != 2 { + t.Fatalf("expected one aggregate event per provider, got %d", len(store.List())) + } +} + +func TestServerAiUsageEventsReturnsCopyAndDoesNotClear(t *testing.T) { + s := &Server{aiUsageStore: &aiUsageEventStore{}} + + s.aiUsageStore.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", 
Model: "claude-opus-4-7"}) + s.aiUsageStore.Add(AiUsageEvent{TsMs: 200, Provider: "anthropic", Model: "claude-opus-4-7"}) + + snap1 := s.AiUsageEvents() + if len(snap1) != 1 { + t.Fatalf("expected one row in snapshot, got %d", len(snap1)) + } + if snap1[0].TsMs != 200 { + t.Fatalf("expected ts_ms=200 (latest), got %d", snap1[0].TsMs) + } + + // Mutating the returned slice must not affect the store. + snap1[0].TsMs = 999 + snap2 := s.AiUsageEvents() + if snap2[0].TsMs != 200 { + t.Fatalf("snapshot must be a copy; store was mutated to ts_ms=%d", snap2[0].TsMs) + } + + // A second call after a flush-equivalent call must still see the data — + // the store is intentionally not cleared. + if len(snap2) != 1 { + t.Fatalf("expected store to retain data, got %d rows", len(snap2)) + } +} diff --git a/internal/ingress/events.go b/internal/ingress/events.go index 9a08316d..6a719b06 100644 --- a/internal/ingress/events.go +++ b/internal/ingress/events.go @@ -40,6 +40,17 @@ type MinPackageAgeEvent struct { Message string `json:"message,omitempty"` } +// AiUsageEvent reports an observed call to a third-party AI provider. +// The proxy posts one event per call (with `ts_ms` populated); the daemon +// collapses repeats by (provider, model) and keeps only the most recent +// timestamp. 
+type AiUsageEvent struct { + ID string `json:"id,omitempty"` + TsMs int64 `json:"ts_ms"` + Provider string `json:"provider"` + Model string `json:"model"` +} + type EcosystemExceptions struct { AllowedPackages []string `json:"allowed_packages"` RejectedPackages []string `json:"rejected_packages"` diff --git a/internal/ingress/server.go b/internal/ingress/server.go index 99ffb228..e8d05e94 100644 --- a/internal/ingress/server.go +++ b/internal/ingress/server.go @@ -42,6 +42,7 @@ type Server struct { eventStore *eventStore tlsEventStore *tlsEventStore minAgeStore *minPackageAgeEventStore + aiUsageStore *aiUsageEventStore chromeNames *chromeExtensionNameResolver mu sync.RWMutex } @@ -54,6 +55,7 @@ func New(cfg *config.ConfigInfo, ui UIProvider, proxy proxy.ProxyManager) *Serve eventStore: &eventStore{}, tlsEventStore: &tlsEventStore{}, minAgeStore: &minPackageAgeEventStore{}, + aiUsageStore: &aiUsageEventStore{}, chromeNames: newChromeExtensionNameResolver(), } } @@ -70,6 +72,7 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("POST /events/tls-termination-failed", s.handleTlsTerminationFailed) mux.HandleFunc("POST /events/min-package-age", s.handleMinPackageAge) mux.HandleFunc("POST /events/permissions", s.handlePermissionsUpdated) + mux.HandleFunc("POST /events/ai-usage", s.handleAiUsage) mux.HandleFunc("GET /ping", s.handlePing) mux.HandleFunc("POST /v1/events/{id}/request-access", s.handleRequestBypass) @@ -80,6 +83,7 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("GET /v1/tls-events/{id}", s.handleGetTlsEventByID) mux.HandleFunc("GET /v1/min-package-age-events", s.handleMinPackageAgeEvents) mux.HandleFunc("GET /v1/min-package-age-events/{id}", s.handleGetMinPackageAgeEventByID) + mux.HandleFunc("GET /v1/ai-usage-events", s.handleAiUsageEvents) mux.HandleFunc("GET /v1/version", s.handleVersion) diff --git a/proxy-lib/src/http/firewall/events.rs b/proxy-lib/src/http/firewall/events.rs index 2b859428..3315ffd2 100644 --- 
a/proxy-lib/src/http/firewall/events.rs +++ b/proxy-lib/src/http/firewall/events.rs @@ -74,6 +74,17 @@ pub struct TlsTerminationFailedEvent { pub error: String, } +/// Reports a single observed call to a third-party AI provider. +/// Detection only — no blocking, no token counts (yet). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AiUsageEvent { + pub ts_ms: SystemTimestampMilliseconds, + /// Provider/company key, e.g. `"anthropic"`, `"openai"`, `"gemini"`. + pub provider: ArcStr, + /// Model identifier as sent by the client, e.g. `"claude-3-5-sonnet-20241022"`. + pub model: ArcStr, +} + #[cfg(test)] #[path = "events_tests.rs"] mod tests; diff --git a/proxy-lib/src/http/firewall/mod.rs b/proxy-lib/src/http/firewall/mod.rs index 5e62fcd2..811a30a0 100644 --- a/proxy-lib/src/http/firewall/mod.rs +++ b/proxy-lib/src/http/firewall/mod.rs @@ -257,6 +257,7 @@ impl Firewall { .await .context("create block rule: skills.sh")? .into_dyn(), + self::rule::anthropic::RuleAnthropic::new(notifier.clone()).into_dyn(), self::rule::hijack::RuleHijack::new().into_dyn(), ]), notifier, diff --git a/proxy-lib/src/http/firewall/notifier.rs b/proxy-lib/src/http/firewall/notifier.rs index f5ff5265..2e01060b 100644 --- a/proxy-lib/src/http/firewall/notifier.rs +++ b/proxy-lib/src/http/firewall/notifier.rs @@ -32,7 +32,7 @@ use tokio::sync::{Semaphore, SemaphorePermit}; use crate::{ endpoint_protection::types::EndpointConfig, - http::firewall::events::MinPackageAgeEvent, + http::firewall::events::{AiUsageEvent, MinPackageAgeEvent}, package::version::{PackageVersion, PackageVersionKey}, utils::env::{compute_concurrent_request_count, network_service_identifier}, }; @@ -143,6 +143,12 @@ impl EventNotifier { }); } + pub fn notify_ai_usage(&self, event: AiUsageEvent) { + self.spawn_event_task(|client, reporting_endpoint| { + send_ai_usage_event(client, reporting_endpoint, event) + }); + } + fn spawn_event_task(&self, f: F) where F: FnOnce(BoxService, String) -> Fut + Send + 
'static, @@ -225,6 +231,22 @@ async fn send_tls_termination_failed_event( send_event(client, reporting_endpoint, event, &url).await; } +async fn send_ai_usage_event( + client: BoxService, + reporting_endpoint: String, + event: AiUsageEvent, +) { + tracing::debug!( + provider = %event.provider, + model = %event.model, + "sending ai-usage event notification" + ); + + let url = format!("{}/events/ai-usage", reporting_endpoint); + + send_event(client, reporting_endpoint, event, &url).await; +} + async fn send_permissions_updated_event( client: BoxService, reporting_endpoint: String, diff --git a/proxy-lib/src/http/firewall/rule/anthropic/mod.rs b/proxy-lib/src/http/firewall/rule/anthropic/mod.rs new file mode 100644 index 00000000..bb54ced0 --- /dev/null +++ b/proxy-lib/src/http/firewall/rule/anthropic/mod.rs @@ -0,0 +1,164 @@ +use std::fmt; + +use crate::{ + http::{ + KnownContentType, + firewall::{ + domain_matcher::DomainMatcher, + events::AiUsageEvent, + notifier::EventNotifier, + rule::{RequestAction, Rule}, + }, + }, + utils::time::SystemTimestampMilliseconds, +}; +use rama::{ + error::{BoxError, ErrorContext as _}, + http::{ + Body, Method, Request, + body::util::BodyExt as _, + headers::{ContentType, HeaderMapExt as _}, + }, + net::address::Domain, + telemetry::tracing, + utils::str::arcstr::{ArcStr, arcstr}, +}; +use serde::Deserialize; + +#[cfg(feature = "pac")] +use crate::http::firewall::pac::PacScriptGenerator; + +const ANTHROPIC_PROVIDER_KEY: ArcStr = arcstr!("anthropic"); + +/// Cap on how much of the request body we will buffer to extract the model name. +/// Anthropic's prompts can include base64 images and long contexts, but the JSON +/// preamble carrying `"model": "..."` is small. Requests above this cap pass +/// through untouched (no event emitted). 
+const MAX_BODY_BYTES: usize = 16 * 1024 * 1024; + +pub(in crate::http::firewall) struct RuleAnthropic { + target_domains: DomainMatcher, + notifier: Option, +} + +impl RuleAnthropic { + pub(in crate::http::firewall) fn new(notifier: Option) -> Self { + Self { + target_domains: ["api.anthropic.com"].into_iter().collect(), + notifier, + } + } +} + +impl fmt::Debug for RuleAnthropic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RuleAnthropic").finish() + } +} + +impl Rule for RuleAnthropic { + #[inline(always)] + fn match_domain(&self, domain: &Domain) -> bool { + self.target_domains.is_match(domain) + } + + #[cfg(feature = "pac")] + #[inline(always)] + fn collect_pac_domains(&self, generator: &mut PacScriptGenerator) { + for domain in self.target_domains.iter() { + generator.write_domain(&domain); + } + } + + async fn evaluate_request(&self, req: Request) -> Result { + if req.method() != Method::POST { + return Ok(RequestAction::Allow(req)); + } + + if !path_carries_model(req.uri().path()) { + return Ok(RequestAction::Allow(req)); + } + + if req + .headers() + .typed_get::() + .and_then(KnownContentType::detect_from_content_type_header) + != Some(KnownContentType::Json) + { + return Ok(RequestAction::Allow(req)); + } + + let (parts, body) = req.into_parts(); + + let bytes = body + .collect() + .await + .context("collect anthropic request body")? 
+ .to_bytes(); + + if bytes.len() > MAX_BODY_BYTES { + tracing::debug!( + body_size = bytes.len(), + "anthropic request body exceeds cap, skipping ai-usage detection" + ); + return Ok(RequestAction::Allow(Request::from_parts( + parts, + Body::from(bytes), + ))); + } + + if let Some(model) = parse_model_field(&bytes) { + tracing::debug!(model = %model, "anthropic request observed"); + if let Some(notifier) = self.notifier.as_ref() { + notifier.notify_ai_usage(AiUsageEvent { + ts_ms: SystemTimestampMilliseconds::now(), + provider: ANTHROPIC_PROVIDER_KEY, + model, + }); + } + } else { + tracing::debug!("anthropic request body did not contain a recognizable model field"); + } + + Ok(RequestAction::Allow(Request::from_parts( + parts, + Body::from(bytes), + ))) + } +} + +/// Endpoints whose request body carries a top-level `model` field. +fn path_carries_model(path: &str) -> bool { + matches!( + path, + "/v1/messages" | "/v1/messages/count_tokens" | "/v1/complete" + ) +} + +#[derive(Debug, Deserialize)] +struct AnthropicRequestModel { + model: ArcStr, +} + +fn parse_model_field(bytes: &[u8]) -> Option { + let parsed: AnthropicRequestModel = serde_json::from_slice(bytes) + .inspect_err(|err| { + tracing::debug!( + error = %err, + "failed to parse anthropic request body as JSON with `model`" + ); + }) + .ok()?; + let trimmed = parsed.model.trim(); + if trimmed.is_empty() { + return None; + } + if trimmed.len() == parsed.model.len() { + Some(parsed.model) + } else { + Some(ArcStr::from(trimmed)) + } +} + +#[cfg(test)] +mod tests; diff --git a/proxy-lib/src/http/firewall/rule/anthropic/tests.rs b/proxy-lib/src/http/firewall/rule/anthropic/tests.rs new file mode 100644 index 00000000..fd38c065 --- /dev/null +++ b/proxy-lib/src/http/firewall/rule/anthropic/tests.rs @@ -0,0 +1,95 @@ +use super::*; +use rama::utils::str::arcstr::ArcStr; + +// --- path_carries_model: Anthropic POST endpoints whose body has `model` --- + +#[test] +fn test_path_carries_model_messages() { + 
assert!(path_carries_model("/v1/messages")); +} + +#[test] +fn test_path_carries_model_count_tokens() { + assert!(path_carries_model("/v1/messages/count_tokens")); +} + +#[test] +fn test_path_carries_model_legacy_complete() { + assert!(path_carries_model("/v1/complete")); +} + +// --- path_carries_model: paths we should ignore --- + +#[test] +fn test_path_carries_model_rejects_models_listing() { + assert!(!path_carries_model("/v1/models")); +} + +#[test] +fn test_path_carries_model_rejects_root() { + assert!(!path_carries_model("/")); +} + +#[test] +fn test_path_carries_model_rejects_unknown_path() { + assert!(!path_carries_model("/v1/messages/something-else")); +} + +// --- parse_model_field: typical Anthropic Messages API request bodies --- + +#[test] +fn test_parse_model_basic() { + let body = br#"{"model":"claude-3-5-sonnet-20241022","messages":[]}"#; + assert_eq!( + parse_model_field(body), + Some(ArcStr::from("claude-3-5-sonnet-20241022")) + ); +} + +#[test] +fn test_parse_model_with_extra_fields() { + let body = br#"{ + "model": "claude-opus-4-7", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "hi"}] + }"#; + assert_eq!( + parse_model_field(body), + Some(ArcStr::from("claude-opus-4-7")) + ); +} + +#[test] +fn test_parse_model_trims_whitespace() { + let body = br#"{"model":" claude-haiku-4-5 ","messages":[]}"#; + assert_eq!( + parse_model_field(body), + Some(ArcStr::from("claude-haiku-4-5")) + ); +} + +// --- parse_model_field: malformed inputs --- + +#[test] +fn test_parse_model_missing_field() { + let body = br#"{"messages":[]}"#; + assert!(parse_model_field(body).is_none()); +} + +#[test] +fn test_parse_model_empty_string() { + let body = br#"{"model":""}"#; + assert!(parse_model_field(body).is_none()); +} + +#[test] +fn test_parse_model_invalid_json() { + let body = b"not even json"; + assert!(parse_model_field(body).is_none()); +} + +#[test] +fn test_parse_model_wrong_type() { + let body = br#"{"model":42}"#; + 
assert!(parse_model_field(body).is_none()); +} diff --git a/proxy-lib/src/http/firewall/rule/mod.rs b/proxy-lib/src/http/firewall/rule/mod.rs index ecbbfd39..f00210d7 100644 --- a/proxy-lib/src/http/firewall/rule/mod.rs +++ b/proxy-lib/src/http/firewall/rule/mod.rs @@ -49,6 +49,7 @@ pub(crate) fn block_reason_for(decision: PackagePolicyDecision) -> BlockReason { #[cfg(feature = "pac")] pub use super::pac::PacScriptGenerator; +pub mod anthropic; pub mod chrome; pub mod golang; pub mod hijack;