Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public ChannelFinder(
ChannelEndpointCache endpointCache,
@Nullable EndpointLifecycleManager lifecycleManager,
@Nullable String finderKey) {
this.rangeCache = new KeyRangeCache(Objects.requireNonNull(endpointCache), lifecycleManager);
this.rangeCache =
new KeyRangeCache(Objects.requireNonNull(endpointCache), lifecycleManager, finderKey);
this.lifecycleManager = lifecycleManager;
this.finderKey = finderKey;
}
Expand All @@ -91,6 +92,11 @@ void useDeterministicRandom() {
rangeCache.useDeterministicRandom();
}

@Nullable
String finderKey() {
return finderKey;
}

private static ExecutorService createCacheUpdatePool() {
ThreadPoolExecutor executor =
new ThreadPoolExecutor(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner.spi.v1;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Ticker;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/** Shared process-local latency scores for routed Spanner endpoints. */
final class EndpointLatencyRegistry {
private static final String GLOBAL_SCOPE = "__global__";

static final Duration DEFAULT_ERROR_PENALTY = Duration.ofSeconds(10);
static final Duration DEFAULT_RTT = Duration.ofMillis(10);
static final double DEFAULT_PENALTY_VALUE = 1_000_000.0;
@VisibleForTesting static final Duration TRACKER_EXPIRE_AFTER_ACCESS = Duration.ofMinutes(10);
@VisibleForTesting static final long MAX_TRACKERS = 100_000L;

private static volatile Cache<TrackerKey, LatencyTracker> TRACKERS =
newTrackerCache(Ticker.systemTicker());
private static final ConcurrentHashMap<String, AtomicInteger> INFLIGHT_REQUESTS =
Comment thread
rahul2393 marked this conversation as resolved.
Outdated
new ConcurrentHashMap<>();

private EndpointLatencyRegistry() {}

static boolean hasScore(
@javax.annotation.Nullable String databaseScope,
long operationUid,
String endpointLabelOrAddress) {
TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress);
return trackerKey != null && TRACKERS.getIfPresent(trackerKey) != null;
}

static double getSelectionCost(
@javax.annotation.Nullable String databaseScope,
long operationUid,
String endpointLabelOrAddress) {
TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress);
if (trackerKey == null) {
return Double.MAX_VALUE;
}
double activeRequests = getInflight(endpointLabelOrAddress);
LatencyTracker tracker = TRACKERS.getIfPresent(trackerKey);
if (tracker != null) {
return tracker.getScore() * (activeRequests + 1.0);
}
if (activeRequests > 0.0) {
return DEFAULT_PENALTY_VALUE + activeRequests;
}
return defaultRttMicros() * (activeRequests + 1.0);
Comment thread
rahul2393 marked this conversation as resolved.
Outdated
}

static void recordLatency(
@javax.annotation.Nullable String databaseScope,
long operationUid,
String endpointLabelOrAddress,
Duration latency) {
TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress);
if (trackerKey == null || latency == null) {
return;
}
getOrCreateTracker(trackerKey).update(latency);
}

static void recordError(
@javax.annotation.Nullable String databaseScope,
long operationUid,
String endpointLabelOrAddress) {
recordError(databaseScope, operationUid, endpointLabelOrAddress, DEFAULT_ERROR_PENALTY);
}

static void recordError(
@javax.annotation.Nullable String databaseScope,
long operationUid,
String endpointLabelOrAddress,
Duration penalty) {
TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress);
if (trackerKey == null || penalty == null) {
return;
}
getOrCreateTracker(trackerKey).recordError(penalty);
}

static void beginRequest(String endpointLabelOrAddress) {
String address = normalizeAddress(endpointLabelOrAddress);
if (address == null) {
return;
}
INFLIGHT_REQUESTS.computeIfAbsent(address, ignored -> new AtomicInteger()).incrementAndGet();
}

static void finishRequest(String endpointLabelOrAddress) {
String address = normalizeAddress(endpointLabelOrAddress);
if (address == null) {
return;
}
AtomicInteger counter = INFLIGHT_REQUESTS.get(address);
if (counter == null) {
return;
}
counter.updateAndGet(current -> current > 0 ? current - 1 : 0);
}

