getStateClass() {
@@ -436,7 +512,8 @@ public Executing getState() {
userCodeClassLoader,
failureCollection,
stateTransitionManagerFactory,
- rescaleOnFailedCheckpointCount);
+ rescaleOnFailedCheckpointCount,
+ activeCheckpointTriggerEnabled);
}
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java
index 98229a9afd3e3..e057addd40f4a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java
@@ -89,5 +89,24 @@ interface Context extends RescaleContext {
* @return the {@link JobID} of the job
*/
JobID getJobId();
+
+ /**
+ * Requests the context to actively trigger a checkpoint to expedite rescaling. Called by
+ * the {@link DefaultStateTransitionManager} from within phase lifecycle methods:
+ *
+ *
+ * - On entering {@link DefaultStateTransitionManager.Stabilizing} (to overlap
+ * checkpoint with the stabilization wait)
+ *
- On each {@link DefaultStateTransitionManager.Stabilizing#onChange} event (retry if
+ * a previous trigger was skipped)
+ *
- On entering {@link DefaultStateTransitionManager.Stabilized} (fallback if no
+ * checkpoint completed during stabilization)
+ *
+ *
+ * The implementation decides whether to actually trigger based on its own guard
+ * conditions (e.g., checkpointing enabled, no checkpoint in progress, config flag).
+ * Multiple calls are safe; guards prevent redundant triggers.
+ */
+ default void requestActiveCheckpointTrigger() {}
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
index 0a505ab07f2c6..5493e05667c68 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
@@ -63,6 +63,7 @@
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.clock.Clock;
import org.apache.flink.util.clock.SystemClock;
import org.apache.flink.util.concurrent.Executors;
import org.apache.flink.util.concurrent.ManuallyTriggeredScheduledExecutor;
@@ -791,6 +792,8 @@ public static class CheckpointCoordinatorBuilder {
VertexFinishedStateChecker>
vertexFinishedStateCheckerFactory = VertexFinishedStateChecker::new;
+ private Clock clock = SystemClock.getInstance();
+
public CheckpointCoordinatorBuilder setCheckpointCoordinatorConfiguration(
CheckpointCoordinatorConfiguration checkpointCoordinatorConfiguration) {
this.checkpointCoordinatorConfiguration = checkpointCoordinatorConfiguration;
@@ -870,6 +873,11 @@ public CheckpointCoordinatorBuilder setVertexFinishedStateCheckerFactory(
return this;
}
+ public CheckpointCoordinatorBuilder setClock(Clock clock) {
+ this.clock = clock;
+ return this;
+ }
+
public CheckpointCoordinator build(ScheduledExecutorService executorService)
throws Exception {
return build(
@@ -899,7 +907,7 @@ public CheckpointCoordinator build(ExecutionGraph executionGraph) throws Excepti
timer,
failureManager,
checkpointPlanCalculator,
- SystemClock.getInstance(),
+ clock,
checkpointStatsTracker,
vertexFinishedStateCheckerFactory);
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/DefaultStateTransitionManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/DefaultStateTransitionManagerTest.java
index cf6051a90b315..83c9002d9facf 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/DefaultStateTransitionManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/DefaultStateTransitionManagerTest.java
@@ -400,6 +400,59 @@ private static void assertFinalStateTransitionHappened(
assertThat(testInstance.getPhase()).isInstanceOf(Transitioning.class);
}
+ @Test
+ void testActiveCheckpointTriggerCalledOnEnteringStabilizing() {
+ final TestingStateTransitionManagerContext ctx =
+ TestingStateTransitionManagerContext.stableContext();
+ ctx.withSufficientResources();
+ final DefaultStateTransitionManager testInstance =
+ ctx.createTestInstanceThatPassedCooldownPhase();
+
+ assertThat(testInstance.getPhase()).isInstanceOf(Idling.class);
+ ctx.clearActiveCheckpointTriggerCount();
+
+ testInstance.onChange(true);
+
+ assertThat(testInstance.getPhase()).isInstanceOf(Stabilizing.class);
+ assertThat(ctx.getActiveCheckpointTriggerCount()).isGreaterThanOrEqualTo(1);
+ }
+
+ @Test
+ void testActiveCheckpointTriggerCalledOnChangeInStabilizing() {
+ final TestingStateTransitionManagerContext ctx =
+ TestingStateTransitionManagerContext.stableContext();
+ ctx.withSufficientResources();
+ final DefaultStateTransitionManager testInstance =
+ ctx.createTestInstanceThatPassedCooldownPhase();
+
+ testInstance.onChange(true);
+ assertThat(testInstance.getPhase()).isInstanceOf(Stabilizing.class);
+ ctx.clearActiveCheckpointTriggerCount();
+
+ testInstance.onChange(true);
+
+ assertThat(testInstance.getPhase()).isInstanceOf(Stabilizing.class);
+ assertThat(ctx.getActiveCheckpointTriggerCount()).isGreaterThanOrEqualTo(1);
+ }
+
+ @Test
+ void testActiveCheckpointTriggerCalledOnEnteringStabilized() {
+ final TestingStateTransitionManagerContext ctx =
+ TestingStateTransitionManagerContext.stableContext();
+ ctx.withSufficientResources();
+ final DefaultStateTransitionManager testInstance =
+ ctx.createTestInstanceThatPassedCooldownPhase();
+
+ testInstance.onChange(true);
+ assertThat(testInstance.getPhase()).isInstanceOf(Stabilizing.class);
+ ctx.clearActiveCheckpointTriggerCount();
+
+ ctx.passResourceStabilizationTimeout();
+
+ assertThat(testInstance.getPhase()).isInstanceOf(Stabilized.class);
+ assertThat(ctx.getActiveCheckpointTriggerCount()).isGreaterThanOrEqualTo(1);
+ }
+
private static void changeWithoutPhaseMove(
TestingStateTransitionManagerContext ctx,
DefaultStateTransitionManager testInstance,
@@ -460,6 +513,7 @@ private static class TestingStateTransitionManagerContext
// internal state used for assertions
private final AtomicBoolean transitionTriggered = new AtomicBoolean();
+ private int activeCheckpointTriggerCount = 0;
private final SortedMap>> scheduledTasks =
new TreeMap<>();
@@ -537,6 +591,11 @@ public void transitionToSubsequentState() {
transitionTriggered.set(true);
}
+ @Override
+ public void requestActiveCheckpointTrigger() {
+ activeCheckpointTriggerCount++;
+ }
+
@Override
public ScheduledFuture> scheduleOperation(Runnable callback, Duration delay) {
final Instant triggerTime =
@@ -703,5 +762,13 @@ public boolean stateTransitionWasTriggered() {
public void clearStateTransition() {
transitionTriggered.set(false);
}
+
+ public int getActiveCheckpointTriggerCount() {
+ return activeCheckpointTriggerCount;
+ }
+
+ public void clearActiveCheckpointTriggerCount() {
+ activeCheckpointTriggerCount = 0;
+ }
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
index f593082f6c75b..cf216f0296e57 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
@@ -57,6 +57,7 @@
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.operators.coordination.CoordinatorStoreImpl;
import org.apache.flink.runtime.scheduler.DefaultVertexParallelismInfo;
@@ -75,6 +76,8 @@
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.clock.ManualClock;
+import org.apache.flink.util.concurrent.ManuallyTriggeredScheduledExecutor;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
@@ -160,7 +163,8 @@ void testNoDeploymentCallOnEnterWhenVertexRunning() throws Exception {
ClassLoader.getSystemClassLoader(),
new ArrayList<>(),
(context) -> TestingStateTransitionManager.withNoOp(),
- 1);
+ 1,
+ false);
assertThat(mockExecutionVertex.isDeployCalled()).isFalse();
}
}
@@ -186,7 +190,8 @@ void testIllegalStateExceptionOnNotRunningExecutionGraph() {
ClassLoader.getSystemClassLoader(),
new ArrayList<>(),
context -> TestingStateTransitionManager.withNoOp(),
- 1);
+ 1,
+ false);
}
})
.isInstanceOf(IllegalStateException.class);
@@ -556,6 +561,313 @@ public CheckpointCoordinator getCheckpointCoordinator() {
}
}
+ @Test
+ void testActiveCheckpointTriggerSkipsWhenDisabled() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ final AtomicBoolean coordinatorAccessed = new AtomicBoolean(false);
+ CheckpointCoordinator coordinator =
+ new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
+ .build(EXECUTOR_EXTENSION.getExecutor());
+ StateTrackingMockExecutionGraph graph =
+ new StateTrackingMockExecutionGraph() {
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ coordinatorAccessed.set(true);
+ return coordinator;
+ }
+ };
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(false)
+ .build(ctx);
+
+ exec.requestActiveCheckpointTrigger();
+ assertThat(coordinatorAccessed.get()).isFalse();
+ }
+ }
+
+ @Test
+ void testActiveCheckpointTriggerSkipsWhenNoCoordinator() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ MockExecutionJobVertex mejv = new MockExecutionJobVertex(MockExecutionVertex::new);
+ ExecutionGraph graph = new MockExecutionGraph(() -> Collections.singletonList(mejv));
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(true)
+ .build(ctx);
+ exec.requestActiveCheckpointTrigger();
+ }
+ }
+
+ @Test
+ void testActiveCheckpointTriggerSkipsWhenPeriodicCheckpointingNotConfigured() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ CheckpointCoordinator coordinator =
+ new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
+ .setCheckpointCoordinatorConfiguration(
+ new CheckpointCoordinatorConfiguration
+ .CheckpointCoordinatorConfigurationBuilder()
+ .setCheckpointInterval(Long.MAX_VALUE)
+ .build())
+ .build(EXECUTOR_EXTENSION.getExecutor());
+ StateTrackingMockExecutionGraph graph =
+ new StateTrackingMockExecutionGraph() {
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return coordinator;
+ }
+ };
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(true)
+ .build(ctx);
+
+ assertThat(coordinator.isPeriodicCheckpointingConfigured()).isFalse();
+ exec.requestActiveCheckpointTrigger();
+ }
+ }
+
+ @Test
+ void testActiveCheckpointTriggerSkipsWhenParallelismUnchanged() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ MockExecutionJobVertex mejv = new MockExecutionJobVertex(MockExecutionVertex::new);
+
+ CheckpointCoordinator coordinator =
+ new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
+ .build(EXECUTOR_EXTENSION.getExecutor());
+ StateTrackingMockExecutionGraph graph =
+ new StateTrackingMockExecutionGraph() {
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return coordinator;
+ }
+
+ @Override
+ public Iterable getVerticesTopologically() {
+ return Collections.singletonList(mejv);
+ }
+ };
+ ctx.setVertexParallelism(
+ new VertexParallelism(
+ graph.getAllVertices().values().stream()
+ .collect(
+ Collectors.toMap(
+ AccessExecutionJobVertex::getJobVertexId,
+ AccessExecutionJobVertex::getParallelism))));
+
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(true)
+ .build(ctx);
+ exec.requestActiveCheckpointTrigger();
+ assertThat(coordinator.getNumberOfPendingCheckpoints()).isEqualTo(0);
+ }
+ }
+
+ @Test
+ void testActiveCheckpointTriggerSkipsWhenCheckpointInProgress() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ MockExecutionJobVertex mejv = new MockExecutionJobVertex(MockExecutionVertex::new);
+
+ ManuallyTriggeredScheduledExecutor checkpointTimer =
+ new ManuallyTriggeredScheduledExecutor();
+ CheckpointCoordinator coordinator =
+ new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
+ .setTimer(checkpointTimer)
+ .build(EXECUTOR_EXTENSION.getExecutor());
+ StateTrackingMockExecutionGraph graph =
+ new StateTrackingMockExecutionGraph() {
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return coordinator;
+ }
+
+ @Override
+ public Iterable getVerticesTopologically() {
+ return Collections.singletonList(mejv);
+ }
+ };
+
+ ctx.setVertexParallelism(
+ new VertexParallelism(
+ graph.getAllVertices().values().stream()
+ .collect(
+ Collectors.toMap(
+ AccessExecutionJobVertex::getJobVertexId,
+ v -> v.getParallelism() + 1))));
+
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(true)
+ .build(ctx);
+ coordinator.triggerCheckpoint(false);
+ checkpointTimer.triggerAll();
+
+ int pendingBefore = coordinator.getNumberOfPendingCheckpoints();
+ assertThat(pendingBefore).isGreaterThan(0);
+ exec.requestActiveCheckpointTrigger();
+ checkpointTimer.triggerAll();
+ assertThat(coordinator.getNumberOfPendingCheckpoints()).isEqualTo(pendingBefore);
+ }
+ }
+
+ @Test
+ void testActiveCheckpointTriggerFiresWhenAllGuardsPass() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ MockExecutionJobVertex mejv = new MockExecutionJobVertex(MockExecutionVertex::new);
+
+ ManuallyTriggeredScheduledExecutor checkpointTimer =
+ new ManuallyTriggeredScheduledExecutor();
+ CheckpointCoordinator coordinator =
+ new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
+ .setTimer(checkpointTimer)
+ .build(EXECUTOR_EXTENSION.getExecutor());
+ StateTrackingMockExecutionGraph graph =
+ new StateTrackingMockExecutionGraph() {
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return coordinator;
+ }
+
+ @Override
+ public Iterable getVerticesTopologically() {
+ return Collections.singletonList(mejv);
+ }
+ };
+ ctx.setVertexParallelism(
+ new VertexParallelism(
+ graph.getAllVertices().values().stream()
+ .collect(
+ Collectors.toMap(
+ AccessExecutionJobVertex::getJobVertexId,
+ v -> v.getParallelism() + 1))));
+
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(true)
+ .build(ctx);
+
+ assertThat(coordinator.getNumberOfPendingCheckpoints()).isEqualTo(0);
+
+ exec.requestActiveCheckpointTrigger();
+ checkpointTimer.triggerAll();
+ assertThat(coordinator.getNumberOfPendingCheckpoints()).isGreaterThan(0);
+ }
+ }
+
+ @Test
+ void testActiveCheckpointTriggerRespectsMinPauseBetweenCheckpoints() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ MockExecutionJobVertex mejv = new MockExecutionJobVertex(MockExecutionVertex::new);
+
+ ManuallyTriggeredScheduledExecutor checkpointTimer =
+ new ManuallyTriggeredScheduledExecutor();
+ ManualClock clock = new ManualClock();
+ CheckpointCoordinatorConfiguration coordConfig =
+ new CheckpointCoordinatorConfiguration
+ .CheckpointCoordinatorConfigurationBuilder()
+ .setCheckpointInterval(10_000L)
+ .setMinPauseBetweenCheckpoints(10_000L)
+ .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
+ .build();
+
+ CheckpointCoordinator coordinator =
+ new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
+ .setCheckpointCoordinatorConfiguration(coordConfig)
+ .setTimer(checkpointTimer)
+ .setClock(clock)
+ .build(EXECUTOR_EXTENSION.getExecutor());
+
+ StateTrackingMockExecutionGraph graph =
+ new StateTrackingMockExecutionGraph() {
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return coordinator;
+ }
+
+ @Override
+ public Iterable getVerticesTopologically() {
+ return Collections.singletonList(mejv);
+ }
+ };
+ ctx.setVertexParallelism(
+ new VertexParallelism(
+ graph.getAllVertices().values().stream()
+ .collect(
+ Collectors.toMap(
+ AccessExecutionJobVertex::getJobVertexId,
+ v -> v.getParallelism() + 1))));
+
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(true)
+ .build(ctx);
+ exec.requestActiveCheckpointTrigger();
+ checkpointTimer.triggerAll();
+ assertThat(coordinator.getNumberOfPendingCheckpoints()).isEqualTo(0);
+ clock.advanceTime(10_000L, java.util.concurrent.TimeUnit.MILLISECONDS);
+ exec.requestActiveCheckpointTrigger();
+ checkpointTimer.triggerAll();
+ assertThat(coordinator.getNumberOfPendingCheckpoints()).isEqualTo(1);
+ }
+ }
+
+ @Test
+ void testActiveCheckpointTriggerHandlesFailureGracefully() throws Exception {
+ try (MockExecutingContext ctx = new MockExecutingContext()) {
+ MockExecutionJobVertex mejv = new MockExecutionJobVertex(MockExecutionVertex::new);
+ ManuallyTriggeredScheduledExecutor checkpointTimer =
+ new ManuallyTriggeredScheduledExecutor();
+ CheckpointCoordinator coordinator =
+ new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
+ .setTimer(checkpointTimer)
+ .build(EXECUTOR_EXTENSION.getExecutor());
+ coordinator.shutdown();
+
+ StateTrackingMockExecutionGraph graph =
+ new StateTrackingMockExecutionGraph() {
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return coordinator;
+ }
+
+ @Override
+ public Iterable getVerticesTopologically() {
+ return Collections.singletonList(mejv);
+ }
+ };
+ ctx.setVertexParallelism(
+ new VertexParallelism(
+ graph.getAllVertices().values().stream()
+ .collect(
+ Collectors.toMap(
+ AccessExecutionJobVertex::getJobVertexId,
+ v -> v.getParallelism() + 1))));
+
+ Executing exec =
+ new ExecutingStateBuilder()
+ .setExecutionGraph(graph)
+ .setActiveCheckpointTriggerEnabled(true)
+ .build(ctx);
+ exec.requestActiveCheckpointTrigger();
+ checkpointTimer.triggerAll();
+ }
+ }
+
@Test
void testJobInformationMethods() throws Exception {
try (MockExecutingContext ctx = new MockExecutingContext()) {
@@ -691,6 +1003,7 @@ private final class ExecutingStateBuilder {
private Function
stateTransitionManagerFactory = context -> TestingStateTransitionManager.withNoOp();
private int rescaleOnFailedCheckpointCount = 1;
+ private boolean activeCheckpointTriggerEnabled = false;
private ExecutingStateBuilder() throws JobException, JobExecutionException {
operatorCoordinatorHandler = new TestingOperatorCoordinatorHandler();
@@ -720,6 +1033,12 @@ public ExecutingStateBuilder setRescaleOnFailedCheckpointCount(
return this;
}
+ public ExecutingStateBuilder setActiveCheckpointTriggerEnabled(
+ boolean activeCheckpointTriggerEnabled) {
+ this.activeCheckpointTriggerEnabled = activeCheckpointTriggerEnabled;
+ return this;
+ }
+
private Executing build(MockExecutingContext ctx) {
executionGraph.transitionToRunning();
@@ -733,7 +1052,8 @@ private Executing build(MockExecutingContext ctx) {
ClassLoader.getSystemClassLoader(),
new ArrayList<>(),
stateTransitionManagerFactory::apply,
- rescaleOnFailedCheckpointCount);
+ rescaleOnFailedCheckpointCount,
+ activeCheckpointTriggerEnabled);
} finally {
Preconditions.checkState(
!ctx.hadStateTransition,
@@ -1029,6 +1349,12 @@ public boolean updateState(TaskExecutionStateTransition state) {
public Iterable getVerticesTopologically() {
return getVerticesTopologicallySupplier.get();
}
+
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return null;
+ }
}
private static class FinishingMockExecutionGraph extends StateTrackingMockExecutionGraph {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java
index 6a24600f1ace1..dcc9ceefba776 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java
@@ -175,4 +175,93 @@ void testRescaleOnCheckpoint(
restClusterClient.cancel(jobGraph.getJobID()).join();
}
}
+
+ @Test
+ void testRescaleWithActiveCheckpointTrigger(
+ @InjectMiniCluster MiniCluster miniCluster,
+ @InjectClusterClient RestClusterClient> restClusterClient)
+ throws Exception {
+ final Configuration config = new Configuration();
+ config.set(JobManagerOptions.SCHEDULER_RESCALE_TRIGGER_ACTIVE_CHECKPOINT_ENABLED, true);
+
+ final StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(BEFORE_RESCALE_PARALLELISM);
+ env.enableCheckpointing(Duration.ofHours(24).toMillis());
+ env.fromSequence(0, Integer.MAX_VALUE).sinkTo(new DiscardingSink<>());
+
+ final JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+ final Iterator jobVertexIterator = jobGraph.getVertices().iterator();
+ assertThat(jobVertexIterator.hasNext()).isTrue();
+ final JobVertexID jobVertexId = jobVertexIterator.next().getID();
+
+ final JobResourceRequirements jobResourceRequirements =
+ JobResourceRequirements.newBuilder()
+ .setParallelismForJobVertex(jobVertexId, 1, AFTER_RESCALE_PARALLELISM)
+ .build();
+
+ restClusterClient.submitJob(jobGraph).join();
+
+ final JobID jobId = jobGraph.getJobID();
+ try {
+ LOG.info(
+ "Waiting for job {} to reach parallelism of {} for vertex {}.",
+ jobId,
+ BEFORE_RESCALE_PARALLELISM,
+ jobVertexId);
+ waitForRunningTasks(restClusterClient, jobId, BEFORE_RESCALE_PARALLELISM);
+
+ LOG.info(
+ "Updating job {} resource requirements: parallelism {} -> {}.",
+ jobId,
+ BEFORE_RESCALE_PARALLELISM,
+ AFTER_RESCALE_PARALLELISM);
+ restClusterClient.updateJobResourceRequirements(jobId, jobResourceRequirements).join();
+ LOG.info(
+ "Waiting for job {} to rescale to parallelism {} via active checkpoint trigger.",
+ jobId,
+ AFTER_RESCALE_PARALLELISM);
+ waitForRunningTasks(restClusterClient, jobId, AFTER_RESCALE_PARALLELISM);
+ final int expectedFreeSlotCount = NUMBER_OF_SLOTS - AFTER_RESCALE_PARALLELISM;
+ LOG.info(
+ "Waiting for {} slot(s) to become available after scale down.",
+ expectedFreeSlotCount);
+ waitForAvailableSlots(restClusterClient, expectedFreeSlotCount);
+ } finally {
+ restClusterClient.cancel(jobGraph.getJobID()).join();
+ }
+ }
+
+ @Test
+ void testNoRescaleWithoutCheckpointingConfigured(
+ @InjectMiniCluster MiniCluster miniCluster,
+ @InjectClusterClient RestClusterClient> restClusterClient)
+ throws Exception {
+ final Configuration config = new Configuration();
+ final StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(BEFORE_RESCALE_PARALLELISM);
+ env.fromSequence(0, Integer.MAX_VALUE).sinkTo(new DiscardingSink<>());
+
+ final JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+ final Iterator jobVertexIterator = jobGraph.getVertices().iterator();
+ assertThat(jobVertexIterator.hasNext()).isTrue();
+ final JobVertexID jobVertexId = jobVertexIterator.next().getID();
+
+ final JobResourceRequirements jobResourceRequirements =
+ JobResourceRequirements.newBuilder()
+ .setParallelismForJobVertex(jobVertexId, 1, AFTER_RESCALE_PARALLELISM)
+ .build();
+ restClusterClient.submitJob(jobGraph).join();
+ final JobID jobId = jobGraph.getJobID();
+ try {
+ waitForRunningTasks(restClusterClient, jobId, BEFORE_RESCALE_PARALLELISM);
+ restClusterClient.updateJobResourceRequirements(jobId, jobResourceRequirements).join();
+ Thread.sleep(REQUIREMENT_UPDATE_TO_CHECKPOINT_GAP.toMillis());
+ waitForRunningTasks(restClusterClient, jobId, BEFORE_RESCALE_PARALLELISM);
+ LOG.info("Verified: job {} did not rescale without checkpointing configured.", jobId);
+ } finally {
+ restClusterClient.cancel(jobGraph.getJobID()).join();
+ }
+ }
}