Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions internal/proxy/iolog_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"bytes"
"context"
"io"
"net/http"
Expand Down Expand Up @@ -308,3 +309,86 @@ func TestIOLog_SinkReceivesRecordViaOnDone(t *testing.T) {
t.Fatalf("opted-in request must hand exactly 1 record to the sink, got %d", got)
}
}

// newIOLogServerWithLog mirrors newIOLogServer, except that the logger is
// injected by the caller. A test can therefore redirect WARN output into a
// buffer and assert on the request-body-truncation observability line.
func newIOLogServerWithLog(t *testing.T, upstream *url.URL, policy iolog.Policy, sink iolog.Sink, maxBody int, log *logging.Logger) *Server {
	t.Helper()
	return NewWithIOLog(
		&config.Settings{ListenAddr: ":0"},
		log,
		registry.NewStatic(upstream),
		&recordingEmitter{},
		policy,
		sink,
		maxBody,
	)
}

// TestIOLog_RequestTruncationLogged verifies Hugo's D1 decision: hard truncation
// of the captured request body is fine, but it must be OBSERVABLE — a single WARN
// line fires (with request_id and original→cap sizes) ONLY when the body exceeds
// the cap, and stays silent for an under-cap body.
func TestIOLog_RequestTruncationLogged(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Drain so the upstream still sees the full forwarded body.
_, _ = io.Copy(io.Discard, r.Body)
_, _ = io.WriteString(w, `{"ok":true,"model":"m"}`)
}))
defer backend.Close()
upstream, _ := url.Parse(backend.URL)

const bodyCap = 100

t.Run("over cap logs", func(t *testing.T) {
var buf bytes.Buffer
log := logging.New(logging.WARN)
log.Warn.SetOutput(&buf)

policy := &spyPolicy{decision: true}
sink := &recordingSink{}
srv := newIOLogServerWithLog(t, upstream, policy, sink, bodyCap, log)

big := strings.Repeat("A", 500)
rr := httptest.NewRecorder()
srv.Handler().ServeHTTP(rr, iologRequest(http.MethodPost, big))

if rr.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", rr.Code)
}
// The record must mark the request body truncated...
recs := sink.all()
if len(recs) != 1 || !recs[0].RequestTruncated {
t.Fatalf("expected 1 record with RequestTruncated=true, got %+v", recs)
}
// ...and the WARN line must fire with request_id and original→cap sizes.
out := buf.String()
if !strings.Contains(out, "request body truncated") {
t.Fatalf("no truncation WARN logged; got %q", out)
}
if !strings.Contains(out, "request_id=req-iolog-1") {
t.Errorf("WARN missing request_id; got %q", out)
}
if !strings.Contains(out, "500 → 100 bytes") {
t.Errorf("WARN missing original→cap sizes (want \"500 → 100 bytes\"); got %q", out)
}
})

t.Run("under cap silent", func(t *testing.T) {
var buf bytes.Buffer
log := logging.New(logging.WARN)
log.Warn.SetOutput(&buf)

policy := &spyPolicy{decision: true}
sink := &recordingSink{}
srv := newIOLogServerWithLog(t, upstream, policy, sink, bodyCap, log)

rr := httptest.NewRecorder()
srv.Handler().ServeHTTP(rr, iologRequest(http.MethodPost, `{"model":"m"}`))

if rr.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", rr.Code)
}
recs := sink.all()
if len(recs) != 1 || recs[0].RequestTruncated {
t.Fatalf("under-cap body must not be truncated; got %+v", recs)
}
if out := buf.String(); strings.Contains(out, "request body truncated") {
t.Errorf("under-cap body must NOT log truncation; got %q", out)
}
})
}
10 changes: 9 additions & 1 deletion internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,20 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) {
// Cap the LOGGED request-body copy at the same bound as the response copy
// (s.ioMaxBodyLen) — an uncapped body flows into to_tsvector and fails the
// INSERT past ~1 MiB. The forwarded request keeps the full body.
reqBody, reqTruncated, err = captureRequestBody(r, s.ioMaxBodyLen)
var reqOrigLen int
reqBody, reqTruncated, reqOrigLen, err = captureRequestBody(r, s.ioMaxBodyLen)
if err != nil {
s.log.Error.Printf("capture request body: %v", err)
http.Error(w, "bad request body", http.StatusBadRequest)
return
}
// Hard truncation is intentional (an uncapped body fails the to_tsvector
// INSERT), but make it OBSERVABLE: an operator should be able to see that
// a captured body was cut, and by how much.
if reqTruncated {
s.log.Warn.Printf("iolog: request body truncated for capture request_id=%s (%d → %d bytes)",
requestID, reqOrigLen, s.ioMaxBodyLen)
}
}

