Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,7 @@ func SendActivity(ctx context.Context, config *config.ConfigInfo, event *Activit
func SendRequestPackageInstallation(ctx context.Context, config *config.ConfigInfo, event *RequestPackageInstallationEvent) error {
return sendEvent(ctx, RequestPackageInstallationEndpoint, config, event)
}

func SendAiUsageStats(ctx context.Context, config *config.ConfigInfo, event *AiUsageStatsEvent) error {
Comment thread
bitterpanda63 marked this conversation as resolved.
return sendEvent(ctx, ReportAiStatsEndpoint, config, event)
}
66 changes: 66 additions & 0 deletions internal/cloud/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package cloud

import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/AikidoSec/safechain-internals/internal/config"
Expand Down Expand Up @@ -75,6 +78,69 @@ func TestUnauthorizedTokenIsSuppressedUntilTokenChanges(t *testing.T) {
}
}

func TestSendAiUsageStatsHitsCorrectEndpointWithExpectedShape(t *testing.T) {
withTempRunDir(t)
Init()

var (
gotPath string
gotBody []byte
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("failed to read request body: %v", err)
}
gotBody = body
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

cfg := &config.ConfigInfo{
Token: "good-token",
DeviceID: "device-1",
BaseURL: server.URL,
}

event := &AiUsageStatsEvent{
Models: []AiUsageModel{
{Provider: "anthropic", Model: "claude-opus-4-7"},
},
}

if err := SendAiUsageStats(context.Background(), cfg, event); err != nil {
t.Fatalf("expected SendAiUsageStats to succeed, got %v", err)
}

if want := "/" + ReportAiStatsEndpoint; gotPath != want {
t.Fatalf("expected POST to %q, got %q", want, gotPath)
}

var decoded AiUsageStatsEvent
if err := json.Unmarshal(gotBody, &decoded); err != nil {
t.Fatalf("failed to decode body as AiUsageStatsEvent: %v\nraw: %s", err, string(gotBody))
}
if len(decoded.Models) != 1 {
t.Fatalf("expected one model, got %d", len(decoded.Models))
}
got := decoded.Models[0]
if got.Provider != "anthropic" || got.Model != "claude-opus-4-7" {
t.Fatalf("expected anthropic/claude-opus-4-7, got %s/%s", got.Provider, got.Model)
}

// Wire format must match Wout's spec: `{models: [{provider, model}]}` β€”
// no agent-side timestamp until websockets/SSE.
if !strings.Contains(string(gotBody), `"models"`) {
t.Fatalf("expected payload to contain top-level `models` field, got: %s", string(gotBody))
}
for _, banned := range []string{`"last_seen_at"`, `"last_seen_ms"`, `"first_seen_ms"`, `"count"`, `"ts_ms"`} {
if strings.Contains(string(gotBody), banned) {
t.Fatalf("payload should not contain %s yet: %s", banned, string(gotBody))
}
}
}

// Simulates the race where a 401 for a token-in-flight arrives after the user
// has already replaced the token. The new token must not be marked unauthorized.
func TestStale401DoesNotPoisonReplacedToken(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions internal/cloud/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ const (
ActivityEndpoint = "api/endpoint_protection/callbacks/reportActivity"
RequestPackageInstallationEndpoint = "api/endpoint_protection/callbacks/requestPackageInstallation"
UploadDeviceLogsEndpoint = "api/endpoint_protection/callbacks/uploadDeviceLogs"
ReportAiStatsEndpoint = "api/endpoint_protection/callbacks/reportAiStats"
)
14 changes: 14 additions & 0 deletions internal/cloud/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,17 @@ type RequestPackageInstallationEvent struct {
Ecosystems []EcosystemPackages `json:"ecosystems"`
} `json:"sbom"`
}

// AiUsageModel is one observed (provider, model) pair on this device. The
// cloud stamps `last_seen_at` server-side on receive β€” agent-side timestamps
Comment thread
reiniercriel marked this conversation as resolved.
Outdated
// are deferred until we have a streaming transport (websockets/SSE) where
// per-call accuracy actually pays off.
type AiUsageModel struct {
Provider string `json:"provider"`
Model string `json:"model"`
}

