Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,7 @@ func SendActivity(ctx context.Context, config *config.ConfigInfo, event *Activit
func SendRequestPackageInstallation(ctx context.Context, config *config.ConfigInfo, event *RequestPackageInstallationEvent) error {
return sendEvent(ctx, RequestPackageInstallationEndpoint, config, event)
}

func SendAiUsageStats(ctx context.Context, config *config.ConfigInfo, event *AiUsageStatsEvent) error {
Comment thread
bitterpanda63 marked this conversation as resolved.
return sendEvent(ctx, ReportAiStatsEndpoint, config, event)
}
66 changes: 66 additions & 0 deletions internal/cloud/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package cloud

import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/AikidoSec/safechain-internals/internal/config"
Expand Down Expand Up @@ -75,6 +78,69 @@ func TestUnauthorizedTokenIsSuppressedUntilTokenChanges(t *testing.T) {
}
}

func TestSendAiUsageStatsHitsCorrectEndpointWithExpectedShape(t *testing.T) {
withTempRunDir(t)
Init()

var (
gotPath string
gotBody []byte
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("failed to read request body: %v", err)
}
gotBody = body
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

cfg := &config.ConfigInfo{
Token: "good-token",
DeviceID: "device-1",
BaseURL: server.URL,
}

event := &AiUsageStatsEvent{
Models: []AiUsageModel{
{Provider: "anthropic", Model: "claude-opus-4-7"},
},
}

if err := SendAiUsageStats(context.Background(), cfg, event); err != nil {
t.Fatalf("expected SendAiUsageStats to succeed, got %v", err)
}

if want := "/" + ReportAiStatsEndpoint; gotPath != want {
t.Fatalf("expected POST to %q, got %q", want, gotPath)
}

var decoded AiUsageStatsEvent
if err := json.Unmarshal(gotBody, &decoded); err != nil {
t.Fatalf("failed to decode body as AiUsageStatsEvent: %v\nraw: %s", err, string(gotBody))
}
if len(decoded.Models) != 1 {
t.Fatalf("expected one model, got %d", len(decoded.Models))
}
got := decoded.Models[0]
if got.Provider != "anthropic" || got.Model != "claude-opus-4-7" {
t.Fatalf("expected anthropic/claude-opus-4-7, got %s/%s", got.Provider, got.Model)
}

// Wire format must match Wout's spec: `{models: [{provider, model}]}` β€”
// no agent-side timestamp until websockets/SSE.
if !strings.Contains(string(gotBody), `"models"`) {
t.Fatalf("expected payload to contain top-level `models` field, got: %s", string(gotBody))
}
for _, banned := range []string{`"last_seen_at"`, `"last_seen_ms"`, `"first_seen_ms"`, `"count"`, `"ts_ms"`} {
if strings.Contains(string(gotBody), banned) {
t.Fatalf("payload should not contain %s yet: %s", banned, string(gotBody))
}
}
}

// Simulates the race where a 401 for a token-in-flight arrives after the user
// has already replaced the token. The new token must not be marked unauthorized.
func TestStale401DoesNotPoisonReplacedToken(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions internal/cloud/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ const (
ActivityEndpoint = "api/endpoint_protection/callbacks/reportActivity"
RequestPackageInstallationEndpoint = "api/endpoint_protection/callbacks/requestPackageInstallation"
UploadDeviceLogsEndpoint = "api/endpoint_protection/callbacks/uploadDeviceLogs"
ReportAiStatsEndpoint = "api/endpoint_protection/callbacks/reportAiStats"
)
10 changes: 10 additions & 0 deletions internal/cloud/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,13 @@ type RequestPackageInstallationEvent struct {
Ecosystems []EcosystemPackages `json:"ecosystems"`
} `json:"sbom"`
}

type AiUsageModel struct {
Provider string `json:"provider"`
Model string `json:"model"`
}

// AiUsageStatsEvent is the body sent to reportAiStats.
type AiUsageStatsEvent struct {
Models []AiUsageModel `json:"models"`
}
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ConfigInfo struct {
LastHeartbeatReportTime time.Time `json:"last_heartbeat_report_time"`
LastSBOMReportTime time.Time `json:"last_sbom_report_time"`
LastSetupWizardShownTime time.Time `json:"last_setup_wizard_shown_time"`
LastAiUsageReportTime time.Time `json:"last_ai_usage_report_time"`
BaseURL string `json:"base_url,omitempty"`
ProxyMode string `json:"proxy_mode,omitempty"`

Expand Down
1 change: 1 addition & 0 deletions internal/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ const (
HeartbeatReportInterval = 3 * time.Minute
SBOMReportInterval = 24 * time.Hour
SetupWizardReshowInterval = 24 * time.Hour
AiUsageReportInterval = 10 * time.Minute
ProxyStartMaxRetries = 100
)
29 changes: 29 additions & 0 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,35 @@ func (d *Daemon) heartbeat() error {
}
return nil
})
d.runIfIntervalExceeded(&d.config.LastAiUsageReportTime, constants.AiUsageReportInterval, d.reportAiUsage)
return nil
}

