Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/content-cache/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type ServeCmd struct {
Storage string `kong:"name='storage',default='./cache',env='CACHE_STORAGE',help='Storage directory path',group='Server'"`
TLSCertFile string `kong:"name='tls-cert',env='TLS_CERT_FILE',type='existingfile',help='Path to TLS certificate file (enables HTTPS)',group='Server'"`
TLSKeyFile string `kong:"name='tls-key',env='TLS_KEY_FILE',type='existingfile',help='Path to TLS private key file (enables HTTPS)',group='Server'"`
PublicBaseURL string `kong:"name='public-base-url',env='PUBLIC_BASE_URL',help='External base URL clients use to reach this cache (e.g. https://cache.example.com). Used to build served download links.',group='Server'"`

AuthToken string `kong:"name='auth-token',env='AUTH_TOKEN',help='Bearer token for inbound authentication (mutually exclusive with --oidc-policies)',group='Auth'"`
AuthTokenFile string `kong:"name='auth-token-file',env='AUTH_TOKEN_FILE',type='existingfile',help='Path to file containing auth token (for k8s secret mounts)',group='Auth'"`
Expand Down Expand Up @@ -263,6 +264,7 @@ func (cmd *ServeCmd) Run() error {
StoragePath: cmd.Storage,
TLSCertFile: cmd.TLSCertFile,
TLSKeyFile: cmd.TLSKeyFile,
PublicBaseURL: cmd.PublicBaseURL,
AuthToken: authToken,
OIDCValidator: oidcValidator,
Credentials: creds,
Expand Down
53 changes: 51 additions & 2 deletions download/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,22 @@ package download

import (
"context"
"errors"
"sync/atomic"
"time"

contentcache "github.com/buildkite/content-cache"
"github.com/buildkite/content-cache/telemetry"
"golang.org/x/sync/singleflight"
)

// Role labels attached to the spool events that Downloader.Do records.
const (
	// spoolRoleOrigin labels a caller whose own closure was the one
	// executed for the key.
	spoolRoleOrigin = "origin"
	// spoolRoleCoalesced labels a caller whose closure was not executed.
	spoolRoleCoalesced = "coalesced"
)

// spoolRecorder is the callback used to report a single spool event. It
// receives the caller's role, an outcome label, how long the caller spent in
// Do, and the number of bytes saved (0 when nothing was saved).
type spoolRecorder func(ctx context.Context, role, outcome string, duration time.Duration, bytesSaved int64)

// Result holds the outcome of a download operation.
type Result struct {
Hash contentcache.Hash
Expand All @@ -25,12 +36,13 @@ type DownloadFunc func(ctx context.Context) (*Result, error)
// using singleflight. It uses DoChan so each caller can respect its own
// context deadline without cancelling the in-flight download for others.
type Downloader struct {
group singleflight.Group
group singleflight.Group
recordSpool spoolRecorder
}

// New creates a new Downloader.
func New() *Downloader {
return &Downloader{}
return &Downloader{recordSpool: telemetry.RecordSpoolRequest}
}

// Do deduplicates concurrent downloads for the same key.
Expand All @@ -40,23 +52,60 @@ func New() *Downloader {
// If the caller's context expires before the download completes, Do returns
// the context error but the in-flight download continues for other waiters.
//
// Every call also reports one spool event through d.record. The role is
// "origin" when this caller's closure had been executed by the time the
// event is recorded and "coalesced" otherwise. bytesSaved is the size of the
// shared result for a successful coalesced caller and 0 in every other case.
func (d *Downloader) Do(ctx context.Context, key string, fn DownloadFunc) (*Result, bool, error) {
	started := time.Now()

	// executed is only ever set inside the closure below, so it tells this
	// caller whether its own closure is the one that ran.
	var executed atomic.Bool
	ch := d.group.DoChan(key, func() (any, error) {
		executed.Store(true)
		// Use a detached context so that no single caller's cancellation
		// stops the download for everyone else.
		return fn(context.WithoutCancel(ctx))
	})

	// role must be evaluated after the select fires: the closure runs
	// asynchronously, so the flag is not meaningful before that point.
	role := func() string {
		if executed.Load() {
			return spoolRoleOrigin
		}
		return spoolRoleCoalesced
	}

	select {
	case res := <-ch:
		callerRole := role()
		if res.Err != nil {
			d.record(ctx, callerRole, spoolOutcome(res.Err), time.Since(started), 0)
			return nil, res.Shared, res.Err
		}

		result := res.Val.(*Result)
		var bytesSaved int64
		// fn may return (nil, nil); reading Size through a nil *Result would
		// panic for coalesced callers only, so guard the dereference.
		if callerRole == spoolRoleCoalesced && result != nil {
			bytesSaved = result.Size
		}
		d.record(ctx, callerRole, spoolOutcome(nil), time.Since(started), bytesSaved)
		return result, res.Shared, nil
	case <-ctx.Done():
		d.record(ctx, role(), spoolOutcome(ctx.Err()), time.Since(started), 0)
		return nil, false, ctx.Err()
	}
}

// record hands one spool event to the configured recorder. A Downloader
// without a recorder silently drops the event.
func (d *Downloader) record(ctx context.Context, role, outcome string, duration time.Duration, bytesSaved int64) {
	report := d.recordSpool
	if report == nil {
		return
	}
	report(ctx, role, outcome, duration, bytesSaved)
}

// spoolOutcome maps an error to the outcome label recorded for a spool event:
// "success" for nil, "timeout" for a deadline expiry, "canceled" for a
// cancellation, and "error" for anything else. Wrapped errors are matched
// with errors.Is, and the deadline check takes precedence over cancellation.
func spoolOutcome(err error) string {
	if err == nil {
		return "success"
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return "timeout"
	}
	if errors.Is(err, context.Canceled) {
		return "canceled"
	}
	return "error"
}

// Forget removes the key from the singleflight group, allowing a subsequent
// call to retry. Typically called after a download error.
func (d *Downloader) Forget(key string) {
Expand Down
30 changes: 30 additions & 0 deletions download/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ func TestDo_SingleCall(t *testing.T) {

func TestDo_ConcurrentDeduplication(t *testing.T) {
d := New()
type spoolEvent struct {
role string
outcome string
bytesSaved int64
}
var eventsMu sync.Mutex
var events []spoolEvent
d.recordSpool = func(_ context.Context, role, outcome string, _ time.Duration, bytesSaved int64) {
eventsMu.Lock()
defer eventsMu.Unlock()
events = append(events, spoolEvent{role: role, outcome: outcome, bytesSaved: bytesSaved})
}

var callCount atomic.Int32
expected := &Result{
Expand Down Expand Up @@ -63,6 +75,24 @@ func TestDo_ConcurrentDeduplication(t *testing.T) {
require.NoError(t, errs[i])
require.Equal(t, expected.Hash, results[i].Hash)
}

eventsMu.Lock()
defer eventsMu.Unlock()
var origins, coalesced int
var bytesSaved int64
for _, event := range events {
require.Equal(t, "success", event.outcome)
switch event.role {
case spoolRoleOrigin:
origins++
case spoolRoleCoalesced:
coalesced++
}
bytesSaved += event.bytesSaved
}
require.Equal(t, 1, origins)
require.Equal(t, 9, coalesced)
require.Equal(t, int64(36), bytesSaved)
}

func TestDo_CallerTimeout(t *testing.T) {
Expand Down
75 changes: 75 additions & 0 deletions protocol/git/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,35 @@ func (h *Handler) handleUploadPack(w http.ResponseWriter, r *http.Request, repo
}
}

// In the git v2 wire protocol `ls-refs` is passed to the git-upload-pack endpoint.
// This type of request cannot be cached, as it is backed by mutable refs upstream.
isLsRefs, lsRefsErr := isLsRefs(activeBodyFile)
if lsRefsErr != nil {
logger.Error("failed to check if git command includes ls-refs", "error", lsRefsErr)
} else if isLsRefs {
telemetry.SetCacheResult(r, telemetry.CacheBypass)

upstream := h.router.Match(repo)
rc, err := upstream.FetchUploadPack(ctx, repo, gitProtocol, activeBodyFile)
_ = activeBodyFile.Close()
if errors.Is(err, ErrNotFound) {
http.Error(w, "repository not found", http.StatusNotFound)
return
}
if err != nil {
logger.Error("upstream git-upload-pack ls-refs failed", "error", err)
http.Error(w, "upstream error", http.StatusBadGateway)
return
}
defer func() { _ = rc.Close() }()

w.Header().Set("Content-Type", ContentTypeUploadPackResult)
if _, err := io.Copy(w, rc); err != nil {
logger.Error("failed to stream git-upload-pack ls-refs response", "error", err)
}
return
}

// Check cache
cached, err := h.index.GetCachedPack(ctx, cacheKey)
if err == nil {
Expand Down Expand Up @@ -616,3 +645,49 @@ func parseHexLen(s string) (int, error) {
}
return int(n), nil
}

// isLsRefs reports whether the request body carries an ls-refs command.
//
// The output of an ls-refs command changes as the upstream updates, so it
// cannot be cached the way a normal git-upload-pack request can be.
//
// The body is read from its start (at most maxLsRefsScanSize bytes) and is
// rewound to the start again before returning. The first packet whose
// payload begins with "command=" decides the result; if no such packet is
// found within the scanned bytes the result is false. A failure to rewind is
// returned as an error, because later readers would otherwise see the body
// from the wrong offset.
func isLsRefs(activeBody io.ReadSeeker) (match bool, retErr error) {
	if _, err := activeBody.Seek(0, io.SeekStart); err != nil {
		return false, err
	}
	defer func() {
		// Surface a failed rewind unless an earlier error is already being
		// returned; the earlier error is the more useful one to report.
		if _, err := activeBody.Seek(0, io.SeekStart); err != nil && retErr == nil {
			match, retErr = false, fmt.Errorf("rewinding request body: %w", err)
		}
	}()

	body, err := io.ReadAll(io.LimitReader(activeBody, maxLsRefsScanSize))
	if err != nil {
		return false, err
	}

	pos := 0
	for pos+4 <= len(body) {
		pktLen, err := parseHexLen(string(body[pos : pos+4]))
		if err != nil {
			return false, err
		}
		if pktLen == 0 || pktLen == 1 {
			// pktLen == 0 -> flush packet
			// pktLen == 1 -> delimiter packet
			// Both are control packets that consist of the length prefix only.
			pos += 4
			continue
		}
		if pktLen < 4 {
			// The length includes its own 4-byte prefix, so 2 and 3 cannot
			// describe a data packet.
			return false, fmt.Errorf("git protocol malformed packet size")
		}
		if pos+pktLen > len(body) {
			// The packet runs past the scanned bytes; stop looking.
			break
		}
		content := strings.TrimRight(string(body[pos+4:pos+pktLen]), "\n")
		if cmd, ok := strings.CutPrefix(content, "command="); ok {
			return cmd == "ls-refs", nil
		}
		pos += pktLen
	}

	return false, nil
}
Loading