Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import org.apache.flink.table.catalog.Column;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.data.TimestampData;
import org.apache.flink.table.functions.FunctionKind;
import org.apache.flink.table.planner.catalog.CatalogSchemaModel;
import org.apache.flink.table.planner.catalog.CatalogSchemaTable;
import org.apache.flink.table.planner.functions.sql.ml.SqlMLTableFunction;
import org.apache.flink.table.planner.plan.FlinkCalciteCatalogReader;
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
Expand Down Expand Up @@ -65,6 +65,8 @@
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.SqlWindowTableFunction;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.DelegatingScope;
import org.apache.calcite.sql.validate.IdentifierNamespace;
Expand Down Expand Up @@ -92,7 +94,6 @@
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.apache.calcite.sql.type.SqlTypeName.DECIMAL;
import static org.apache.flink.table.expressions.resolver.lookups.FieldReferenceLookup.includeExpandedColumn;
Expand Down Expand Up @@ -343,7 +344,7 @@ protected void addToSelectList(
final Column column = resolvedSchema.getColumn(columnName).orElse(null);
if (qualified.suffix().size() == 1 && column != null) {
if (includeExpandedColumn(column, columnExpansionStrategies)
|| declaredDescriptorColumn(scope, column)) {
|| isDeclaredOnTimeColumn(scope, column)) {
super.addToSelectList(
list, aliases, fieldList, exp, scope, includeSystemVars);
}
Expand All @@ -360,71 +361,71 @@ protected void addToSelectList(
protected @PolyNull SqlNode performUnconditionalRewrites(
@PolyNull SqlNode node, boolean underFrom) {

// Special case for window TVFs like:
// TUMBLE(TABLE t, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE)) or
// SESSION(TABLE t PARTITION BY a, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE))
// Capture table arguments early:
// TUMBLE(TABLE t, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE) or
// SESSION(TABLE t PARTITION BY a, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE)
// MyPtf(in => TABLE t PARTITION BY a, on_time => DESCRIPTOR(metadata_virtual))
//
// "TABLE t" is translated into an implicit "SELECT * FROM t". This would ignore columns
// that are not expanded by default. However, the descriptor explicitly states the need
// for this column. Therefore, explicit table expressions (for window TVFs at most one)
// are captured before rewriting and replaced with a "marker" SqlSelect that contains the
// descriptor information. The "marker" SqlSelect is considered during column expansion.
// that are not expanded by default. However, the on_time descriptor explicitly states the
// need for time columns. Therefore, explicit table expressions are captured before
// rewriting and replaced with a "marker" SqlSelect that contains the descriptor
// information. The "marker" SqlSelect is considered during column expansion.
final List<SqlIdentifier> tableArgs = getTableOperands(node);

final SqlNode rewritten = super.performUnconditionalRewrites(node, underFrom);

if (!(node instanceof SqlBasicCall)) {
return rewritten;
}

final SqlBasicCall call = (SqlBasicCall) node;
final SqlOperator operator = call.getOperator();

// Special case for MODEL
if (node instanceof SqlExplicitModelCall) {
// Convert it so that model can be accessed in planner. SqlExplicitModelCall
// from parser can't access model.
SqlExplicitModelCall modelCall = (SqlExplicitModelCall) node;
SqlIdentifier modelIdentifier = modelCall.getModelIdentifier();
FlinkCalciteCatalogReader catalogReader =
final SqlExplicitModelCall modelCall = (SqlExplicitModelCall) node;
final SqlIdentifier modelIdentifier = modelCall.getModelIdentifier();
final FlinkCalciteCatalogReader catalogReader =
(FlinkCalciteCatalogReader) getCatalogReader();
CatalogSchemaModel model = catalogReader.getModel(modelIdentifier.names);
final CatalogSchemaModel model = catalogReader.getModel(modelIdentifier.names);
if (model != null) {
return new SqlModelCall(modelCall, model);
}
}

// TODO (FLINK-37819): add test for SqlMLTableFunction
if (operator instanceof SqlWindowTableFunction || operator instanceof SqlMLTableFunction) {
if (tableArgs.stream().allMatch(Objects::isNull)) {
return rewritten;
}

final List<SqlIdentifier> descriptors =
call.getOperandList().stream()
.flatMap(FlinkCalciteSqlValidator::extractDescriptors)
.collect(Collectors.toList());

// Mark rewritten "TABLE t" with on_time columns
if (tableArgs == null || tableArgs.stream().allMatch(Objects::isNull)) {
return rewritten;
}
final List<SqlIdentifier> onTimeColumns = extractOnTime(call);
if (onTimeColumns != null) {
for (int i = 0; i < call.operandCount(); i++) {
final SqlIdentifier tableArg = tableArgs.get(i);
if (tableArg != null) {
final SqlNode opReplacement = new ExplicitTableSqlSelect(tableArg, descriptors);
final SqlNode opReplacement =
new ExplicitTableSqlSelect(tableArg, onTimeColumns);
// for f(TABLE t PARTITION BY c, ...)
if (call.operand(i).getKind() == SqlKind.SET_SEMANTICS_TABLE) {
final SqlCall setSemanticsTable = call.operand(i);
setSemanticsTable.setOperand(0, opReplacement);
} else if (call.operand(i).getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
// for TUMBLE(DATA => TABLE t3, ...)
final SqlCall assignment = call.operand(i);
// for f(in => TABLE t PARTITION BY c, ...)
if (assignment.operand(0).getKind() == SqlKind.SET_SEMANTICS_TABLE) {
final SqlCall setSemanticsTable = assignment.operand(i);
final SqlCall setSemanticsTable = assignment.operand(0);
setSemanticsTable.setOperand(0, opReplacement);
} else {
// for f(in => TABLE t, ...)
assignment.setOperand(0, opReplacement);
}
} else {
// for TUMBLE(TABLE t3, ...)
// for f(TABLE t, ...)
call.setOperand(i, opReplacement);
}
}
// for TUMBLE([DATA =>] SELECT ..., ...)
// for f([in =>] SELECT ..., ...)
}
}

Expand All @@ -446,9 +447,9 @@ public SqlNode maybeCast(SqlNode node, RelDataType currentType, RelDataType desi
*/
static class ExplicitTableSqlSelect extends SqlSelect {

private final List<SqlIdentifier> descriptors;
private final List<SqlIdentifier> onTimeColumns;

public ExplicitTableSqlSelect(SqlIdentifier table, List<SqlIdentifier> descriptors) {
public ExplicitTableSqlSelect(SqlIdentifier table, List<SqlIdentifier> onTimeColumns) {
super(
SqlParserPos.ZERO,
null,
Expand All @@ -462,91 +463,150 @@ public ExplicitTableSqlSelect(SqlIdentifier table, List<SqlIdentifier> descripto
null,
null,
null);
this.descriptors = descriptors;
this.onTimeColumns = onTimeColumns;
}
}

/**
* Returns whether the given column has been declared in a {@link SqlKind#DESCRIPTOR} next to a
* {@link SqlKind#EXPLICIT_TABLE} within TVF operands.
*/
private static boolean declaredDescriptorColumn(SelectScope scope, Column column) {
private static boolean isDeclaredOnTimeColumn(SelectScope scope, Column column) {
if (!(scope.getNode() instanceof ExplicitTableSqlSelect)) {
return false;
}
final ExplicitTableSqlSelect select = (ExplicitTableSqlSelect) scope.getNode();
return select.descriptors.stream()
return select.onTimeColumns.stream()
.map(SqlIdentifier::getSimple)
.anyMatch(id -> id.equals(column.getName()));
}

/**
 * Returns all {@link SqlKind#EXPLICIT_TABLE} and {@link SqlKind#SET_SEMANTICS_TABLE} operands
 * within PTF operands. A list entry is {@code null} if the operand is not an {@link
 * SqlKind#EXPLICIT_TABLE} or {@link SqlKind#SET_SEMANTICS_TABLE}.
 *
 * <p>Returns {@code null} entirely if the node is not a call to a table function, so callers can
 * cheaply skip non-PTF nodes.
 */
private static List<SqlIdentifier> getTableOperands(SqlNode node) {
    if (!(node instanceof SqlBasicCall)) {
        return null;
    }

    final SqlBasicCall call = (SqlBasicCall) node;

    if (!(call.getOperator() instanceof SqlFunction)) {
        return null;
    }

    final SqlFunction function = (SqlFunction) call.getOperator();

    if (!isTableFunction(function)) {
        return null;
    }

    // Keep positional alignment: one entry per operand, null where no explicit table occurs.
    return call.getOperandList().stream()
            .map(FlinkCalciteSqlValidator::extractExplicitTables)
            .collect(Collectors.toList());
}

private static @Nullable SqlIdentifier extractTableOperand(SqlNode op) {
/** Extracts "TABLE t" nodes before they get rewritten into "SELECT * FROM t". */
private static @Nullable SqlIdentifier extractExplicitTables(SqlNode op) {
if (op.getKind() == SqlKind.EXPLICIT_TABLE) {
final SqlBasicCall opCall = (SqlBasicCall) op;
if (opCall.operandCount() == 1 && opCall.operand(0) instanceof SqlIdentifier) {
// for TUMBLE(TABLE t3, ...)
// for f(TABLE t, ...)
return opCall.operand(0);
}
} else if (op.getKind() == SqlKind.SET_SEMANTICS_TABLE) {
// for SESSION windows
// for f(TABLE t PARTITION BY x)
final SqlBasicCall opCall = (SqlBasicCall) op;
final SqlCall setSemanticsTable = opCall.operand(0);
if (setSemanticsTable.operand(0) instanceof SqlIdentifier) {
return setSemanticsTable.operand(0);
}
return extractExplicitTables(opCall.operand(0));
} else if (op.getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
// for TUMBLE(DATA => TABLE t3, ...)
// for f(in => TABLE t, ...)
final SqlBasicCall opCall = (SqlBasicCall) op;
return extractTableOperand(opCall.operand(0));
return extractExplicitTables(opCall.operand(0));
}
return null;
}

private static Stream<SqlIdentifier> extractDescriptors(SqlNode op) {
/** Extracts the on_time argument of a PTF (or TIMECOL for window PTFs for legacy reasons). */
private static @Nullable List<SqlIdentifier> extractOnTime(SqlBasicCall call) {
// Extract from operand from PTF
final SqlNode onTimeOperand;
if (call.getOperator() instanceof SqlWindowTableFunction) {
onTimeOperand = extractOperandByArgName(call, "TIMECOL");
} else if (ShortcutUtils.isFunctionKind(call.getOperator(), FunctionKind.PROCESS_TABLE)) {
onTimeOperand = extractOperandByArgName(call, "on_time");
} else {
onTimeOperand = null;
}

// No operand found
if (onTimeOperand == null) {
return null;
}

return extractDescriptors(onTimeOperand);
}

private static List<SqlIdentifier> extractDescriptors(SqlNode op) {
if (op.getKind() == SqlKind.DESCRIPTOR) {
// for TUMBLE(..., DESCRIPTOR(col), ...)
final SqlBasicCall opCall = (SqlBasicCall) op;
return opCall.getOperandList().stream()
.filter(SqlIdentifier.class::isInstance)
.map(SqlIdentifier.class::cast);
} else if (op.getKind() == SqlKind.SET_SEMANTICS_TABLE) {
// for SESSION windows
final SqlBasicCall opCall = (SqlBasicCall) op;
return ((SqlNodeList) opCall.operand(1))
.stream()
.filter(SqlIdentifier.class::isInstance)
.map(SqlIdentifier.class::cast);
} else if (op.getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
// for TUMBLE(..., TIMECOL => DESCRIPTOR(col), ...)
final SqlBasicCall opCall = (SqlBasicCall) op;
return extractDescriptors(opCall.operand(0));
.map(SqlIdentifier.class::cast)
.collect(Collectors.toList());
}
return List.of();
}

/**
 * Returns the operand for a given argument name from a BasicSqlCall. Supports both positional
 * and named arguments. If at least one ARGUMENT_ASSIGNMENT is used, named lookup is performed.
 * Otherwise, positional lookup using SqlOperandMetadata is used.
 *
 * @param call the SQL call to extract the operand from
 * @param argumentName the name of the argument to retrieve
 * @return the SqlNode for the operand, or null if not found or not supported
 */
private static @Nullable SqlNode extractOperandByArgName(
        SqlBasicCall call, String argumentName) {
    // Check if operator supports SqlOperandMetadata
    final SqlOperator operator = call.getOperator();
    final SqlOperandTypeChecker typeChecker = operator.getOperandTypeChecker();
    if (!(typeChecker instanceof SqlOperandMetadata)) {
        return null;
    }

    final SqlOperandMetadata operandMetadata = (SqlOperandMetadata) typeChecker;

    // Detect if named arguments are used by checking for ARGUMENT_ASSIGNMENT
    final List<SqlNode> operands = call.getOperandList();
    final boolean hasNamedArguments =
            operands.stream().anyMatch(op -> op.getKind() == SqlKind.ARGUMENT_ASSIGNMENT);

    if (hasNamedArguments) {
        // Named mode: search through ARGUMENT_ASSIGNMENT nodes
        for (SqlNode operand : operands) {
            if (operand.getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
                final SqlBasicCall assignment = (SqlBasicCall) operand;
                // operand(1) contains the parameter name as SqlIdentifier
                final SqlIdentifier paramName = assignment.operand(1);
                if (paramName.getSimple().equals(argumentName)) {
                    // operand(0) contains the actual value
                    return assignment.operand(0);
                }
            }
        }
        return null;
    } else {
        // Positional mode: use SqlOperandMetadata to map name to position
        final List<String> paramNames = operandMetadata.paramNames();
        final int index = paramNames.indexOf(argumentName);
        if (index == -1 || index >= call.operandCount()) {
            return null;
        }
        return call.operand(index);
    }
}

private static boolean isTableFunction(SqlFunction function) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionKind;
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.delegation.PlannerBase;
Expand Down Expand Up @@ -156,6 +157,18 @@ public static DataTypeFactory unwrapDataTypeFactory(RelBuilder relBuilder) {
return ((BridgingSqlFunction) call.getOperator()).getDefinition();
}

/**
 * Extracts the Flink {@link FunctionDefinition} behind a Calcite operator, or {@code null} if
 * the operator is not a {@link BridgingSqlFunction}.
 */
public static @Nullable FunctionDefinition unwrapFunctionDefinition(SqlOperator operator) {
    if (operator instanceof BridgingSqlFunction) {
        return ((BridgingSqlFunction) operator).getDefinition();
    }
    return null;
}

/**
 * Checks whether the given operator wraps a Flink function of the given {@link FunctionKind}.
 * Returns {@code false} for operators that carry no Flink function definition at all.
 */
public static boolean isFunctionKind(SqlOperator operator, FunctionKind kind) {
    final FunctionDefinition definition = unwrapFunctionDefinition(operator);
    if (definition == null) {
        return false;
    }
    return definition.getKind() == kind;
}

public static @Nullable BridgingSqlFunction unwrapBridgingSqlFunction(RexCall call) {
final SqlOperator operator = call.getOperator();
if (operator instanceof BridgingSqlFunction) {
Expand Down
Loading