diff --git a/ratis-common/src/main/java/org/apache/ratis/util/CodeInjectionForTesting.java b/ratis-common/src/main/java/org/apache/ratis/util/CodeInjectionForTesting.java
index 112f6bd250..290fa4287b 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/CodeInjectionForTesting.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/CodeInjectionForTesting.java
@@ -50,9 +50,9 @@ public interface Code {
       = new ConcurrentHashMap<>();
 
-  /** Put an injection point. */
-  public static void put(String injectionPoint, Code code) {
+  /** Put an injection point; return the code previously put at that point, or null if there was none. */
+  public static Code put(String injectionPoint, Code code) {
     LOG.debug("put: {}, {}", injectionPoint, code);
-    INJECTION_POINTS.put(injectionPoint, code);
+    return INJECTION_POINTS.put(injectionPoint, code);
   }
 
   /** Execute the injected code, if there is any. */
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
index b4d78c207a..0cf07ee04b 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
@@ -33,6 +33,7 @@
 import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.server.raftlog.RaftLog;
 import org.apache.ratis.server.util.ServerStringUtils;
+import org.apache.ratis.thirdparty.io.grpc.Status;
 import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
 import org.apache.ratis.thirdparty.io.grpc.stub.CallStreamObserver;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
@@ -215,7 +216,7 @@ private void resetClient(AppendEntriesRequest request, Event event) {
     try (AutoCloseableLock writeLock = lock.writeLock(caller, LOG::trace)) {
       getClient().resetConnectBackoff();
       if (appendLogRequestObserver != null) {
-        appendLogRequestObserver.stop();
+        appendLogRequestObserver.stop(event);
         appendLogRequestObserver = null;
       }
       final int errorCount = replyState.process(event);
@@ -266,16 +267,23 @@ private boolean installSnapshot() {
 
   @Override
   public void run() throws IOException {
-    for(; isRunning(); mayWait()) {
-      //HB period is expired OR we have messages OR follower is behind with commit index
-      if (shouldSendAppendEntries() || isFollowerCommitBehindLastCommitIndex()) {
-        final boolean installingSnapshot = installSnapshot();
-        appendLog(installingSnapshot || haveTooManyPendingRequests());
+    try {
+      for (; isRunning(); mayWait()) {
+        //HB period is expired OR we have messages OR follower is behind with commit index
+        if (shouldSendAppendEntries() || isFollowerCommitBehindLastCommitIndex()) {
+          final boolean installingSnapshot = installSnapshot();
+          appendLog(installingSnapshot || haveTooManyPendingRequests());
+        }
+        getLeaderState().checkHealth(getFollower());
+      }
+    } finally {
+      try (AutoCloseableLock writeLock = lock.writeLock(caller, LOG::trace)) {
+        if (appendLogRequestObserver != null) {
+          appendLogRequestObserver.onCompleted();
+          appendLogRequestObserver = null;
+        }
       }
-      getLeaderState().checkHealth(getFollower());
     }
-
-    Optional.ofNullable(appendLogRequestObserver).ifPresent(StreamObservers::onCompleted);
   }
 
   public long getWaitTimeMs() {
@@ -366,16 +374,49 @@ void onNext(AppendEntriesRequestProto proto)
       while (!stream.isReady() && running) {
         sleep(waitForReady, isHeartBeat);
       }
+      if (!running) {
+        // stopped while waiting for the stream to become ready; drop the request
+        return;
+      }
       stream.onNext(proto);
     }
 
-    void stop() {
+    /** Stop sending requests: complete the streams on {@code Event.COMPLETE}; otherwise, cancel them. */
+    void stop(Event event) {
       running = false;
+      if (event == Event.COMPLETE) {
+        onCompleted();
+      } else {
+        cancelStream("stop due to " + event);
+      }
     }
 
     void onCompleted() {
-      appendLog.onCompleted();
-      Optional.ofNullable(heartbeat).ifPresent(StreamObserver::onCompleted);
+      try {
+        appendLog.onCompleted();
+      } catch (Exception e) {
+        LOG.debug("Failed to complete appendLog stream", e);
+      }
+      try {
+        Optional.ofNullable(heartbeat).ifPresent(StreamObserver::onCompleted);
+      } catch (Exception e) {
+        LOG.debug("Failed to complete heartbeat stream", e);
+      }
+    }
+
+    /** Cancel both streams with a CANCELLED status; any failure is only logged at debug level. */
+    void cancelStream(String reason) {
+      try {
+        appendLog.onError(new StatusRuntimeException(Status.CANCELLED.withDescription(reason)));
+      } catch (Exception e) {
+        LOG.debug("Failed to cancel appendLog stream", e);
+      }
+      try {
+        Optional.ofNullable(heartbeat).ifPresent((hb) ->
+            hb.onError(new StatusRuntimeException(Status.CANCELLED.withDescription(reason))));
+      } catch (Exception e) {
+        LOG.debug("Failed to cancel heartbeat stream", e);
+      }
     }
   }
 
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java
index a13e74b89d..aca9357d9c 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java
@@ -31,6 +31,8 @@
 import org.apache.ratis.proto.RaftProtos.*;
 import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceImplBase;
 import org.apache.ratis.util.BatchLogger;
+import org.apache.ratis.util.CodeInjectionForTesting;
+import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.MemoizedSupplier;
 import org.apache.ratis.util.ProtoUtils;
 import org.slf4j.Logger;
@@ -46,6 +48,10 @@
 class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase {
   public static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolService.class);
 
+  /** The injection point executed by handleError after the error is sent to the response observer. */
+  public static final String GRPC_SERVER_HANDLE_ERROR =
+      JavaUtils.getClassSimpleName(GrpcServerProtocolService.class) + ".handleError";
+
   private enum BatchLogKey implements BatchLogger.Key {
     COMPLETED_REQUEST,
     COMPLETED_REPLY
@@ -114,7 +120,9 @@ StatusRuntimeException wrapException(Throwable e, REQUEST request) {
     private void handleError(Throwable e, REQUEST request) {
       GrpcUtil.warn(LOG, () -> getId() + ": Failed " + op + " request " + requestToString(request), e);
       if (isClosed.compareAndSet(false, true)) {
+        previousOnNext.set(null);
         responseObserver.onError(wrapException(e, request));
+        CodeInjectionForTesting.execute(GRPC_SERVER_HANDLE_ERROR, getId(), null, previousOnNext.get());
       }
     }
 
@@ -172,18 +180,22 @@ public void onCompleted() {
         BatchLogger.print(BatchLogKey.COMPLETED_REQUEST, getName(),
             suffix -> LOG.info("{}: Completed {}, lastRequest: {} {}",
                 getId(), op, getPreviousRequestString(), suffix));
+        previousOnNext.set(null);
         requestFuture.get().thenAccept(reply -> {
           BatchLogger.print(BatchLogKey.COMPLETED_REPLY, getName(),
               suffix -> LOG.info("{}: Completed {}, lastReply: {} {}",
                   getId(), op, ProtoUtils.shortDebugString(reply), suffix));
           responseObserver.onCompleted();
         });
+        requestFuture.set(null);
       }
     }
     @Override
     public void onError(Throwable t) {
       GrpcUtil.warn(LOG, () -> getId() + ": "+ op + " onError, lastRequest: " + getPreviousRequestString(), t);
       if (isClosed.compareAndSet(false, true)) {
+        previousOnNext.set(null);
+        requestFuture.set(null);
         Status status = Status.fromThrowable(t);
         if (status != null && status.getCode() != Status.Code.CANCELLED) {
           responseObserver.onCompleted();
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestLogAppenderWithGrpc.java b/ratis-test/src/test/java/org/apache/ratis/grpc/TestLogAppenderWithGrpc.java
index 107cd7ba9a..729c681fee 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestLogAppenderWithGrpc.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestLogAppenderWithGrpc.java
@@ -25,12 +25,16 @@
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.metrics.GrpcServerMetrics;
 import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.protocol.RaftPeerId;
 import org.apache.ratis.server.RaftServer;
 import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.server.leader.FollowerInfo;
+import org.apache.ratis.server.leader.LogAppender;
 import org.apache.ratis.server.impl.RaftServerTestUtil;
 import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
 import org.apache.ratis.statemachine.StateMachine;
+import org.apache.ratis.util.CodeInjectionForTesting;
+import org.apache.ratis.util.CodeInjectionForTesting.Code;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Slf4jUtils;
 import org.junit.jupiter.api.Assertions;
@@ -39,10 +43,15 @@
 import org.slf4j.event.Level;
 
 import java.io.IOException;
+import java.lang.ref.WeakReference;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 
 import static org.apache.ratis.RaftTestUtil.waitForLeader;
 
@@ -148,4 +157,209 @@ private void runTestRestartLogAppender(MiniRaftClusterWithGrpc cluster) throws E
       Assertions.assertTrue(newleaderMetrics.getRegistry().counter(counter).getCount() >= 1L);
     }
   }
+
+  /**
+   * Verify that old LogAppender instances are properly cleaned up (gRPC streams terminated)
+   * after restartLogAppenders. Without the fix, gRPC holds references to unterminated
+   * stream response handlers, preventing old LogAppender instances from being collected.
+   */
+  @ParameterizedTest
+  @MethodSource("data")
+  public void testLogAppenderStreamCleanupOnRestart(Boolean separateHeartbeat) throws Exception {
+    GrpcConfigKeys.Server.setHeartbeatChannel(getProperties(), separateHeartbeat);
+    runWithNewCluster(3, this::runTestLogAppenderStreamCleanupOnRestart);
+  }
+
+  private void runTestLogAppenderStreamCleanupOnRestart(MiniRaftClusterWithGrpc cluster) throws Exception {
+    final RaftServer.Division leader = waitForLeader(cluster);
+    final RaftPeerId leaderId = leader.getId();
+
+    try (RaftClient client = cluster.createClient(leaderId)) {
+      for (int i = 0; i < 10; i++) {
+        final RaftClientReply reply = client.io().send(new RaftTestUtil.SimpleMessage("m" + i));
+        Assertions.assertTrue(reply.isSuccess());
+      }
+    }
+
+    final List<WeakReference<LogAppender>> weakRefs =
+        RaftServerTestUtil.getLogAppenders(leader)
+            .map(WeakReference::new)
+            .collect(Collectors.toList());
+    Assertions.assertFalse(weakRefs.isEmpty());
+
+    RaftServerTestUtil.restartLogAppenders(leader);
+
+    try (RaftClient client = cluster.createClient(leaderId)) {
+      for (int i = 0; i < 10; i++) {
+        client.io().send(new RaftTestUtil.SimpleMessage("after-" + i));
+      }
+    }
+
+    // Old appenders should be GC-able once their gRPC streams are terminated.
+    // If streams leaked, gRPC retains references to the response handlers
+    // (inner classes of GrpcLogAppender), preventing collection.
+    JavaUtils.attempt(() -> {
+      System.gc();
+      for (WeakReference<LogAppender> ref : weakRefs) {
+        Assertions.assertNull(ref.get(),
+            "Old LogAppender should be garbage collected after stream cleanup");
+      }
+    }, 20, ONE_SECOND, "old-appender-gc", LOG);
+  }
+
+  /**
+   * Verify that the follower's ServerRequestStreamObserver cleans up previousOnNext
+   * when handleError is triggered. This injects failures at the APPEND_ENTRIES point
+   * on a specific follower, causing process(request) to fail and handleError to be called.
+   * Without the fix, previousOnNext retains the last PendingServerRequest (including the
+   * full AppendEntriesRequestProto with log entry data) after handleError closes the stream.
+   */
+  @ParameterizedTest
+  @MethodSource("data")
+  public void testFollowerHandleErrorCleanup(Boolean separateHeartbeat) throws Exception {
+    GrpcConfigKeys.Server.setHeartbeatChannel(getProperties(), separateHeartbeat);
+    runWithNewCluster(3, this::runTestFollowerHandleErrorCleanup);
+  }
+
+  private void runTestFollowerHandleErrorCleanup(MiniRaftClusterWithGrpc cluster) throws Exception {
+    final RaftServer.Division leader = waitForLeader(cluster);
+    final RaftPeerId leaderId = leader.getId();
+
+    try (RaftClient client = cluster.createClient(leaderId)) {
+      for (int i = 0; i < 5; i++) {
+        Assertions.assertTrue(client.io().send(
+            new RaftTestUtil.SimpleMessage("init-" + i)).isSuccess());
+      }
+    }
+
+    final RaftPeerId followerId = cluster.getFollowers().get(0).getId();
+    final String APPEND_ENTRIES = "RaftServerImpl.appendEntries";
+    final String GRPC_SERVER_HANDLE_ERROR = "GrpcServerProtocolService.handleError";
+    final AtomicBoolean shouldFail = new AtomicBoolean(false);
+    final AtomicInteger handleErrorCount = new AtomicInteger(0);
+    final AtomicInteger leakCount = new AtomicInteger(0);
+
+    Code previousAECode = CodeInjectionForTesting.put(APPEND_ENTRIES, (localId, remoteId, args) -> {
+      if (shouldFail.get() && localId.toString().equals(followerId.toString())) {
+        throw new RuntimeException("Injected failure for handleError test");
+      }
+      return false;
+    });
+    CodeInjectionForTesting.put(GRPC_SERVER_HANDLE_ERROR, (localId, remoteId, args) -> {
+      handleErrorCount.incrementAndGet();
+      if (args != null && args.length > 0 && args[0] != null) {
+        leakCount.incrementAndGet();
+      }
+      return false;
+    });
+
+    try {
+      final int numCycles = 3;
+      for (int cycle = 0; cycle < numCycles; cycle++) {
+        LOG.info("=== HandleError cycle {} ===", cycle);
+
+        shouldFail.set(true);
+
+        try (RaftClient client = cluster.createClient(leaderId)) {
+          for (int i = 0; i < 5; i++) {
+            client.io().send(new RaftTestUtil.SimpleMessage("fail-" + cycle + "-" + i));
+          }
+        }
+
+        JavaUtils.attempt(() -> {
+          final long leaderCommit = leader.getRaftLog().getLastCommittedIndex();
+          Assertions.assertTrue(leaderCommit > 0);
+        }, 10, ONE_SECOND, "leader-commit-cycle-" + cycle, LOG);
+
+        shouldFail.set(false);
+
+        try (RaftClient client = cluster.createClient(leaderId)) {
+          for (int i = 0; i < 5; i++) {
+            Assertions.assertTrue(client.io().send(
+                new RaftTestUtil.SimpleMessage("recover-" + cycle + "-" + i)).isSuccess());
+          }
+        }
+
+        final int c = cycle;
+        JavaUtils.attempt(() -> {
+          final RaftServer.Division f = cluster.getDivision(followerId);
+          Assertions.assertTrue(f.getInfo().getLastAppliedIndex() > 0,
+              "Follower " + followerId + " should recover after handleError");
+        }, 10, ONE_SECOND, "follower-recover-" + c, LOG);
+      }
+
+      try (RaftClient client = cluster.createClient(leaderId)) {
+        final RaftClientReply reply = client.io().send(new RaftTestUtil.SimpleMessage("final"));
+        Assertions.assertTrue(reply.isSuccess());
+        client.io().watch(reply.getLogIndex(), RaftProtos.ReplicationLevel.ALL_COMMITTED);
+      }
+
+      Assertions.assertTrue(handleErrorCount.get() > 0,
+          "handleError should have been triggered by the injected failures");
+      Assertions.assertEquals(0, leakCount.get(),
+          "previousOnNext should be cleaned up in handleError to prevent memory leaks");
+    } finally {
+      if (previousAECode != null) {
+        CodeInjectionForTesting.put(APPEND_ENTRIES, previousAECode);
+      } else {
+        CodeInjectionForTesting.remove(APPEND_ENTRIES);
+      }
+      CodeInjectionForTesting.remove(GRPC_SERVER_HANDLE_ERROR);
+    }
+  }
+
+  /**
+   * Reproduce the StreamObserver leak by repeatedly killing and restarting a follower.
+   * Each kill triggers onError on the leader's response handler, calling
+   * resetClient -> stop(ERROR) -> cancelStream. Without the fix, the old gRPC streams
+   * are never terminated and accumulate on both leader and follower sides.
+   */
+  @ParameterizedTest
+  @MethodSource("data")
+  public void testStreamObserverCleanupOnFollowerKillRestart(Boolean separateHeartbeat) throws Exception {
+    GrpcConfigKeys.Server.setHeartbeatChannel(getProperties(), separateHeartbeat);
+    runWithNewCluster(3, this::runTestStreamObserverCleanupOnFollowerKillRestart);
+  }
+
+  private void runTestStreamObserverCleanupOnFollowerKillRestart(MiniRaftClusterWithGrpc cluster) throws Exception {
+    final RaftServer.Division leader = waitForLeader(cluster);
+    final RaftPeerId leaderId = leader.getId();
+    final int numCycles = 5;
+
+    try (RaftClient client = cluster.createClient(leaderId)) {
+      for (int i = 0; i < 5; i++) {
+        Assertions.assertTrue(client.io().send(
+            new RaftTestUtil.SimpleMessage("init-" + i)).isSuccess());
+      }
+    }
+
+    for (int cycle = 0; cycle < numCycles; cycle++) {
+      LOG.info("=== Kill/Restart cycle {} ===", cycle);
+      final RaftPeerId followerId = cluster.getFollowers().get(0).getId();
+
+      cluster.killServer(followerId);
+
+      try (RaftClient client = cluster.createClient(leaderId)) {
+        for (int i = 0; i < 5; i++) {
+          Assertions.assertTrue(client.io().send(
+              new RaftTestUtil.SimpleMessage("cycle" + cycle + "-" + i)).isSuccess());
+        }
+      }
+
+      cluster.restartServer(followerId, false);
+
+      JavaUtils.attempt(() -> {
+        final RaftServer.Division f = cluster.getDivision(followerId);
+        Assertions.assertTrue(f.getInfo().getLastAppliedIndex() > 0,
+            "Follower " + followerId + " should have applied entries");
+      }, 10, ONE_SECOND, "follower-catchup-" + cycle, LOG);
+    }
+
+    // Verify all entries committed across the cluster
+    try (RaftClient client = cluster.createClient(leaderId)) {
+      final RaftClientReply reply = client.io().send(new RaftTestUtil.SimpleMessage("final"));
+      Assertions.assertTrue(reply.isSuccess());
+      client.io().watch(reply.getLogIndex(), RaftProtos.ReplicationLevel.ALL_COMMITTED);
+    }
+  }
 }