// Force streaming usage so we never under-bill a streamed response.
Expand Down
12 changes: 7 additions & 5 deletions internal/proxy/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@ import (
// the model sees — only the stored log copy. Truncation is at a rune boundary so
// the stored TEXT is valid UTF-8. maxBodyBytes <= 0 means uncapped.
//
// A nil body yields ("", false) with no error.
func captureRequestBody(r *http.Request, maxBodyBytes int) (body string, truncated bool, err error) {
// A nil body yields ("", false, 0) with no error. origLen is the size of the
// FULL body read off the wire (before capping), so a caller can log the
// original-vs-cap sizes when truncated is true.
func captureRequestBody(r *http.Request, maxBodyBytes int) (body string, truncated bool, origLen int, err error) {
if r.Body == nil {
return "", false, nil
return "", false, 0, nil
}
raw, err := io.ReadAll(r.Body)
_ = r.Body.Close()
if err != nil {
return "", false, err
return "", false, 0, err
}
// Restore the FULL body so forceIncludeUsage (and the upstream) read every
// byte — the cap only bounds the LOG copy, never the forwarded request.
Expand All @@ -41,7 +43,7 @@ func captureRequestBody(r *http.Request, maxBodyBytes int) (body string, truncat
r.Header.Set("Content-Length", strconv.Itoa(len(raw)))

logCopy, truncated := truncateAtRuneBoundary(raw, maxBodyBytes)
return string(logCopy), truncated, nil
return string(logCopy), truncated, len(raw), nil
}

// forceIncludeUsage rewrites a request body so that streamed responses carry a
Expand Down
2 changes: 1 addition & 1 deletion internal/rating/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ priced AS (
+ cached_tokens * cached_price
+ completion_tokens * completion_price
) AS cost,
COUNT(*)::int AS event_count
COUNT(*)::bigint AS event_count
FROM resolved
WHERE prompt_price IS NOT NULL -- priced only
AND auth_id IS NOT NULL -- attributable only
Expand Down
55 changes: 54 additions & 1 deletion internal/rating/store_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ CREATE TABLE rated_usage (
completion_tokens BIGINT NOT NULL,
billable_prompt_tokens BIGINT NOT NULL,
cost NUMERIC(20,9) NOT NULL,
event_count INTEGER NOT NULL,
event_count BIGINT NOT NULL,
rated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT rated_usage_auth_model_window_uq UNIQUE (auth_id, model_id, window_start)
);`
Expand Down Expand Up @@ -472,6 +472,59 @@ func TestIntegration_RatingInstantIndexServesScan(t *testing.T) {
}
}

// TestIntegration_EventCountIsBigint pins D2: rated_usage.event_count has to be
// a BIGINT so that a rollup counting more than 2^31 events can be stored rather
// than overflowing INTEGER. Two things are asserted against a live Postgres:
// the column's declared data type, and that a value above INT32_MAX survives an
// INSERT followed by a SELECT. Skipped when PHOEBE_TEST_DATABASE_URL is unset.
func TestIntegration_EventCountIsBigint(t *testing.T) {
	connStr := os.Getenv("PHOEBE_TEST_DATABASE_URL")
	if connStr == "" {
		t.Skip("PHOEBE_TEST_DATABASE_URL not set; skipping live-Postgres conformance")
	}

	pg, openErr := sql.Open("pgx", connStr)
	if openErr != nil {
		t.Fatalf("open: %v", openErr)
	}
	defer pg.Close()
	// A single connection, so the SET search_path below must stick.
	pg.SetMaxOpenConns(1)

	const (
		schemaName = "phoebe_rating_bigint_it"
		dropSchema = "DROP SCHEMA IF EXISTS " + schemaName + " CASCADE"
	)
	// Start from a clean, dedicated schema and remove it again on the way out.
	exec(t, pg, dropSchema)
	exec(t, pg, "CREATE SCHEMA "+schemaName)
	exec(t, pg, "SET search_path TO "+schemaName)
	defer exec(t, pg, dropSchema)
	exec(t, pg, schemaDDL)

	// Step 1: the catalog must report the column as bigint, not integer.
	var colType string
	typeRow := pg.QueryRow(
		`SELECT data_type FROM information_schema.columns
WHERE table_schema = $1 AND table_name = 'rated_usage' AND column_name = 'event_count'`,
		schemaName)
	if err := typeRow.Scan(&colType); err != nil {
		t.Fatalf("read event_count column type: %v", err)
	}
	if colType != "bigint" {
		t.Fatalf("rated_usage.event_count data_type = %q, want \"bigint\"", colType)
	}

	// Step 2: a count beyond INT32_MAX must be written and read back intact.
	const wantCount int64 = 5_000_000_000 // > 2^31-1 (2_147_483_647)
	insertSQL := fmt.Sprintf(`INSERT INTO rated_usage
(id, auth_id, model_id, window_start, window_end,
prompt_tokens, cached_tokens, completion_tokens, billable_prompt_tokens, cost, event_count)
VALUES ('bc1','a','b','2026-06-08T10:00:00Z','2026-06-08T11:00:00Z',
0,0,0,0,0, %d)`, wantCount)
	exec(t, pg, insertSQL)

	var gotCount int64
	readErr := pg.QueryRow(
		`SELECT event_count FROM rated_usage WHERE id = 'bc1'`).Scan(&gotCount)
	if readErr != nil {
		t.Fatalf("read back event_count: %v", readErr)
	}
	if gotCount != wantCount {
		t.Fatalf("event_count round-trip = %d, want %d", gotCount, wantCount)
	}
}

func exec(t *testing.T, db *sql.DB, q string) {
t.Helper()
if _, err := db.Exec(q); err != nil {
Expand Down
4 changes: 3 additions & 1 deletion migrations/0002_rating.sql
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ CREATE TABLE rated_usage (

-- The money, as exact NUMERIC. Computed and summed in SQL (never in Go).
cost NUMERIC(20, 9) NOT NULL,
event_count INTEGER NOT NULL,
-- BIGINT, not INTEGER: a wide rollup window can count more than 2^31 events,
-- which the COUNT(*)::bigint cast in the rater already produces.
event_count BIGINT NOT NULL,

rated_at TIMESTAMPTZ NOT NULL DEFAULT now(),

Expand Down
5 changes: 4 additions & 1 deletion migrations/atlas/c2f1a3b4d5e6_add_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ def upgrade():
sa.Column("billable_prompt_tokens", sa.BigInteger(), nullable=False),
# The money, exact NUMERIC, computed and summed in SQL.
sa.Column("cost", MONEY, nullable=False),
sa.Column("event_count", sa.Integer(), nullable=False),
# BigInteger, not Integer: a wide rollup window can count more than
# 2^31 events (matches the COUNT(*)::bigint cast in the rater). No data
# migration is needed — rated_usage is empty pre-prod.
sa.Column("event_count", sa.BigInteger(), nullable=False),
sa.Column(
"rated_at",
sa.DateTime(timezone=True),
Expand Down
Loading