diff --git a/pkg/cache/validating_cache.go b/pkg/cache/validating_cache.go index b34dfee8b8..c3f6342f79 100644 --- a/pkg/cache/validating_cache.go +++ b/pkg/cache/validating_cache.go @@ -8,6 +8,8 @@ package cache import ( "errors" "fmt" + "reflect" + "sync" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/sync/singleflight" @@ -18,25 +20,26 @@ import ( var ErrExpired = errors.New("cache entry expired") // ValidatingCache is a node-local write-through cache backed by a -// capacity-bounded LRU map, with singleflight-deduplicated restore on cache -// miss and lazy liveness validation on cache hit. +// capacity-bounded LRU map, with singleflight-deduplicated Get operations and +// lazy liveness validation on cache hit. // // Type parameter K is the key type (must be comparable). // Type parameter V is the cached value type. // -// The no-resurrection invariant (preventing a concurrent restore from -// overwriting a deletion) is enforced via ContainsOrAdd: if a concurrent -// writer stored a value between load() returning and the cache being updated, -// the prior writer's value wins and the just-loaded value is discarded via -// onEvict. +// The entire Get operation — cache hit validation and miss load — runs under a +// singleflight group so at most one operation executes concurrently per key. +// Concurrent callers for the same key share the result, coalescing both +// liveness checks and storage round-trips into a single operation per key. type ValidatingCache[K comparable, V any] struct { lruCache *lru.Cache[K, V] flight singleflight.Group load func(key K) (V, error) check func(key K, val V) error - // onEvict is kept here so we can call it when discarding a concurrently - // loaded value that lost the race to a prior writer. - onEvict func(key K, val V) + onEvict func(K, V) + // mu serializes Set against the conditional eviction in getHit. 
+ // check() runs outside the lock to avoid holding it during I/O; the lock + // is only held for the short Peek+Remove sequence. + mu sync.Mutex } // New creates a ValidatingCache with the given capacity and callbacks. @@ -88,8 +91,16 @@ func New[K comparable, V any]( func (c *ValidatingCache[K, V]) getHit(key K, val V) (V, bool) { if err := c.check(key, val); err != nil { if errors.Is(err, ErrExpired) { - // Remove fires the eviction callback automatically. - c.lruCache.Remove(key) + // check() ran outside the lock to avoid holding it during I/O. + // Re-verify under the lock that the entry hasn't been replaced by a + // concurrent Set before removing it; otherwise we would evict a + // freshly-written value that the caller intended to keep. + c.mu.Lock() + if current, ok := c.lruCache.Peek(key); ok && sameEntry(current, val) { + // Remove fires the eviction callback automatically. + c.lruCache.Remove(key) + } + c.mu.Unlock() var zero V return zero, false } @@ -97,54 +108,57 @@ func (c *ValidatingCache[K, V]) getHit(key K, val V) (V, bool) { return val, true } -// Get returns the value for key, loading it on a cache miss. On a cache hit -// the entry's liveness is validated via the check function provided to New: -// ErrExpired evicts the entry and returns (zero, false); transient errors -// return the cached value unchanged. On a cache miss, load is called under a -// singleflight group so at most one restore runs concurrently per key. +// Get returns the value for key, loading it on a cache miss. The entire +// operation — cache hit validation and miss load — runs under a singleflight +// group so at most one operation executes concurrently per key. Concurrent +// callers for the same key share the result. +// +// On a cache hit the entry's liveness is validated via the check function +// provided to New: ErrExpired evicts the entry and falls through to load; +// transient errors return the cached value unchanged. 
On a cache miss, load +// is called to restore the value. +// +// The returned bool is false whenever the value is unavailable — either +// because load returned an error or because the key does not exist in the +// backing store. Callers cannot distinguish these two cases. func (c *ValidatingCache[K, V]) Get(key K) (V, bool) { - if val, ok := c.lruCache.Get(key); ok { - return c.getHit(key, val) - } - - // Cache miss: use singleflight to prevent concurrent restores for the same key. type result struct{ v V } + // fmt.Sprint(key) is the singleflight key. For string keys this is + // exact. For other types, distinct values with identical string + // representations would be incorrectly coalesced — avoid non-string K + // types unless their fmt.Sprint output is guaranteed unique. raw, err, _ := c.flight.Do(fmt.Sprint(key), func() (any, error) { - // Re-check the cache: a concurrent singleflight group may have stored - // the value between our miss check above and acquiring this group. - if existing, ok := c.lruCache.Get(key); ok { - return result{v: existing}, nil + // Cache hit path: validate liveness. + if val, ok := c.lruCache.Get(key); ok { + v, alive := c.getHit(key, val) + if alive { + return result{v: v}, nil + } + // Entry expired and evicted; fall through to load. } + // Cache miss (or expired): load the value and store it. v, loadErr := c.load(key) if loadErr != nil { return nil, loadErr } - // Guard against a concurrent Set or Remove that occurred while load() was - // running. ContainsOrAdd stores only if absent; if another writer got - // in first, their value wins and we discard ours via onEvict. - ok, _ := c.lruCache.ContainsOrAdd(key, v) - if ok { - // Another writer stored a value first; discard our loaded value and - // return the winner's. 
ContainsOrAdd and Get are separate lock - // acquisitions, so the winner may itself have been evicted by LRU - // pressure between the two calls — fall back to our freshly loaded - // value in that case rather than returning a zero value. - winner, found := c.lruCache.Get(key) - // Discard our loaded value in favour of the winner (or clean up if - // the winner was itself evicted between ContainsOrAdd and Get). - if c.onEvict != nil { - c.onEvict(key, v) + // Guard against a concurrent Set that occurred while load() was running. + // ContainsOrAdd stores only if absent; if a concurrent Set got in first, + // their value wins and we return it instead. + if alreadySet, _ := c.lruCache.ContainsOrAdd(key, v); alreadySet { + if winner, ok := c.lruCache.Get(key); ok { + // Winner confirmed: v is definitively discarded — release its resources. + if c.onEvict != nil { + c.onEvict(key, v) + } + return result{v: winner}, nil } - if !found { - // Winner was evicted before we could retrieve it; signal a cache - // miss so the caller retries rather than receiving a stale value. - return nil, nil - } - return result{v: winner}, nil + // The concurrent winner was itself evicted by LRU pressure between + // ContainsOrAdd and Get. Fall back to storing v — do NOT call onEvict + // since v has not been released and is still valid. + c.lruCache.Add(key, v) } return result{v: v}, nil }) if err != nil { @@ -159,6 +173,8 @@ func (c *ValidatingCache[K, V]) Get(key K) (V, bool) { // cache is at capacity, the least-recently-used entry is evicted first and // onEvict is called for it. func (c *ValidatingCache[K, V]) Set(key K, value V) { + c.mu.Lock() + defer c.mu.Unlock() c.lruCache.Add(key, value) } @@ -166,3 +182,20 @@ func (c *ValidatingCache[K, V]) Set(key K, value V) { // Len returns the number of entries currently cached. func (c *ValidatingCache[K, V]) Len() int { return c.lruCache.Len() } + +// sameEntry reports whether a and b are the same cache entry. 
+// For pointer types it compares addresses (identity), so a concurrent Set that +// stores a distinct new value is never mistaken for the stale entry. For +// non-pointer types it falls back to reflect.DeepEqual, which is safe for all +// comparable and non-comparable types. +func sameEntry[V any](a, b V) bool { + ra := reflect.ValueOf(any(a)) + if ra.IsValid() { + switch ra.Kind() { //nolint:exhaustive + case reflect.Ptr, reflect.UnsafePointer: + rb := reflect.ValueOf(any(b)) + return rb.IsValid() && ra.Pointer() == rb.Pointer() + } + } + return reflect.DeepEqual(a, b) +} diff --git a/pkg/cache/validating_cache_test.go b/pkg/cache/validating_cache_test.go index 209d8769d3..0d78536e35 100644 --- a/pkg/cache/validating_cache_test.go +++ b/pkg/cache/validating_cache_test.go @@ -9,6 +9,7 @@ import ( "sync" "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -44,6 +45,20 @@ func TestValidatingCache_New_PanicsOnNegativeCapacity(t *testing.T) { }) } +func TestValidatingCache_New_PanicsOnNilLoad(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + New[string, string](1, nil, alwaysAliveCheck, nil) + }) +} + +func TestValidatingCache_New_PanicsOnNilCheck(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + New(1, func(_ string) (string, error) { return "", nil }, nil, nil) + }) +} + // --------------------------------------------------------------------------- // Cache miss / restore // --------------------------------------------------------------------------- @@ -80,8 +95,8 @@ func TestValidatingCache_CacheMiss_StoresResult(t *testing.T) { nil, ) - c.Get("k") //nolint:errcheck - c.Get("k") //nolint:errcheck + c.Get("k") + c.Get("k") assert.Equal(t, 1, calls, "load should be called only once after caching") } @@ -112,7 +127,7 @@ func TestValidatingCache_CacheHit_AliveCheck_ReturnsCached(t *testing.T) { alwaysAliveCheck, nil, ) - c.Get("k") //nolint:errcheck // prime the cache + c.Get("k") // prime 
the cache // Second Get should return cached value without calling load again. v, ok := c.Get("k") @@ -133,11 +148,13 @@ func TestValidatingCache_CacheHit_Expired_EvictsAndCallsOnEvict(t *testing.T) { evictedVal = val }, ) - c.Get("k") //nolint:errcheck // prime the cache + c.Get("k") // prime the cache + // With singleflight wrapping the full Get, an expired hit evicts the entry + // and falls through to load within the same operation, returning the fresh value. v, ok := c.Get("k") - assert.False(t, ok) - assert.Empty(t, v) + require.True(t, ok) + assert.Equal(t, "v", v) assert.Equal(t, "k", evictedKey) assert.Equal(t, "v", evictedVal) } @@ -161,11 +178,11 @@ func TestValidatingCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { nil, ) - c.Get("k") //nolint:errcheck // prime the cache; check returns alive + c.Get("k") // prime the cache; check returns alive expired = true - c.Get("k") //nolint:errcheck // check returns ErrExpired → evict + c.Get("k") // check returns ErrExpired → evict expired = false - c.Get("k") //nolint:errcheck // cache miss again → load called + c.Get("k") // cache miss again → load called assert.Equal(t, 2, calls, "load should be called twice: initial + after eviction") } @@ -178,7 +195,7 @@ func TestValidatingCache_CacheHit_TransientCheckError_ReturnsCached(t *testing.T func(_ string, _ string) error { return errors.New("transient storage error") }, nil, ) - c.Get("k") //nolint:errcheck // prime the cache + c.Get("k") // prime the cache v, ok := c.Get("k") require.True(t, ok) @@ -213,7 +230,7 @@ func TestValidatingCache_Set_UpdatesExisting(t *testing.T) { alwaysAliveCheck, nil, ) - c.Get("k") //nolint:errcheck // prime with "loaded" + c.Get("k") // prime with "loaded" c.Set("k", "updated") v, ok := c.Get("k") @@ -242,9 +259,9 @@ func TestValidatingCache_LRU_EvictsLeastRecentlyUsed(t *testing.T) { }, ) - c.Get("a") //nolint:errcheck // a=MRU - c.Get("b") //nolint:errcheck // b=MRU, a=LRU - c.Get("c") //nolint:errcheck // c=MRU, b, 
a=LRU → evicts a + c.Get("a") // a=MRU + c.Get("b") // b=MRU, a=LRU + c.Get("c") // c=MRU, b, a=LRU → evicts a mu.Lock() defer mu.Unlock() @@ -273,10 +290,10 @@ func TestValidatingCache_LRU_GetRefreshesMRUPosition(t *testing.T) { }, ) - c.Get("a") //nolint:errcheck // a loaded (MRU) - c.Get("b") //nolint:errcheck // b loaded (MRU), a=LRU - c.Get("a") //nolint:errcheck // a accessed → a becomes MRU, b=LRU - c.Get("c") //nolint:errcheck // c loaded → evicts b (LRU), not a + c.Get("a") // a loaded (MRU) + c.Get("b") // b loaded (MRU), a=LRU + c.Get("a") // a accessed → a becomes MRU, b=LRU + c.Get("c") // c loaded → evicts b (LRU), not a mu.Lock() defer mu.Unlock() @@ -302,10 +319,10 @@ func TestValidatingCache_LRU_SetRefreshesMRUPosition(t *testing.T) { }, ) - c.Get("a") //nolint:errcheck // a=MRU - c.Get("b") //nolint:errcheck // b=MRU, a=LRU + c.Get("a") // a=MRU + c.Get("b") // b=MRU, a=LRU c.Set("a", "x") // Set refreshes a to MRU; b becomes LRU - c.Get("c") //nolint:errcheck // c loaded → evicts b + c.Get("c") // c loaded → evicts b mu.Lock() defer mu.Unlock() @@ -328,9 +345,9 @@ func TestValidatingCache_LRU_CapacityOne(t *testing.T) { }, ) - c.Get("a") //nolint:errcheck - c.Get("b") //nolint:errcheck // evicts a - c.Get("c") //nolint:errcheck // evicts b + c.Get("a") + c.Get("b") // evicts a + c.Get("c") // evicts b mu.Lock() defer mu.Unlock() @@ -350,7 +367,7 @@ func TestValidatingCache_LRU_LargeCapacityNoEviction(t *testing.T) { ) for i := range n { - c.Get(fmt.Sprintf("k%d", i)) //nolint:errcheck + c.Get(fmt.Sprintf("k%d", i)) } assert.Equal(t, n, c.Len(), "no entries should be evicted when under capacity") } @@ -365,9 +382,9 @@ func TestValidatingCache_LRU_Len(t *testing.T) { ) assert.Equal(t, 0, c.Len()) - c.Get("a") //nolint:errcheck + c.Get("a") assert.Equal(t, 1, c.Len()) - c.Get("b") //nolint:errcheck + c.Get("b") assert.Equal(t, 2, c.Len()) } @@ -375,19 +392,24 @@ func TestValidatingCache_LRU_Len(t *testing.T) { // Re-check inside singleflight (TOCTOU 
prevention) // --------------------------------------------------------------------------- -func TestValidatingCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) { +// TestValidatingCache_Singleflight_SetBeforeLoadReturns verifies that when +// Set is called for a key before the in-flight load completes, the Set value +// wins: ContainsOrAdd does not overwrite the writer's value, and the caller +// receives the Set value. +func TestValidatingCache_Singleflight_SetBeforeLoadReturns(t *testing.T) { t.Parallel() var loadCount atomic.Int32 - // The load function is gated: it waits until we signal that an external - // Set has been applied, mimicking a value written by another goroutine - // between the miss check and the singleflight group. - storeApplied := make(chan struct{}) + // loadReached is closed once load has definitely started, so the test can + // inject a concurrent Set before load returns. + loadReached := make(chan struct{}) + allowReturn := make(chan struct{}) c := newStringCache( func(_ string) (string, error) { - <-storeApplied // wait until external Set is applied + close(loadReached) // signal: load is now in-flight + <-allowReturn // block until test injects the concurrent Set loadCount.Add(1) return "from-load", nil }, @@ -400,70 +422,114 @@ func TestValidatingCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) result string ok bool ) - wg.Go(func() { + wg.Add(1) + go func() { + defer wg.Done() result, ok = c.Get("k") - }) + }() - // Set the value externally to simulate a concurrent writer, then release - // the load function. The re-check at the top of the singleflight function - // fires first and finds "external-value", so load is never called. + // Wait until load is definitely executing, then write via Set so that + // ContainsOrAdd inside the miss path finds the key already present. 
+ <-loadReached c.Set("k", "external-value") - close(storeApplied) + close(allowReturn) // let load return "from-load" wg.Wait() require.True(t, ok) - assert.Equal(t, "external-value", result) - assert.Equal(t, int32(0), loadCount.Load(), "re-check should short-circuit before load is called") + assert.Equal(t, "external-value", result, "Set value should win over concurrent load") + assert.Equal(t, int32(1), loadCount.Load(), "load is called but its value is discarded") } -// TestValidatingCache_Singleflight_EvictsLoserWhenLoadRacesWriter covers the -// path where load() runs to completion but loses the ContainsOrAdd race to a -// concurrent Set. The loaded-but-discarded value must be passed to onEvict so -// any resources it holds (e.g. connections) can be cleaned up. -func TestValidatingCache_Singleflight_EvictsLoserWhenLoadRacesWriter(t *testing.T) { +// TestValidatingCache_Singleflight_DeduplicatesConcurrentLivenessChecks verifies +// that concurrent Gets on an expired entry coalesce into a single load call. +// +// Design: load blocks until all goroutines have signalled they are about to +// call Get. Because expired.Store(false) runs inside the singleflight callback +// (before it returns), goroutines that arrive late — after load() has already +// returned — find either: +// +// (a) the singleflight still in progress (they join it and share the result), or +// (b) a live entry in the cache (expired=false, check passes, no load needed). +// +// Either way loadCount == 1 is an invariant enforced by the implementation, not +// by timing luck. +func TestValidatingCache_Singleflight_DeduplicatesConcurrentLivenessChecks(t *testing.T) { t.Parallel() - // loadReached is closed when load() is about to return, giving us a hook to - // race a Set before ContainsOrAdd is called. - loadReached := make(chan struct{}) - // allowReturn lets the test control exactly when load() returns. 
- allowReturn := make(chan struct{}) + const goroutines = 10 + var ( + loadCount atomic.Int32 + allStarted sync.WaitGroup + wg sync.WaitGroup + results = make([]string, goroutines) + oks = make([]bool, goroutines) + ) + + var expired atomic.Bool - var evictedKey, evictedVal string c := newStringCache( func(_ string) (string, error) { - close(loadReached) // signal: load has run - <-allowReturn // wait until test injects the concurrent Set - return "from-load", nil + // Wait until every goroutine has signalled it is about to call Get. + // allStarted.Done() is called before Get(), so this unblocks once + // the goroutine scheduler has scheduled all callers — not necessarily + // once they've all entered flight.Do. That is fine: goroutines + // arriving after load() returns find a live entry (expired is cleared + // below) and return early via the cache-hit path. loadCount = 1 + // either way. + allStarted.Wait() + loadCount.Add(1) + expired.Store(false) // refresh: late arrivals see a live entry + return "reloaded", nil }, - alwaysAliveCheck, - func(key, val string) { - evictedKey = key - evictedVal = val + func(_ string, _ string) error { + if expired.Load() { + return ErrExpired + } + return nil }, + nil, ) - var wg sync.WaitGroup - var gotVal string - var gotOk bool - wg.Go(func() { - gotVal, gotOk = c.Get("k") - }) + // Prime the cache with a live entry. allStarted has count 0 here, so + // Wait() inside load returns immediately — no deadlock. + _, ok := c.Get("k") + require.True(t, ok) + assert.Equal(t, int32(1), loadCount.Load()) - // Wait until load() is running, then inject a concurrent Set so that - // ContainsOrAdd finds the key already present and discards the loaded value. - <-loadReached - c.Set("k", "from-set") - close(allowReturn) // let load() return "from-load" - wg.Wait() + // Reset state: add the goroutine count first, then mark expired so load + // will block waiting for goroutines to pile up. 
+ loadCount.Store(0) + allStarted.Add(goroutines) + expired.Store(true) - // The concurrent Set wins: caller receives the Set value. - require.True(t, gotOk) - assert.Equal(t, "from-set", gotVal, "concurrent Set value should win") + for i := range goroutines { + wg.Add(1) + go func(i int) { + defer wg.Done() + allStarted.Done() // signal: about to call Get + results[i], oks[i] = c.Get("k") + }(i) + } + + // Use the test deadline as a safeguard so a future refactor that breaks + // the allStarted synchronisation causes a fast failure rather than a hang. + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + deadline, ok := t.Deadline() + if !ok { + deadline = time.Now().Add(10 * time.Second) + } + select { + case <-done: + case <-time.After(time.Until(deadline)): + t.Fatal("timed out waiting for goroutines — possible deadlock in load synchronisation") + } - // The loaded-but-discarded value must be passed to onEvict. - assert.Equal(t, "k", evictedKey, "onEvict must be called for the discarded loaded value") - assert.Equal(t, "from-load", evictedVal, "onEvict must receive the discarded loaded value") + assert.Equal(t, int32(1), loadCount.Load(), "concurrent expired-entry Gets should coalesce to a single load") + for i := range goroutines { + assert.True(t, oks[i], "all goroutines should get ok=true") + assert.Equal(t, "reloaded", results[i]) + } } // ---------------------------------------------------------------------------