static int getInflight(String endpointLabelOrAddress) {
String address = normalizeAddress(endpointLabelOrAddress);
if (address == null) {
return 0;
}
AtomicInteger counter = INFLIGHT_REQUESTS.get(address);
return counter == null ? 0 : Math.max(0, counter.get());
}

@VisibleForTesting
static void clear() {
TRACKERS.invalidateAll();
INFLIGHT_REQUESTS.clear();
}

@VisibleForTesting
static void useTrackerTicker(Ticker ticker) {
TRACKERS = newTrackerCache(ticker);
}

@VisibleForTesting
static String normalizeAddress(String endpointLabelOrAddress) {
if (endpointLabelOrAddress == null || endpointLabelOrAddress.isEmpty()) {
return null;
}
return endpointLabelOrAddress;
}

@VisibleForTesting
static TrackerKey trackerKey(
@javax.annotation.Nullable String databaseScope,
long operationUid,
String endpointLabelOrAddress) {
String address = normalizeAddress(endpointLabelOrAddress);
if (operationUid <= 0 || address == null) {
return null;
}
return new TrackerKey(normalizeScope(databaseScope), operationUid, address);
}

private static long defaultRttMicros() {
return DEFAULT_RTT.toNanos() / 1_000L;
}

private static String normalizeScope(@javax.annotation.Nullable String databaseScope) {
return (databaseScope == null || databaseScope.isEmpty()) ? GLOBAL_SCOPE : databaseScope;
}

private static LatencyTracker getOrCreateTracker(TrackerKey trackerKey) {
try {
return TRACKERS.get(trackerKey, EwmaLatencyTracker::new);
} catch (ExecutionException e) {
throw new IllegalStateException("Failed to create latency tracker", e);
}
}

private static Cache<TrackerKey, LatencyTracker> newTrackerCache(Ticker ticker) {
return CacheBuilder.newBuilder()
.maximumSize(MAX_TRACKERS)
.expireAfterAccess(TRACKER_EXPIRE_AFTER_ACCESS.toNanos(), TimeUnit.NANOSECONDS)
.ticker(ticker)
.build();
}

@VisibleForTesting
static final class TrackerKey {
private final String databaseScope;
private final long operationUid;
private final String address;
Comment thread
rahul2393 marked this conversation as resolved.

private TrackerKey(String databaseScope, long operationUid, String address) {
this.databaseScope = databaseScope;
this.operationUid = operationUid;
this.address = address;
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof TrackerKey)) {
return false;
}
TrackerKey that = (TrackerKey) other;
return operationUid == that.operationUid
&& Objects.equals(databaseScope, that.databaseScope)
&& Objects.equals(address, that.address);
}

@Override
public int hashCode() {
return Objects.hash(databaseScope, operationUid, address);
}

