diff --git a/internal/proxy/iolog_test.go b/internal/proxy/iolog_test.go index d91cc99..d6553b8 100644 --- a/internal/proxy/iolog_test.go +++ b/internal/proxy/iolog_test.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "context" "io" "net/http" @@ -308,3 +309,86 @@ func TestIOLog_SinkReceivesRecordViaOnDone(t *testing.T) { t.Fatalf("opted-in request must hand exactly 1 record to the sink, got %d", got) } } + +// newIOLogServerWithLog is newIOLogServer but with a caller-supplied logger, so a +// test can capture WARN output (the request-body-truncation observability line). +func newIOLogServerWithLog(t *testing.T, upstream *url.URL, policy iolog.Policy, sink iolog.Sink, maxBody int, log *logging.Logger) *Server { + t.Helper() + s := &config.Settings{ListenAddr: ":0"} + resolver := registry.NewStatic(upstream) + return NewWithIOLog(s, log, resolver, &recordingEmitter{}, policy, sink, maxBody) +} + +// TestIOLog_RequestTruncationLogged verifies Hugo's D1 decision: hard truncation +// of the captured request body is fine, but it must be OBSERVABLE — a single WARN +// line fires (with request_id and original→cap sizes) ONLY when the body exceeds +// the cap, and stays silent for an under-cap body. +func TestIOLog_RequestTruncationLogged(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Drain so the upstream still sees the full forwarded body. + _, _ = io.Copy(io.Discard, r.Body) + _, _ = io.WriteString(w, `{"ok":true,"model":"m"}`) + })) + defer backend.Close() + upstream, _ := url.Parse(backend.URL) + + const bodyCap = 100 + + t.Run("over cap logs", func(t *testing.T) { + var buf bytes.Buffer + log := logging.New(logging.WARN) + log.Warn.SetOutput(&buf) + + policy := &spyPolicy{decision: true} + sink := &recordingSink{} + srv := newIOLogServerWithLog(t, upstream, policy, sink, bodyCap, log) + + big := strings.Repeat("A", 500) + rr := httptest.NewRecorder() + srv.Handler().ServeHTTP(rr, iologRequest(http.MethodPost, big)) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + // The record must mark the request body truncated... + recs := sink.all() + if len(recs) != 1 || !recs[0].RequestTruncated { + t.Fatalf("expected 1 record with RequestTruncated=true, got %+v", recs) + } + // ...and the WARN line must fire with request_id and original→cap sizes. + out := buf.String() + if !strings.Contains(out, "request body truncated") { + t.Fatalf("no truncation WARN logged; got %q", out) + } + if !strings.Contains(out, "request_id=req-iolog-1") { + t.Errorf("WARN missing request_id; got %q", out) + } + if !strings.Contains(out, "500 → 100 bytes") { + t.Errorf("WARN missing original→cap sizes (want \"500 → 100 bytes\"); got %q", out) + } + }) + + t.Run("under cap silent", func(t *testing.T) { + var buf bytes.Buffer + log := logging.New(logging.WARN) + log.Warn.SetOutput(&buf) + + policy := &spyPolicy{decision: true} + sink := &recordingSink{} + srv := newIOLogServerWithLog(t, upstream, policy, sink, bodyCap, log) + + rr := httptest.NewRecorder() + srv.Handler().ServeHTTP(rr, iologRequest(http.MethodPost, `{"model":"m"}`)) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + recs := sink.all() + if len(recs) != 1 || recs[0].RequestTruncated { + t.Fatalf("under-cap body must not be truncated; got %+v", recs) + } + if out := buf.String(); strings.Contains(out, "request body truncated") { + t.Errorf("under-cap body must NOT log truncation; got %q", out) + } + }) +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 42ec17e..7c71c01 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -251,12 +251,20 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { // Cap the LOGGED request-body copy at the same bound as the response copy // (s.ioMaxBodyLen) — an uncapped body flows into to_tsvector and fails the // INSERT past ~1 MiB. The forwarded request keeps the full body. - reqBody, reqTruncated, err = captureRequestBody(r, s.ioMaxBodyLen) + var reqOrigLen int + reqBody, reqTruncated, reqOrigLen, err = captureRequestBody(r, s.ioMaxBodyLen) if err != nil { s.log.Error.Printf("capture request body: %v", err) http.Error(w, "bad request body", http.StatusBadRequest) return } + // Hard truncation is intentional (an uncapped body fails the to_tsvector + // INSERT), but make it OBSERVABLE: an operator should be able to see that + // a captured body was cut, and by how much. + if reqTruncated { + s.log.Warn.Printf("iolog: request body truncated for capture request_id=%s (%d → %d bytes)", + requestID, reqOrigLen, s.ioMaxBodyLen) + } } // Force streaming usage so we never under-bill a streamed response. diff --git a/internal/proxy/rewrite.go b/internal/proxy/rewrite.go index 5f59985..c7e22ab 100644 --- a/internal/proxy/rewrite.go +++ b/internal/proxy/rewrite.go @@ -24,15 +24,17 @@ import ( // the model sees — only the stored log copy. Truncation is at a rune boundary so // the stored TEXT is valid UTF-8. maxBodyBytes <= 0 means uncapped. // -// A nil body yields ("", false) with no error. -func captureRequestBody(r *http.Request, maxBodyBytes int) (body string, truncated bool, err error) { +// A nil body yields ("", false, 0) with no error. origLen is the size of the +// FULL body read off the wire (before capping), so a caller can log the +// original-vs-cap sizes when truncated is true. +func captureRequestBody(r *http.Request, maxBodyBytes int) (body string, truncated bool, origLen int, err error) { if r.Body == nil { - return "", false, nil + return "", false, 0, nil } raw, err := io.ReadAll(r.Body) _ = r.Body.Close() if err != nil { - return "", false, err + return "", false, 0, err } // Restore the FULL body so forceIncludeUsage (and the upstream) read every // byte — the cap only bounds the LOG copy, never the forwarded request. @@ -41,7 +43,7 @@ func captureRequestBody(r *http.Request, maxBodyBytes int) (body string, truncat r.Header.Set("Content-Length", strconv.Itoa(len(raw))) logCopy, truncated := truncateAtRuneBoundary(raw, maxBodyBytes) - return string(logCopy), truncated, nil + return string(logCopy), truncated, len(raw), nil } // forceIncludeUsage rewrites a request body so that streamed responses carry a diff --git a/internal/rating/store.go b/internal/rating/store.go index 4395d30..275cc4d 100644 --- a/internal/rating/store.go +++ b/internal/rating/store.go @@ -325,7 +325,7 @@ priced AS ( + cached_tokens * cached_price + completion_tokens * completion_price ) AS cost, - COUNT(*)::int AS event_count + COUNT(*)::bigint AS event_count FROM resolved WHERE prompt_price IS NOT NULL -- priced only AND auth_id IS NOT NULL -- attributable only diff --git a/internal/rating/store_integration_test.go b/internal/rating/store_integration_test.go index c799be4..8970ef3 100644 --- a/internal/rating/store_integration_test.go +++ b/internal/rating/store_integration_test.go @@ -96,7 +96,7 @@ CREATE TABLE rated_usage ( completion_tokens BIGINT NOT NULL, billable_prompt_tokens BIGINT NOT NULL, cost NUMERIC(20,9) NOT NULL, - event_count INTEGER NOT NULL, + event_count BIGINT NOT NULL, rated_at TIMESTAMPTZ NOT NULL DEFAULT now(), CONSTRAINT rated_usage_auth_model_window_uq UNIQUE (auth_id, model_id, window_start) );` @@ -472,6 +472,59 @@ func TestIntegration_RatingInstantIndexServesScan(t *testing.T) { } } +// TestIntegration_EventCountIsBigint pins D2: rated_usage.event_count must be a +// BIGINT, so a rollup counting more than 2^31 events INSERTs and round-trips +// instead of overflowing INTEGER. Asserts both the declared column type and an +// actual > 2^31 value surviving INSERT + SELECT. +func TestIntegration_EventCountIsBigint(t *testing.T) { + dsn := os.Getenv("PHOEBE_TEST_DATABASE_URL") + if dsn == "" { + t.Skip("PHOEBE_TEST_DATABASE_URL not set; skipping live-Postgres conformance") + } + db, err := sql.Open("pgx", dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + db.SetMaxOpenConns(1) // SET/search_path must stick + + const sch = "phoebe_rating_bigint_it" + exec(t, db, "DROP SCHEMA IF EXISTS "+sch+" CASCADE") + exec(t, db, "CREATE SCHEMA "+sch) + exec(t, db, "SET search_path TO "+sch) + defer func() { exec(t, db, "DROP SCHEMA IF EXISTS "+sch+" CASCADE") }() + exec(t, db, schemaDDL) + + // The declared column type must be bigint, not integer. + var dataType string + if err := db.QueryRow( + `SELECT data_type FROM information_schema.columns + WHERE table_schema = $1 AND table_name = 'rated_usage' AND column_name = 'event_count'`, + sch).Scan(&dataType); err != nil { + t.Fatalf("read event_count column type: %v", err) + } + if dataType != "bigint" { + t.Fatalf("rated_usage.event_count data_type = %q, want \"bigint\"", dataType) + } + + // A value past INT32_MAX must INSERT and round-trip unharmed. + const bigCount int64 = 5_000_000_000 // > 2^31-1 (2_147_483_647) + exec(t, db, fmt.Sprintf(`INSERT INTO rated_usage + (id, auth_id, model_id, window_start, window_end, + prompt_tokens, cached_tokens, completion_tokens, billable_prompt_tokens, cost, event_count) + VALUES ('bc1','a','b','2026-06-08T10:00:00Z','2026-06-08T11:00:00Z', + 0,0,0,0,0, %d)`, bigCount)) + + var got int64 + if err := db.QueryRow( + `SELECT event_count FROM rated_usage WHERE id = 'bc1'`).Scan(&got); err != nil { + t.Fatalf("read back event_count: %v", err) + } + if got != bigCount { + t.Fatalf("event_count round-trip = %d, want %d", got, bigCount) + } +} + func exec(t *testing.T, db *sql.DB, q string) { t.Helper() if _, err := db.Exec(q); err != nil { diff --git a/migrations/0002_rating.sql b/migrations/0002_rating.sql index 26556a9..3fb9878 100644 --- a/migrations/0002_rating.sql +++ b/migrations/0002_rating.sql @@ -217,7 +217,9 @@ CREATE TABLE rated_usage ( -- The money, as exact NUMERIC. Computed and summed in SQL (never in Go). cost NUMERIC(20, 9) NOT NULL, - event_count INTEGER NOT NULL, + -- BIGINT, not INTEGER: a wide rollup window can count more than 2^31 events, + -- which the COUNT(*)::bigint cast in the rater already produces. + event_count BIGINT NOT NULL, rated_at TIMESTAMPTZ NOT NULL DEFAULT now(), diff --git a/migrations/atlas/c2f1a3b4d5e6_add_rating.py b/migrations/atlas/c2f1a3b4d5e6_add_rating.py index 00ca054..aa8b771 100644 --- a/migrations/atlas/c2f1a3b4d5e6_add_rating.py +++ b/migrations/atlas/c2f1a3b4d5e6_add_rating.py @@ -179,7 +179,10 @@ def upgrade(): sa.Column("billable_prompt_tokens", sa.BigInteger(), nullable=False), # The money, exact NUMERIC, computed and summed in SQL. sa.Column("cost", MONEY, nullable=False), - sa.Column("event_count", sa.Integer(), nullable=False), + # BigInteger, not Integer: a wide rollup window can count more than + # 2^31 events (matches the COUNT(*)::bigint cast in the rater). No data + # migration is needed — rated_usage is empty pre-prod. + sa.Column("event_count", sa.BigInteger(), nullable=False), sa.Column( "rated_at", sa.DateTime(timezone=True),