diff --git a/cmd/content-cache/main.go b/cmd/content-cache/main.go index c934dcc..d755c3a 100644 --- a/cmd/content-cache/main.go +++ b/cmd/content-cache/main.go @@ -37,6 +37,7 @@ type ServeCmd struct { Storage string `kong:"name='storage',default='./cache',env='CACHE_STORAGE',help='Storage directory path',group='Server'"` TLSCertFile string `kong:"name='tls-cert',env='TLS_CERT_FILE',type='existingfile',help='Path to TLS certificate file (enables HTTPS)',group='Server'"` TLSKeyFile string `kong:"name='tls-key',env='TLS_KEY_FILE',type='existingfile',help='Path to TLS private key file (enables HTTPS)',group='Server'"` + PublicBaseURL string `kong:"name='public-base-url',env='PUBLIC_BASE_URL',help='External base URL clients use to reach this cache (e.g. https://cache.example.com). Used to build served download links.',group='Server'"` AuthToken string `kong:"name='auth-token',env='AUTH_TOKEN',help='Bearer token for inbound authentication (mutually exclusive with --oidc-policies)',group='Auth'"` AuthTokenFile string `kong:"name='auth-token-file',env='AUTH_TOKEN_FILE',type='existingfile',help='Path to file containing auth token (for k8s secret mounts)',group='Auth'"` @@ -263,6 +264,7 @@ func (cmd *ServeCmd) Run() error { StoragePath: cmd.Storage, TLSCertFile: cmd.TLSCertFile, TLSKeyFile: cmd.TLSKeyFile, + PublicBaseURL: cmd.PublicBaseURL, AuthToken: authToken, OIDCValidator: oidcValidator, Credentials: creds, diff --git a/download/downloader.go b/download/downloader.go index 1f8eefe..752d5cf 100644 --- a/download/downloader.go +++ b/download/downloader.go @@ -5,11 +5,22 @@ package download import ( "context" + "errors" + "sync/atomic" + "time" contentcache "github.com/buildkite/content-cache" + "github.com/buildkite/content-cache/telemetry" "golang.org/x/sync/singleflight" ) +const ( + spoolRoleOrigin = "origin" + spoolRoleCoalesced = "coalesced" +) + +type spoolRecorder func(ctx context.Context, role, outcome string, duration time.Duration, bytesSaved int64) + // Result holds the outcome of a download operation. type Result struct { Hash contentcache.Hash @@ -25,12 +36,13 @@ type DownloadFunc func(ctx context.Context) (*Result, error) // using singleflight. It uses DoChan so each caller can respect its own // context deadline without cancelling the in-flight download for others. type Downloader struct { - group singleflight.Group + group singleflight.Group + recordSpool spoolRecorder } // New creates a new Downloader. func New() *Downloader { - return &Downloader{} + return &Downloader{recordSpool: telemetry.RecordSpoolRequest} } // Do deduplicates concurrent downloads for the same key. @@ -40,7 +52,10 @@ func New() *Downloader { // If the caller's context expires before the download completes, Do returns // the context error but the in-flight download continues for other waiters. func (d *Downloader) Do(ctx context.Context, key string, fn DownloadFunc) (*Result, bool, error) { + started := time.Now() + var executed atomic.Bool ch := d.group.DoChan(key, func() (any, error) { + executed.Store(true) // Use a detached context so that no single caller's cancellation // stops the download for everyone else. return fn(context.WithoutCancel(ctx)) @@ -48,15 +63,49 @@ func (d *Downloader) Do(ctx context.Context, key string, fn DownloadFunc) (*Resu select { case res := <-ch: + role := spoolRoleCoalesced + if executed.Load() { + role = spoolRoleOrigin + } + outcome := spoolOutcome(res.Err) + var bytesSaved int64 + if role == spoolRoleCoalesced && res.Err == nil { + bytesSaved = res.Val.(*Result).Size + } + d.record(ctx, role, outcome, time.Since(started), bytesSaved) if res.Err != nil { return nil, res.Shared, res.Err } return res.Val.(*Result), res.Shared, nil case <-ctx.Done(): + role := spoolRoleCoalesced + if executed.Load() { + role = spoolRoleOrigin + } + d.record(ctx, role, spoolOutcome(ctx.Err()), time.Since(started), 0) return nil, false, ctx.Err() } } +func (d *Downloader) record(ctx context.Context, role, outcome string, duration time.Duration, bytesSaved int64) { + if d.recordSpool != nil { + d.recordSpool(ctx, role, outcome, duration, bytesSaved) + } +} + +func spoolOutcome(err error) string { + switch { + case err == nil: + return "success" + case errors.Is(err, context.DeadlineExceeded): + return "timeout" + case errors.Is(err, context.Canceled): + return "canceled" + default: + return "error" + } +} + // Forget removes the key from the singleflight group, allowing a subsequent // call to retry. Typically called after a download error. func (d *Downloader) Forget(key string) { diff --git a/download/downloader_test.go b/download/downloader_test.go index 1d6c5dd..e520efc 100644 --- a/download/downloader_test.go +++ b/download/downloader_test.go @@ -32,6 +32,18 @@ func TestDo_SingleCall(t *testing.T) { func TestDo_ConcurrentDeduplication(t *testing.T) { d := New() + type spoolEvent struct { + role string + outcome string + bytesSaved int64 + } + var eventsMu sync.Mutex + var events []spoolEvent + d.recordSpool = func(_ context.Context, role, outcome string, _ time.Duration, bytesSaved int64) { + eventsMu.Lock() + defer eventsMu.Unlock() + events = append(events, spoolEvent{role: role, outcome: outcome, bytesSaved: bytesSaved}) + } var callCount atomic.Int32 expected := &Result{ @@ -63,6 +75,24 @@ func TestDo_ConcurrentDeduplication(t *testing.T) { require.NoError(t, errs[i]) require.Equal(t, expected.Hash, results[i].Hash) } + + eventsMu.Lock() + defer eventsMu.Unlock() + var origins, coalesced int + var bytesSaved int64 + for _, event := range events { + require.Equal(t, "success", event.outcome) + switch event.role { + case spoolRoleOrigin: + origins++ + case spoolRoleCoalesced: + coalesced++ + } + bytesSaved += event.bytesSaved + } + require.Equal(t, 1, origins) + require.Equal(t, 9, coalesced) + require.Equal(t, int64(36), bytesSaved) } func TestDo_CallerTimeout(t *testing.T) { diff --git a/protocol/git/handler.go b/protocol/git/handler.go index cbe6911..a506bd5 100644 --- a/protocol/git/handler.go +++ b/protocol/git/handler.go @@ -376,6 +376,35 @@ func (h *Handler) handleUploadPack(w http.ResponseWriter, r *http.Request, repo } } + // In the git v2 wire protocol `ls-refs` is passed to the git-upload-pack endpoint. + // This type of request cannot be cached, as it is backed by mutable refs upstream. + isLsRefs, lsRefsErr := isLsRefs(activeBodyFile) + if lsRefsErr != nil { + logger.Error("failed to check if git command includes ls-refs", "error", lsRefsErr) + } else if isLsRefs { + telemetry.SetCacheResult(r, telemetry.CacheBypass) + + upstream := h.router.Match(repo) + rc, err := upstream.FetchUploadPack(ctx, repo, gitProtocol, activeBodyFile) + _ = activeBodyFile.Close() + if errors.Is(err, ErrNotFound) { + http.Error(w, "repository not found", http.StatusNotFound) + return + } + if err != nil { + logger.Error("upstream git-upload-pack ls-refs failed", "error", err) + http.Error(w, "upstream error", http.StatusBadGateway) + return + } + defer func() { _ = rc.Close() }() + + w.Header().Set("Content-Type", ContentTypeUploadPackResult) + if _, err := io.Copy(w, rc); err != nil { + logger.Error("failed to stream git-upload-pack ls-refs response", "error", err) + } + return + } + // Check cache cached, err := h.index.GetCachedPack(ctx, cacheKey) if err == nil { @@ -616,3 +645,49 @@ func parseHexLen(s string) (int, error) { } return int(n), nil } + +// isLsRefs determines if a particular request is a ls-refs command. +// +// The output of ls-refs commands change as the upstream updates, +// and so they cannot be cached in the same way a normal git-upload-pack request can be. +func isLsRefs(activeBody io.ReadSeeker) (bool, error) { + if _, err := activeBody.Seek(0, io.SeekStart); err != nil { + return false, err + } + defer func() { _, _ = activeBody.Seek(0, io.SeekStart) }() + + body, err := io.ReadAll(io.LimitReader(activeBody, maxLsRefsScanSize)) + if err != nil { + return false, err + } + + pos := 0 + for pos+4 <= len(body) { + pktLen, err := parseHexLen(string(body[pos : pos+4])) + if err != nil { + return false, err + } + if pktLen == 0 || pktLen == 1 { + // pktLen == 0 -> flush packet + // pktLen == 1 -> delimiter packet + // in both cases, they are just special control packets + pos += 4 + continue + } + if pktLen < 4 { + return false, fmt.Errorf("git protocol malformed packet size") + } + if pos+pktLen > len(body) { + break + } + content := strings.TrimRight(string(body[pos+4:pos+pktLen]), "\n") + cmd, ok := strings.CutPrefix(content, "command=") + if !ok { + pos += pktLen + continue + } + return cmd == "ls-refs", nil + } + + return false, nil +} diff --git a/protocol/git/handler_test.go b/protocol/git/handler_test.go index 422e352..f90afae 100644 --- a/protocol/git/handler_test.go +++ b/protocol/git/handler_test.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "context" "fmt" + "io" "net/http" "net/http/httptest" "os" @@ -646,6 +647,206 @@ func TestHandlerStaleCacheEviction(t *testing.T) { require.Equal(t, firstBody, w2.Body.String()) } +// pktLine encodes s as a git pkt-line: a 4-hex-digit length prefix covering the +// prefix itself plus the payload, followed by the payload. +func pktLine(s string) string { + return fmt.Sprintf("%04x%s", len(s)+4, s) +} + +const ( + flushPkt = "0000" + delimPkt = "0001" +) + +func TestIsLsRefs(t *testing.T) { + tests := []struct { + name string + body string + want bool + wantErr bool + }{ + { + name: "ls-refs command only", + body: pktLine("command=ls-refs\n"), + want: true, + }, + { + name: "ls-refs with capabilities, delimiter and args", + body: pktLine("command=ls-refs\n") + + pktLine("object-format=sha1\n") + + delimPkt + + pktLine("peel\n") + + pktLine("ref-prefix HEAD\n") + + flushPkt, + want: true, + }, + { + name: "ls-refs preceded by control packets", + body: flushPkt + delimPkt + pktLine("command=ls-refs\n"), + want: true, + }, + { + name: "fetch command is not ls-refs", + body: pktLine("command=fetch\n") + + pktLine("object-format=sha1\n") + + delimPkt + + pktLine("want 0123456789012345678901234567890123456789\n") + + flushPkt, + want: false, + }, + { + name: "v1-style want/done body has no command", + body: pktLine("want 0123456789012345678901234567890123456789\n") + + flushPkt + + pktLine("done\n"), + want: false, + }, + { + name: "empty body", + body: "", + want: false, + }, + { + name: "only a flush packet", + body: flushPkt, + want: false, + }, + { + name: "malformed packet length below minimum", + body: "0003", + wantErr: true, + }, + { + name: "invalid hex length", + body: "zzzzcommand=ls-refs\n", + wantErr: true, + }, + { + name: "truncated packet is treated as not ls-refs", + body: "0050command=ls-refs\n", // declares 0x50 bytes but the body is far shorter + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := isLsRefs(strings.NewReader(tt.body)) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestIsLsRefsResetsPosition(t *testing.T) { + body := pktLine("command=ls-refs\n") + flushPkt + r := strings.NewReader(body) + + got, err := isLsRefs(r) + require.NoError(t, err) + require.True(t, got) + + rest, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, body, string(rest), "body must be readable from the start after isLsRefs") +} + +func TestIsLsRefsSeeksToStartBeforeScanning(t *testing.T) { + body := pktLine("command=ls-refs\n") + r := strings.NewReader(body) + + _, err := r.Seek(4, io.SeekStart) + require.NoError(t, err) + + got, err := isLsRefs(r) + require.NoError(t, err) + require.True(t, got) +} + +// countingUploadPackUpstream returns a test server that records every +// git-upload-pack request body and responds with a fake pack. +func countingUploadPackUpstream(fetchCount *atomic.Int32, lastBody *[]byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/git-upload-pack") { + fetchCount.Add(1) + if b, err := io.ReadAll(r.Body); err == nil { + *lastBody = b + } + w.Header().Set("Content-Type", ContentTypeUploadPackResult) + _, _ = w.Write([]byte(fakeUploadPackBody)) + return + } + w.WriteHeader(http.StatusNotFound) + })) +} + +func TestHandlerUploadPackLsRefsBypassesCache(t *testing.T) { + var fetchCount atomic.Int32 + var lastBody []byte + upstreamSrv := countingUploadPackUpstream(&fetchCount, &lastBody) + defer upstreamSrv.Close() + + h, cleanup := newTestHandlerWithTransport(t, upstreamSrv) + defer cleanup() + + body := []byte(pktLine("command=ls-refs\n") + + pktLine("object-format=sha1\n") + + delimPkt + + pktLine("peel\n") + + pktLine("ref-prefix HEAD\n") + + flushPkt) + + for i := range 2 { + req := httptest.NewRequest(http.MethodPost, "/github.com/user/repo.git/git-upload-pack", bytes.NewReader(body)) + req.Header.Set("Content-Type", ContentTypeUploadPackRequest) + req.Header.Set("Git-Protocol", "version=2") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "request %d", i) + require.Equal(t, ContentTypeUploadPackResult, w.Header().Get("Content-Type")) + require.Equal(t, fakeUploadPackBody, w.Body.String()) + } + + require.Equal(t, int32(2), fetchCount.Load(), "ls-refs must reach upstream on every request and never be cached") + require.Equal(t, body, lastBody, "the full ls-refs body must be forwarded to upstream") +} + +func TestHandlerUploadPackLsRefsGzip(t *testing.T) { + var fetchCount atomic.Int32 + var lastBody []byte + upstreamSrv := countingUploadPackUpstream(&fetchCount, &lastBody) + defer upstreamSrv.Close() + + h, cleanup := newTestHandlerWithTransport(t, upstreamSrv) + defer cleanup() + + plain := []byte(pktLine("command=ls-refs\n") + flushPkt) + + var gzBuf bytes.Buffer + gz := gzip.NewWriter(&gzBuf) + _, err := gz.Write(plain) + require.NoError(t, err) + require.NoError(t, gz.Close()) + + req := httptest.NewRequest(http.MethodPost, "/github.com/user/repo.git/git-upload-pack", bytes.NewReader(gzBuf.Bytes())) + req.Header.Set("Content-Type", ContentTypeUploadPackRequest) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Git-Protocol", "version=2") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, fakeUploadPackBody, w.Body.String()) + require.Equal(t, int32(1), fetchCount.Load()) + require.Equal(t, plain, lastBody, "the decompressed ls-refs body must be forwarded to upstream") +} + // newTestHandlerWithTransport creates a test handler that redirects all HTTPS // requests to the given test server via a custom transport. func newTestHandlerWithTransport(t *testing.T, targetServer *httptest.Server) (*Handler, func()) { diff --git a/protocol/git/types.go b/protocol/git/types.go index be8b424..a1c9251 100644 --- a/protocol/git/types.go +++ b/protocol/git/types.go @@ -23,6 +23,11 @@ const ( // DefaultMaxDecompressedBodySize is the maximum size of a decompressed git-upload-pack body (500MB). DefaultMaxDecompressedBodySize int64 = 500 * 1024 * 1024 + + // maxLsRefsScanSize bounds how much of an upload-pack request body is scanned + // to detect a protocol v2 "command=" pkt-line. The command appears at the very + // start of the body, so a small cap is sufficient. + maxLsRefsScanSize int64 = 64 * 1024 ) // ErrNotFound is returned when a cached pack is not found. diff --git a/protocol/npm/handler.go b/protocol/npm/handler.go index a070583..caa1dcc 100644 --- a/protocol/npm/handler.go +++ b/protocol/npm/handler.go @@ -40,6 +40,10 @@ type Handler struct { logger *slog.Logger downloader *download.Downloader + // publicBaseURL, when set, is the external base URL used to build served + // tarball links instead of the request scheme and Host header. + publicBaseURL string + // Lifecycle management for background goroutines wg sync.WaitGroup shutdownCtx context.Context @@ -79,6 +83,16 @@ func WithDownloader(dl *download.Downloader) HandlerOption { } } +// WithPublicBaseURL sets the external base URL (e.g. https://cache.example.com) +// used when rewriting tarball download links. When empty, the request scheme and +// Host header are used. Set this when TLS is terminated by an upstream load +// balancer, so the cache cannot infer https from the inbound request. +func WithPublicBaseURL(baseURL string) HandlerOption { + return func(h *Handler) { + h.publicBaseURL = strings.TrimSuffix(baseURL, "/") + } +} + // NewHandler creates a new NPM registry handler. func NewHandler(index *Index, store store.Store, opts ...HandlerOption) *Handler { shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) @@ -249,11 +263,14 @@ func (h *Handler) rewriteTarballURLs(r *http.Request, meta map[string]any) { } // Determine our base URL - scheme := "http" - if r.TLS != nil { - scheme = "https" + baseURL := h.publicBaseURL + if baseURL == "" { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + baseURL = fmt.Sprintf("%s://%s", scheme, r.Host) } - baseURL := fmt.Sprintf("%s://%s", scheme, r.Host) for _, v := range versions { version, ok := v.(map[string]any) diff --git a/protocol/pypi/handler.go b/protocol/pypi/handler.go index 6623c2a..f539b69 100644 --- a/protocol/pypi/handler.go +++ b/protocol/pypi/handler.go @@ -39,6 +39,10 @@ type Handler struct { downloader *download.Downloader metadataTTL time.Duration + // publicBaseURL, when set, is the external base URL used to build served + // download links instead of the request scheme and Host header. + publicBaseURL string + // Lifecycle management for background goroutines wg sync.WaitGroup ctx context.Context @@ -76,6 +80,16 @@ func WithMetadataTTL(ttl time.Duration) HandlerOption { } } +// WithPublicBaseURL sets the external base URL (e.g. https://cache.example.com) +// used when rewriting file download links. When empty, the request scheme and +// Host header are used. Set this when TLS is terminated by an upstream load +// balancer, so the cache cannot infer https from the inbound request. +func WithPublicBaseURL(baseURL string) HandlerOption { + return func(h *Handler) { + h.publicBaseURL = strings.TrimSuffix(baseURL, "/") + } +} + // NewHandler creates a new PyPI Simple API handler. func NewHandler(index *Index, store store.Store, opts ...HandlerOption) *Handler { ctx, cancel := context.WithCancel(context.Background()) @@ -562,18 +576,28 @@ func (h *Handler) cacheFile(ctx context.Context, project, filename string, hash logger.Info("cached file", "filename", filename, "hash", hash.ShortString(), "size", size) } +// baseURL returns the external base URL used to build served download links. +// It prefers the configured public base URL and otherwise derives it from the +// request scheme and Host header. +func (h *Handler) baseURL(r *http.Request) string { + if h.publicBaseURL != "" { + return h.publicBaseURL + } + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + return fmt.Sprintf("%s://%s", scheme, r.Host) +} + // writeProjectResponse writes the project page response in HTML or JSON format. func (h *Handler) writeProjectResponse(w http.ResponseWriter, r *http.Request, cached *CachedProject, normalized string) { - // Build file list for response + // Rewrite file URLs to point to our proxy. We include the /pypi prefix + // since the server strips it before passing to this handler. + base := h.baseURL(r) var files []ProjectFile for _, f := range cached.Files { - // Rewrite URL to point to our proxy - // Note: We include /pypi prefix since the server strips it before passing to this handler - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - proxyURL := fmt.Sprintf("%s://%s/pypi/packages/%s/%s", scheme, r.Host, normalized, f.Filename) + proxyURL := fmt.Sprintf("%s/pypi/packages/%s/%s", base, normalized, f.Filename) pf := ProjectFile{ Filename: f.Filename, diff --git a/protocol/pypi/handler_test.go b/protocol/pypi/handler_test.go index cc9bec1..c317dfe 100644 --- a/protocol/pypi/handler_test.go +++ b/protocol/pypi/handler_test.go @@ -27,7 +27,7 @@ func setupMetaDB(t *testing.T, tmpDir string) (*metadb.BoltDB, *Index, func()) { return db, idx, func() { _ = db.Close() } } -func newTestHandler(t *testing.T, upstreamServer *httptest.Server) (*Handler, func()) { +func newTestHandler(t *testing.T, upstreamServer *httptest.Server, extraOpts ...HandlerOption) (*Handler, func()) { t.Helper() // Use a manual temp dir instead of t.TempDir() to avoid race with async goroutines tmpDir, err := os.MkdirTemp("", "pypi-test-*") @@ -44,7 +44,7 @@ func newTestHandler(t *testing.T, upstreamServer *httptest.Server) (*Handler, fu } upstream := NewUpstream(opts...) - h := NewHandler(idx, cafs, WithUpstream(upstream)) + h := NewHandler(idx, cafs, append([]HandlerOption{WithUpstream(upstream)}, extraOpts...)...) return h, func() { h.Close() closeDB() @@ -121,6 +121,57 @@ func TestHandlerProject(t *testing.T) { }) } +func TestHandlerProjectPublicBaseURL(t *testing.T) { + upstreamHandler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/simple/requests/" { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(` + +requests-2.31.0-py3-none-any.whl
+`)) + return + } + w.WriteHeader(http.StatusNotFound) + } + + t.Run("configured base URL overrides request scheme and host", func(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(upstreamHandler)) + defer upstream.Close() + + h, cleanup := newTestHandler(t, upstream, WithPublicBaseURL("https://cache.example.com/")) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/simple/requests/", nil) + req.Host = "localhost:8080" + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + body := w.Body.String() + // Trailing slash trimmed, configured host used, request Host ignored. + require.Contains(t, body, "https://cache.example.com/pypi/packages/requests/") + require.NotContains(t, body, "localhost:8080") + }) + + t.Run("falls back to request scheme and host when unset", func(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(upstreamHandler)) + defer upstream.Close() + + h, cleanup := newTestHandler(t, upstream) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/simple/requests/", nil) + req.Host = "localhost:8080" + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Contains(t, w.Body.String(), "http://localhost:8080/pypi/packages/requests/") + }) +} + func TestHandlerRoot(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) diff --git a/protocol/pypi/upstream.go b/protocol/pypi/upstream.go index 45d5131..07bddf8 100644 --- a/protocol/pypi/upstream.go +++ b/protocol/pypi/upstream.go @@ -27,17 +27,36 @@ var ErrNotFound = errors.New("not found") // Upstream fetches packages from an upstream PyPI Simple API. type Upstream struct { - baseURL string - client *http.Client + baseURL string + username string + password string + client *http.Client } // UpstreamOption configures an Upstream. type UpstreamOption func(*Upstream) -// WithSimpleURL sets the upstream Simple API URL. -func WithSimpleURL(url string) UpstreamOption { +// WithSimpleURL sets the upstream Simple API URL. Any userinfo (user:pass@) is +// stripped from the URL and applied as Basic auth on every upstream request — +// including file downloads, which the index commonly serves from a different +// host (e.g. a CDN) that requires the same credentials. +func WithSimpleURL(rawURL string) UpstreamOption { return func(u *Upstream) { - u.baseURL = strings.TrimSuffix(url, "/") + "/" + trimmed := strings.TrimSuffix(rawURL, "/") + "/" + if parsed, err := url.Parse(trimmed); err == nil && parsed.User != nil { + u.username = parsed.User.Username() + u.password, _ = parsed.User.Password() + parsed.User = nil + trimmed = parsed.String() + } + u.baseURL = trimmed + } +} + +// setAuth applies the upstream Basic credentials to a request, if configured. +func (u *Upstream) setAuth(req *http.Request) { + if u.username != "" || u.password != "" { + req.SetBasicAuth(u.username, u.password) } } @@ -75,6 +94,7 @@ func (u *Upstream) FetchProjectPage(ctx context.Context, project string) ([]byte // Request JSON preferred, fallback to HTML req.Header.Set("Accept", ContentTypeJSON+", "+ContentTypeHTML+";q=0.9") + u.setAuth(req) resp, err := u.client.Do(req) //nolint:gosec // request targets operator-configured upstream, not user-controlled if err != nil { @@ -107,6 +127,7 @@ func (u *Upstream) FetchFile(ctx context.Context, fileURL string) (io.ReadCloser if err != nil { return nil, fmt.Errorf("creating request: %w", err) } + u.setAuth(req) resp, err := u.client.Do(req) if err != nil { diff --git a/protocol/pypi/upstream_test.go b/protocol/pypi/upstream_test.go new file mode 100644 index 0000000..6d080f5 --- /dev/null +++ b/protocol/pypi/upstream_test.go @@ -0,0 +1,66 @@ +package pypi + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpstreamAppliesBasicAuth(t *testing.T) { + const ( + user = "pyx-user" + pass = "pyx-secret" + ) + + var gotIndexAuth, gotFileAuth bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if !ok || u != user || p != pass { + w.WriteHeader(http.StatusUnauthorized) + return + } + switch r.URL.Path { + case "/simple/requests/": + gotIndexAuth = true + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(``)) + case "/files/requests-2.31.0.whl": + gotFileAuth = true + _, _ = w.Write([]byte("wheel-bytes")) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + // Credentials are supplied via userinfo, as the deployment does. + up := NewUpstream(WithSimpleURL(srv.URL + "/simple/")) + up.username = user + up.password = pass + + ctx := context.Background() + + _, _, err := up.FetchProjectPage(ctx, "requests") + require.NoError(t, err) + require.True(t, gotIndexAuth, "index fetch should send Basic auth") + + rc, err := up.FetchFile(ctx, srv.URL+"/files/requests-2.31.0.whl") + require.NoError(t, err) + body, err := io.ReadAll(rc) + require.NoError(t, err) + require.NoError(t, rc.Close()) + require.Equal(t, "wheel-bytes", string(body)) + require.True(t, gotFileAuth, "file fetch should send Basic auth") +} + +func TestWithSimpleURLStripsUserinfo(t *testing.T) { + up := NewUpstream(WithSimpleURL("https://user:pw@api.example.com/simple/ramp/pypi/")) + require.Equal(t, "user", up.username) + require.Equal(t, "pw", up.password) + // Userinfo must not remain in the base URL (it would leak in logs/metrics). + require.Equal(t, "https://api.example.com/simple/ramp/pypi/", up.baseURL) +} diff --git a/server/http.go b/server/http.go index 2e7ed4a..b3dce81 100644 --- a/server/http.go +++ b/server/http.go @@ -171,6 +171,12 @@ type Config struct { // When both TLSCertFile and TLSKeyFile are set, the server starts with TLS. TLSCertFile string + // PublicBaseURL is the external base URL (scheme://host[:port]) + // clients use to reach this cache. + // When set, served download links (PyPI files, NPM tarballs) are built from it + // instead of the request scheme and Host header. + PublicBaseURL string + // TLSKeyFile is the path to the TLS private key file. TLSKeyFile string @@ -214,6 +220,7 @@ type Server struct { httpcacheIndex *httpcache.Index httpcache *httpcache.Handler metaDB metadb.MetaDB + metadataReapers *metadataReapers gcManager *gc.Manager s3fifoManager *s3fifo.Manager } @@ -298,6 +305,13 @@ func New(cfg Config) (*Server, error) { return nil, err } + expiryCheckInterval := cfg.ExpiryCheckInterval + if expiryCheckInterval <= 0 { + expiryCheckInterval = time.Hour + } + cfg.ExpiryCheckInterval = expiryCheckInterval + metadataReapers := newMetadataReapers(metaDB, expiryCheckInterval, cfg.Logger) + // Create a single shared EnvelopeCodec (zstd encoder/decoder) for all EnvelopeIndex instances. sharedCodec, err := metadb.NewEnvelopeCodec() if err != nil { @@ -467,6 +481,9 @@ func New(cfg Config) (*Server, error) { } npmHandlerOpts = append(npmHandlerOpts, npm.WithUpstream(npm.NewUpstream(npmUpstreamOpts...))) } + if cfg.PublicBaseURL != "" { + npmHandlerOpts = append(npmHandlerOpts, npm.WithPublicBaseURL(cfg.PublicBaseURL)) + } npmHandler := npm.NewHandler(npmIndex, cafsStore, npmHandlerOpts...) // Initialize OCI components using metadb EnvelopeIndex @@ -561,6 +578,9 @@ func New(cfg Config) (*Server, error) { if cfg.PyPIMetadataTTL > 0 { pypiHandlerOpts = append(pypiHandlerOpts, pypi.WithMetadataTTL(cfg.PyPIMetadataTTL)) } + if cfg.PublicBaseURL != "" { + pypiHandlerOpts = append(pypiHandlerOpts, pypi.WithPublicBaseURL(cfg.PublicBaseURL)) + } pypiHandler := pypi.NewHandler(pypiIndex, cafsStore, pypiHandlerOpts...) // Initialize Maven components using metadb EnvelopeIndex @@ -817,6 +837,7 @@ func New(cfg Config) (*Server, error) { httpcacheIndex: httpcacheIdx, httpcache: httpcacheHndlr, metaDB: metaDB, + metadataReapers: metadataReapers, gcManager: gcManager, s3fifoManager: s3fifoMgr, } @@ -1063,6 +1084,13 @@ func (s *Server) loggingMiddleware(next http.Handler) http.Handler { // Start starts the server. func (s *Server) Start() error { + if s.metadataReapers != nil { + s.logger.Info("starting metadata expiry reapers", + "interval", s.config.ExpiryCheckInterval, + ) + s.metadataReapers.Start(context.Background()) + } + if s.s3fifoManager != nil { s.logger.Info("starting S3-FIFO eviction manager") s.s3fifoManager.Start(context.Background()) @@ -1089,6 +1117,11 @@ func (s *Server) Start() error { func (s *Server) Shutdown(ctx context.Context) error { s.logger.Info("shutting down server") + // Stop metadata expiry reapers before closing the HTTP server. + if s.metadataReapers != nil { + s.metadataReapers.Stop() + } + // Stop S3-FIFO eviction manager if s.s3fifoManager != nil { s.s3fifoManager.Stop() diff --git a/server/reapers.go b/server/reapers.go new file mode 100644 index 0000000..2784179 --- /dev/null +++ b/server/reapers.go @@ -0,0 +1,57 @@ +package server + +import ( + "context" + "log/slog" + "sync" + "time" + + "github.com/buildkite/content-cache/store/metadb" +) + +// metadataReapers owns the background cleanup loops for both metadata formats. +// Reaping expired entries also decrements their blob reference counts, making +// those blobs eligible for S3-FIFO eviction and garbage collection. +type metadataReapers struct { + expiry *metadb.ExpiryReaper + envelope *metadb.EnvelopeReaper + + cancel context.CancelFunc + done sync.WaitGroup +} + +func newMetadataReapers(db metadb.MetaDB, interval time.Duration, logger *slog.Logger) *metadataReapers { + return &metadataReapers{ + expiry: metadb.NewExpiryReaper(db, + metadb.WithReaperInterval(interval), + metadb.WithReaperLogger(logger.With("component", "expiry-reaper")), + ), + envelope: metadb.NewEnvelopeReaper(db, + metadb.WithEnvelopeReaperInterval(interval), + metadb.WithEnvelopeReaperLogger(logger.With("component", "envelope-reaper")), + ), + } +} + +func (r *metadataReapers) Start(parent context.Context) { + ctx, cancel := context.WithCancel(parent) + r.cancel = cancel + + r.done.Add(2) + go func() { + defer r.done.Done() + r.expiry.Run(ctx) + }() + go func() { + defer r.done.Done() + r.envelope.Run(ctx) + }() +} + +func (r *metadataReapers) Stop() { + if r.cancel == nil { + return + } + r.cancel() + r.done.Wait() +} diff --git a/server/reapers_test.go b/server/reapers_test.go new file mode 100644 index 0000000..faedaf8 --- /dev/null +++ b/server/reapers_test.go @@ -0,0 +1,52 @@ +package server + +import ( + "context" + "errors" + "io" + "log/slog" + "testing" + "time" + + "github.com/buildkite/content-cache/store/metadb" + "github.com/stretchr/testify/require" +) + +func TestMetadataReapersReleaseExpiredEnvelopeBlobRefs(t *testing.T) { + ctx := context.Background() + db := metadb.NewBoltDB() + require.NoError(t, db.Open(t.TempDir()+"/metadata.db")) + t.Cleanup(func() { require.NoError(t, db.Close()) }) + + const hash = "sha256:1111111111111111111111111111111111111111111111111111111111111111" + require.NoError(t, db.PutBlob(ctx, &metadb.BlobEntry{ + Hash: hash, + Size: 100, + CachedAt: time.Now(), + LastAccess: time.Now(), + })) + require.NoError(t, db.PutEnvelope(ctx, "test", "artifact", "key", &metadb.MetadataEnvelope{ + EnvelopeVersion: metadb.CurrentEnvelopeVersion, + ExpiresAtUnixMs: time.Now().Add(-time.Minute).UnixMilli(), + BlobRefs: []string{hash}, + })) + + blob, err := db.GetBlob(ctx, hash) + require.NoError(t, err) + require.Equal(t, 1, blob.RefCount) + + reapers := newMetadataReapers( + db, + 10*time.Millisecond, + slog.New(slog.NewTextHandler(io.Discard, nil)), + ) + reapers.Start(ctx) + t.Cleanup(reapers.Stop) + + require.Eventually(t, func() bool { + _, envelopeErr := db.GetEnvelope(ctx, "test", "artifact", "key") + blob, blobErr := db.GetBlob(ctx, hash) + return errors.Is(envelopeErr, metadb.ErrNotFound) && + blobErr == nil && blob.RefCount == 0 + }, time.Second, 10*time.Millisecond) +} diff --git a/telemetry/metrics.go b/telemetry/metrics.go index 3f8b3e0..b3a332a 100644 --- a/telemetry/metrics.go +++ b/telemetry/metrics.go @@ -50,6 +50,9 @@ type Metrics struct { upstreamFetchBytesTotal metric.Int64Counter blobTouchesTotal metric.Int64Counter blobTouchMissesTotal metric.Int64Counter + spoolRequestsTotal metric.Int64Counter + spoolWaitDuration metric.Float64Histogram + spoolBytesSavedTotal metric.Int64Counter backendRequestDuration metric.Float64Histogram backendRequestsTotal metric.Int64Counter backendBytesTotal metric.Int64Counter @@ -233,6 +236,34 @@ func doInitMetrics(ctx context.Context, cfg MetricsConfig) error { return err } + spoolRequestsTotal, err := meter.Int64Counter( + "content_cache_spool_requests_total", + metric.WithDescription("Total requests entering download spooling, by caller role and outcome"), + metric.WithUnit("{request}"), + ) + if err != nil { + return err + } + + spoolWaitDuration, err := meter.Float64Histogram( + "content_cache_spool_wait_duration_seconds", + metric.WithDescription("Time coalesced requests waited for the origin request's download"), + metric.WithUnit("s"), + metric.WithExplicitBucketBoundaries(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 40, 60, 120, 300), + ) + if err != nil { + return err + } + + spoolBytesSavedTotal, err := meter.Int64Counter( + "content_cache_spool_bytes_saved_total", + metric.WithDescription("Estimated upstream payload bytes avoided by successful coalesced requests"), + metric.WithUnit("By"), + ) + if err != nil { + return err + } + blobTouchesTotal, err := meter.Int64Counter( "content_cache_blob_touches_total", metric.WithDescription("Total blob access count increments"), @@ -475,6 +506,9 @@ func doInitMetrics(ctx context.Context, cfg MetricsConfig) error { upstreamFetchBytesTotal: upstreamFetchBytesTotal, blobTouchesTotal: blobTouchesTotal, blobTouchMissesTotal: blobTouchMissesTotal, + spoolRequestsTotal: spoolRequestsTotal, + spoolWaitDuration: spoolWaitDuration, + spoolBytesSavedTotal: spoolBytesSavedTotal, backendRequestDuration: backendRequestDuration, backendRequestsTotal: backendRequestsTotal, backendBytesTotal: backendBytesTotal, @@ -721,6 +755,35 @@ func RecordUpstreamFetch(ctx context.Context, protocol string, duration time.Dur } } +// RecordSpoolRequest records one caller entering the request spool. Role is +// "origin" when the caller executed the download and "coalesced" when it +// joined an existing download. bytesSaved is the downloaded object size for a +// successful coalesced caller and zero otherwise. +func RecordSpoolRequest(ctx context.Context, role, outcome string, duration time.Duration, bytesSaved int64) { + if globalMetrics == nil { + return + } + + protocol := ProtocolFromContext(ctx) + if protocol == "" { + protocol = "unknown" + } + + attrs := []attribute.KeyValue{ + attribute.String("protocol", protocol), + attribute.String("role", role), + attribute.String("outcome", outcome), + } + globalMetrics.spoolRequestsTotal.Add(ctx, 1, metric.WithAttributes(attrs...)) + + if role == "coalesced" { + globalMetrics.spoolWaitDuration.Record(ctx, duration.Seconds(), metric.WithAttributes(attrs...)) + if bytesSaved > 0 { + globalMetrics.spoolBytesSavedTotal.Add(ctx, bytesSaved, metric.WithAttributes(attribute.String("protocol", protocol))) + } + } +} + // RecordBlobTouch records a blob access count increment. func RecordBlobTouch(ctx context.Context, protocol string, newAccessCount int) { if globalMetrics == nil { diff --git a/telemetry/metrics_test.go b/telemetry/metrics_test.go index d4004da..ff061c2 100644 --- a/telemetry/metrics_test.go +++ b/telemetry/metrics_test.go @@ -39,12 +39,24 @@ func setupTestMetrics(t *testing.T) *sdkmetric.ManualReader { authRequestsTotal, err := meter.Int64Counter("content_cache_auth_requests_total") require.NoError(t, err) + spoolRequestsTotal, err := meter.Int64Counter("content_cache_spool_requests_total") + require.NoError(t, err) + + spoolWaitDuration, err := meter.Float64Histogram("content_cache_spool_wait_duration_seconds") + require.NoError(t, err) + + spoolBytesSavedTotal, err := meter.Int64Counter("content_cache_spool_bytes_saved_total") + require.NoError(t, err) + globalMetrics = &Metrics{ requestsTotal: requestsTotal, responseBytesTotal: responseBytesTotal, requestDuration: requestDuration, requestsByEndpointTotal: requestsByEndpointTotal, authRequestsTotal: authRequestsTotal, + spoolRequestsTotal: spoolRequestsTotal, + spoolWaitDuration: spoolWaitDuration, + spoolBytesSavedTotal: spoolBytesSavedTotal, meterProvider: mp, } @@ -56,6 +68,31 @@ func setupTestMetrics(t *testing.T) *sdkmetric.ManualReader { return reader } +func TestRecordSpoolRequest(t *testing.T) { + reader := setupTestMetrics(t) + + ctx := WithProtocolContext(context.Background(), "npm") + RecordSpoolRequest(ctx, "origin", "success", 100*time.Millisecond, 0) + RecordSpoolRequest(ctx, "coalesced", "success", 75*time.Millisecond, 4096) + + rm := collectMetrics(t, reader) + + requestDps := findCounter(rm, "content_cache_spool_requests_total") + require.Len(t, requestDps, 2) + require.True(t, hasAttr(requestDps[0].Attributes, "protocol", "npm")) + require.True(t, hasAttr(requestDps[0].Attributes, "outcome", "success")) + + waitDps := findHistogram(rm, "content_cache_spool_wait_duration_seconds") + require.Len(t, waitDps, 1) + require.Equal(t, uint64(1), waitDps[0].Count) + require.True(t, hasAttr(waitDps[0].Attributes, "role", "coalesced")) + + bytesDps := findCounter(rm, "content_cache_spool_bytes_saved_total") + require.Len(t, bytesDps, 1) + require.EqualValues(t, 4096, bytesDps[0].Value) + require.True(t, hasAttr(bytesDps[0].Attributes, "protocol", "npm")) +} + // collectMetrics reads all metrics from the ManualReader. func collectMetrics(t *testing.T, reader *sdkmetric.ManualReader) metricdata.ResourceMetrics { t.Helper()