@Override
public String toString() {
return databaseScope + ":" + operationUid + "@" + address;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -70,8 +71,9 @@ class EndpointLifecycleManager {
private static final long EVICTION_CHECK_INTERVAL_SECONDS = 300;

/**
* Maximum consecutive TRANSIENT_FAILURE probes before evicting an endpoint. Gives the channel
* time to recover from transient network issues before we tear it down and recreate.
* Maximum observed TRANSIENT_FAILURE probes before evicting an endpoint. The counter resets only
* after the channel reaches READY, so CONNECTING/IDLE oscillation does not hide a persistently
* unhealthy endpoint.
*/
private static final int MAX_TRANSIENT_FAILURE_COUNT = 3;

Expand Down Expand Up @@ -104,6 +106,7 @@ static final class EndpointState {

private final ChannelEndpointCache endpointCache;
private final Map<String, EndpointState> endpoints = new ConcurrentHashMap<>();
private final Set<String> evictedAddresses = ConcurrentHashMap.newKeySet();
Comment thread
rahul2393 marked this conversation as resolved.
Outdated
private final Set<String> transientFailureEvictedAddresses = ConcurrentHashMap.newKeySet();
private final Map<String, Long> finderGenerations = new ConcurrentHashMap<>();
private final Map<String, PendingActiveAddressUpdate> pendingActiveAddressUpdates =
Expand Down Expand Up @@ -215,6 +218,7 @@ private boolean ensureEndpointExists(String address) {
address,
addr -> {
logger.log(Level.FINE, "Creating endpoint state for address: {0}", addr);
evictedAddresses.remove(addr);
created[0] = true;
return new EndpointState(addr, clock.instant());
});
Expand Down Expand Up @@ -493,7 +497,8 @@ private void stopProbing(String address) {
* <p>All exceptions are caught to prevent {@link ScheduledExecutorService} from cancelling future
* runs of this task.
*/
private void probe(String address) {
@VisibleForTesting
void probe(String address) {
try {
if (isShutdown.get()) {
return;
Expand Down Expand Up @@ -530,25 +535,24 @@ private void probe(String address) {
logger.log(
Level.FINE, "Probe for {0}: channel IDLE, requesting connection (warmup)", address);
channel.getState(true);
state.consecutiveTransientFailures = 0;
break;

case CONNECTING:
state.consecutiveTransientFailures = 0;
break;

case TRANSIENT_FAILURE:
state.consecutiveTransientFailures++;
logger.log(
Level.FINE,
"Probe for {0}: channel in TRANSIENT_FAILURE ({1}/{2})",
"Probe for {0}: channel in TRANSIENT_FAILURE ({1}/{2} observed failures since last"
+ " READY)",
new Object[] {
address, state.consecutiveTransientFailures, MAX_TRANSIENT_FAILURE_COUNT
});
if (state.consecutiveTransientFailures >= MAX_TRANSIENT_FAILURE_COUNT) {
logger.log(
Level.FINE,
"Evicting endpoint {0}: {1} consecutive TRANSIENT_FAILURE probes",
"Evicting endpoint {0}: {1} TRANSIENT_FAILURE probes without reaching READY",
new Object[] {address, state.consecutiveTransientFailures});
evictEndpoint(address, EvictionReason.TRANSIENT_FAILURE);
}
Expand Down Expand Up @@ -608,6 +612,7 @@ private void evictEndpoint(String address, EvictionReason reason) {

stopProbing(address);
endpoints.remove(address);
evictedAddresses.add(address);
if (reason == EvictionReason.TRANSIENT_FAILURE) {
markTransientFailureEvicted(address);
} else {
Expand Down Expand Up @@ -636,6 +641,7 @@ void requestEndpointRecreation(String address) {

logger.log(Level.FINE, "Recreating previously evicted endpoint for address: {0}", address);
EndpointState state = new EndpointState(address, clock.instant());
evictedAddresses.remove(address);
if (endpoints.putIfAbsent(address, state) == null) {
// Schedule after putIfAbsent returns so the entry is visible to the scheduler thread.
scheduler.submit(() -> createAndStartProbing(address));
Expand Down Expand Up @@ -663,6 +669,32 @@ int managedEndpointCount() {
return endpoints.size();
}

Map<String, Long> snapshotEndpointStateCounts() {
Map<String, Long> counts = new HashMap<>();
snapshotEndpointStates().values().forEach(state -> counts.merge(state, 1L, Long::sum));
return counts;
}

Map<String, String> snapshotEndpointStates() {
Map<String, String> states = new HashMap<>();
for (String address : endpoints.keySet()) {
ChannelEndpoint endpoint = endpointCache.getIfPresent(address);
String stateName = "unknown";
if (endpoint != null) {
ConnectivityState state = endpoint.getChannel().getState(false);
stateName =
state == ConnectivityState.TRANSIENT_FAILURE
? "transient_failure"
: state.name().toLowerCase();
}
states.put(address, stateName);
}
for (String address : evictedAddresses) {
states.putIfAbsent(address, "evicted");
}
return states;
}

/** Shuts down the lifecycle manager and all probing. */
void shutdown() {
if (!isShutdown.compareAndSet(false, true)) {
Expand All @@ -684,6 +716,7 @@ void shutdown() {
}
}
endpoints.clear();
evictedAddresses.clear();
transientFailureEvictedAddresses.clear();
pendingActiveAddressUpdates.clear();
queuedFinderKeys.clear();
Expand Down
Loading
Loading