Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 81 additions & 48 deletions pkg/cache/validating_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package cache
import (
"errors"
"fmt"
"reflect"
"sync"

lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/sync/singleflight"
Expand All @@ -18,25 +20,26 @@ import (
var ErrExpired = errors.New("cache entry expired")

// ValidatingCache is a node-local write-through cache backed by a
// capacity-bounded LRU map, with singleflight-deduplicated restore on cache
// miss and lazy liveness validation on cache hit.
// capacity-bounded LRU map, with singleflight-deduplicated Get operations and
// lazy liveness validation on cache hit.
//
// Type parameter K is the key type (must be comparable).
// Type parameter V is the cached value type.
//
// The no-resurrection invariant (preventing a concurrent restore from
// overwriting a deletion) is enforced via ContainsOrAdd: if a concurrent
// writer stored a value between load() returning and the cache being updated,
// the prior writer's value wins and the just-loaded value is discarded via
// onEvict.
// The entire Get operation — cache hit validation and miss load — runs under a
// singleflight group so at most one operation executes concurrently per key.
// Concurrent callers for the same key share the result, coalescing both
// liveness checks and storage round-trips into a single operation per key.
type ValidatingCache[K comparable, V any] struct {
lruCache *lru.Cache[K, V]
flight singleflight.Group
load func(key K) (V, error)
check func(key K, val V) error
// onEvict is kept here so we can call it when discarding a concurrently
// loaded value that lost the race to a prior writer.
onEvict func(key K, val V)
onEvict func(K, V)
// mu serializes Set against the conditional eviction in getHit.
// check() runs outside the lock to avoid holding it during I/O; the lock
// is only held for the short Peek+Remove sequence.
mu sync.Mutex
}

// New creates a ValidatingCache with the given capacity and callbacks.
Expand Down Expand Up @@ -88,63 +91,74 @@ func New[K comparable, V any](
func (c *ValidatingCache[K, V]) getHit(key K, val V) (V, bool) {
if err := c.check(key, val); err != nil {
if errors.Is(err, ErrExpired) {
// Remove fires the eviction callback automatically.
c.lruCache.Remove(key)
// check() ran outside the lock to avoid holding it during I/O.
// Re-verify under the lock that the entry hasn't been replaced by a
// concurrent Set before removing it; otherwise we would evict a
// freshly-written value that the caller intended to keep.
c.mu.Lock()
if current, ok := c.lruCache.Peek(key); ok && sameEntry(current, val) {
// Remove fires the eviction callback automatically.
c.lruCache.Remove(key)
}
c.mu.Unlock()
Comment thread
yrobla marked this conversation as resolved.
var zero V
return zero, false
}
Comment thread
yrobla marked this conversation as resolved.
}
return val, true
}

// Get returns the value for key, loading it on a cache miss. On a cache hit
// the entry's liveness is validated via the check function provided to New:
// ErrExpired evicts the entry and returns (zero, false); transient errors
// return the cached value unchanged. On a cache miss, load is called under a
// singleflight group so at most one restore runs concurrently per key.
// Get returns the value for key, loading it on a cache miss. The entire
// operation — cache hit validation and miss load — runs under a singleflight
// group so at most one operation executes concurrently per key. Concurrent
// callers for the same key share the result.
//
// On a cache hit the entry's liveness is validated via the check function
// provided to New: ErrExpired evicts the entry and falls through to load;
// transient errors return the cached value unchanged. On a cache miss, load
// is called to restore the value.
//
// The returned bool is false whenever the value is unavailable — either
// because load returned an error or because the key does not exist in the
// backing store. Callers cannot distinguish these two cases.
func (c *ValidatingCache[K, V]) Get(key K) (V, bool) {
if val, ok := c.lruCache.Get(key); ok {
return c.getHit(key, val)
}

// Cache miss: use singleflight to prevent concurrent restores for the same key.
type result struct{ v V }
// fmt.Sprint(key) is the singleflight key. For string keys this is
// exact. For other types, distinct values with identical string
// representations would be incorrectly coalesced — avoid non-string K
// types unless their fmt.Sprint output is guaranteed unique.
raw, err, _ := c.flight.Do(fmt.Sprint(key), func() (any, error) {
// Re-check the cache: a concurrent singleflight group may have stored
// the value between our miss check above and acquiring this group.
if existing, ok := c.lruCache.Get(key); ok {
return result{v: existing}, nil
// Cache hit path: validate liveness.
Comment thread
yrobla marked this conversation as resolved.
if val, ok := c.lruCache.Get(key); ok {
v, alive := c.getHit(key, val)
if alive {
return result{v: v}, nil
}
// Entry expired and evicted; fall through to load.
}
Comment thread
yrobla marked this conversation as resolved.

// Cache miss (or expired): load the value and store it.
v, loadErr := c.load(key)
if loadErr != nil {
return nil, loadErr
}

// Guard against a concurrent Set or Remove that occurred while load() was
// running. ContainsOrAdd stores only if absent; if another writer got
// in first, their value wins and we discard ours via onEvict.
ok, _ := c.lruCache.ContainsOrAdd(key, v)
if ok {
// Another writer stored a value first; discard our loaded value and
// return the winner's. ContainsOrAdd and Get are separate lock
// acquisitions, so the winner may itself have been evicted by LRU
// pressure between the two calls — fall back to our freshly loaded
// value in that case rather than returning a zero value.
winner, found := c.lruCache.Get(key)
// Discard our loaded value in favour of the winner (or clean up if
// the winner was itself evicted between ContainsOrAdd and Get).
if c.onEvict != nil {
c.onEvict(key, v)
// Guard against a concurrent Set that occurred while load() was running.
// ContainsOrAdd stores only if absent; if a concurrent Set got in first,
// their value wins and we return it instead.
if alreadySet, _ := c.lruCache.ContainsOrAdd(key, v); alreadySet {
if winner, ok := c.lruCache.Get(key); ok {
// Winner confirmed: v is definitively discarded — release its resources.
if c.onEvict != nil {
c.onEvict(key, v)
}
return result{v: winner}, nil
Comment thread
yrobla marked this conversation as resolved.
}
Comment thread
yrobla marked this conversation as resolved.
if !found {
// Winner was evicted before we could retrieve it; signal a cache
// miss so the caller retries rather than receiving a stale value.
return nil, nil
}
return result{v: winner}, nil
// The concurrent winner was itself evicted by LRU pressure between
// ContainsOrAdd and Get. Fall back to storing v — do NOT call onEvict
// since v has not been released and is still valid.
c.lruCache.Add(key, v)
}
Comment thread
yrobla marked this conversation as resolved.

return result{v: v}, nil
Comment thread
yrobla marked this conversation as resolved.
Comment thread
yrobla marked this conversation as resolved.
})
if err != nil {
Expand All @@ -159,10 +173,29 @@ func (c *ValidatingCache[K, V]) Get(key K) (V, bool) {
// cache is at capacity, the least-recently-used entry is evicted first and
// onEvict is called for it.
func (c *ValidatingCache[K, V]) Set(key K, value V) {
c.mu.Lock()
defer c.mu.Unlock()
c.lruCache.Add(key, value)
}

// Len returns the number of entries currently in the cache.
func (c *ValidatingCache[K, V]) Len() int {
return c.lruCache.Len()
}

// sameEntry reports whether a and b are the same cache entry.
// For pointer types it compares addresses (identity), so a concurrent Set that
// stores a distinct new value is never mistaken for the stale entry. For
// non-pointer types it falls back to reflect.DeepEqual, which is safe for all
// comparable and non-comparable types.
func sameEntry[V any](a, b V) bool {
ra := reflect.ValueOf(any(a))
if ra.IsValid() {
switch ra.Kind() { //nolint:exhaustive
case reflect.Ptr, reflect.UnsafePointer:
rb := reflect.ValueOf(any(b))
return rb.IsValid() && ra.Pointer() == rb.Pointer()
}
}
return reflect.DeepEqual(a, b)
Comment thread
yrobla marked this conversation as resolved.
}
Loading
Loading