diff --git a/cmd/content-cache/main.go b/cmd/content-cache/main.go index c934dcc..d755c3a 100644 --- a/cmd/content-cache/main.go +++ b/cmd/content-cache/main.go @@ -37,6 +37,7 @@ type ServeCmd struct { Storage string `kong:"name='storage',default='./cache',env='CACHE_STORAGE',help='Storage directory path',group='Server'"` TLSCertFile string `kong:"name='tls-cert',env='TLS_CERT_FILE',type='existingfile',help='Path to TLS certificate file (enables HTTPS)',group='Server'"` TLSKeyFile string `kong:"name='tls-key',env='TLS_KEY_FILE',type='existingfile',help='Path to TLS private key file (enables HTTPS)',group='Server'"` + PublicBaseURL string `kong:"name='public-base-url',env='PUBLIC_BASE_URL',help='External base URL clients use to reach this cache (e.g. https://cache.example.com). Used to build served download links.',group='Server'"` AuthToken string `kong:"name='auth-token',env='AUTH_TOKEN',help='Bearer token for inbound authentication (mutually exclusive with --oidc-policies)',group='Auth'"` AuthTokenFile string `kong:"name='auth-token-file',env='AUTH_TOKEN_FILE',type='existingfile',help='Path to file containing auth token (for k8s secret mounts)',group='Auth'"` @@ -263,6 +264,7 @@ func (cmd *ServeCmd) Run() error { StoragePath: cmd.Storage, TLSCertFile: cmd.TLSCertFile, TLSKeyFile: cmd.TLSKeyFile, + PublicBaseURL: cmd.PublicBaseURL, AuthToken: authToken, OIDCValidator: oidcValidator, Credentials: creds, diff --git a/download/downloader.go b/download/downloader.go index 1f8eefe..752d5cf 100644 --- a/download/downloader.go +++ b/download/downloader.go @@ -5,11 +5,22 @@ package download import ( "context" + "errors" + "sync/atomic" + "time" contentcache "github.com/buildkite/content-cache" + "github.com/buildkite/content-cache/telemetry" "golang.org/x/sync/singleflight" ) +const ( + spoolRoleOrigin = "origin" + spoolRoleCoalesced = "coalesced" +) + +type spoolRecorder func(ctx context.Context, role, outcome string, duration time.Duration, bytesSaved int64) + // Result holds the outcome of a download operation. type Result struct { Hash contentcache.Hash @@ -25,12 +36,13 @@ type DownloadFunc func(ctx context.Context) (*Result, error) // using singleflight. It uses DoChan so each caller can respect its own // context deadline without cancelling the in-flight download for others. type Downloader struct { - group singleflight.Group + group singleflight.Group + recordSpool spoolRecorder } // New creates a new Downloader. func New() *Downloader { - return &Downloader{} + return &Downloader{recordSpool: telemetry.RecordSpoolRequest} } // Do deduplicates concurrent downloads for the same key. @@ -40,7 +52,10 @@ func New() *Downloader { // If the caller's context expires before the download completes, Do returns // the context error but the in-flight download continues for other waiters. func (d *Downloader) Do(ctx context.Context, key string, fn DownloadFunc) (*Result, bool, error) { + started := time.Now() + var executed atomic.Bool ch := d.group.DoChan(key, func() (any, error) { + executed.Store(true) // Use a detached context so that no single caller's cancellation // stops the download for everyone else. return fn(context.WithoutCancel(ctx)) @@ -48,15 +63,49 @@ func (d *Downloader) Do(ctx context.Context, key string, fn DownloadFunc) (*Resu select { case res := <-ch: + role := spoolRoleCoalesced + if executed.Load() { + role = spoolRoleOrigin + } + outcome := spoolOutcome(res.Err) + var bytesSaved int64 + if role == spoolRoleCoalesced && res.Err == nil { + bytesSaved = res.Val.(*Result).Size + } + d.record(ctx, role, outcome, time.Since(started), bytesSaved) if res.Err != nil { return nil, res.Shared, res.Err } return res.Val.(*Result), res.Shared, nil case <-ctx.Done(): + role := spoolRoleCoalesced + if executed.Load() { + role = spoolRoleOrigin + } + d.record(ctx, role, spoolOutcome(ctx.Err()), time.Since(started), 0) return nil, false, ctx.Err() } } +func (d *Downloader) record(ctx context.Context, role, outcome string, duration time.Duration, bytesSaved int64) { + if d.recordSpool != nil { + d.recordSpool(ctx, role, outcome, duration, bytesSaved) + } +} + +func spoolOutcome(err error) string { + switch { + case err == nil: + return "success" + case errors.Is(err, context.DeadlineExceeded): + return "timeout" + case errors.Is(err, context.Canceled): + return "canceled" + default: + return "error" + } +} + // Forget removes the key from the singleflight group, allowing a subsequent // call to retry. Typically called after a download error. func (d *Downloader) Forget(key string) { diff --git a/download/downloader_test.go b/download/downloader_test.go index 1d6c5dd..e520efc 100644 --- a/download/downloader_test.go +++ b/download/downloader_test.go @@ -32,6 +32,18 @@ func TestDo_SingleCall(t *testing.T) { func TestDo_ConcurrentDeduplication(t *testing.T) { d := New() + type spoolEvent struct { + role string + outcome string + bytesSaved int64 + } + var eventsMu sync.Mutex + var events []spoolEvent + d.recordSpool = func(_ context.Context, role, outcome string, _ time.Duration, bytesSaved int64) { + eventsMu.Lock() + defer eventsMu.Unlock() + events = append(events, spoolEvent{role: role, outcome: outcome, bytesSaved: bytesSaved}) + } var callCount atomic.Int32 expected := &Result{ @@ -63,6 +75,24 @@ func TestDo_ConcurrentDeduplication(t *testing.T) { require.NoError(t, errs[i]) require.Equal(t, expected.Hash, results[i].Hash) } + + eventsMu.Lock() + defer eventsMu.Unlock() + var origins, coalesced int + var bytesSaved int64 + for _, event := range events { + require.Equal(t, "success", event.outcome) + switch event.role { + case spoolRoleOrigin: + origins++ + case spoolRoleCoalesced: + coalesced++ + } + bytesSaved += event.bytesSaved + } + require.Equal(t, 1, origins) + require.Equal(t, 9, coalesced) + require.Equal(t, int64(36), bytesSaved) } func TestDo_CallerTimeout(t *testing.T) { diff --git a/protocol/git/handler.go b/protocol/git/handler.go index cbe6911..a506bd5 100644 --- a/protocol/git/handler.go +++ b/protocol/git/handler.go @@ -376,6 +376,35 @@ func (h *Handler) handleUploadPack(w http.ResponseWriter, r *http.Request, repo } } + // In the git v2 wire protocol `ls-refs` is passed to the git-upload-pack endpoint. + // This type of request cannot be cached, as it is backed by mutable refs upstream. + isLsRefs, lsRefsErr := isLsRefs(activeBodyFile) + if lsRefsErr != nil { + logger.Error("failed to check if git command includes ls-refs", "error", lsRefsErr) + } else if isLsRefs { + telemetry.SetCacheResult(r, telemetry.CacheBypass) + + upstream := h.router.Match(repo) + rc, err := upstream.FetchUploadPack(ctx, repo, gitProtocol, activeBodyFile) + _ = activeBodyFile.Close() + if errors.Is(err, ErrNotFound) { + http.Error(w, "repository not found", http.StatusNotFound) + return + } + if err != nil { + logger.Error("upstream git-upload-pack ls-refs failed", "error", err) + http.Error(w, "upstream error", http.StatusBadGateway) + return + } + defer func() { _ = rc.Close() }() + + w.Header().Set("Content-Type", ContentTypeUploadPackResult) + if _, err := io.Copy(w, rc); err != nil { + logger.Error("failed to stream git-upload-pack ls-refs response", "error", err) + } + return + } + // Check cache cached, err := h.index.GetCachedPack(ctx, cacheKey) if err == nil { @@ -616,3 +645,49 @@ func parseHexLen(s string) (int, error) { } return int(n), nil } + +// isLsRefs determines if a particular request is a ls-refs command. +// +// The output of ls-refs commands change as the upstream updates, +// and so they cannot be cached in the same way a normal git-upload-pack request can be. +func isLsRefs(activeBody io.ReadSeeker) (bool, error) { + if _, err := activeBody.Seek(0, io.SeekStart); err != nil { + return false, err + } + defer func() { _, _ = activeBody.Seek(0, io.SeekStart) }() + + body, err := io.ReadAll(io.LimitReader(activeBody, maxLsRefsScanSize)) + if err != nil { + return false, err + } + + pos := 0 + for pos+4 <= len(body) { + pktLen, err := parseHexLen(string(body[pos : pos+4])) + if err != nil { + return false, err + } + if pktLen == 0 || pktLen == 1 { + // pktLen == 0 -> flush packet + // pktLen == 1 -> delimiter packet + // in both cases, they are just special control packets + pos += 4 + continue + } + if pktLen < 4 { + return false, fmt.Errorf("git protocol malformed packet size") + } + if pos+pktLen > len(body) { + break + } + content := strings.TrimRight(string(body[pos+4:pos+pktLen]), "\n") + cmd, ok := strings.CutPrefix(content, "command=") + if !ok { + pos += pktLen + continue + } + return cmd == "ls-refs", nil + } + + return false, nil +} diff --git a/protocol/git/handler_test.go b/protocol/git/handler_test.go index 422e352..f90afae 100644 --- a/protocol/git/handler_test.go +++ b/protocol/git/handler_test.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "context" "fmt" + "io" "net/http" "net/http/httptest" "os" @@ -646,6 +647,206 @@ func TestHandlerStaleCacheEviction(t *testing.T) { require.Equal(t, firstBody, w2.Body.String()) } +// pktLine encodes s as a git pkt-line: a 4-hex-digit length prefix covering the +// prefix itself plus the payload, followed by the payload. +func pktLine(s string) string { + return fmt.Sprintf("%04x%s", len(s)+4, s) +} + +const ( + flushPkt = "0000" + delimPkt = "0001" +) + +func TestIsLsRefs(t *testing.T) { + tests := []struct { + name string + body string + want bool + wantErr bool + }{ + { + name: "ls-refs command only", + body: pktLine("command=ls-refs\n"), + want: true, + }, + { + name: "ls-refs with capabilities, delimiter and args", + body: pktLine("command=ls-refs\n") + + pktLine("object-format=sha1\n") + + delimPkt + + pktLine("peel\n") + + pktLine("ref-prefix HEAD\n") + + flushPkt, + want: true, + }, + { + name: "ls-refs preceded by control packets", + body: flushPkt + delimPkt + pktLine("command=ls-refs\n"), + want: true, + }, + { + name: "fetch command is not ls-refs", + body: pktLine("command=fetch\n") + + pktLine("object-format=sha1\n") + + delimPkt + + pktLine("want 0123456789012345678901234567890123456789\n") + + flushPkt, + want: false, + }, + { + name: "v1-style want/done body has no command", + body: pktLine("want 0123456789012345678901234567890123456789\n") + + flushPkt + + pktLine("done\n"), + want: false, + }, + { + name: "empty body", + body: "", + want: false, + }, + { + name: "only a flush packet", + body: flushPkt, + want: false, + }, + { + name: "malformed packet length below minimum", + body: "0003", + wantErr: true, + }, + { + name: "invalid hex length", + body: "zzzzcommand=ls-refs\n", + wantErr: true, + }, + { + name: "truncated packet is treated as not ls-refs", + body: "0050command=ls-refs\n", // declares 0x50 bytes but the body is far shorter + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := isLsRefs(strings.NewReader(tt.body)) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestIsLsRefsResetsPosition(t *testing.T) { + body := pktLine("command=ls-refs\n") + flushPkt + r := strings.NewReader(body) + + got, err := isLsRefs(r) + require.NoError(t, err) + require.True(t, got) + + rest, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, body, string(rest), "body must be readable from the start after isLsRefs") +} + +func TestIsLsRefsSeeksToStartBeforeScanning(t *testing.T) { + body := pktLine("command=ls-refs\n") + r := strings.NewReader(body) + + _, err := r.Seek(4, io.SeekStart) + require.NoError(t, err) + + got, err := isLsRefs(r) + require.NoError(t, err) + require.True(t, got) +} + +// countingUploadPackUpstream returns a test server that records every +// git-upload-pack request body and responds with a fake pack. +func countingUploadPackUpstream(fetchCount *atomic.Int32, lastBody *[]byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/git-upload-pack") { + fetchCount.Add(1) + if b, err := io.ReadAll(r.Body); err == nil { + *lastBody = b + } + w.Header().Set("Content-Type", ContentTypeUploadPackResult) + _, _ = w.Write([]byte(fakeUploadPackBody)) + return + } + w.WriteHeader(http.StatusNotFound) + })) +} + +func TestHandlerUploadPackLsRefsBypassesCache(t *testing.T) { + var fetchCount atomic.Int32 + var lastBody []byte + upstreamSrv := countingUploadPackUpstream(&fetchCount, &lastBody) + defer upstreamSrv.Close() + + h, cleanup := newTestHandlerWithTransport(t, upstreamSrv) + defer cleanup() + + body := []byte(pktLine("command=ls-refs\n") + + pktLine("object-format=sha1\n") + + delimPkt + + pktLine("peel\n") + + pktLine("ref-prefix HEAD\n") + + flushPkt) + + for i := range 2 { + req := httptest.NewRequest(http.MethodPost, "/github.com/user/repo.git/git-upload-pack", bytes.NewReader(body)) + req.Header.Set("Content-Type", ContentTypeUploadPackRequest) + req.Header.Set("Git-Protocol", "version=2") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "request %d", i) + require.Equal(t, ContentTypeUploadPackResult, w.Header().Get("Content-Type")) + require.Equal(t, fakeUploadPackBody, w.Body.String()) + } + + require.Equal(t, int32(2), fetchCount.Load(), "ls-refs must reach upstream on every request and never be cached") + require.Equal(t, body, lastBody, "the full ls-refs body must be forwarded to upstream") +} + +func TestHandlerUploadPackLsRefsGzip(t *testing.T) { + var fetchCount atomic.Int32 + var lastBody []byte + upstreamSrv := countingUploadPackUpstream(&fetchCount, &lastBody) + defer upstreamSrv.Close() + + h, cleanup := newTestHandlerWithTransport(t, upstreamSrv) + defer cleanup() + + plain := []byte(pktLine("command=ls-refs\n") + flushPkt) + + var gzBuf bytes.Buffer + gz := gzip.NewWriter(&gzBuf) + _, err := gz.Write(plain) + require.NoError(t, err) + require.NoError(t, gz.Close()) + + req := httptest.NewRequest(http.MethodPost, "/github.com/user/repo.git/git-upload-pack", bytes.NewReader(gzBuf.Bytes())) + req.Header.Set("Content-Type", ContentTypeUploadPackRequest) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Git-Protocol", "version=2") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, fakeUploadPackBody, w.Body.String()) + require.Equal(t, int32(1), fetchCount.Load()) + require.Equal(t, plain, lastBody, "the decompressed ls-refs body must be forwarded to upstream") +} + // newTestHandlerWithTransport creates a test handler that redirects all HTTPS // requests to the given test server via a custom transport. func newTestHandlerWithTransport(t *testing.T, targetServer *httptest.Server) (*Handler, func()) { diff --git a/protocol/git/types.go b/protocol/git/types.go index be8b424..a1c9251 100644 --- a/protocol/git/types.go +++ b/protocol/git/types.go @@ -23,6 +23,11 @@ const ( // DefaultMaxDecompressedBodySize is the maximum size of a decompressed git-upload-pack body (500MB). DefaultMaxDecompressedBodySize int64 = 500 * 1024 * 1024 + + // maxLsRefsScanSize bounds how much of an upload-pack request body is scanned + // to detect a protocol v2 "command=" pkt-line. The command appears at the very + // start of the body, so a small cap is sufficient. + maxLsRefsScanSize int64 = 64 * 1024 ) // ErrNotFound is returned when a cached pack is not found. diff --git a/protocol/npm/handler.go b/protocol/npm/handler.go index a070583..caa1dcc 100644 --- a/protocol/npm/handler.go +++ b/protocol/npm/handler.go @@ -40,6 +40,10 @@ type Handler struct { logger *slog.Logger downloader *download.Downloader + // publicBaseURL, when set, is the external base URL used to build served + // tarball links instead of the request scheme and Host header. + publicBaseURL string + // Lifecycle management for background goroutines wg sync.WaitGroup shutdownCtx context.Context @@ -79,6 +83,16 @@ func WithDownloader(dl *download.Downloader) HandlerOption { } } +// WithPublicBaseURL sets the external base URL (e.g. https://cache.example.com) +// used when rewriting tarball download links. When empty, the request scheme and +// Host header are used. Set this when TLS is terminated by an upstream load +// balancer, so the cache cannot infer https from the inbound request. +func WithPublicBaseURL(baseURL string) HandlerOption { + return func(h *Handler) { + h.publicBaseURL = strings.TrimSuffix(baseURL, "/") + } +} + // NewHandler creates a new NPM registry handler. func NewHandler(index *Index, store store.Store, opts ...HandlerOption) *Handler { shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) @@ -249,11 +263,14 @@ func (h *Handler) rewriteTarballURLs(r *http.Request, meta map[string]any) { } // Determine our base URL - scheme := "http" - if r.TLS != nil { - scheme = "https" + baseURL := h.publicBaseURL + if baseURL == "" { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + baseURL = fmt.Sprintf("%s://%s", scheme, r.Host) } - baseURL := fmt.Sprintf("%s://%s", scheme, r.Host) for _, v := range versions { version, ok := v.(map[string]any) diff --git a/protocol/pypi/handler.go b/protocol/pypi/handler.go index 6623c2a..f539b69 100644 --- a/protocol/pypi/handler.go +++ b/protocol/pypi/handler.go @@ -39,6 +39,10 @@ type Handler struct { downloader *download.Downloader metadataTTL time.Duration + // publicBaseURL, when set, is the external base URL used to build served + // download links instead of the request scheme and Host header. + publicBaseURL string + // Lifecycle management for background goroutines wg sync.WaitGroup ctx context.Context @@ -76,6 +80,16 @@ func WithMetadataTTL(ttl time.Duration) HandlerOption { } } +// WithPublicBaseURL sets the external base URL (e.g. https://cache.example.com) +// used when rewriting file download links. When empty, the request scheme and +// Host header are used. Set this when TLS is terminated by an upstream load +// balancer, so the cache cannot infer https from the inbound request. +func WithPublicBaseURL(baseURL string) HandlerOption { + return func(h *Handler) { + h.publicBaseURL = strings.TrimSuffix(baseURL, "/") + } +} + // NewHandler creates a new PyPI Simple API handler. func NewHandler(index *Index, store store.Store, opts ...HandlerOption) *Handler { ctx, cancel := context.WithCancel(context.Background()) @@ -562,18 +576,28 @@ func (h *Handler) cacheFile(ctx context.Context, project, filename string, hash logger.Info("cached file", "filename", filename, "hash", hash.ShortString(), "size", size) } +// baseURL returns the external base URL used to build served download links. +// It prefers the configured public base URL and otherwise derives it from the +// request scheme and Host header. +func (h *Handler) baseURL(r *http.Request) string { + if h.publicBaseURL != "" { + return h.publicBaseURL + } + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + return fmt.Sprintf("%s://%s", scheme, r.Host) +} + // writeProjectResponse writes the project page response in HTML or JSON format. func (h *Handler) writeProjectResponse(w http.ResponseWriter, r *http.Request, cached *CachedProject, normalized string) { - // Build file list for response + // Rewrite file URLs to point to our proxy. We include the /pypi prefix + // since the server strips it before passing to this handler. + base := h.baseURL(r) var files []ProjectFile for _, f := range cached.Files { - // Rewrite URL to point to our proxy - // Note: We include /pypi prefix since the server strips it before passing to this handler - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - proxyURL := fmt.Sprintf("%s://%s/pypi/packages/%s/%s", scheme, r.Host, normalized, f.Filename) + proxyURL := fmt.Sprintf("%s/pypi/packages/%s/%s", base, normalized, f.Filename) pf := ProjectFile{ Filename: f.Filename, diff --git a/protocol/pypi/handler_test.go b/protocol/pypi/handler_test.go index cc9bec1..c317dfe 100644 --- a/protocol/pypi/handler_test.go +++ b/protocol/pypi/handler_test.go @@ -27,7 +27,7 @@ func setupMetaDB(t *testing.T, tmpDir string) (*metadb.BoltDB, *Index, func()) { return db, idx, func() { _ = db.Close() } } -func newTestHandler(t *testing.T, upstreamServer *httptest.Server) (*Handler, func()) { +func newTestHandler(t *testing.T, upstreamServer *httptest.Server, extraOpts ...HandlerOption) (*Handler, func()) { t.Helper() // Use a manual temp dir instead of t.TempDir() to avoid race with async goroutines tmpDir, err := os.MkdirTemp("", "pypi-test-*") @@ -44,7 +44,7 @@ func newTestHandler(t *testing.T, upstreamServer *httptest.Server) (*Handler, fu } upstream := NewUpstream(opts...) - h := NewHandler(idx, cafs, WithUpstream(upstream)) + h := NewHandler(idx, cafs, append([]HandlerOption{WithUpstream(upstream)}, extraOpts...)...) return h, func() { h.Close() closeDB() @@ -121,6 +121,57 @@ func TestHandlerProject(t *testing.T) { }) } +func TestHandlerProjectPublicBaseURL(t *testing.T) { + upstreamHandler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/simple/requests/" { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(` +
+requests-2.31.0-py3-none-any.whl