func (d *Daemon) reportAiUsage() error {
if d.config.Token == "" {
return fmt.Errorf("Token is not set, skipping AI usage report")
}

cutoffMs := d.config.LastAiUsageReportTime.UnixMilli()
snapshot := d.ingress.AiUsageEvents()
models := make([]cloud.AiUsageModel, 0, len(snapshot))
for _, ev := range snapshot {
if ev.TsMs <= cutoffMs {
continue
}
models = append(models, cloud.AiUsageModel{
Provider: ev.Provider,
Model: ev.Model,
})
}
if len(models) == 0 {
return nil
}
event := &cloud.AiUsageStatsEvent{Models: models}
if err := cloud.SendAiUsageStats(d.ctx, d.config, event); err != nil {
return fmt.Errorf("Failed to report AI usage: %v", err)
}
log.Printf("AI usage report sent successfully (%d model(s))", len(models))
return nil
}

Expand Down
44 changes: 44 additions & 0 deletions internal/ingress/ai_usage_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package ingress

import (
"encoding/json"
"log"
"net/http"
)

func (s *Server) handleAiUsage(w http.ResponseWriter, r *http.Request) {
log.Printf("ai-usage: POST /events/ai-usage from %s", r.RemoteAddr)

var event AiUsageEvent
if err := json.NewDecoder(r.Body).Decode(&event); err != nil {
log.Printf("ai-usage: invalid JSON body: %v", err)
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}

if event.Provider == "" || event.Model == "" {
log.Printf("ai-usage: rejecting event with missing fields: provider=%q model=%q", event.Provider, event.Model)
http.Error(w, "provider and model are required", http.StatusBadRequest)
return
}

stored, isNew := s.aiUsageStore.Add(event)
if isNew {
log.Printf("ai-usage: first observation: provider=%s model=%s", stored.Provider, stored.Model)
} else {
log.Printf("ai-usage: repeat observation: provider=%s model=%s", stored.Provider, stored.Model)
}

w.WriteHeader(http.StatusOK)
}