// AiUsageStatsEvent is the body sent to reportAiStats.
type AiUsageStatsEvent struct {
Models []AiUsageModel `json:"models"`
}
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ConfigInfo struct {
LastHeartbeatReportTime time.Time `json:"last_heartbeat_report_time"`
LastSBOMReportTime time.Time `json:"last_sbom_report_time"`
LastSetupWizardShownTime time.Time `json:"last_setup_wizard_shown_time"`
LastAiUsageReportTime time.Time `json:"last_ai_usage_report_time"`
BaseURL string `json:"base_url,omitempty"`
ProxyMode string `json:"proxy_mode,omitempty"`

Expand Down
1 change: 1 addition & 0 deletions internal/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ const (
HeartbeatReportInterval = 3 * time.Minute
SBOMReportInterval = 24 * time.Hour
SetupWizardReshowInterval = 24 * time.Hour
AiUsageReportInterval = 10 * time.Minute
ProxyStartMaxRetries = 100
)
29 changes: 29 additions & 0 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,35 @@ func (d *Daemon) heartbeat() error {
}
return nil
})
d.runIfIntervalExceeded(&d.config.LastAiUsageReportTime, constants.AiUsageReportInterval, d.reportAiUsage)
return nil
}

func (d *Daemon) reportAiUsage() error {
if d.config.Token == "" {
return fmt.Errorf("Token is not set, skipping AI usage report")
}

cutoffMs := d.config.LastAiUsageReportTime.UnixMilli()
snapshot := d.ingress.AiUsageEvents()
models := make([]cloud.AiUsageModel, 0, len(snapshot))
for _, ev := range snapshot {
if ev.TsMs <= cutoffMs {
continue
}
models = append(models, cloud.AiUsageModel{
Provider: ev.Provider,
Model: ev.Model,
})
}
if len(models) == 0 {
return nil
}
event := &cloud.AiUsageStatsEvent{Models: models}
if err := cloud.SendAiUsageStats(d.ctx, d.config, event); err != nil {
return fmt.Errorf("Failed to report AI usage: %v", err)
}
log.Printf("AI usage report sent successfully (%d model(s))", len(models))
return nil
}

Expand Down
44 changes: 44 additions & 0 deletions internal/ingress/ai_usage_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package ingress

import (
"encoding/json"
"log"
"net/http"
)

func (s *Server) handleAiUsage(w http.ResponseWriter, r *http.Request) {
log.Printf("ai-usage: POST /events/ai-usage from %s", r.RemoteAddr)

var event AiUsageEvent
if err := json.NewDecoder(r.Body).Decode(&event); err != nil {
log.Printf("ai-usage: invalid JSON body: %v", err)
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}

if event.Provider == "" || event.Model == "" {
log.Printf("ai-usage: rejecting event with missing fields: provider=%q model=%q", event.Provider, event.Model)
http.Error(w, "provider and model are required", http.StatusBadRequest)
return
}

stored, isNew := s.aiUsageStore.Add(event)
if isNew {
log.Printf("ai-usage: first observation: provider=%s model=%s", stored.Provider, stored.Model)
} else {
log.Printf("ai-usage: repeat observation: provider=%s model=%s", stored.Provider, stored.Model)
}

w.WriteHeader(http.StatusOK)
}

