diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttle.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttle.java new file mode 100644 index 0000000000000..56a299062bd4d --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttle.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.optimize; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttleImpl; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.logical.LogicalCorrelate; +import org.apache.calcite.rel.logical.LogicalTableFunctionScan; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSubQuery; + +import java.util.Optional; + +public final class CorrelVariableNormalizerShuttle extends RelShuttleImpl { + + private final RexBuilder rexBuilder; + private final RexShuttle rexCorrelNormalizer; + + private int correlationIdAdjustment = Integer.MIN_VALUE; + + public CorrelVariableNormalizerShuttle(RexBuilder rexBuilder) { + this.rexBuilder = rexBuilder; + rexCorrelNormalizer = new RexCorrelNormalizer(); + } + + @Override + public RelNode visit(LogicalCorrelate correlate) { + var adjustedId = adjustCorrelationId(correlate.getCorrelationId()); + if (adjustedId.isPresent()) { + var left = correlate.getLeft().accept(this); + var right = correlate.getRight().accept(this); + return correlate.copy( + correlate.getTraitSet(), + left, + right, + adjustedId.get(), + correlate.getRequiredColumns(), + correlate.getJoinType()); + } + + return super.visit(correlate); + } + + @Override + public RelNode visit(RelNode relNode) { + if (relNode instanceof LogicalTableFunctionScan && relNode.getInputs().isEmpty()) { + // Since we only visit the children with the rexshuttle below, + // we have to explicitly visit the table function scan since it is a sink (i.e. no + // children) but can have RexNodes (a TableScan cannot) + return relNode.accept(rexCorrelNormalizer); + } + + return super.visit(relNode); + } + + @Override + protected RelNode visitChild(RelNode parent, int i, RelNode child) { + if (i == 0) { + parent = parent.accept(rexCorrelNormalizer); + } + + return super.visitChild(parent, i, child); + } + + private Optional adjustCorrelationId(CorrelationId correlationId) { + if (!correlationId.getName().startsWith(CorrelationId.CORREL_PREFIX)) { + return Optional.empty(); + } + + if (correlationIdAdjustment < 0) { + correlationIdAdjustment = correlationId.getId() - 1; + } + + if (correlationIdAdjustment == 0) { + return Optional.empty(); + } + + return Optional.of(new CorrelationId(correlationId.getId() - correlationIdAdjustment)); + } + + private final class RexCorrelNormalizer extends RexShuttle { + + @Override + public RexNode visitCorrelVariable(RexCorrelVariable variable) { + var adjustedId = adjustCorrelationId(variable.id); + if (adjustedId.isPresent()) { + return rexBuilder.makeCorrel(variable.getType(), adjustedId.get()); + } else { + return super.visitCorrelVariable(variable); + } + } + + @Override + public RexNode visitSubQuery(RexSubQuery subQuery) { + var rewritten = subQuery.rel.accept(CorrelVariableNormalizerShuttle.this); + return subQuery.clone(rewritten); + } + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/CommonSubGraphBasedOptimizer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/CommonSubGraphBasedOptimizer.scala index 92c7f316ae03c..854939941657c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/CommonSubGraphBasedOptimizer.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/CommonSubGraphBasedOptimizer.scala @@ -84,7 +84,14 @@ abstract class CommonSubGraphBasedOptimizer extends Optimizer { val clearQueryBlockAliasResolver = new ClearQueryBlockAliasResolver val resolvedAliasRoots = clearQueryBlockAliasResolver.resolve(resolvedHintRoots) - val sinkBlocks = doOptimize(resolvedAliasRoots) + // Normalize correlation variable ids per root so structurally equivalent + // subplans across sinks share digests (required for SubplanReuser to dedupe). + // A fresh shuttle per root avoids the first root's adjustment leaking into + // the next one's numbering. + val normalizedRoots = resolvedAliasRoots.map( + root => root.accept(new CorrelVariableNormalizerShuttle(root.getCluster.getRexBuilder))) + + val sinkBlocks = doOptimize(normalizedRoots) val optimizedPlan = sinkBlocks.map { block => val plan = block.getOptimizedPlan