func (s *Server) handleAiUsageEvents(w http.ResponseWriter, r *http.Request) {
if !s.validateUIToken(w, r) {
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(s.aiUsageStore.List()); err != nil {
log.Printf("failed to encode ai-usage events: %v", err)
}
}
53 changes: 53 additions & 0 deletions internal/ingress/ai_usage_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package ingress

import (
"fmt"
"sync"
)

type aiUsageEventStore struct {
mu sync.RWMutex
events []AiUsageEvent
}

func aiUsageEventID(provider, model string) string {
return fmt.Sprintf("ai-usage-%s-%s", provider, model)
}

// Add records an observation for the given (provider, model) pair. Repeats
// just refresh the stored timestamp. Returns the stored row and `isNew=true`
// when this was the first observation of that (provider, model).
func (s *aiUsageEventStore) Add(ev AiUsageEvent) (AiUsageEvent, bool) {
s.mu.Lock()
defer s.mu.Unlock()

id := aiUsageEventID(ev.Provider, ev.Model)

for i := range s.events {
if s.events[i].ID == id {
s.events[i].TsMs = ev.TsMs
return s.events[i], false
}
}

stored := AiUsageEvent{
ID: id,
TsMs: ev.TsMs,
Provider: ev.Provider,
Model: ev.Model,
}
s.events = append(s.events, stored)
return stored, true
}

func (s *aiUsageEventStore) List() []AiUsageEvent {
s.mu.RLock()
defer s.mu.RUnlock()
out := make([]AiUsageEvent, len(s.events))
copy(out, s.events)
return out
}

func (s *Server) AiUsageEvents() []AiUsageEvent {
return s.aiUsageStore.List()
}
98 changes: 98 additions & 0 deletions internal/ingress/ai_usage_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package ingress

import "testing"

func TestAiUsageEventStoreAddAssignsStableIDAndReportsNew(t *testing.T) {
store := &aiUsageEventStore{}

stored, isNew := store.Add(AiUsageEvent{
TsMs: 100,
Provider: "anthropic",
Model: "claude-3-5-sonnet-20241022",
})

want := "ai-usage-anthropic-claude-3-5-sonnet-20241022"
if stored.ID != want {
t.Fatalf("expected stable id %q, got %q", want, stored.ID)
}
if !isNew {
t.Fatalf("expected first observation to report isNew=true")
}
if stored.TsMs != 100 {
t.Fatalf("expected stored ts_ms=100, got %d", stored.TsMs)
}
}

func TestAiUsageEventStoreAddCollapsesSameModelAndRefreshesTimestamp(t *testing.T) {
store := &aiUsageEventStore{}

first, firstIsNew := store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"})
second, secondIsNew := store.Add(AiUsageEvent{TsMs: 250, Provider: "anthropic", Model: "claude-opus-4-7"})
third, thirdIsNew := store.Add(AiUsageEvent{TsMs: 400, Provider: "anthropic", Model: "claude-opus-4-7"})

if !firstIsNew {
t.Fatalf("expected first call to report isNew=true")
}
if secondIsNew || thirdIsNew {
t.Fatalf("expected repeats to report isNew=false, got second=%v third=%v", secondIsNew, thirdIsNew)
}
if first.ID != second.ID || second.ID != third.ID {
t.Fatalf("expected stable aggregate id, got %q / %q / %q", first.ID, second.ID, third.ID)
}
if len(store.List()) != 1 {
t.Fatalf("expected one aggregate event, got %d", len(store.List()))
}
if third.TsMs != 400 {
t.Fatalf("expected ts_ms to refresh to 400, got %d", third.TsMs)
}
}

func TestAiUsageEventStoreAddSeparatesEntriesPerModel(t *testing.T) {
store := &aiUsageEventStore{}

store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"})
store.Add(AiUsageEvent{TsMs: 200, Provider: "anthropic", Model: "claude-haiku-4-5"})

if len(store.List()) != 2 {
t.Fatalf("expected one aggregate event per model, got %d", len(store.List()))
}
}

func TestAiUsageEventStoreAddSeparatesEntriesPerProvider(t *testing.T) {
store := &aiUsageEventStore{}

store.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "shared-name"})
store.Add(AiUsageEvent{TsMs: 200, Provider: "openai", Model: "shared-name"})

if len(store.List()) != 2 {
t.Fatalf("expected one aggregate event per provider, got %d", len(store.List()))
}
}

func TestServerAiUsageEventsReturnsCopyAndDoesNotClear(t *testing.T) {
s := &Server{aiUsageStore: &aiUsageEventStore{}}

s.aiUsageStore.Add(AiUsageEvent{TsMs: 100, Provider: "anthropic", Model: "claude-opus-4-7"})
s.aiUsageStore.Add(AiUsageEvent{TsMs: 200, Provider: "anthropic", Model: "claude-opus-4-7"})

snap1 := s.AiUsageEvents()
if len(snap1) != 1 {
t.Fatalf("expected one row in snapshot, got %d", len(snap1))
}
if snap1[0].TsMs != 200 {
t.Fatalf("expected ts_ms=200 (latest), got %d", snap1[0].TsMs)
}

// Mutating the returned slice must not affect the store.
snap1[0].TsMs = 999
snap2 := s.AiUsageEvents()
if snap2[0].TsMs != 200 {
t.Fatalf("snapshot must be a copy; store was mutated to ts_ms=%d", snap2[0].TsMs)
}

// A second call after a flush-equivalent call must still see the data β€”
// the store is intentionally not cleared.
if len(snap2) != 1 {
t.Fatalf("expected store to retain data, got %d rows", len(snap2))
}
}
11 changes: 11 additions & 0 deletions internal/ingress/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ type MinPackageAgeEvent struct {
Message string `json:"message,omitempty"`
}

// AiUsageEvent reports an observed call to a third-party AI provider.
// The proxy posts one event per call (with `ts_ms` populated); the daemon
// collapses repeats by (provider, model) and keeps only the most recent
// timestamp.
type AiUsageEvent struct {
Comment thread
bitterpanda63 marked this conversation as resolved.
ID string `json:"id,omitempty"`
TsMs int64 `json:"ts_ms"`
Provider string `json:"provider"`
Model string `json:"model"`
}

type EcosystemExceptions struct {
AllowedPackages []string `json:"allowed_packages"`
RejectedPackages []string `json:"rejected_packages"`
Expand Down
Loading
Loading