func (s *Server) handleAiUsageEvents(w http.ResponseWriter, r *http.Request) {
if !s.validateUIToken(w, r) {
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(s.aiUsageStore.List()); err != nil {
log.Printf("failed to encode ai-usage events: %v", err)
}
}
53 changes: 53 additions & 0 deletions internal/ingress/ai_usage_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package ingress

import (
"fmt"
"sync"
)

type aiUsageEventStore struct {
mu sync.RWMutex
events []AiUsageEvent
}

func aiUsageEventID(provider, model string) string {
return fmt.Sprintf("ai-usage-%s-%s", provider, model)
}

// Add records an observation for the given (provider, model) pair. Repeats
// just refresh the stored timestamp. Returns the stored row and `isNew=true`
// when this was the first observation of that (provider, model).
func (s *aiUsageEventStore) Add(ev AiUsageEvent) (AiUsageEvent, bool) {
s.mu.Lock()
defer s.mu.Unlock()

id := aiUsageEventID(ev.Provider, ev.Model)

for i := range s.events {
if s.events[i].ID == id {
s.events[i].TsMs = ev.TsMs
return s.events[i], false
}
}

stored := AiUsageEvent{
ID: id,
TsMs: ev.TsMs,
Provider: ev.Provider,
Model: ev.Model,
}
s.events = append(s.events, stored)
return stored, true
}

func (s *aiUsageEventStore) List() []AiUsageEvent {
s.mu.RLock()
defer s.mu.RUnlock()
out := make([]AiUsageEvent, len(s.events))
copy(out, s.events)
return out
}

func (s *Server) AiUsageEvents() []AiUsageEvent {
return s.aiUsageStore.List()
}
98 changes: 98 additions & 0 deletions internal/ingress/ai_usage_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package ingress

import "testing"

func TestAiUsageEventStoreAddAssignsStableIDAndReportsNew(t *testing.T) {
store := &aiUsageEventStore{}

stored, isNew := store.Add(AiUsageEvent{
TsMs: 100,
Provider: "anthropic",
Model: "claude-3-5-sonnet-20241022",
})

want := "ai-usage-anthropic-claude-3-5-sonnet-20241022"
if stored.ID != want {
t.Fatalf("expected stable id %q, got %q", want, stored.ID)
}
if !isNew {
t.Fatalf("expected first observation to report isNew=true")
}
if stored.TsMs != 100 {
t.Fatalf("expected stored ts_ms=100, got %d", stored.TsMs)
}
}

func TestAiUsageEventStoreAddCollapsesSameModelAndRefreshesTimestamp(t *testing.T) {
store := &aiUsageEventStore{}

first, firstIsNew := store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"})
second, secondIsNew := store.Add(AiUsageEvent{TsMs: 250, Provider: "anthropic", Model: "claude-opus-4-7"})
third, thirdIsNew := store.Add(AiUsageEvent{TsMs: 400, Provider: "anthropic", Model: "claude-opus-4-7"})

if !firstIsNew {
t.Fatalf("expected first call to report isNew=true")
}
if secondIsNew || thirdIsNew {
t.Fatalf("expected repeats to report isNew=false, got second=%v third=%v", secondIsNew, thirdIsNew)
}
if first.ID != second.ID || second.ID != third.ID {
t.Fatalf("expected stable aggregate id, got %q / %q / %q", first.ID, second.ID, third.ID)
}
if len(store.List()) != 1 {
t.Fatalf("expected one aggregate event, got %d", len(store.List()))
}
if third.TsMs != 400 {
t.Fatalf("expected ts_ms to refresh to 400, got %d", third.TsMs)
}
}

func TestAiUsageEventStoreAddSeparatesEntriesPerModel(t *testing.T) {
store := &aiUsageEventStore{}

store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"})
store.Add(AiUsageEvent{TsMs: 200, Provider: "anthropic", Model: "claude-haiku-4-5"})

if len(store.List()) != 2 {
t.Fatalf("expected one aggregate event per model, got %d", len(store.List()))
}
}

func TestAiUsageEventStoreAddSeparatesEntriesPerProvider(t *testing.T) {
store := &aiUsageEventStore{}

store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "shared-name"})
store.Add(AiUsageEvent{TsMs: 200, Provider: "openai", Model: "shared-name"})

if len(store.List()) != 2 {
t.Fatalf("expected one aggregate event per provider, got %d", len(store.List()))
}
}

func TestServerAiUsageEventsReturnsCopyAndDoesNotClear(t *testing.T) {
s := &Server{aiUsageStore: &aiUsageEventStore{}}

s.aiUsageStore.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"})
s.aiUsageStore.Add(AiUsageEvent{TsMs: 200, Provider: "anthropic", Model: "claude-opus-4-7"})

snap1 := s.AiUsageEvents()
if len(snap1) != 1 {
t.Fatalf("expected one row in snapshot, got %d", len(snap1))
}
if snap1[0].TsMs != 200 {
t.Fatalf("expected ts_ms=200 (latest), got %d", snap1[0].TsMs)
}

// Mutating the returned slice must not affect the store.
snap1[0].TsMs = 999
snap2 := s.AiUsageEvents()
if snap2[0].TsMs != 200 {
t.Fatalf("snapshot must be a copy; store was mutated to ts_ms=%d", snap2[0].TsMs)
}

// A second call after a flush-equivalent call must still see the data β€”
// the store is intentionally not cleared.
if len(snap2) != 1 {
t.Fatalf("expected store to retain data, got %d rows", len(snap2))
}
}
11 changes: 11 additions & 0 deletions internal/ingress/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ type MinPackageAgeEvent struct {
Message string `json:"message,omitempty"`
}

// AiUsageEvent reports an observed call to a third-party AI provider.
// The proxy posts one event per call (with `ts_ms` populated); the daemon
// collapses repeats by (provider, model) and keeps only the most recent
// timestamp.
type AiUsageEvent struct {
Comment thread
bitterpanda63 marked this conversation as resolved.
ID string `json:"id,omitempty"`
TsMs int64 `json:"ts_ms"`
Provider string `json:"provider"`
Model string `json:"model"`
}

type EcosystemExceptions struct {
AllowedPackages []string `json:"allowed_packages"`
RejectedPackages []string `json:"rejected_packages"`
Expand Down
Loading
Loading