diff --git a/LICENSE-binary b/LICENSE-binary index 1e30e9d8f28..774bc1539fc 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -308,6 +308,9 @@ com.squareup.retrofit2:retrofit com.squareup.okhttp3:okhttp org.apache.kafka:kafka-clients org.xerial:sqlite-jdbc +com.openai:openai-java +com.github.victools:jsonschema-generator +com.github.victools:jsonschema-module-jackson BSD ------------ diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md index 22621050711..0d9f3e6fb9d 100644 --- a/docs/configuration/settings.md +++ b/docs/configuration/settings.md @@ -153,7 +153,8 @@ You can configure the Kyuubi properties in `$KYUUBI_HOME/conf/kyuubi-defaults.co | kyuubi.engine.data.agent.max.iterations | 100 | The maximum number of ReAct loop iterations for the Data Agent engine. | int | 1.12.0 | | kyuubi.engine.data.agent.memory | 1g | The heap memory for the Data Agent engine | string | 1.12.0 | | kyuubi.engine.data.agent.provider | ECHO | The provider for the Data Agent engine. Candidates: | string | 1.12.0 | -| kyuubi.engine.data.agent.query.timeout | PT5M | The query execution timeout for the Data Agent SQL tool. | duration | 1.12.0 | +| kyuubi.engine.data.agent.query.timeout | PT3M | The JDBC query execution timeout for the Data Agent SQL tools. Passed to Statement.setQueryTimeout so the server (Spark/Trino/...) can cooperatively cancel long-running queries and release cluster resources. Should be set lower than kyuubi.engine.data.agent.tool.call.timeout so server-side cancellation has time to react before the outer wall-clock cap fires. | duration | 1.12.0 | +| kyuubi.engine.data.agent.tool.call.timeout | PT5M | The maximum wall-clock execution time for any tool call in the Data Agent engine. Acts as the outer safety net enforced by the agent runtime via Future.cancel(), applied uniformly to every tool. 
For SQL tools the inner JDBC-level timeout is controlled separately by kyuubi.engine.data.agent.query.timeout, which should be set lower so server-side cancellation has time to react before this hard cap fires. | duration | 1.12.0 | | kyuubi.engine.deregister.exception.classes || A comma-separated list of exception classes. If there is any exception thrown, whose class matches the specified classes, the engine would deregister itself. | set | 1.2.0 | | kyuubi.engine.deregister.exception.messages || A comma-separated list of exception messages. If there is any exception thrown, whose message or stacktrace matches the specified message list, the engine would deregister itself. | set | 1.2.0 | | kyuubi.engine.deregister.exception.ttl | PT30M | Time to live(TTL) for exceptions pattern specified in kyuubi.engine.deregister.exception.classes and kyuubi.engine.deregister.exception.messages to deregister engines. Once the total error count hits the kyuubi.engine.deregister.job.max.failures within the TTL, an engine will deregister itself and wait for self-terminated. Otherwise, we suppose that the engine has recovered from temporary failures. 
| duration | 1.2.0 | diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml index 47174958d6b..c34d049360c 100644 --- a/externals/kyuubi-data-agent-engine/pom.xml +++ b/externals/kyuubi-data-agent-engine/pom.xml @@ -50,6 +50,21 @@ ${project.version} + + com.openai + openai-java + + + + com.github.victools + jsonschema-generator + + + + com.github.victools + jsonschema-module-jackson + + org.apache.kyuubi @@ -65,6 +80,18 @@ test + + org.testcontainers + testcontainers-mysql + test + + + + com.mysql + mysql-connector-j + test + + junit junit diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/DataSourceFactory.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/DataSourceFactory.java new file mode 100644 index 00000000000..9f5eb154188 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/DataSourceFactory.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import javax.sql.DataSource; + +/** Factory for creating pooled DataSource instances from JDBC URLs. */ +public final class DataSourceFactory { + + private static final int DEFAULT_MAX_POOL_SIZE = 5; + + private DataSourceFactory() {} + + /** + * Create a pooled DataSource from a JDBC URL. Supports any JDBC driver available on the + * classpath. + * + * @param jdbcUrl the JDBC connection URL + * @return a HikariCP-backed DataSource + */ + public static DataSource create(String jdbcUrl) { + return create(jdbcUrl, null, null); + } + + /** + * Create a pooled DataSource from a JDBC URL with an explicit username. When the data-agent + * connects back to Kyuubi Server, the username determines the proxy user for the downstream + * engine (e.g. Spark). Without it, Kyuubi defaults to "anonymous" which typically fails Hadoop + * impersonation checks. + * + * @param jdbcUrl the JDBC connection URL + * @param user the username for the JDBC connection, may be null + * @return a HikariCP-backed DataSource + */ + public static DataSource create(String jdbcUrl, String user) { + return create(jdbcUrl, user, null); + } + + /** + * Create a pooled DataSource from a JDBC URL with explicit credentials. Prefer this overload when + * a password is required: passing the password through {@link HikariConfig#setPassword} keeps it + * out of the JDBC URL, which would otherwise leak the password into log lines, JMX pool metadata, + * exception messages, and connection strings printed by debug tooling. 
+ * + * @param jdbcUrl the JDBC connection URL + * @param user the username for the JDBC connection, may be null + * @param password the password for the JDBC connection, may be null + * @return a HikariCP-backed DataSource + */ + public static DataSource create(String jdbcUrl, String user, String password) { + if (jdbcUrl == null || jdbcUrl.isEmpty()) { + throw new IllegalArgumentException("jdbcUrl must not be null or empty"); + } + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(jdbcUrl); + if (user != null && !user.isEmpty()) { + config.setUsername(user); + } + if (password != null && !password.isEmpty()) { + config.setPassword(password); + } + config.setMaximumPoolSize(DEFAULT_MAX_POOL_SIZE); + config.setMinimumIdle(1); + config.setInitializationFailTimeout(-1); + config.setPoolName("kyuubi-data-agent"); + return new HikariDataSource(config); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java new file mode 100644 index 00000000000..3ea22ed54e3 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +/** + * Fallback dialect for JDBC subprotocols that have no dedicated implementation. Carries the + * subprotocol name (e.g. "postgresql", "clickhouse") so prompts can still tell the LLM which SQL + * flavor it is talking to. {@link #quoteIdentifier(String)} is intentionally unsupported — callers + * that need quoting must check the dialect type first or pick a tool that does not depend on + * dialect-specific identifier quoting. + */ +public final class GenericDialect implements JdbcDialect { + + private final String name; + + public GenericDialect(String name) { + this.name = name; + } + + @Override + public String datasourceName() { + return name; + } + + @Override + public String quoteIdentifier(String identifier) { + throw new UnsupportedOperationException( + "quoteIdentifier is not supported for generic dialect: " + name); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java new file mode 100644 index 00000000000..c3be1dad61a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +/** + * SQL dialect abstraction for datasource-specific SQL generation. + * + *

Each dialect maps to a datasource name used for prompt resource lookup ({@code + * prompts/datasource-{name}.md}). + * + *

Relationship to kyuubi-jdbc-engine's JdbcDialect

+ * + *

This interface is intentionally decoupled from {@code + * org.apache.kyuubi.engine.jdbc.dialect.JdbcDialect} in the JDBC engine module. That dialect serves + * the Thrift protocol layer (mapping JDBC results into TRowSets for Kyuubi clients), while this one + * serves the Data Agent's tool system — providing identifier quoting, qualified name construction + * ({@link #qualify(TableRef)}), and (in the future) dialect-specific "recipe" SQL for preset + * analytical tools. The two evolve independently and share no dependency. + */ +public interface JdbcDialect { + + /** Datasource name for prompt resource lookup (e.g. "spark", "trino"). */ + String datasourceName(); + + /** + * Quote an identifier (table/column/database name) using the dialect-appropriate quote character. + * Escapes any embedded quote characters by doubling them. + */ + String quoteIdentifier(String identifier); + + /** + * Build a fully-qualified table name from a {@link TableRef}, quoting each segment with {@link + * #quoteIdentifier(String)} and joining with {@code .}. Null segments (catalog, schema) are + * skipped. + * + * @param ref the table reference + * @return the qualified name, e.g. {@code `mydb`.`users`} or {@code "hive"."sales"."orders"} + */ + default String qualify(TableRef ref) { + StringBuilder sb = new StringBuilder(); + if (ref.getCatalog() != null) sb.append(quoteIdentifier(ref.getCatalog())).append('.'); + if (ref.getSchema() != null) sb.append(quoteIdentifier(ref.getSchema())).append('.'); + sb.append(quoteIdentifier(ref.getTable())); + return sb.toString(); + } + + /** + * Infer the dialect from a JDBC URL. + * + *

If the URL prefix matches a built-in dialect (spark/trino/mysql/sqlite) the corresponding + * implementation is returned. Otherwise a {@link GenericDialect} carrying the extracted + * subprotocol name (e.g. "postgresql", "clickhouse", "oracle") is returned so prompts can still + * tell the LLM which SQL flavor it is talking to. Returns {@code null} only when the URL is null + * or not a parseable {@code jdbc::...} string. + * + * @param jdbcUrl the JDBC connection URL + * @return the matching dialect, or {@code null} if the URL is unparseable + */ + static JdbcDialect fromUrl(String jdbcUrl) { + String name = extractSubprotocol(jdbcUrl); + if (name == null) { + return null; + } + switch (name) { + case "hive2": + case "spark": + return SparkDialect.INSTANCE; + case "trino": + return TrinoDialect.INSTANCE; + case "mysql": + return MysqlDialect.INSTANCE; + case "sqlite": + return SqliteDialect.INSTANCE; + default: + return new GenericDialect(name); + } + } + + /** + * Extract the subprotocol name from a JDBC URL: {@code jdbc:postgresql://host} → {@code + * "postgresql"}. Returns null if the URL does not start with {@code jdbc:} or has no second + * colon. 
+ */ + static String extractSubprotocol(String jdbcUrl) { + if (jdbcUrl == null) { + return null; + } + String lower = jdbcUrl.toLowerCase(); + if (!lower.startsWith("jdbc:")) { + return null; + } + int end = lower.indexOf(':', 5); + if (end <= 5) { + return null; + } + return lower.substring(5, end); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java new file mode 100644 index 00000000000..98747ffa30c --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +/** MySQL dialect. Uses backtick quoting for identifiers. 
*/ +public final class MysqlDialect implements JdbcDialect { + + static final MysqlDialect INSTANCE = new MysqlDialect(); + + private MysqlDialect() {} + + @Override + public String datasourceName() { + return "mysql"; + } + + @Override + public String quoteIdentifier(String identifier) { + String escaped = identifier.replace("`", "``"); + return "`" + escaped + "`"; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java new file mode 100644 index 00000000000..3adb43fa398 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +/** Spark SQL dialect. Uses backtick quoting for identifiers. 
*/ +public final class SparkDialect implements JdbcDialect { + + static final SparkDialect INSTANCE = new SparkDialect(); + + private SparkDialect() {} + + @Override + public String datasourceName() { + return "spark"; + } + + @Override + public String quoteIdentifier(String identifier) { + String escaped = identifier.replace("`", "``"); + return "`" + escaped + "`"; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java new file mode 100644 index 00000000000..a53255a9c67 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +/** SQLite dialect. Uses double-quote quoting for identifiers. 
*/ +public final class SqliteDialect implements JdbcDialect { + + static final SqliteDialect INSTANCE = new SqliteDialect(); + + private SqliteDialect() {} + + @Override + public String datasourceName() { + return "sqlite"; + } + + @Override + public String quoteIdentifier(String identifier) { + String escaped = identifier.replace("\"", "\"\""); + return "\"" + escaped + "\""; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TableRef.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TableRef.java new file mode 100644 index 00000000000..0dc69b72240 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TableRef.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import java.util.Objects; + +/** + * Immutable reference to a table within a datasource, modeled as a three-level namespace: {@code + * catalog.schema.table}. + * + *

Design intent

+ * + *

Different datasources use different namespace hierarchies: Trino has {@code + * catalog.schema.table} (3 levels), MySQL has {@code database.table} (2 levels), and SQLite is + * essentially flat. This class normalises them into a single {@code (catalog?, schema?, table)} + * triple so that dialect "recipes" (preset analytical tools) can pass table references around + * without branching on the datasource type. + * + *

The mapping convention follows the JDBC standard: MySQL's "database" maps to the {@code + * schema} field (MySQL's JDBC driver reports databases via {@code DatabaseMetaData.getSchemas()}), + * while {@code catalog} is {@code null}. Trino and Spark use all three levels. + * + *

Where instances come from

+ * + * + * + *

How to turn a {@code TableRef} into SQL

+ * + *

Use {@link JdbcDialect#qualify(TableRef)} — it quotes each non-null segment with the + * dialect-appropriate character and joins them with {@code .}. Do not concatenate the + * fields manually; that bypasses identifier escaping and risks SQL injection. + * + *

Namespace mapping reference

+ * + * + * + * + * + * + * + * + * + * + * + *
DatasourcecatalogschemaExample
MySQL{@code null}database name{@code TableRef.of("mydb", "users")}
Trinoconnector nameschema name{@code TableRef.of("hive", "sales", "orders")}
Sparkcatalog namedatabase name{@code TableRef.of("spark_catalog", "default", "t")}
SQLite{@code null}{@code null} (or attached db){@code TableRef.of("t")}
+ */ +public final class TableRef { + + @JsonPropertyDescription( + "Catalog name (e.g. Trino connector or Spark catalog). Omit for MySQL/SQLite.") + private final String catalog; + + @JsonPropertyDescription( + "Schema or database name (e.g. MySQL database, Trino schema). Omit if not applicable.") + private final String schema; + + @JsonPropertyDescription("Table name (required).") + private final String table; + + @JsonCreator + private TableRef( + @JsonProperty("catalog") String catalog, + @JsonProperty("schema") String schema, + @JsonProperty(value = "table", required = true) String table) { + if (table == null || table.isEmpty()) { + throw new IllegalArgumentException("table must not be null or empty"); + } + this.catalog = emptyToNull(catalog); + this.schema = emptyToNull(schema); + this.table = table; + } + + /** Create a table reference with only a table name. */ + public static TableRef of(String table) { + return new TableRef(null, null, table); + } + + /** Create a table reference with schema (or database) and table. */ + public static TableRef of(String schema, String table) { + return new TableRef(null, schema, table); + } + + /** Create a table reference with catalog, schema, and table. */ + public static TableRef of(String catalog, String schema, String table) { + return new TableRef(catalog, schema, table); + } + + public String getCatalog() { + return catalog; + } + + public String getSchema() { + return schema; + } + + public String getTable() { + return table; + } + + private static String emptyToNull(String s) { + return (s == null || s.isEmpty()) ? 
null : s; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof TableRef)) return false; + TableRef that = (TableRef) o; + return Objects.equals(catalog, that.catalog) + && Objects.equals(schema, that.schema) + && Objects.equals(table, that.table); + } + + @Override + public int hashCode() { + return Objects.hash(catalog, schema, table); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("TableRef{"); + if (catalog != null) sb.append("catalog='").append(catalog).append("', "); + if (schema != null) sb.append("schema='").append(schema).append("', "); + sb.append("table='").append(table).append("'}"); + return sb.toString(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java new file mode 100644 index 00000000000..edacf2f87e2 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +/** Trino SQL dialect. Uses double-quote quoting for identifiers. */ +public final class TrinoDialect implements JdbcDialect { + + static final TrinoDialect INSTANCE = new TrinoDialect(); + + private TrinoDialect() {} + + @Override + public String datasourceName() { + return "trino"; + } + + @Override + public String quoteIdentifier(String identifier) { + String escaped = identifier.replace("\"", "\"\""); + return "\"" + escaped + "\""; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/prompt/SystemPromptBuilder.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/prompt/SystemPromptBuilder.java new file mode 100644 index 00000000000..a315010ecb8 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/prompt/SystemPromptBuilder.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.prompt; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Builder for composing system prompts from Markdown resource sections. + * + *

Prompt resources live under {@code prompts/} on the classpath as {@code .md} files. The base + * template supports one placeholder: + * + *

+ * + *

A single datasource section ({@code prompts/datasource-{name}.md}) is set via {@link + * #datasource(String)}. Calling it again replaces the previous datasource section — an agent + * session talks to exactly one datasource. Free-form text sections are appended after the + * datasource section. + * + *

Usage: + * + *

{@code
+ * String prompt = SystemPromptBuilder.create()
+ *     .toolDescriptions(registry.describeTools())
+ *     .datasource("spark")
+ *     .section("Only query tables in the public schema.")
+ *     .build();
+ * }
+ */ +public final class SystemPromptBuilder { + + private static final String RESOURCE_PREFIX = "prompts/"; + + private String base; + private String toolDescriptions = ""; + private String datasourceSection; + private final List sections = new ArrayList<>(); + + private SystemPromptBuilder() { + this.base = loadResource("base"); + } + + public static SystemPromptBuilder create() { + return new SystemPromptBuilder(); + } + + /** Override the base prompt with custom text instead of the default {@code prompts/base.md}. */ + public SystemPromptBuilder base(String base) { + this.base = base; + return this; + } + + /** Set tool descriptions to substitute into the {@code {{tool_descriptions}}} placeholder. */ + public SystemPromptBuilder toolDescriptions(String toolDescriptions) { + if (toolDescriptions != null) { + this.toolDescriptions = toolDescriptions; + } + return this; + } + + /** + * Set the datasource-specific guidelines. Loads {@code prompts/datasource-{name}.md} from the + * classpath. If the resource does not exist, falls back to a generic "current dialect is X" hint. + * + *

Calling this method again replaces the previous datasource section. + */ + public SystemPromptBuilder datasource(String name) { + if (name != null) { + String lower = name.toLowerCase(); + String content = loadResourceOrNull("datasource-" + lower); + this.datasourceSection = (content != null) ? content : genericDialectSection(lower); + } + return this; + } + + private static String genericDialectSection(String name) { + return "## Current SQL dialect: " + + name + + "\n\nFollow " + + name + + " SQL syntax rules. When unsure about specific syntax, run schema exploration commands" + + " (e.g. SHOW TABLES, DESCRIBE) to verify before writing the query."; + } + + /** Append a free-form text section to the prompt. */ + public SystemPromptBuilder section(String text) { + if (text != null && !text.isEmpty()) { + sections.add(text); + } + return this; + } + + /** Build the final prompt by resolving placeholders and joining all sections. */ + public String build() { + String result = base.replace("{{tool_descriptions}}", toolDescriptions); + + StringBuilder sb = new StringBuilder(result); + sb.append("\n\nToday's date: ").append(LocalDate.now()).append("."); + if (datasourceSection != null) { + sb.append("\n\n").append(datasourceSection); + } + for (String section : sections) { + sb.append("\n\n").append(section); + } + return sb.toString(); + } + + static String loadResource(String name) { + String content = loadResourceOrNull(name); + if (content == null) { + throw new IllegalArgumentException( + "Prompt resource not found: " + RESOURCE_PREFIX + name + ".md"); + } + return content; + } + + private static String loadResourceOrNull(String name) { + String path = RESOURCE_PREFIX + name + ".md"; + try (InputStream is = SystemPromptBuilder.class.getClassLoader().getResourceAsStream(path)) { + if (is == null) { + return null; + } + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) { + return 
reader.lines().collect(Collectors.joining("\n")); + } + } catch (IOException e) { + throw new RuntimeException("Failed to read prompt resource: " + path, e); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java new file mode 100644 index 00000000000..297a7c8d74a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool; + +/** + * Base interface for agent tools. Separates tool metadata (name, description) and execution logic + * from the parameter schema (the args class with {@code @JsonPropertyDescription} annotations). + * + *

The args class defines parameter fields for JSON Schema generation and SDK deserialization. + * The tool implementation holds runtime dependencies (e.g. DataSource) and performs the actual + * work. + * + * @param the args class with fields annotated by {@code @JsonPropertyDescription} + */ +public interface AgentTool { + + /** Unique name for this tool, used by the LLM to select it. */ + String name(); + + /** Description shown to the LLM to help it decide when to use this tool. */ + String description(); + + /** Returns the args class for JSON Schema generation and deserialization. */ + Class argsType(); + + /** + * Returns the risk level of this tool, used to determine whether user approval is required. + * Defaults to {@link ToolRiskLevel#SAFE}. + */ + default ToolRiskLevel riskLevel() { + return ToolRiskLevel.SAFE; + } + + /** + * Execute the tool with the given deserialized arguments. + * + * @param args the deserialized arguments from the LLM's tool call + * @return the result string to feed back to the LLM + */ + String execute(T args); +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java new file mode 100644 index 00000000000..a403c66b58d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.core.JsonValue; +import com.openai.models.FunctionDefinition; +import com.openai.models.FunctionParameters; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionFunctionTool; +import com.openai.models.chat.completions.ChatCompletionTool; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicLong; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Registry for agent tools. Provides tool lookup, schema generation for LLM requests, and + * deserialization + execution of tool calls. + * + *

Tool metadata (name, description) comes from {@link AgentTool}. Parameter schemas are + * generated from the args class via {@link ToolSchemaGenerator}. The SDK's {@code + * function.arguments(Class)} is used for deserializing LLM responses. + * + *

Concurrency model: single-writer at startup. {@link #register} is expected to be called + * only during engine initialization, before any LLM request is dispatched. After that the registry + * is effectively read-only. Note that the read paths still take the intrinsic lock today: {@link + * #addToolsTo} builds its spec cache through the synchronized {@code ensureSpecs()}, and {@link + * #executeTool} synchronizes for the tool lookup — the single-writer-at-startup assumption makes + * this contention-free, not unnecessary. If we ever need true runtime registration, the right move is to swap {@code + * tools} for an immutable snapshot per refresh, not to grab a mutex on the hot path. + */ +public class ToolRegistry implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(ToolRegistry.class); + private static final ObjectMapper JSON = new ObjectMapper(); + + /** + * Hard ceiling on concurrent tool-call workers. Sized for the realistic working set: one LLM turn + * rarely dispatches more than a handful of parallel {@code tool_calls}, and a Kyuubi data-agent + * engine serves a small number of overlapping sessions. When the ceiling is reached we reject + * fast and tell the LLM to retry — better UX than burying the request in a queue for minutes. + * Tuned as a safety rail, not a user knob. + */ + private static final int MAX_POOL_SIZE = 8; + + private final Map> tools = new LinkedHashMap<>(); + private volatile Map cachedSpecs; + private final long toolCallTimeoutSeconds; + private final ExecutorService executor; + + /** + * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} call, sourced from + * {@code kyuubi.engine.data.agent.tool.call.timeout}. When the timeout fires, the thread is + * interrupted and a descriptive error is returned to the LLM. 
+ */ + public ToolRegistry(long toolCallTimeoutSeconds) { + this.toolCallTimeoutSeconds = toolCallTimeoutSeconds; + AtomicLong threadId = new AtomicLong(); + this.executor = + new ThreadPoolExecutor( + 0, + MAX_POOL_SIZE, + 60L, + TimeUnit.SECONDS, + new SynchronousQueue<>(), + r -> { + Thread t = new Thread(r, "tool-call-worker-" + threadId.incrementAndGet()); + t.setDaemon(true); + return t; + }); + } + + /** Shut down the worker pool. Idempotent. */ + @Override + public void close() { + executor.shutdownNow(); + } + + /** Register a tool. Keyed by {@link AgentTool#name()}. */ + public synchronized ToolRegistry register(AgentTool tool) { + tools.put(tool.name(), tool); + cachedSpecs = null; // invalidate cache + return this; + } + + public synchronized boolean isEmpty() { + return tools.isEmpty(); + } + + /** Returns the risk level of the named tool, or {@link ToolRiskLevel#SAFE} if not found. */ + public synchronized ToolRiskLevel getRiskLevel(String toolName) { + AgentTool tool = tools.get(toolName); + return tool != null ? tool.riskLevel() : ToolRiskLevel.SAFE; + } + + /** Add all tools to the ChatCompletion request builder. */ + public void addToolsTo(ChatCompletionCreateParams.Builder builder) { + Map specs = ensureSpecs(); + for (ChatCompletionTool spec : specs.values()) { + builder.addTool(spec); + } + } + + private synchronized Map ensureSpecs() { + if (cachedSpecs == null) { + Map specs = new LinkedHashMap<>(); + for (AgentTool tool : tools.values()) { + specs.put(tool.name(), buildChatCompletionTool(tool)); + } + cachedSpecs = specs; + } + return cachedSpecs; + } + + /** + * Execute a tool call: deserialize the JSON args, then delegate to the tool, with a wall-clock + * timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. If the tool does not + * finish within the timeout, the worker thread is interrupted and a descriptive error is returned + * to the LLM so it can react (e.g. simplify the query, retry with LIMIT). 
+ * + * @param toolName the function name from the LLM response + * @param argsJson the raw JSON arguments string + * @return the result string, or an error message + */ + @SuppressWarnings("unchecked") + public String executeTool(String toolName, String argsJson) { + AgentTool tool; + synchronized (this) { + tool = tools.get(toolName); + } + if (tool == null) { + return "Error: unknown tool '" + toolName + "'"; + } + return executeWithTimeout((AgentTool) tool, argsJson); + } + + private String executeWithTimeout(AgentTool tool, String argsJson) { + Callable task = + () -> { + T args = JSON.readValue(argsJson, tool.argsType()); + return tool.execute(args); + }; + Future future; + try { + future = executor.submit(task); + } catch (RejectedExecutionException e) { + LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", tool.name(), MAX_POOL_SIZE); + return "Error: tool call '" + + tool.name() + + "' rejected — server is handling too many concurrent tool calls. " + + "Retry in a moment."; + } + try { + return future.get(toolCallTimeoutSeconds, TimeUnit.SECONDS); + } catch (TimeoutException e) { + future.cancel(true); + LOG.warn("Tool call '{}' timed out after {} seconds", tool.name(), toolCallTimeoutSeconds); + return "Error: tool call '" + + tool.name() + + "' timed out after " + + toolCallTimeoutSeconds + + " seconds. " + + "Try simplifying the query or adding filters to reduce execution time."; + } catch (ExecutionException e) { + Throwable cause = e.getCause() != null ? 
e.getCause() : e; + return "Error executing " + tool.name() + ": " + cause.getMessage(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return "Error: tool call '" + tool.name() + "' was interrupted."; + } + } + + private static ChatCompletionTool buildChatCompletionTool(AgentTool tool) { + Map schema = ToolSchemaGenerator.generateSchema(tool.argsType()); + + FunctionParameters.Builder paramsBuilder = FunctionParameters.builder(); + for (Map.Entry entry : schema.entrySet()) { + paramsBuilder.putAdditionalProperty(entry.getKey(), JsonValue.from(entry.getValue())); + } + + return ChatCompletionTool.ofFunction( + ChatCompletionFunctionTool.builder() + .function( + FunctionDefinition.builder() + .name(tool.name()) + .description(tool.description()) + .parameters(paramsBuilder.build()) + .build()) + .build()); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolSchemaGenerator.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolSchemaGenerator.java new file mode 100644 index 00000000000..ffa104ada3e --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolSchemaGenerator.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.victools.jsonschema.generator.OptionPreset; +import com.github.victools.jsonschema.generator.SchemaGenerator; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfig; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; +import com.github.victools.jsonschema.generator.SchemaVersion; +import com.github.victools.jsonschema.module.jackson.JacksonModule; +import com.github.victools.jsonschema.module.jackson.JacksonOption; +import java.util.Map; + +/** + * Generates JSON Schema from annotated Java classes using Jackson annotations. Used to build the + * {@code parameters} section of OpenAI function definitions. + * + *

Backed by victools + * jsonschema-generator with its Jackson module, which natively reads {@code @JsonProperty} and + * {@code @JsonPropertyDescription} annotations. + */ +public class ToolSchemaGenerator { + + private static final ObjectMapper JSON = new ObjectMapper(); + private static final SchemaGenerator GENERATOR; + + static { + JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED); + SchemaGeneratorConfig config = + new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_7, OptionPreset.PLAIN_JSON) + .with(jacksonModule) + .build(); + GENERATOR = new SchemaGenerator(config); + } + + /** Generate a JSON Schema (as a Map) from the given args class using Jackson annotations. */ + @SuppressWarnings("unchecked") + public static Map generateSchema(Class argsClass) { + JsonNode schemaNode = GENERATOR.generateSchema(argsClass); + Map schema = JSON.convertValue(schemaNode, Map.class); + // Remove $schema key — OpenAI function parameters don't expect it. + schema.remove("$schema"); + return schema; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java new file mode 100644 index 00000000000..06b12f2be72 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import javax.sql.DataSource; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; + +/** + * Mutating SQL tool for INSERT/UPDATE/DELETE/MERGE/DDL statements. Marked {@link + * ToolRiskLevel#DESTRUCTIVE} so the approval layer gates every call. There is no read-only check — + * passing a SELECT here will execute fine but the LLM will be told via the description to use + * {@link RunSelectQueryTool} for read paths. + */ +public class RunMutationQueryTool implements AgentTool { + + private final DataSource dataSource; + private final int queryTimeoutSeconds; + + /** + * @param dataSource pooled JDBC connection source + * @param queryTimeoutSeconds value for {@code Statement.setQueryTimeout}, sourced from {@code + * kyuubi.engine.data.agent.query.timeout}. This is the JDBC-level inner timeout that lets + * Spark/Trino cooperatively cancel a long-running query. The outer wall-clock cap on the + * whole tool call is enforced separately by the agent runtime via {@code + * kyuubi.engine.data.agent.tool.call.timeout}. + */ + public RunMutationQueryTool(DataSource dataSource, int queryTimeoutSeconds) { + this.dataSource = dataSource; + this.queryTimeoutSeconds = queryTimeoutSeconds; + } + + @Override + public String name() { + return "run_mutation_query"; + } + + @Override + public String description() { + return "Execute a SQL statement that MODIFIES data or schema. 
" + + "Use this for INSERT, UPDATE, DELETE, MERGE, CREATE, DROP, ALTER, TRUNCATE, GRANT, etc. " + + "REQUIRES USER APPROVAL before execution. " + + "For read-only queries (SELECT/SHOW/DESCRIBE/EXPLAIN), use run_select_query instead."; + } + + @Override + public ToolRiskLevel riskLevel() { + return ToolRiskLevel.DESTRUCTIVE; + } + + @Override + public Class argsType() { + return SqlQueryArgs.class; + } + + @Override + public String execute(SqlQueryArgs args) { + return SqlExecutor.execute(dataSource, args.sql, queryTimeoutSeconds); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java new file mode 100644 index 00000000000..0c57cc049ed --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import javax.sql.DataSource; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; + +/** + * Read-only SQL tool. Accepts only statements whose first significant token is a known read-only + * keyword (SELECT, WITH, SHOW, DESCRIBE, EXPLAIN, ...). Mutating statements are rejected with a + * clear error so the LLM knows to use {@link RunMutationQueryTool} instead. + */ +public class RunSelectQueryTool implements AgentTool { + + private final DataSource dataSource; + private final int queryTimeoutSeconds; + + /** + * @param dataSource pooled JDBC connection source + * @param queryTimeoutSeconds value for {@code Statement.setQueryTimeout}, sourced from {@code + * kyuubi.engine.data.agent.query.timeout}. This is the JDBC-level inner timeout that lets + * Spark/Trino cooperatively cancel a long-running query. The outer wall-clock cap on the + * whole tool call is enforced separately by the agent runtime via {@code + * kyuubi.engine.data.agent.tool.call.timeout}. + */ + public RunSelectQueryTool(DataSource dataSource, int queryTimeoutSeconds) { + this.dataSource = dataSource; + this.queryTimeoutSeconds = queryTimeoutSeconds; + } + + @Override + public String name() { + return "run_select_query"; + } + + @Override + public String description() { + return "Execute a READ-ONLY SQL query and return the results. " + + "Accepts SELECT, WITH (CTE), SHOW, DESCRIBE, EXPLAIN, USE, VALUES, TABLE, LIST. " + + "REJECTS any statement that modifies data or schema (INSERT, UPDATE, DELETE, MERGE, " + + "CREATE, DROP, ALTER, TRUNCATE, GRANT, ...). 
" + + "For mutating statements, use the run_mutation_query tool instead."; + } + + @Override + public ToolRiskLevel riskLevel() { + return ToolRiskLevel.SAFE; + } + + @Override + public Class argsType() { + return SqlQueryArgs.class; + } + + @Override + public String execute(SqlQueryArgs args) { + String sql = args.sql; + if (sql == null || sql.trim().isEmpty()) { + return "Error: 'sql' parameter is required."; + } + if (!SqlReadOnlyChecker.isReadOnly(sql)) { + return "Error: run_select_query only accepts read-only statements " + + "(SELECT, WITH, SHOW, DESCRIBE, EXPLAIN, USE, VALUES, TABLE, LIST). " + + "Use run_mutation_query for INSERT/UPDATE/DELETE/DDL statements."; + } + return SqlExecutor.execute(dataSource, sql, queryTimeoutSeconds); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlExecutor.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlExecutor.java new file mode 100644 index 00000000000..b36bb260874 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlExecutor.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.Statement; +import javax.sql.DataSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Shared JDBC execution + result formatting for SQL tools. Owned by {@link RunSelectQueryTool} and + * {@link RunMutationQueryTool}; not part of the public tool API. + */ +final class SqlExecutor { + + private static final Logger LOG = LoggerFactory.getLogger(SqlExecutor.class); + + private SqlExecutor() {} + + /** + * Execute the given SQL and return a markdown-formatted result string. Errors are caught and + * returned as a one-line "Error: ..." message so the LLM can self-correct. + * + *

No client-side row cap. The tool deliberately does not call {@link + * Statement#setMaxRows(int)}. Result-size discipline is delegated to the LLM via the system + * prompt, which requires every read query to include an explicit {@code LIMIT} (or equivalent + * pagination). Adding a hidden client-side cap here would silently truncate the LLM's view of the + * data and create magic numbers that are impossible to tune across deployments. + * + *

Two-layer timeout model. The {@code timeoutSeconds} arg here is the JDBC inner + * timeout, set via {@link Statement#setQueryTimeout(int)} and sourced from {@code + * kyuubi.engine.data.agent.query.timeout}. This is the cooperative path: the driver tells the + * server (Spark/Trino/...) to cancel the query, and cluster resources are released gracefully. + * The outer hard wall-clock cap on the whole tool call is enforced separately by the agent + * runtime via {@code kyuubi.engine.data.agent.tool.call.timeout} — that one fires even if the + * JDBC driver ignores {@code setQueryTimeout} (some legacy Hive/Spark Thrift drivers do), and is + * the ungraceful "kill the thread" backstop. + */ + static String execute(DataSource dataSource, String sql, int timeoutSeconds) { + if (sql == null || sql.trim().isEmpty()) { + return "Error: 'sql' parameter is required."; + } + + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.setQueryTimeout(timeoutSeconds); + boolean hasResultSet = stmt.execute(sql); + if (hasResultSet) { + try (ResultSet rs = stmt.getResultSet()) { + return formatResult(rs); + } + } else { + return "[Statement executed successfully. " + stmt.getUpdateCount() + " row(s) affected]"; + } + } catch (Exception e) { + LOG.warn("SQL execution error", e); + return "Error: SQL execution failed. 
" + extractRootCause(e); + } + } + + private static String formatResult(ResultSet rs) throws Exception { + ResultSetMetaData meta = rs.getMetaData(); + int colCount = meta.getColumnCount(); + StringBuilder sb = new StringBuilder(); + + sb.append("| "); + for (int i = 1; i <= colCount; i++) { + if (i > 1) sb.append(" | "); + sb.append(meta.getColumnLabel(i)); + } + sb.append(" |\n|"); + for (int i = 1; i <= colCount; i++) { + sb.append(" --- |"); + } + sb.append("\n"); + + int rowCount = 0; + while (rs.next()) { + sb.append("| "); + for (int i = 1; i <= colCount; i++) { + if (i > 1) sb.append(" | "); + String val = rs.getString(i); + if (val != null) { + sb.append(val.replace("|", "\\|")); + } else { + sb.append("NULL"); + } + } + sb.append(" |\n"); + rowCount++; + } + sb.append("\n[").append(rowCount).append(" row(s) returned]"); + return sb.toString(); + } + + /** + * Walk the exception cause chain to find the root cause message, then truncate to a single-line + * summary so the LLM can diagnose the problem without a full stack trace. + */ + private static String extractRootCause(Exception e) { + Throwable root = e; + while (root.getCause() != null) { + root = root.getCause(); + } + String msg = root.getMessage(); + if (msg == null) { + return root.getClass().getSimpleName(); + } + // Take only the first line — Spark errors often include the full query plan after a newline. 
+ int newline = msg.indexOf('\n'); + if (newline > 0) { + msg = msg.substring(0, newline); + } + if (msg.length() > 500) { + msg = msg.substring(0, 500) + "..."; + } + return msg; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlQueryArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlQueryArgs.java new file mode 100644 index 00000000000..cc2db71a31b --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlQueryArgs.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** Parameter schema shared by {@link RunSelectQueryTool} and {@link RunMutationQueryTool}. */ +public class SqlQueryArgs { + + @JsonProperty(required = true) + @JsonPropertyDescription( + "The SQL statement to execute. 
Read-only queries MUST include a LIMIT clause " + + "(or equivalent) — the tool does not enforce row limits.") + public String sql; +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java new file mode 100644 index 00000000000..d2cac9ae518 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Locale; +import java.util.Set; + +/** + * Lightweight read-only check for the {@code run_select_query} tool. + * + *

Design note — this is a guardrail, not a security boundary. The real approval gate for + * writes is the separate {@code run_mutation_query} tool; an LLM with DML intent calls that + * directly, not via a crafted bypass of this check. This checker only catches the "LLM hallucinates + * and sends a mutation to the read-only tool" case, for which a first-token keyword whitelist is + * enough. Pulling a full SQL parser into the tool layer for a sanity check isn't warranted. + * + *

The whitelist is intentionally broad to cover common big-data read-only entry points. If a + * legitimate read-only keyword is missing, add it here rather than working around it elsewhere. + */ +final class SqlReadOnlyChecker { + + /** + * Read-only keywords accepted as the first token of a SQL statement. + * + *

    + *
  • {@code SELECT} — standard query + *
  • {@code WITH} — CTE-prefixed query + *
  • {@code VALUES} — standalone VALUES expression (Trino/Postgres) + *
  • {@code TABLE} — Spark/Postgres {@code TABLE x} shorthand for {@code SELECT * FROM x} + *
  • {@code FROM} — Hive {@code FROM t SELECT ...} variant + *
  • {@code SHOW} — SHOW TABLES / DATABASES / CREATE / PARTITIONS / FUNCTIONS / COLUMNS / + * VIEWS / TBLPROPERTIES / INDEXES / GRANT / ROLE / STATS / LOCKS / ... + *
  • {@code DESCRIBE} / {@code DESC} — table / function / formatted / extended + *
  • {@code EXPLAIN} — query plan inspection + *
  • {@code USE} — switch session catalog/database; not data-mutating + *
  • {@code LIST} — Spark {@code LIST FILE} / {@code LIST JAR} inspection + *
  • {@code HELP} — some engines expose interactive help + *
+ */ + private static final Set READ_ONLY_KEYWORDS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + "SELECT", + "WITH", + "VALUES", + "TABLE", + "FROM", + "SHOW", + "DESCRIBE", + "DESC", + "EXPLAIN", + "USE", + "LIST", + "HELP"))); + + private SqlReadOnlyChecker() {} + + /** + * Returns true if the SQL's first significant token is a known read-only keyword. Strips leading + * whitespace and {@code --} / {@code /* *}{@code /} comments before testing. + */ + static boolean isReadOnly(String sql) { + if (sql == null) { + return false; + } + String stripped = stripLeadingNoise(sql); + if (stripped.isEmpty()) { + return false; + } + int end = 0; + while (end < stripped.length() && Character.isLetter(stripped.charAt(end))) { + end++; + } + if (end == 0) { + return false; + } + String firstToken = stripped.substring(0, end).toUpperCase(Locale.ROOT); + return READ_ONLY_KEYWORDS.contains(firstToken); + } + + private static String stripLeadingNoise(String sql) { + int i = 0; + int n = sql.length(); + while (i < n) { + char c = sql.charAt(i); + if (Character.isWhitespace(c)) { + i++; + } else if (c == '-' && i + 1 < n && sql.charAt(i + 1) == '-') { + // line comment until newline + i += 2; + while (i < n && sql.charAt(i) != '\n') { + i++; + } + } else if (c == '/' && i + 1 < n && sql.charAt(i + 1) == '*') { + // block comment until */ + i += 2; + while (i + 1 < n && !(sql.charAt(i) == '*' && sql.charAt(i + 1) == '/')) { + i++; + } + if (i + 1 < n) { + i += 2; + } else { + i = n; + } + } else { + break; + } + } + return sql.substring(i); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/resources/prompts/base.md b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/base.md new file mode 100644 index 00000000000..5a0b544cc15 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/base.md @@ -0,0 +1,93 @@ +You are a data analysis agent. You query databases and explain data — nothing else. 
+You write and execute SQL to answer questions. You never fabricate data. +When uncertain about data meaning, ask the user rather than assuming. + +**Scope:** SELECT queries, schema exploration, and data interpretation. +You do not handle ETL pipelines, database administration, DDL migrations, or application code generation. + +## Available tools + +{{tool_descriptions}} + +### When NOT to use tools + +- If the question can be answered from your knowledge or conversation context (e.g. "What is a LEFT JOIN?"), answer directly. +- If you have already inspected a table's schema in this conversation, do not query it again — use the previous result. +- If the user pastes SQL for review or optimization, analyze the text directly unless you need to verify execution. + +## SQL workflow + +1. **Explore**: Use schema exploration SQL to understand tables and columns before writing queries. Read sample values — they reveal exact column contents (enum values, date formats, ID patterns) so you can write precise WHERE clauses without exploratory queries. +2. **Estimate scale**: Before querying a table for the first time, estimate its row count using metadata or a lightweight query. Classify the table and choose the right strategy: + - **S-class (< 1M rows)**: full `COUNT(*)`, `COUNT(DISTINCT col)` are safe. + - **M-class (1M–100M rows)**: use sampling or approximate functions for statistics. Always add filters before aggregation. + - **L-class (> 100M rows)**: metadata-only + `LIMIT` sampling. Never run unfiltered `COUNT(DISTINCT)` or full-table `JOIN` without partition/filter pruning. +3. **Write & execute**: Prefer flat JOINs over subqueries and CTEs — use CTEs or subqueries only when the logic genuinely requires two-level aggregation or self-reference. For simple aggregate + sort, use `ORDER BY ... LIMIT` directly. For complex analyses, break into smaller queries: validate assumptions first (row counts, distinct values, date ranges), then build the full query. +4. 
**Validate**: Check row counts, value ranges, and NULLs. If results look wrong, investigate before presenting. +5. **Present**: Lead with the conclusion, then explain reasoning. + +## Query risk control + +Classify query complexity before execution: + +- **L1 (single table + LIMIT)**: always safe, execute directly. +- **L2 (single table + aggregation + filter)**: safe with proper filters. +- **L3 (2–3 table JOIN + aggregation)**: check that the largest table has filter conditions; prefer filtering before joining. +- **L4 (4+ table JOIN or complex aggregation)**: run `EXPLAIN` first on large datasets; on M/L-class tables, add partition filters or reduce to L3. + +General rules: +- **Every read query MUST include an explicit `LIMIT`** (or equivalent: `TABLESAMPLE`, `WHERE` on a partition column reducing to ≤ a few thousand rows, `GROUP BY` aggregation that collapses to ≤ a few hundred rows). The tool does NOT cap rows for you — an unbounded `SELECT *` will pull the entire result set into the context window and break the conversation. Default to `LIMIT 100` for inspection, raise it consciously when you need more rows. +- Prefer aggregation over detail rows — `GROUP BY` with `COUNT`/`SUM` over `SELECT *`. +- Filter the largest table first, then JOIN to smaller tables. +- For complex queries on large datasets, run `EXPLAIN` to verify the plan uses partition pruning, predicate pushdown, or broadcast joins before executing. + +## Field attribution + +When multiple tables contain similar columns, always choose the column from the **entity's primary table** — the table whose purpose is that entity. + +1. Identify which entity the field describes (school? student? order?). +2. Pick the table named after that entity — it is the authoritative source. +3. Do not use a field just because it exists in a table you already selected. Confirm the table is the correct home for that attribute. 
+ +## Relationship discovery + +When you need to find how tables connect (before writing JOINs), use these signals in order: + +1. **Naming convention**: columns named `{table}_id`, `{table}_sk`, or `{table}_key` likely reference the table of that name. For example, `ss_customer_sk` → joins to `customer.c_customer_sk`. Surrogate keys ending in `_sk` are common in star schemas. +2. **Type compatibility**: join columns must share compatible types (int↔bigint is OK, int↔string is a red flag). Check both sides with DESCRIBE before writing the JOIN. +3. **Value overlap verification**: when a join path is uncertain, run a quick overlap check: + - Compare value ranges: `SELECT MIN(col), MAX(col) FROM both_tables` + - Check coverage: `SELECT COUNT(DISTINCT a.fk) as matched FROM a JOIN b ON a.fk = b.pk` vs total distinct values + - Check for orphans: `SELECT COUNT(*) FROM a LEFT JOIN b ON a.fk = b.pk WHERE b.pk IS NULL` +4. **Cardinality**: determine if the relationship is 1:1, 1:N, or N:M by comparing distinct counts on both sides. This affects whether you need GROUP BY or deduplication after joining. + +Report discovered relationships to the user with confidence level (certain for explicit FKs, likely for naming matches, uncertain for inferred paths) so they can validate. + +## Error handling + +- If a query fails, analyze the error message, identify the root cause, fix the SQL, then retry. Never retry the same query verbatim. +- On permission errors or table/column-not-found errors, **stop immediately** — report to the user. +- **Maximum 3 retries** per query. After 3 failures, explain the problem and ask for guidance. + +## Security + +- Use `run_select_query` for all read-only work (SELECT, SHOW, DESCRIBE, EXPLAIN, WITH, USE, VALUES, TABLE, LIST). It will reject any mutating statement. +- Use `run_mutation_query` only when the user has explicitly asked you to modify data or schema (INSERT, UPDATE, DELETE, MERGE, CREATE, DROP, ALTER, TRUNCATE, GRANT, ...). 
The platform requires user approval for every call to this tool. +- Reject queries that use system functions, file I/O (`LOAD_FILE`, `INTO OUTFILE`, `COPY`, `pg_read_file`), or administrative commands — even if wrapped in a SELECT. +- Never expose database credentials or connection strings in your responses. + +## Result validation + +- **Empty results**: If 0 rows are returned, check WHERE conditions, table names, and JOIN keys. Run a simpler query first to confirm data exists. +- **Outliers**: Flag columns with excessive NULLs, negative amounts, future dates, or anomalous values to the user rather than silently ignoring them. +- A task is complete only when results are validated and clearly presented. Do not report "no data" without investigating the cause. + +## Output style + +- Respond in the same language the user used. +- Lead with the conclusion, then explain reasoning. +- Present tabular data in Markdown tables. Truncate result sets beyond 20 rows and state the total count. +- For multi-step analyses, number each step so the user can follow the logic. +- When results are ambiguous or incomplete, state limitations explicitly. +- Display NULL as `NULL`, not as empty strings or "N/A". +- Do not restate the user's question. \ No newline at end of file diff --git a/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-mysql.md b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-mysql.md new file mode 100644 index 00000000000..10f8e4da84b --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-mysql.md @@ -0,0 +1,22 @@ +## MySQL guidelines + +- You are connected to a MySQL-compatible database. +- Use backticks to quote identifiers that contain special characters. 
+ +### Schema exploration + +- List databases: `SHOW DATABASES` +- List tables: `SHOW TABLES FROM database_name` +- Describe table: `DESCRIBE database_name.table_name` +- Show columns: `SHOW COLUMNS FROM database_name.table_name` +- Show create statement: `SHOW CREATE TABLE database_name.table_name` +- Show indexes: `SHOW INDEX FROM database_name.table_name` + +**Important:** Always use fully-qualified table names (`database_name.table_name`) in all SQL statements. +Do NOT use `USE database_name` — each tool call may run on a different connection, so session state is not preserved between calls. + +### Performance tips + +- For large datasets, avoid `SELECT *` without `LIMIT`. +- **Index usage**: include indexed columns in `WHERE` clauses. Use `EXPLAIN` to verify index usage. +- **EXPLAIN**: run `EXPLAIN sql` before executing complex queries on large tables. diff --git a/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-spark.md b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-spark.md new file mode 100644 index 00000000000..fab8cd9c35d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-spark.md @@ -0,0 +1,36 @@ +## Spark SQL guidelines + +- You are connected to a Spark SQL engine via Kyuubi. +- Use backticks to quote identifiers that contain special characters. +- Spark supports ANSI SQL, HiveQL extensions, and Delta Lake syntax. + +### Schema exploration + +- List databases: `SHOW DATABASES` +- List tables: `SHOW TABLES IN database_name` +- Describe table: `DESCRIBE TABLE database_name.table_name` or `DESCRIBE TABLE EXTENDED database_name.table_name` +- Show columns: `SHOW COLUMNS IN database_name.table_name` +- Show create statement: `SHOW CREATE TABLE database_name.table_name` +- Show partitions: `SHOW PARTITIONS database_name.table_name` + +**Important:** Always use fully-qualified table names (`database_name.table_name`) in all SQL statements. 
+Do NOT use `USE database_name` — each tool call may run on a different connection, so session state like the current database is not preserved between calls. + +### Estimating table size + +- Use `DESCRIBE TABLE EXTENDED table_name` and check the `Statistics` row (e.g. `7783057 bytes`) to estimate size without scanning data. +- For partitioned tables, `SHOW PARTITIONS` reveals partition count — a rough proxy for data volume. +- If metadata is unavailable, use `SELECT COUNT(*) * 100 FROM table TABLESAMPLE (1 PERCENT)` for a fast estimate. + +### Performance tips + +- Prefer built-in functions: `explode`, `collect_list`, window functions, etc. +- For large datasets, avoid `SELECT *` without `LIMIT`. +- **Partition pruning**: always include the partition column (often `*_date_sk`, `dt`, `ds`) in `WHERE` clauses. Check `EXPLAIN` output for `PartitionFilters` to verify pruning is effective. +- **Broadcast joins**: Spark auto-broadcasts small tables in JOINs. When joining a large fact table with dimension tables, put dimension filters in the `WHERE` clause so the optimizer can apply dynamic partition pruning. +- **Sampling**: use `TABLESAMPLE(N PERCENT)` for approximate statistics on large tables. Multiply results by `100/N` to estimate full-table values. +- **Approximate functions**: use `APPROX_COUNT_DISTINCT(col)` instead of `COUNT(DISTINCT col)` for cardinality estimation on large tables — it is orders of magnitude faster. +- **EXPLAIN**: run `EXPLAIN sql` before executing complex queries on large tables. 
Check for: + - `BroadcastHashJoin` (good) vs `SortMergeJoin` (expensive shuffle) + - `PartitionFilters` and `PushedFilters` (predicate pushdown working) + - `dynamicpruningexpression` (dynamic partition pruning active) \ No newline at end of file diff --git a/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-sqlite.md b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-sqlite.md new file mode 100644 index 00000000000..14202a85bd8 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-sqlite.md @@ -0,0 +1,37 @@ +## SQLite SQL compatibility + +SQLite differs significantly from MySQL/PostgreSQL. Follow these rules strictly. + +### Schema exploration + +- List tables: `SELECT name FROM sqlite_master WHERE type='table' ORDER BY name` +- Describe table: `PRAGMA table_info(table_name)` +- Show indexes: `PRAGMA index_list(table_name)` +- Sample data: `SELECT * FROM table_name LIMIT 3` + +### Type casting +- `CAST(x AS REAL)` — no FLOAT/DOUBLE. `CAST(x AS INTEGER)` — no INT/BIGINT. +- `CAST(x AS TEXT)` — no VARCHAR/CHAR(n). No BOOLEAN type — use 1/0. + +### Date and time +- No `NOW()` or `CURDATE()`. Use `DATE('now')`, `DATETIME('now')`. +- Arithmetic: `DATE(col, '+7 days')`, `DATE(col, 'start of year')` — no DATEADD/INTERVAL. +- Extract: `STRFTIME('%Y', col)` for year — no YEAR(), EXTRACT(), MONTH(). +- Difference: `JULIANDAY(d1) - JULIANDAY(d2)` — no DATEDIFF(). + +### String functions +- Concatenation: `||` — no CONCAT(). Substring: `SUBSTR()` — no SUBSTRING()/LEFT()/RIGHT(). +- Locate: `INSTR(haystack, needle)` — no LOCATE()/CHARINDEX()/POSITION(). +- `GROUP_CONCAT(x, ',')` — not `GROUP_CONCAT(x SEPARATOR ',')`. +- No LPAD/RPAD. + +### Math functions +- No CEIL/FLOOR/LOG/LN/EXP/POWER/SQRT. Modulo: `x % y`. +- `ROUND(x, n)`, `ABS(x)`, `MAX(a, b)`, `MIN(a, b)` are available. + +### Unsupported features +- No GREATEST/LEAST — use scalar MAX()/MIN(). 
+- No IF() — use IIF(cond, t, f) or CASE WHEN. +- No regex — use LIKE with %/_ or GLOB with */? +- Booleans: use 1/0 — no TRUE/FALSE literals. +- Quote identifiers with double quotes `"col"` — not brackets [col]. \ No newline at end of file diff --git a/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-trino.md b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-trino.md new file mode 100644 index 00000000000..bde538c2609 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/resources/prompts/datasource-trino.md @@ -0,0 +1,25 @@ +## Trino SQL guidelines + +- You are connected to a Trino cluster. +- Use double quotes to quote identifiers that contain special characters. +- Trino supports ANSI SQL with extensions. + +### Schema exploration + +- List catalogs: `SHOW CATALOGS` +- List schemas: `SHOW SCHEMAS IN catalog_name` +- List tables: `SHOW TABLES IN catalog_name.schema_name` +- Describe table: `DESCRIBE catalog_name.schema_name.table_name` +- Show columns: `SHOW COLUMNS FROM catalog_name.schema_name.table_name` +- Show create statement: `SHOW CREATE TABLE catalog_name.schema_name.table_name` + +**Important:** Always use fully-qualified table names (`catalog_name.schema_name.table_name`) in all SQL statements. +Do NOT use `USE catalog.schema` — each tool call may run on a different connection, so session state is not preserved between calls. + +### Performance tips + +- For large datasets, avoid `SELECT *` without `LIMIT`. +- **Partition pruning**: always include partition columns in `WHERE` clauses. +- **Sampling**: use `SELECT ... FROM table TABLESAMPLE BERNOULLI(N)` for approximate statistics on large tables. +- **Approximate functions**: use `APPROX_DISTINCT(col)` instead of `COUNT(DISTINCT col)` for cardinality estimation on large tables. +- **EXPLAIN**: run `EXPLAIN sql` before executing complex queries on large tables to check the query plan. 
diff --git a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala index 2ea1d8a41aa..04d0defa2dc 100644 --- a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala +++ b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala @@ -145,7 +145,7 @@ class ExecuteStatement( n.put("requestId", req.requestId()) n.put("id", req.toolCallId()) n.put("name", req.toolName()) - n.put("args", req.toolArgs().toString) + n.set("args", JSON.valueToTree(req.toolArgs())) n.put("riskLevel", req.riskLevel().name()) })) case EventType.AGENT_FINISH => diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/DataSourceFactoryAuthTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/DataSourceFactoryAuthTest.java new file mode 100644 index 00000000000..04873e46b76 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/DataSourceFactoryAuthTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +import static org.junit.Assert.*; + +import com.zaxxer.hikari.HikariDataSource; +import java.io.File; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.Statement; +import javax.sql.DataSource; +import org.junit.After; +import org.junit.Test; + +/** Tests for DataSourceFactory. Uses real SQLite. */ +public class DataSourceFactoryAuthTest { + + private DataSource ds; + private File tmpFile; + + @After + public void tearDown() { + if (ds instanceof HikariDataSource) { + ((HikariDataSource) ds).close(); + } + if (tmpFile != null) { + tmpFile.delete(); + } + } + + @Test + public void testCreateWithUrl() throws Exception { + tmpFile = File.createTempFile("kyuubi-ds-test-", ".db"); + tmpFile.deleteOnExit(); + + ds = DataSourceFactory.create("jdbc:sqlite:" + tmpFile.getAbsolutePath()); + assertNotNull(ds); + + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE t (id INTEGER)"); + stmt.execute("INSERT INTO t VALUES (1)"); + try (ResultSet rs = stmt.executeQuery("SELECT id FROM t")) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + } + } + } + + @Test + public void testCreateWithUserOnly() throws Exception { + tmpFile = File.createTempFile("kyuubi-ds-test-", ".db"); + tmpFile.deleteOnExit(); + + ds = DataSourceFactory.create("jdbc:sqlite:" + tmpFile.getAbsolutePath(), "testuser"); + assertNotNull(ds); + assertTrue(ds instanceof HikariDataSource); + assertEquals("testuser", ((HikariDataSource) 
ds).getUsername()); + } + + @Test + public void testCreateWithUserAndPassword() throws Exception { + tmpFile = File.createTempFile("kyuubi-ds-test-", ".db"); + tmpFile.deleteOnExit(); + + ds = DataSourceFactory.create("jdbc:sqlite:" + tmpFile.getAbsolutePath(), "user", "pass123"); + assertNotNull(ds); + HikariDataSource hds = (HikariDataSource) ds; + assertEquals("user", hds.getUsername()); + assertEquals("pass123", hds.getPassword()); + } + + @Test + public void testCreateWithNullAndEmptyCredentials() throws Exception { + tmpFile = File.createTempFile("kyuubi-ds-test-", ".db"); + tmpFile.deleteOnExit(); + + // null user/password — should not set username/password on config + ds = DataSourceFactory.create("jdbc:sqlite:" + tmpFile.getAbsolutePath(), null, null); + assertNotNull(ds); + + // empty strings — treated same as null + DataSource ds2 = DataSourceFactory.create("jdbc:sqlite:" + tmpFile.getAbsolutePath(), "", ""); + assertNotNull(ds2); + ((HikariDataSource) ds2).close(); + } + + @Test + public void testMultipleDataSourcesIsolated() throws Exception { + File tmpFile1 = File.createTempFile("kyuubi-ds-test1-", ".db"); + File tmpFile2 = File.createTempFile("kyuubi-ds-test2-", ".db"); + tmpFile1.deleteOnExit(); + tmpFile2.deleteOnExit(); + + DataSource ds1 = DataSourceFactory.create("jdbc:sqlite:" + tmpFile1.getAbsolutePath()); + DataSource ds2 = DataSourceFactory.create("jdbc:sqlite:" + tmpFile2.getAbsolutePath()); + + try (Connection c1 = ds1.getConnection(); + Statement s1 = c1.createStatement()) { + s1.execute("CREATE TABLE t1 (v TEXT)"); + s1.execute("INSERT INTO t1 VALUES ('from-ds1')"); + } + try (Connection c2 = ds2.getConnection(); + Statement s2 = c2.createStatement()) { + s2.execute("CREATE TABLE t2 (v TEXT)"); + s2.execute("INSERT INTO t2 VALUES ('from-ds2')"); + } + + // Verify isolation + try (Connection c1 = ds1.getConnection(); + Statement s1 = c1.createStatement(); + ResultSet rs = s1.executeQuery("SELECT v FROM t1")) { + assertTrue(rs.next()); + 
assertEquals("from-ds1", rs.getString(1)); + } + + ((HikariDataSource) ds1).close(); + ((HikariDataSource) ds2).close(); + tmpFile1.delete(); + tmpFile2.delete(); + } + + @Test(expected = IllegalArgumentException.class) + public void testNullJdbcUrlThrows() { + DataSourceFactory.create(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyJdbcUrlThrows() { + DataSourceFactory.create(""); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java new file mode 100644 index 00000000000..e728ff871e7 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +import static org.junit.Assert.*; + +import org.junit.Test; + +public class JdbcDialectTest { + + @Test + public void testSparkViaHive2() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:hive2://localhost:10009/default"); + assertNotNull(d); + assertEquals("spark", d.datasourceName()); + } + + @Test + public void testSparkViaSpark() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:spark://localhost:10009/default"); + assertNotNull(d); + assertEquals("spark", d.datasourceName()); + } + + @Test + public void testCaseInsensitive() { + assertNotNull(JdbcDialect.fromUrl("JDBC:HIVE2://localhost:10009")); + assertNotNull(JdbcDialect.fromUrl("JDBC:SPARK://localhost:10009")); + } + + @Test + public void testQuoteIdentifierBacktick() { + JdbcDialect spark = JdbcDialect.fromUrl("jdbc:hive2://localhost:10009"); + assertEquals("`my_table`", spark.quoteIdentifier("my_table")); + assertEquals("` ``inject`` `", spark.quoteIdentifier(" `inject` ")); + } + + @Test + public void testTrino() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:trino://localhost:9090"); + assertNotNull(d); + assertEquals("trino", d.datasourceName()); + } + + @Test + public void testTrinoCaseInsensitive() { + assertNotNull(JdbcDialect.fromUrl("JDBC:TRINO://localhost:9090")); + } + + @Test + public void testTrinoQuoteIdentifier() { + JdbcDialect trino = JdbcDialect.fromUrl("jdbc:trino://localhost:9090"); + assertEquals("\"my_table\"", trino.quoteIdentifier("my_table")); + assertEquals("\" \"\"inject\"\" \"", trino.quoteIdentifier(" \"inject\" ")); + } + + @Test + public void testSqlite() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:sqlite:/tmp/test.db"); + assertNotNull(d); + assertEquals("sqlite", d.datasourceName()); + } + + @Test + public void testSqliteCaseInsensitive() { + assertNotNull(JdbcDialect.fromUrl("JDBC:SQLITE:test.db")); + } + + @Test + public void testSqliteQuoteIdentifier() { + JdbcDialect sqlite = 
JdbcDialect.fromUrl("jdbc:sqlite:test.db"); + assertEquals("\"my_table\"", sqlite.quoteIdentifier("my_table")); + assertEquals("\" \"\"inject\"\" \"", sqlite.quoteIdentifier(" \"inject\" ")); + } + + @Test + public void testMysql() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:mysql://localhost:3306"); + assertNotNull(d); + assertEquals("mysql", d.datasourceName()); + assertEquals("`my_table`", d.quoteIdentifier("my_table")); + } + + @Test + public void testUnknownReturnsGenericDialectWithName() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:postgresql://localhost:5432/db"); + assertNotNull(d); + assertTrue(d instanceof GenericDialect); + assertEquals("postgresql", d.datasourceName()); + } + + @Test + public void testGenericDialectQuoteIdentifierUnsupported() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:clickhouse://localhost:8123"); + assertEquals("clickhouse", d.datasourceName()); + try { + d.quoteIdentifier("col"); + fail("expected UnsupportedOperationException"); + } catch (UnsupportedOperationException expected) { + // ok + } + } + + // --- qualify tests --- + + @Test + public void testMysqlQualifySchemaAndTable() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:mysql://localhost:3306"); + assertEquals("`mydb`.`users`", d.qualify(TableRef.of("mydb", "users"))); + } + + @Test + public void testMysqlQualifyTableOnly() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:mysql://localhost:3306"); + assertEquals("`users`", d.qualify(TableRef.of("users"))); + } + + @Test + public void testTrinoQualifyFull() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:trino://localhost:9090"); + assertEquals( + "\"hive\".\"sales\".\"orders\"", d.qualify(TableRef.of("hive", "sales", "orders"))); + } + + @Test + public void testTrinoQualifySchemaAndTable() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:trino://localhost:9090"); + assertEquals("\"sales\".\"orders\"", d.qualify(TableRef.of("sales", "orders"))); + } + + @Test + public void testSparkQualifyFull() { + JdbcDialect d = 
JdbcDialect.fromUrl("jdbc:hive2://localhost:10009"); + assertEquals( + "`spark_catalog`.`default`.`t`", d.qualify(TableRef.of("spark_catalog", "default", "t"))); + } + + @Test + public void testSqliteQualifyTableOnly() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:sqlite:test.db"); + assertEquals("\"t\"", d.qualify(TableRef.of("t"))); + } + + @Test(expected = UnsupportedOperationException.class) + public void testGenericQualifyThrows() { + JdbcDialect d = JdbcDialect.fromUrl("jdbc:clickhouse://localhost:8123"); + d.qualify(TableRef.of("db", "t")); + } + + @Test + public void testQualifyWithSpecialCharacters() { + JdbcDialect mysql = JdbcDialect.fromUrl("jdbc:mysql://localhost:3306"); + assertEquals("`my``db`.`user``s`", mysql.qualify(TableRef.of("my`db", "user`s"))); + + JdbcDialect trino = JdbcDialect.fromUrl("jdbc:trino://localhost:9090"); + assertEquals("\"my\"\"schema\".\"tab\"", trino.qualify(TableRef.of("my\"schema", "tab"))); + } + + @Test + public void testNullReturnsNull() { + assertNull(JdbcDialect.fromUrl(null)); + } + + @Test + public void testNonJdbcReturnsNull() { + assertNull(JdbcDialect.fromUrl("postgresql://localhost")); + } + + @Test + public void testMalformedReturnsNull() { + assertNull(JdbcDialect.fromUrl("jdbc:")); + assertNull(JdbcDialect.fromUrl("jdbc::")); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/TableRefTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/TableRefTest.java new file mode 100644 index 00000000000..c575ad1f386 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/TableRefTest.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.datasource; + +import static org.junit.Assert.*; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Test; + +public class TableRefTest { + + private static final ObjectMapper JSON = new ObjectMapper(); + + // --- Factory methods --- + + @Test + public void testOfTableOnly() { + TableRef ref = TableRef.of("users"); + assertNull(ref.getCatalog()); + assertNull(ref.getSchema()); + assertEquals("users", ref.getTable()); + } + + @Test + public void testOfSchemaAndTable() { + TableRef ref = TableRef.of("mydb", "users"); + assertNull(ref.getCatalog()); + assertEquals("mydb", ref.getSchema()); + assertEquals("users", ref.getTable()); + } + + @Test + public void testOfCatalogSchemaTable() { + TableRef ref = TableRef.of("hive", "sales", "orders"); + assertEquals("hive", ref.getCatalog()); + assertEquals("sales", ref.getSchema()); + assertEquals("orders", ref.getTable()); + } + + @Test + public void testNullCatalogAndSchemaAllowed() { + TableRef ref = TableRef.of(null, null, "t"); + assertNull(ref.getCatalog()); + assertNull(ref.getSchema()); + assertEquals("t", ref.getTable()); + } + + // --- Validation --- + + @Test(expected = IllegalArgumentException.class) + public void testNullTableThrows() { + TableRef.of(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyTableThrows() { + 
TableRef.of(""); + } + + @Test(expected = IllegalArgumentException.class) + public void testNullTableInThreeArgThrows() { + TableRef.of("cat", "sch", null); + } + + @Test + public void testEmptyCatalogNormalizedToNull() { + TableRef ref = TableRef.of("", "mydb", "users"); + assertNull(ref.getCatalog()); + assertEquals("mydb", ref.getSchema()); + } + + @Test + public void testEmptySchemaNormalizedToNull() { + TableRef ref = TableRef.of("", "users"); + assertNull(ref.getSchema()); + } + + @Test + public void testEmptyCatalogAndSchemaEqualsNull() { + assertEquals(TableRef.of(null, null, "t"), TableRef.of("", "", "t")); + } + + // --- equals / hashCode --- + + @Test + public void testEqualsSameFields() { + TableRef a = TableRef.of("hive", "sales", "orders"); + TableRef b = TableRef.of("hive", "sales", "orders"); + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + } + + @Test + public void testEqualsDifferentFields() { + assertNotEquals(TableRef.of("a", "b", "c"), TableRef.of("x", "b", "c")); + assertNotEquals(TableRef.of("a", "b", "c"), TableRef.of("a", "x", "c")); + assertNotEquals(TableRef.of("a", "b", "c"), TableRef.of("a", "b", "x")); + } + + @Test + public void testEqualsNullFields() { + assertEquals(TableRef.of("t"), TableRef.of("t")); + assertNotEquals(TableRef.of("t"), TableRef.of("s", "t")); + } + + // --- JSON deserialization --- + + @Test + public void testDeserializeTableOnly() throws Exception { + TableRef ref = JSON.readValue("{\"table\":\"users\"}", TableRef.class); + assertNull(ref.getCatalog()); + assertNull(ref.getSchema()); + assertEquals("users", ref.getTable()); + } + + @Test + public void testDeserializeSchemaAndTable() throws Exception { + TableRef ref = JSON.readValue("{\"schema\":\"mydb\",\"table\":\"users\"}", TableRef.class); + assertNull(ref.getCatalog()); + assertEquals("mydb", ref.getSchema()); + assertEquals("users", ref.getTable()); + } + + @Test + public void testDeserializeFull() throws Exception { + TableRef ref = + 
JSON.readValue( + "{\"catalog\":\"hive\",\"schema\":\"sales\",\"table\":\"orders\"}", TableRef.class); + assertEquals("hive", ref.getCatalog()); + assertEquals("sales", ref.getSchema()); + assertEquals("orders", ref.getTable()); + } + + @Test(expected = Exception.class) + public void testDeserializeMissingTableThrows() throws Exception { + JSON.readValue("{\"schema\":\"mydb\"}", TableRef.class); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DataSourceFactoryTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DataSourceFactoryTest.java new file mode 100644 index 00000000000..90fbee79238 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DataSourceFactoryTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.mysql; + +import static org.junit.Assert.*; + +import com.zaxxer.hikari.HikariDataSource; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.Statement; +import javax.sql.DataSource; +import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory; +import org.junit.Test; + +/** Integration tests for {@link DataSourceFactory} against a real MySQL instance. */ +public class DataSourceFactoryTest extends WithMySQLContainer { + + @Test + public void testCreateWithUserPassword() throws Exception { + DataSource ds = + DataSourceFactory.create(mysql.getJdbcUrl(), mysql.getUsername(), mysql.getPassword()); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT 1")) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + } finally { + if (ds instanceof HikariDataSource) { + ((HikariDataSource) ds).close(); + } + } + } + + @Test + public void testCreateWithWrongPassword() { + DataSource ds = DataSourceFactory.create(mysql.getJdbcUrl(), "root", "wrong_password"); + try (Connection conn = ds.getConnection()) { + fail("Expected exception for wrong password"); + } catch (Exception e) { + // HikariCP wraps MySQL auth error — just verify we get an exception. + // The exception chain should contain the MySQL access-denied message. 
+ String fullMsg = getFullExceptionMessage(e); + assertTrue( + "Expected auth error, got: " + fullMsg, + fullMsg.contains("Access denied") || fullMsg.contains("password")); + } finally { + if (ds instanceof HikariDataSource) { + ((HikariDataSource) ds).close(); + } + } + } + + private static String getFullExceptionMessage(Throwable t) { + StringBuilder sb = new StringBuilder(); + while (t != null) { + if (t.getMessage() != null) { + sb.append(t.getMessage()).append(" | "); + } + t = t.getCause(); + } + return sb.toString(); + } + + @Test + public void testConnectionPoolReuse() throws Exception { + DataSource ds = + DataSourceFactory.create(mysql.getJdbcUrl(), mysql.getUsername(), mysql.getPassword()); + try { + // Acquire and release 5 connections sequentially; pool should handle this fine. + for (int i = 0; i < 5; i++) { + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT " + i)) { + assertTrue(rs.next()); + assertEquals(i, rs.getInt(1)); + } + } + } finally { + if (ds instanceof HikariDataSource) { + ((HikariDataSource) ds).close(); + } + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java new file mode 100644 index 00000000000..4e713b9cca3 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.mysql; + +import static org.junit.Assert.*; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; +import org.apache.kyuubi.engine.dataagent.datasource.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.apache.kyuubi.engine.dataagent.tool.sql.SqlQueryArgs; +import org.junit.BeforeClass; +import org.junit.Test; + +/** Integration tests for {@link MysqlDialect} end-to-end with a real MySQL instance. 
*/ +public class DialectTest extends WithMySQLContainer { + + private static RunSelectQueryTool selectTool; + + @BeforeClass + public static void setUp() { + selectTool = new RunSelectQueryTool(dataSource, 30); + } + + @Test + public void testDialectFromUrl() { + JdbcDialect dialect = JdbcDialect.fromUrl(mysql.getJdbcUrl()); + assertNotNull(dialect); + assertTrue(dialect instanceof MysqlDialect); + assertEquals("mysql", dialect.datasourceName()); + } + + @Test + public void testBacktickQuotingWithReservedWord() { + JdbcDialect dialect = JdbcDialect.fromUrl(mysql.getJdbcUrl()); + + // Create a table with a column named after a MySQL reserved word + String quotedTable = dialect.quoteIdentifier("order"); + String quotedCol = dialect.quoteIdentifier("select"); + + exec("DROP TABLE IF EXISTS " + quotedTable); + exec( + "CREATE TABLE " + + quotedTable + + " (" + + "id INT PRIMARY KEY, " + + quotedCol + + " VARCHAR(255)" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"); + exec("INSERT INTO " + quotedTable + " VALUES (1, 'value1')"); + + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT " + quotedCol + " FROM " + quotedTable + " WHERE id = 1"; + String result = selectTool.execute(args); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("value1")); + + exec("DROP TABLE " + quotedTable); + } + + @Test + public void testBacktickEscaping() { + JdbcDialect dialect = JdbcDialect.fromUrl(mysql.getJdbcUrl()); + // Identifier containing a backtick should be escaped by doubling + String quoted = dialect.quoteIdentifier("col`name"); + assertEquals("`col``name`", quoted); + } + + @Test + public void testPromptBuilderWithMySQLDatasource() { + JdbcDialect dialect = JdbcDialect.fromUrl(mysql.getJdbcUrl()); + String prompt = SystemPromptBuilder.create().datasource(dialect.datasourceName()).build(); + assertTrue("Prompt should mention mysql dialect", prompt.toLowerCase().contains("mysql")); + } +} diff --git 
a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/RunMutationQueryTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/RunMutationQueryTest.java new file mode 100644 index 00000000000..2401d53d1cf --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/RunMutationQueryTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.mysql; + +import static org.junit.Assert.*; + +import org.junit.Before; +import org.junit.Test; + +/** + * Integration tests for run_mutation_query tool against a real MySQL instance. All calls go through + * {@link org.apache.kyuubi.engine.dataagent.tool.ToolRegistry}, the same path as the LLM runtime. 
+ */ +public class RunMutationQueryTest extends WithMySQLContainer { + + @Before + public void setUp() { + exec("DROP TABLE IF EXISTS mutation_test"); + exec( + "CREATE TABLE mutation_test (" + + "id BIGINT PRIMARY KEY, " + + "name VARCHAR(255), " + + "age INT" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"); + exec("INSERT INTO mutation_test VALUES (1, 'Alice', 30), (2, 'Bob', 25), (3, 'Carol', 28)"); + } + + @Test + public void testInsert() { + String result = mutate("INSERT INTO mutation_test VALUES (4, 'Dave', 35)"); + assertTrue(result.contains("1 row(s) affected")); + assertEquals("4", queryScalar("SELECT COUNT(*) FROM mutation_test")); + } + + @Test + public void testBatchInsert() { + String result = + mutate("INSERT INTO mutation_test VALUES (10, 'X', 1), (11, 'Y', 2), (12, 'Z', 3)"); + assertTrue(result.contains("3 row(s) affected")); + assertEquals("6", queryScalar("SELECT COUNT(*) FROM mutation_test")); + } + + @Test + public void testUpdate() { + String result = mutate("UPDATE mutation_test SET age = 99 WHERE age < 30"); + assertTrue(result.contains("2 row(s) affected")); + assertEquals("2", queryScalar("SELECT COUNT(*) FROM mutation_test WHERE age = 99")); + } + + @Test + public void testDelete() { + String result = mutate("DELETE FROM mutation_test WHERE id = 1"); + assertTrue(result.contains("1 row(s) affected")); + assertEquals("2", queryScalar("SELECT COUNT(*) FROM mutation_test")); + } + + @Test + public void testCreateAndDropTable() { + String createResult = + mutate("CREATE TABLE tmp_ddl_test (id INT PRIMARY KEY, v TEXT) ENGINE=InnoDB"); + assertTrue(createResult.contains("executed successfully")); + + String dropResult = mutate("DROP TABLE tmp_ddl_test"); + assertTrue(dropResult.contains("executed successfully")); + } + + @Test + public void testCreateIndex() { + String result = mutate("CREATE INDEX idx_name ON mutation_test (name)"); + assertTrue(result.contains("executed successfully")); + } + + @Test + public void testDuplicateKeyError() { + 
String result = mutate("INSERT INTO mutation_test VALUES (1, 'Duplicate', 40)"); + assertTrue(result.startsWith("Error:")); + assertTrue(result.contains("Duplicate entry")); + } + + @Test + public void testSyntaxError() { + String result = mutate("INSRET INTO mutation_test VALUES (99, 'typo', 1)"); + assertTrue(result.startsWith("Error:")); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/RunSelectQueryTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/RunSelectQueryTest.java new file mode 100644 index 00000000000..d74b399649d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/RunSelectQueryTest.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.mysql; + +import static org.junit.Assert.*; + +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Integration tests for run_select_query tool against a real MySQL instance. All calls go through + * {@link org.apache.kyuubi.engine.dataagent.tool.ToolRegistry}, the same path as the LLM runtime. 
+ */ +public class RunSelectQueryTest extends WithMySQLContainer { + + @BeforeClass + public static void setUp() { + exec( + "CREATE TABLE IF NOT EXISTS select_test (" + + "id BIGINT PRIMARY KEY, " + + "name VARCHAR(255), " + + "price DECIMAL(27,9), " + + "birth_date DATE, " + + "created_at DATETIME, " + + "updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " + + "active BOOLEAN, " + + "score DOUBLE, " + + "status ENUM('ACTIVE','INACTIVE','PENDING')" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"); + + exec( + "INSERT INTO select_test (id, name, price, birth_date, created_at, active, score, status) VALUES " + + "(1, 'Alice', 99.123456789, '1990-01-15', '2024-06-01 10:30:00', true, 88.5, 'ACTIVE'), " + + "(2, 'Bob', 0.000000001, '1985-12-31', '2024-06-02 14:00:00', false, 72.3, 'INACTIVE'), " + + "(3, NULL, NULL, NULL, NULL, NULL, NULL, 'PENDING')"); + } + + @Test + public void testSimpleSelect() { + String result = select("SELECT id, name FROM select_test ORDER BY id LIMIT 10"); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("Alice")); + assertTrue(result.contains("Bob")); + assertTrue(result.contains("[3 row(s) returned]")); + } + + @Test + public void testColumnAliasRenderedInHeader() { + // Verifies getColumnLabel (not getColumnName) is used so AS aliases show in the markdown + // header. MySQL is one of the drivers that distinguishes the two — getColumnName returns + // "id"/"name" here while getColumnLabel returns the AS alias. SQLite does not distinguish, + // so this guard lives on the MySQL integration path. 
+ String result = select("SELECT id AS user_id, name AS user_name FROM select_test LIMIT 1"); + assertFalse(result.startsWith("Error:")); + assertTrue( + "header should contain alias 'user_id', got:\n" + result, result.contains("user_id")); + assertTrue( + "header should contain alias 'user_name', got:\n" + result, result.contains("user_name")); + assertFalse( + "header should NOT leak the base column name 'id', got:\n" + result, + result.contains("| id |") || result.contains("| id ")); + } + + @Test + public void testMySQLTypes() { + String result = select("SELECT * FROM select_test WHERE id = 1"); + assertFalse(result.startsWith("Error:")); + assertTrue("DECIMAL value", result.contains("99.123456789")); + assertTrue("DATE value", result.contains("1990-01-15")); + assertTrue("DATETIME value", result.contains("2024-06-01")); + assertTrue("DOUBLE value", result.contains("88.5")); + assertTrue("ENUM value", result.contains("ACTIVE")); + } + + @Test + public void testShowDatabases() { + String result = select("SHOW DATABASES"); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("test_db")); + } + + @Test + public void testShowTables() { + String result = select("SHOW TABLES"); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("select_test")); + } + + @Test + public void testDescribeTable() { + String result = select("DESCRIBE select_test"); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("id")); + assertTrue(result.contains("name")); + assertTrue(result.contains("bigint")); + } + + @Test + public void testShowCreateTable() { + String result = select("SHOW CREATE TABLE select_test"); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("`select_test`") || result.contains("select_test")); + } + + @Test + public void testCteQuery() { + String result = + select( + "WITH active_users AS (SELECT id, name FROM select_test WHERE active = true) " + + "SELECT * FROM active_users"); + 
assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("Alice")); + assertFalse(result.contains("Bob")); + } + + @Test + public void testChineseData() { + exec( + "CREATE TABLE IF NOT EXISTS chinese_test (id INT PRIMARY KEY, name VARCHAR(255)) " + + "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"); + exec("INSERT IGNORE INTO chinese_test VALUES (1, '张三'), (2, '李四')"); + + String result = select("SELECT * FROM chinese_test ORDER BY id"); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("张三")); + assertTrue(result.contains("李四")); + } + + @Test + public void testNullRenderedAsNULL() { + String result = select("SELECT id, name, price FROM select_test WHERE id = 3"); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("NULL")); + } + + @Test + public void testRejectsMutationStatement() { + String result = select("INSERT INTO select_test (id, name) VALUES (999, 'hacker')"); + assertTrue(result.startsWith("Error:")); + assertTrue(result.contains("read-only")); + assertTrue(result.contains("run_mutation_query")); + } + + @Test + public void testNonexistentTableReturnsUsefulError() { + String result = select("SELECT * FROM no_such_table LIMIT 1"); + assertTrue(result.startsWith("Error:")); + assertTrue(result.contains("no_such_table")); + } + + @Test + public void testPipeEscapedInMarkdownOutput() { + exec( + "CREATE TABLE IF NOT EXISTS pipe_test (id INT PRIMARY KEY, val TEXT) " + + "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"); + exec("INSERT IGNORE INTO pipe_test VALUES (1, 'a|b|c')"); + + String result = select("SELECT val FROM pipe_test WHERE id = 1"); + assertFalse(result.startsWith("Error:")); + assertTrue("Pipe should be escaped for markdown table", result.contains("a\\|b\\|c")); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/ToolExecutionTest.java 
b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/ToolExecutionTest.java new file mode 100644 index 00000000000..b887244a41b --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/ToolExecutionTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.mysql; + +import static org.junit.Assert.*; + +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Integration tests for ToolRegistry-level behavior and cross-tool workflows against a real MySQL + * instance. Individual tool behavior is tested in {@link RunSelectQueryTest} and {@link + * RunMutationQueryTest}; this class covers registry dispatch, timeout, error handling, and + * multi-step agent workflows. 
+ */ +public class ToolExecutionTest extends WithMySQLContainer { + + @BeforeClass + public static void setUp() { + exec( + "CREATE TABLE IF NOT EXISTS workflow_test (" + + "id BIGINT PRIMARY KEY, " + + "name VARCHAR(255), " + + "department VARCHAR(100)" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"); + exec( + "INSERT INTO workflow_test VALUES " + + "(1, 'Alice', 'Engineering'), " + + "(2, 'Bob', 'Marketing')"); + } + + // ---- Agent workflow: explore → analyze → mutate → verify ---- + + @Test + public void testAgentSchemaExplorationWorkflow() { + // Step 1: agent explores available tables + String tables = select("SHOW TABLES"); + assertFalse(tables.startsWith("Error:")); + assertTrue(tables.contains("workflow_test")); + + // Step 2: agent inspects table structure + String schema = select("DESCRIBE workflow_test"); + assertFalse(schema.startsWith("Error:")); + assertTrue(schema.contains("department")); + + // Step 3: agent queries data + String data = select("SELECT * FROM workflow_test ORDER BY id LIMIT 10"); + assertFalse(data.startsWith("Error:")); + assertTrue(data.contains("Alice")); + assertTrue(data.contains("[2 row(s) returned]")); + } + + @Test + public void testAgentMutateThenVerify() { + // Agent inserts a row via mutation tool + String insertResult = + mutate("INSERT INTO workflow_test VALUES (100, 'NewHire', 'Engineering')"); + assertTrue(insertResult.contains("1 row(s) affected")); + + // Agent reads it back via select tool to confirm + String readBack = select("SELECT name FROM workflow_test WHERE id = 100"); + assertFalse(readBack.startsWith("Error:")); + assertTrue(readBack.contains("NewHire")); + + // Cleanup + mutate("DELETE FROM workflow_test WHERE id = 100"); + } + + // ---- Registry-level error handling ---- + + @Test + public void testUnknownTool() { + String result = registry.executeTool("nonexistent_tool", "{\"sql\": \"SELECT 1\"}"); + assertTrue(result.contains("Error:")); + assertTrue(result.contains("unknown tool")); + } + + @Test + 
public void testMalformedJson() { + String result = registry.executeTool("run_select_query", "not valid json"); + assertTrue(result.startsWith("Error")); + } + + // ---- Timeout enforcement ---- + + @Test + public void testQueryTimeoutKillsSlowQuery() { + // Create a separate registry with a short timeout to test timeout behavior + ToolRegistry shortTimeoutRegistry = new ToolRegistry(3); + shortTimeoutRegistry.register(new RunSelectQueryTool(dataSource, 2)); + + long start = System.currentTimeMillis(); + String result = + shortTimeoutRegistry.executeTool("run_select_query", "{\"sql\": \"SELECT SLEEP(60)\"}"); + long elapsed = System.currentTimeMillis() - start; + + // Must not block for 60 seconds + assertTrue("Should return within 10s, took " + elapsed + "ms", elapsed < 10_000); + // Either JDBC timeout or ToolRegistry timeout fires + assertTrue( + "Expected timeout or interrupted sleep, got: " + result, + result.contains("timed out") || result.startsWith("Error:") || result.contains("1")); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/WithMySQLContainer.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/WithMySQLContainer.java new file mode 100644 index 00000000000..91c66de03ab --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/WithMySQLContainer.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.dataagent.mysql;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import java.sql.Connection;
import java.sql.Statement;
import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool;
import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.testcontainers.mysql.MySQLContainer;
import org.testcontainers.utility.DockerImageName;

/**
 * Shared base class for MySQL integration tests. Starts a MySQL 8.0.32 container once per test
 * class and provides a pooled DataSource plus a {@link ToolRegistry} pre-loaded with the SQL
 * select/mutation tools.
 *
 * <p>Subclasses interact through {@link #select(String)}, {@link #mutate(String)} (which go
 * through the registry, exactly as the LLM runtime would) or {@link #exec(String)} /
 * {@link #queryScalar(String)} (which bypass the tools for fixture setup and verification).
 */
public abstract class WithMySQLContainer {

  // Image is overridable so CI can substitute a mirror; must stay MySQL-compatible.
  private static final String MYSQL_IMAGE =
      System.getProperty("kyuubi.test.mysql.image", "mysql:8.0.32");

  // Outer wall-clock cap enforced by ToolRegistry; inner JDBC query timeout is shorter
  // so server-side cancellation has a chance to fire first (mirrors the production config).
  private static final long TOOL_CALL_TIMEOUT_SECONDS = 30;
  private static final int QUERY_TIMEOUT_SECONDS = 10;
  private static final ObjectMapper JSON = new ObjectMapper();

  protected static MySQLContainer<?> mysql;
  protected static HikariDataSource dataSource;
  protected static ToolRegistry registry;

  /** Starts the container, then builds the pool and the tool registry on top of it. */
  @BeforeClass
  public static void startContainer() {
    DockerImageName imageName =
        DockerImageName.parse(MYSQL_IMAGE).asCompatibleSubstituteFor("mysql");
    mysql =
        new MySQLContainer<>(imageName)
            .withUsername("root")
            .withPassword("kyuubi")
            .withDatabaseName("test_db")
            .withEnv("MYSQL_ROOT_PASSWORD", "kyuubi");
    mysql.start();

    HikariConfig config = new HikariConfig();
    config.setJdbcUrl(mysql.getJdbcUrl());
    config.setUsername(mysql.getUsername());
    config.setPassword(mysql.getPassword());
    config.setMaximumPoolSize(5);
    config.setMinimumIdle(1);
    config.setPoolName("kyuubi-mysql-it");
    dataSource = new HikariDataSource(config);

    registry = new ToolRegistry(TOOL_CALL_TIMEOUT_SECONDS);
    registry.register(new RunSelectQueryTool(dataSource, QUERY_TIMEOUT_SECONDS));
    registry.register(new RunMutationQueryTool(dataSource, QUERY_TIMEOUT_SECONDS));
  }

  /** Tears down in reverse order of creation: registry, pool, then container. */
  @AfterClass
  public static void stopContainer() {
    if (registry != null) {
      registry.close();
    }
    if (dataSource != null) {
      dataSource.close();
    }
    if (mysql != null) {
      mysql.stop();
    }
  }

  /** Execute a SQL statement that does not return a result set (fixture DDL/DML). */
  protected static void exec(String sql) {
    try (Connection conn = dataSource.getConnection();
        Statement stmt = conn.createStatement()) {
      stmt.execute(sql);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Invoke a tool through the ToolRegistry, same as the LLM runtime does. Builds the JSON args
   * envelope ({"sql": ...}) automatically.
   */
  protected static String callTool(String toolName, String sql) {
    ObjectNode args = JSON.createObjectNode();
    args.put("sql", sql);
    try {
      return registry.executeTool(toolName, JSON.writeValueAsString(args));
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /** Shorthand for {@code callTool("run_select_query", sql)}. */
  protected static String select(String sql) {
    return callTool("run_select_query", sql);
  }

  /** Shorthand for {@code callTool("run_mutation_query", sql)}. */
  protected static String mutate(String sql) {
    return callTool("run_mutation_query", sql);
  }

  /**
   * Execute a SQL statement and return the first column of the first row (as a String), or null
   * when the query yields no rows. Bypasses the tool layer for direct verification.
   */
  protected static String queryScalar(String sql) {
    try (Connection conn = dataSource.getConnection();
        Statement stmt = conn.createStatement();
        java.sql.ResultSet rs = stmt.executeQuery(sql)) {
      if (rs.next()) {
        return rs.getString(1);
      }
      return null;
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.dataagent.prompt;

import static org.junit.Assert.*;

import java.time.LocalDate;
import org.junit.Test;

/**
 * Unit tests for {@link SystemPromptBuilder}: base-prompt content, placeholder substitution,
 * per-datasource dialect sections, custom sections, and null/empty-input tolerance.
 */
public class SystemPromptBuilderTest {

  /** The default prompt carries the base persona text and today's date. */
  @Test
  public void testDefaultBuildContainsBaseAndDate() {
    String prompt = SystemPromptBuilder.create().build();
    assertTrue(prompt.contains("data analysis agent"));
    assertTrue(prompt.contains(LocalDate.now().toString()));
  }

  /** Unfilled template placeholders must never leak into the final prompt. */
  @Test
  public void testPlaceholdersRemovedByDefault() {
    String prompt = SystemPromptBuilder.create().build();
    assertFalse(prompt.contains("{{tool_descriptions}}"));
  }

  /** Supplied tool descriptions replace the placeholder verbatim. */
  @Test
  public void testToolDescriptionsSubstituted() {
    String tools = "- `sql_query`: Execute a SQL statement against the database.";
    String prompt = SystemPromptBuilder.create().toolDescriptions(tools).build();
    assertTrue(prompt.contains("sql_query"));
    assertTrue(prompt.contains("Execute a SQL statement"));
    assertFalse(prompt.contains("{{tool_descriptions}}"));
  }

  /** "sqlite" selects the dedicated SQLite dialect section. */
  @Test
  public void testDatasourceSqlite() {
    String prompt = SystemPromptBuilder.create().datasource("sqlite").build();
    assertTrue(prompt.contains("SQLite SQL compatibility"));
    assertTrue(prompt.contains("JULIANDAY"));
  }

  /** "spark" selects the Spark SQL dialect section and keeps the base text. */
  @Test
  public void testDatasourceSpark() {
    String prompt = SystemPromptBuilder.create().datasource("spark").build();
    assertTrue(prompt.contains("data analysis agent"));
    assertTrue(prompt.contains("Spark SQL"));
  }

  /** A custom base replaces the default persona but the date is still appended. */
  @Test
  public void testWithCustomBase() {
    String prompt = SystemPromptBuilder.create().base("You are a helper.").build();
    assertTrue(prompt.startsWith("You are a helper."));
    assertTrue(prompt.contains(LocalDate.now().toString()));
  }

  /** Free-form sections are appended to the prompt. */
  @Test
  public void testWithSection() {
    String prompt = SystemPromptBuilder.create().section("Only query public schema.").build();
    assertTrue(prompt.contains("Only query public schema."));
  }

  /** All builder features compose in a single prompt. */
  @Test
  public void testFullComposition() {
    String prompt =
        SystemPromptBuilder.create()
            .toolDescriptions("- `sql_query`: Execute SQL.")
            .datasource("spark")
            .section("Limit all queries to 1000 rows.")
            .build();
    assertTrue(prompt.contains("sql_query"));
    assertTrue(prompt.contains("Spark SQL"));
    assertTrue(prompt.contains("Limit all queries to 1000 rows."));
    assertTrue(prompt.contains(LocalDate.now().toString()));
  }

  /** Calling datasource() twice replaces the section rather than appending a second one. */
  @Test
  public void testDatasourceReplacesNotAppends() {
    String prompt = SystemPromptBuilder.create().datasource("sqlite").datasource("spark").build();
    assertTrue("Last datasource wins", prompt.contains("Spark SQL"));
    assertFalse("Previous datasource replaced", prompt.contains("SQLite SQL compatibility"));
  }

  /** Datasources without a dedicated section fall back to a generic dialect section. */
  @Test
  public void testUnknownDatasourceUsesGenericDialectSection() {
    String prompt = SystemPromptBuilder.create().datasource("postgresql").build();
    assertTrue(prompt.contains("Current SQL dialect: postgresql"));
    assertTrue(prompt.contains("postgresql SQL syntax"));
  }

  /** clickhouse has no dedicated section either — generic fallback applies. */
  @Test
  public void testDatasourceClickhouseUsesGenericDialectSection() {
    String prompt = SystemPromptBuilder.create().datasource("clickhouse").build();
    assertTrue(prompt.contains("Current SQL dialect: clickhouse"));
  }

  /** "mysql" selects the MySQL dialect section. */
  @Test
  public void testDatasourceMysql() {
    String prompt = SystemPromptBuilder.create().datasource("mysql").build();
    assertTrue(prompt.contains("MySQL"));
  }

  /** Missing classpath resources fail fast with IllegalArgumentException. */
  @Test(expected = IllegalArgumentException.class)
  public void testLoadResourceThrowsOnMissing() {
    SystemPromptBuilder.loadResource("nonexistent-resource");
  }

  /** Null (and empty-section) inputs are ignored and yield the plain default prompt. */
  @Test
  public void testNullsIgnored() {
    String plain = SystemPromptBuilder.create().build();
    assertEquals(plain, SystemPromptBuilder.create().datasource(null).build());
    assertEquals(plain, SystemPromptBuilder.create().toolDescriptions(null).build());
    assertEquals(plain, SystemPromptBuilder.create().section(null).build());
    assertEquals(plain, SystemPromptBuilder.create().section("").build());
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.dataagent.tool;

import static org.junit.Assert.*;

import com.openai.models.ChatModel;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;

/**
 * Thread safety tests for ToolRegistry. Uses real tool implementations.
 *
 * <p>Each test releases all worker threads simultaneously through a {@link CyclicBarrier} to
 * maximize interleaving, and collects failures in an {@link AtomicInteger} because assertions
 * thrown on pool threads would otherwise be swallowed by the executor.
 *
 * <p>NOTE(review): the registries created here are not closed; this assumes ToolRegistry does not
 * pin non-daemon threads after the test — confirm, or add close() calls.
 */
public class ToolRegistryThreadSafetyTest {

  /** Concurrent register() and read-side calls (isEmpty, addToolsTo) must not throw. */
  @Test
  public void testConcurrentRegisterAndAccess() throws Exception {
    ToolRegistry registry = new ToolRegistry(30);
    int numThreads = 8;
    CyclicBarrier barrier = new CyclicBarrier(numThreads);
    AtomicInteger errors = new AtomicInteger(0);
    ExecutorService pool = Executors.newFixedThreadPool(numThreads);

    // Half threads register, half threads read
    for (int i = 0; i < numThreads; i++) {
      final int idx = i;
      pool.submit(
          () -> {
            try {
              barrier.await();
              if (idx % 2 == 0) {
                // Register a tool
                registry.register(
                    new AgentTool<DummyArgs>() {
                      @Override
                      public String name() {
                        return "tool_" + idx;
                      }

                      @Override
                      public String description() {
                        return "test tool " + idx;
                      }

                      @Override
                      public Class<DummyArgs> argsType() {
                        return DummyArgs.class;
                      }

                      @Override
                      public String execute(DummyArgs args) {
                        return "result_" + idx;
                      }
                    });
              } else {
                // Read — may see partial state, but should not throw
                try {
                  registry.isEmpty();
                  ChatCompletionCreateParams.Builder builder =
                      ChatCompletionCreateParams.builder()
                          .model(ChatModel.GPT_4O)
                          .addUserMessage("test");
                  registry.addToolsTo(builder);
                } catch (Exception e) {
                  errors.incrementAndGet();
                }
              }
            } catch (Exception e) {
              errors.incrementAndGet();
            }
          });
    }

    pool.shutdown();
    assertTrue(pool.awaitTermination(10, TimeUnit.SECONDS));
    assertEquals("Should have no errors from concurrent access", 0, errors.get());
    // At least the even-indexed threads registered something
    assertFalse(registry.isEmpty());
  }

  /** executeTool on an existing tool must keep returning the right result while other threads register new tools. */
  @Test
  public void testConcurrentExecuteWhileRegistering() throws Exception {
    ToolRegistry registry = new ToolRegistry(30);

    // Pre-register a tool
    registry.register(
        new AgentTool<DummyArgs>() {
          @Override
          public String name() {
            return "existing_tool";
          }

          @Override
          public String description() {
            return "existing";
          }

          @Override
          public Class<DummyArgs> argsType() {
            return DummyArgs.class;
          }

          @Override
          public String execute(DummyArgs args) {
            return "existing_result";
          }
        });

    int numThreads = 8;
    CyclicBarrier barrier = new CyclicBarrier(numThreads);
    AtomicInteger errors = new AtomicInteger(0);
    ExecutorService pool = Executors.newFixedThreadPool(numThreads);

    for (int i = 0; i < numThreads; i++) {
      final int idx = i;
      pool.submit(
          () -> {
            try {
              barrier.await();
              for (int j = 0; j < 100; j++) {
                if (idx % 2 == 0) {
                  // Reader/executor threads: result must never be corrupted by concurrent writes
                  String result = registry.executeTool("existing_tool", "{}");
                  if (!result.equals("existing_result")) {
                    errors.incrementAndGet();
                  }
                } else {
                  // Writer threads: keep registering (same name per thread — assumes
                  // re-registration is an overwrite, not an error)
                  registry.register(
                      new AgentTool<DummyArgs>() {
                        @Override
                        public String name() {
                          return "dynamic_" + idx + "_" + Thread.currentThread().getId();
                        }

                        @Override
                        public String description() {
                          return "dynamic";
                        }

                        @Override
                        public Class<DummyArgs> argsType() {
                          return DummyArgs.class;
                        }

                        @Override
                        public String execute(DummyArgs args) {
                          return "dynamic";
                        }
                      });
                }
              }
            } catch (Exception e) {
              errors.incrementAndGet();
            }
          });
    }

    pool.shutdown();
    assertTrue(pool.awaitTermination(10, TimeUnit.SECONDS));
    assertEquals(0, errors.get());
  }

  /** Minimal args class for testing; shared with ToolTest. */
  public static class DummyArgs {}
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.dataagent.tool;

import static org.junit.Assert.*;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import java.util.List;
import java.util.Map;
import org.apache.kyuubi.engine.dataagent.tool.sql.SqlQueryArgs;
import org.junit.Test;

/**
 * Tests for ToolSchemaGenerator: Java args classes must map to JSON-Schema-style maps with the
 * right "type" strings, with {@code @JsonPropertyDescription} preserved as "description" and
 * {@code @JsonProperty(required = true)} reflected in the "required" list.
 */
public class ToolSchemaGeneratorTest {

  // --- Real args classes ---

  /** The production SqlQueryArgs schema: an object with a required, described "sql" string. */
  @Test
  public void testSqlQueryArgsSchema() {
    Map<String, Object> schema = ToolSchemaGenerator.generateSchema(SqlQueryArgs.class);
    assertEquals("object", schema.get("type"));

    @SuppressWarnings("unchecked")
    Map<String, Object> props = (Map<String, Object>) schema.get("properties");
    assertNotNull(props);
    assertTrue(props.containsKey("sql"));

    @SuppressWarnings("unchecked")
    Map<String, Object> sqlProp = (Map<String, Object>) props.get("sql");
    assertEquals("string", sqlProp.get("type"));
    assertNotNull("sql should have a description", sqlProp.get("description"));

    @SuppressWarnings("unchecked")
    List<String> required = (List<String>) schema.get("required");
    assertNotNull(required);
    assertTrue(required.contains("sql"));
  }

  // --- Synthetic test classes to verify type mapping ---

  /** One field per supported Java type; only stringField is marked required. */
  public static class AllTypesArgs {
    @JsonProperty(required = true)
    @JsonPropertyDescription("a string field")
    public String stringField;

    @JsonPropertyDescription("an int field")
    public int intField;

    @JsonPropertyDescription("an Integer field")
    public Integer integerField;

    @JsonPropertyDescription("a long field")
    public long longField;

    @JsonPropertyDescription("a double field")
    public double doubleField;

    @JsonPropertyDescription("a float field")
    public float floatField;

    @JsonPropertyDescription("a boolean field")
    public boolean booleanField;

    @JsonPropertyDescription("a Boolean field")
    public Boolean booleanWrapperField;
  }

  /** Pins the full Java-type -> JSON-type mapping table and annotation handling. */
  @Test
  public void testAllPrimitiveTypeMappingsAndAnnotations() {
    Map<String, Object> schema = ToolSchemaGenerator.generateSchema(AllTypesArgs.class);

    @SuppressWarnings("unchecked")
    Map<String, Object> props = (Map<String, Object>) schema.get("properties");

    // Type mappings: primitives and wrappers map identically; long is "integer", float is "number"
    assertEquals("string", getType(props, "stringField"));
    assertEquals("integer", getType(props, "intField"));
    assertEquals("integer", getType(props, "integerField"));
    assertEquals("integer", getType(props, "longField"));
    assertEquals("number", getType(props, "doubleField"));
    assertEquals("number", getType(props, "floatField"));
    assertEquals("boolean", getType(props, "booleanField"));
    assertEquals("boolean", getType(props, "booleanWrapperField"));

    // Descriptions preserved
    assertEquals("a string field", getDescription(props, "stringField"));
    assertEquals("an int field", getDescription(props, "intField"));
    assertEquals("a boolean field", getDescription(props, "booleanField"));

    // Required fields: only those annotated @JsonProperty(required = true)
    @SuppressWarnings("unchecked")
    List<String> required = (List<String>) schema.get("required");
    assertNotNull(required);
    assertTrue("stringField should be required", required.contains("stringField"));
    assertFalse("intField should not be required", required.contains("intField"));
  }

  // --- Edge cases ---

  public static class EmptyArgs {}

  /** A field-less class still yields a valid "object" schema. */
  @Test
  public void testEmptyArgsClass() {
    Map<String, Object> schema = ToolSchemaGenerator.generateSchema(EmptyArgs.class);
    assertEquals("object", schema.get("type"));
  }

  public static class NoRequiredArgs {
    @JsonPropertyDescription("optional field")
    public String name;
  }

  /** With no required fields, "required" is either absent or an empty list. */
  @Test
  public void testNoRequiredFields() {
    Map<String, Object> schema = ToolSchemaGenerator.generateSchema(NoRequiredArgs.class);
    @SuppressWarnings("unchecked")
    List<String> required = (List<String>) schema.get("required");
    assertTrue(required == null || required.isEmpty());
  }

  /** The "$schema" meta key must be stripped — LLM function-call schemas do not want it. */
  @Test
  public void testSchemaDoesNotContainDollarSchema() {
    Map<String, Object> schema = ToolSchemaGenerator.generateSchema(AllTypesArgs.class);
    assertFalse("$schema key should be stripped", schema.containsKey("$schema"));
  }

  public static class ArrayArgs {
    @JsonPropertyDescription("tags")
    public String[] tags;
  }

  /** Java arrays map to JSON "array". */
  @Test
  public void testArrayTypeSupported() {
    Map<String, Object> schema = ToolSchemaGenerator.generateSchema(ArrayArgs.class);
    @SuppressWarnings("unchecked")
    Map<String, Object> props = (Map<String, Object>) schema.get("properties");
    assertEquals("array", getType(props, "tags"));
  }

  // --- Helpers ---

  /** Returns the "type" entry of the named property, asserting the property exists. */
  @SuppressWarnings("unchecked")
  private static String getType(Map<String, Object> props, String fieldName) {
    Map<String, Object> prop = (Map<String, Object>) props.get(fieldName);
    assertNotNull("Property " + fieldName + " should exist", prop);
    return (String) prop.get("type");
  }

  /** Returns the "description" entry of the named property, asserting the property exists. */
  @SuppressWarnings("unchecked")
  private static String getDescription(Map<String, Object> props, String fieldName) {
    Map<String, Object> prop = (Map<String, Object>) props.get(fieldName);
    assertNotNull("Property " + fieldName + " should exist", prop);
    return (String) prop.get("description");
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.dataagent.tool;

import static org.junit.Assert.*;

import com.openai.models.ChatModel;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import com.openai.models.chat.completions.ChatCompletionTool;
import java.io.File;
import java.sql.Connection;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool;
import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool;
import org.junit.After;
import org.junit.Test;
import org.sqlite.SQLiteDataSource;

/**
 * Unit tests for ToolRegistry against SQLite-backed SQL tools: tool-spec generation for the OpenAI
 * chat API, JSON dispatch, error reporting, timeout enforcement, risk levels, and datasource
 * isolation.
 */
public class ToolTest {

  // Per-test resources, released in cleanup(): temp SQLite files and every registry created
  // through newRegistry(). Registries were previously never closed, leaking their internal
  // executors across tests.
  private final List<File> tempFiles = new ArrayList<>();
  private final List<ToolRegistry> registries = new ArrayList<>();

  @After
  public void cleanup() {
    registries.forEach(ToolRegistry::close);
    registries.clear();
    tempFiles.forEach(File::delete);
    tempFiles.clear();
  }

  // --- ToolRegistry ---

  /** Registered tools must appear as function specs on the ChatCompletion request. */
  @Test
  public void testRegistryBuildsChatCompletionToolSpecs() {
    SQLiteDataSource ds = createDataSource();
    ToolRegistry registry = newRegistry(30);
    registry.register(new RunSelectQueryTool(ds, 30));
    registry.register(new RunMutationQueryTool(ds, 30));
    assertFalse(registry.isEmpty());

    ChatCompletionCreateParams.Builder builder =
        ChatCompletionCreateParams.builder().model(ChatModel.GPT_4O).addUserMessage("test");
    registry.addToolsTo(builder);
    ChatCompletionCreateParams params = builder.build();

    List<ChatCompletionTool> tools = params.tools().orElse(Collections.emptyList());
    assertEquals(2, tools.size());

    List<String> names = new ArrayList<>();
    tools.forEach(t -> names.add(t.asFunction().function().name()));
    assertTrue("Missing run_select_query", names.contains("run_select_query"));
    assertTrue("Missing run_mutation_query", names.contains("run_mutation_query"));
  }

  /** executeTool must deserialize the JSON envelope and delegate to the named tool. */
  @Test
  public void testRegistryExecuteToolDeserializesAndDelegates() {
    SQLiteDataSource ds = createDataSource();
    setupTestTable(ds);
    ToolRegistry registry = newRegistry(30);
    registry.register(new RunSelectQueryTool(ds, 30));

    String result =
        registry.executeTool("run_select_query", "{\"sql\": \"SELECT COUNT(*) FROM users\"}");
    assertTrue("Expected count of 3, got: " + result, result.contains("3"));
  }

  /** Unknown tool names yield an error string, never an exception. */
  @Test
  public void testRegistryReturnsErrorForUnknownTool() {
    ToolRegistry registry = newRegistry(30);
    String result = registry.executeTool("nonexistent", "{}");
    assertTrue(result.startsWith("Error: unknown tool"));
  }

  // --- ToolRegistry timeout enforcement ---

  /** A tool that sleeps past the registry timeout is cut off and reports "timed out". */
  @Test
  public void testRegistryTimeoutKillsSlowToolCall() {
    ToolRegistry registry = newRegistry(2);
    registry.register(
        new AgentTool<ToolRegistryThreadSafetyTest.DummyArgs>() {
          @Override
          public String name() {
            return "slow_tool";
          }

          @Override
          public String description() {
            return "a tool that sleeps forever";
          }

          @Override
          public Class<ToolRegistryThreadSafetyTest.DummyArgs> argsType() {
            return ToolRegistryThreadSafetyTest.DummyArgs.class;
          }

          @Override
          public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) {
            try {
              Thread.sleep(60_000);
            } catch (InterruptedException e) {
              // Cancellation path: restore the interrupt flag and bail out
              Thread.currentThread().interrupt();
            }
            return "should not reach here";
          }
        });
    long start = System.currentTimeMillis();
    String result = registry.executeTool("slow_tool", "{}");
    long elapsed = System.currentTimeMillis() - start;
    assertTrue("Should contain timeout error", result.contains("timed out"));
    assertTrue("Should contain tool name", result.contains("slow_tool"));
    assertTrue("Should finish within ~3s (timeout=2s)", elapsed < 5000);
  }

  /** An exception thrown inside a tool is converted to an "Error executing ..." string. */
  @Test
  public void testRegistryToolExceptionReturnsError() {
    ToolRegistry registry = newRegistry(30);
    registry.register(
        new AgentTool<ToolRegistryThreadSafetyTest.DummyArgs>() {
          @Override
          public String name() {
            return "boom";
          }

          @Override
          public String description() {
            return "always fails";
          }

          @Override
          public Class<ToolRegistryThreadSafetyTest.DummyArgs> argsType() {
            return ToolRegistryThreadSafetyTest.DummyArgs.class;
          }

          @Override
          public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) {
            throw new RuntimeException("intentional failure");
          }
        });
    String result = registry.executeTool("boom", "{}");
    assertTrue(result.startsWith("Error executing boom"));
    assertTrue(result.contains("intentional failure"));
  }

  /** Unknown tools default to the SAFE risk level. */
  @Test
  public void testRegistryGetRiskLevelForUnknownToolReturnsSafe() {
    ToolRegistry registry = newRegistry(30);
    assertEquals(ToolRiskLevel.SAFE, registry.getRiskLevel("nonexistent"));
  }

  /** Select is SAFE; mutation is DESTRUCTIVE. */
  @Test
  public void testRegistryGetRiskLevelForRegisteredTool() {
    SQLiteDataSource ds = createDataSource();
    ToolRegistry registry = newRegistry(30);
    registry.register(new RunSelectQueryTool(ds, 30));
    registry.register(new RunMutationQueryTool(ds, 30));
    assertEquals(ToolRiskLevel.SAFE, registry.getRiskLevel("run_select_query"));
    assertEquals(ToolRiskLevel.DESTRUCTIVE, registry.getRiskLevel("run_mutation_query"));
  }

  /** Non-JSON arguments are reported as an execution error, not thrown. */
  @Test
  public void testRegistryInvalidJsonReturnsError() {
    SQLiteDataSource ds = createDataSource();
    ToolRegistry registry = newRegistry(30);
    registry.register(new RunSelectQueryTool(ds, 30));
    String result = registry.executeTool("run_select_query", "not valid json");
    assertTrue(result.startsWith("Error executing"));
  }

  // --- DataSource isolation ---

  /** Two registries over two SQLite files must not see each other's tables. */
  @Test
  public void testDifferentDataSourcesAreIsolated() {
    SQLiteDataSource ds1 = createDataSource();
    SQLiteDataSource ds2 = createDataSource();
    setupTable(ds1, "CREATE TABLE t1 (x TEXT)", "INSERT INTO t1 VALUES ('from-ds1')");
    setupTable(ds2, "CREATE TABLE t2 (y TEXT)", "INSERT INTO t2 VALUES ('from-ds2')");

    ToolRegistry reg1 = newRegistry(30);
    reg1.register(new RunSelectQueryTool(ds1, 30));

    ToolRegistry reg2 = newRegistry(30);
    reg2.register(new RunSelectQueryTool(ds2, 30));

    assertTrue(
        reg1.executeTool("run_select_query", "{\"sql\": \"SELECT x FROM t1\"}")
            .contains("from-ds1"));
    assertTrue(
        reg2.executeTool("run_select_query", "{\"sql\": \"SELECT y FROM t2\"}")
            .contains("from-ds2"));

    // ds1 does not have t2
    assertTrue(
        reg1.executeTool("run_select_query", "{\"sql\": \"SELECT * FROM t2\"}").contains("Error:"));
    // ds2 does not have t1
    assertTrue(
        reg2.executeTool("run_select_query", "{\"sql\": \"SELECT * FROM t1\"}").contains("Error:"));
  }

  // --- Helpers ---

  /** Creates a ToolRegistry and tracks it so cleanup() closes it after the test. */
  private ToolRegistry newRegistry(long toolCallTimeoutSeconds) {
    ToolRegistry registry = new ToolRegistry(toolCallTimeoutSeconds);
    registries.add(registry);
    return registry;
  }

  /** Creates a SQLite datasource backed by a tracked temp file (deleted in cleanup()). */
  private SQLiteDataSource createDataSource() {
    try {
      File tmpFile = File.createTempFile("kyuubi-tool-test-", ".db");
      tmpFile.deleteOnExit();
      tempFiles.add(tmpFile);
      SQLiteDataSource ds = new SQLiteDataSource();
      ds.setUrl("jdbc:sqlite:" + tmpFile.getAbsolutePath());
      return ds;
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /** Seeds the canonical 3-row users table used by the count assertion. */
  private void setupTestTable(SQLiteDataSource ds) {
    setupTable(
        ds,
        "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)",
        "INSERT INTO users VALUES (1, 'Alice', 25), (2, 'Bob', 30), (3, 'Charlie', 35)");
  }

  /** Runs one DDL and one DML statement against the given datasource. */
  private void setupTable(SQLiteDataSource ds, String ddl, String dml) {
    try (Connection conn = ds.getConnection();
        Statement stmt = conn.createStatement()) {
      stmt.execute(ddl);
      stmt.execute(dml);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
}
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import static org.junit.Assert.*; + +import java.io.File; +import java.sql.Connection; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.sqlite.SQLiteDataSource; + +/** Tests for RunMutationQueryTool. Uses real SQLite — no mocks. 
*/ +public class RunMutationQueryToolTest { + + private static final int TEST_TIMEOUT_SECONDS = 30; + + private SQLiteDataSource ds; + private RunMutationQueryTool tool; + private final List tempFiles = new ArrayList<>(); + + @Before + public void setUp() { + ds = createDataSource(); + setupTable(ds); + tool = new RunMutationQueryTool(ds, TEST_TIMEOUT_SECONDS); + } + + @After + public void tearDown() { + tempFiles.forEach(File::delete); + } + + @Test + public void testRiskLevelDestructive() { + assertEquals(ToolRiskLevel.DESTRUCTIVE, tool.riskLevel()); + } + + @Test + public void testInsert() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "INSERT INTO t VALUES (9999, 'hello')"; + String result = tool.execute(args); + assertTrue(result.contains("1 row(s) affected")); + } + + @Test + public void testUpdate() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "UPDATE t SET v = 'updated' WHERE id = 1"; + assertTrue(tool.execute(args).contains("1 row(s) affected")); + } + + @Test + public void testDelete() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "DELETE FROM t WHERE id = 1"; + assertTrue(tool.execute(args).contains("1 row(s) affected")); + } + + @Test + public void testCreateTable() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "CREATE TABLE new_t (id INTEGER PRIMARY KEY, v TEXT)"; + assertTrue(tool.execute(args).contains("executed successfully")); + } + + @Test + public void testAlsoAcceptsSelect() { + // Mutation tool does not enforce read-only check; SELECT works fine here. 
+ SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT v FROM t WHERE id = 1"; + String result = tool.execute(args); + assertFalse(result.startsWith("Error:")); + } + + @Test + public void testRejectsEmptyAndNullSql() { + SqlQueryArgs emptyArgs = new SqlQueryArgs(); + emptyArgs.sql = ""; + assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + + SqlQueryArgs nullArgs = new SqlQueryArgs(); + nullArgs.sql = null; + assertTrue(tool.execute(nullArgs).startsWith("Error:")); + } + + @Test + public void testInvalidSqlReturnsError() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "INSERT INTO nonexistent_table VALUES (1)"; + assertTrue(tool.execute(args).startsWith("Error:")); + } + + // --- Helpers --- + + private SQLiteDataSource createDataSource() { + try { + File tmpFile = File.createTempFile("kyuubi-mutation-test-", ".db"); + tmpFile.deleteOnExit(); + tempFiles.add(tmpFile); + SQLiteDataSource dataSource = new SQLiteDataSource(); + dataSource.setUrl("jdbc:sqlite:" + tmpFile.getAbsolutePath()); + return dataSource; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void setupTable(SQLiteDataSource dataSource) { + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)"); + stmt.execute("INSERT INTO t VALUES (1, 'one'), (2, 'two'), (3, 'three')"); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java new file mode 100644 index 00000000000..3c6579cf9a7 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import static org.junit.Assert.*; + +import java.io.File; +import java.sql.Connection; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.sqlite.SQLiteDataSource; + +/** Tests for RunSelectQueryTool. Uses real SQLite — no mocks. 
*/ +public class RunSelectQueryToolTest { + + private static final int TEST_TIMEOUT_SECONDS = 30; + + private SQLiteDataSource ds; + private RunSelectQueryTool tool; + private final List tempFiles = new ArrayList<>(); + + @Before + public void setUp() { + ds = createDataSource(); + setupLargeTable(ds); + tool = new RunSelectQueryTool(ds, TEST_TIMEOUT_SECONDS); + } + + @After + public void tearDown() { + tempFiles.forEach(File::delete); + } + + // --- Read-only enforcement --- + + @Test + public void testRejectsInsert() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "INSERT INTO large_table VALUES (9999, 'x')"; + String result = tool.execute(args); + assertTrue(result.startsWith("Error:")); + assertTrue(result.contains("read-only")); + assertTrue(result.contains("run_mutation_query")); + } + + @Test + public void testRejectsUpdate() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "UPDATE large_table SET value = 'x' WHERE id = 1"; + assertTrue(tool.execute(args).startsWith("Error:")); + } + + @Test + public void testRejectsDelete() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "DELETE FROM large_table WHERE id = 1"; + assertTrue(tool.execute(args).startsWith("Error:")); + } + + @Test + public void testRejectsCreateTable() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "CREATE TABLE x (id INT)"; + assertTrue(tool.execute(args).startsWith("Error:")); + } + + @Test + public void testAllowsSelect() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT id FROM large_table LIMIT 100"; + String result = tool.execute(args); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("[100 row(s) returned]")); + } + + @Test + public void testAllowsCte() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "WITH cte AS (SELECT id, value FROM large_table LIMIT 5) SELECT * FROM cte"; + String result = tool.execute(args); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("row(s)")); + } + + 
// --- LIMIT is the LLM's responsibility (no client-side cap) --- + + @Test + public void testRespectsLimitInSql() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT id FROM large_table LIMIT 5"; + assertTrue(tool.execute(args).contains("[5 row(s) returned]")); + } + + @Test + public void testNoClientSideCapWhenLimitOmitted() { + // 1500 rows in table; tool MUST return all of them when no LIMIT is given. + // Cap discipline is delegated to the LLM via the system prompt. + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT id FROM large_table"; + assertTrue(tool.execute(args).contains("[1500 row(s) returned]")); + } + + // --- Zero-row result --- + + @Test + public void testZeroRowsResult() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT id FROM large_table WHERE id < 0"; + String result = tool.execute(args); + assertFalse(result.startsWith("Error:")); + assertTrue(result.contains("[0 row(s) returned]")); + } + + // --- Comment handling end-to-end --- + + @Test + public void testSelectWithLeadingBlockComment() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "/* get count */ SELECT COUNT(*) FROM large_table"; + assertFalse(tool.execute(args).startsWith("Error:")); + } + + @Test + public void testRejectsMutationHiddenBehindComment() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "-- looks innocent\nDROP TABLE large_table"; + assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args).contains("read-only")); + } + + // --- Edge cases --- + + @Test + public void testRejectsEmptyAndNullSql() { + SqlQueryArgs emptyArgs = new SqlQueryArgs(); + emptyArgs.sql = ""; + assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + + SqlQueryArgs nullArgs = new SqlQueryArgs(); + nullArgs.sql = null; + assertTrue(tool.execute(nullArgs).startsWith("Error:")); + } + + @Test + public void testRejectsWhitespaceOnlySql() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = " \t\n "; + 
assertTrue(tool.execute(args).startsWith("Error:")); + } + + @Test + public void testInvalidSqlReturnsError() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT * FROM nonexistent_table"; + assertTrue(tool.execute(args).startsWith("Error:")); + } + + // --- Output formatting --- + + @Test + public void testNullValuesRenderedAsNULL() { + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE nullable_test (id INTEGER, name TEXT)"); + stmt.execute("INSERT INTO nullable_test VALUES (1, NULL)"); + stmt.execute("INSERT INTO nullable_test VALUES (NULL, 'Alice')"); + } catch (Exception e) { + throw new RuntimeException(e); + } + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT id, name FROM nullable_test ORDER BY ROWID"; + String result = tool.execute(args); + assertTrue(result.contains("NULL")); + assertTrue(result.contains("Alice")); + } + + @Test + public void testPipeCharacterEscapedInOutput() { + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE pipe_test (val TEXT)"); + stmt.execute("INSERT INTO pipe_test VALUES ('a|b|c')"); + } catch (Exception e) { + throw new RuntimeException(e); + } + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT val FROM pipe_test"; + String result = tool.execute(args); + assertTrue("Pipe should be escaped for markdown table", result.contains("a\\|b\\|c")); + } + + // --- Error formatting --- + + @Test + public void testExtractRootCauseFromNestedExceptions() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT * FROM this_table_does_not_exist_at_all"; + String result = tool.execute(args); + assertTrue(result.startsWith("Error:")); + assertTrue(result.contains("this_table_does_not_exist_at_all")); + } + + @Test + public void testErrorMessageIsConcise() { + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELEC INVALID SYNTAX HERE !!!"; + String result = 
tool.execute(args); + assertTrue(result.startsWith("Error:")); + long newlines = result.chars().filter(c -> c == '\n').count(); + assertTrue("Error should be concise (<=2 newlines), got " + newlines, newlines <= 2); + } + + // --- Query timeout --- + + @Test + public void testCustomQueryTimeout() { + RunSelectQueryTool customTool = new RunSelectQueryTool(ds, 5); + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT COUNT(*) FROM large_table"; + assertFalse(customTool.execute(args).startsWith("Error:")); + } + + @Test + public void testQueryTimeoutReturnsError() throws Exception { + javax.sql.DataSource slowDs = + new javax.sql.DataSource() { + @Override + public Connection getConnection() throws java.sql.SQLException { + Connection real = ds.getConnection(); + return (Connection) + java.lang.reflect.Proxy.newProxyInstance( + getClass().getClassLoader(), + new Class[] {Connection.class}, + (proxy, method, args) -> { + if ("createStatement".equals(method.getName())) { + Statement realStmt = real.createStatement(); + return java.lang.reflect.Proxy.newProxyInstance( + getClass().getClassLoader(), + new Class[] {Statement.class}, + (p2, m2, a2) -> { + if ("execute".equals(m2.getName())) { + realStmt.close(); + real.close(); + throw new java.sql.SQLTimeoutException( + "Query timed out after 1 seconds"); + } + return m2.invoke(realStmt, a2); + }); + } + if ("close".equals(method.getName())) { + real.close(); + return null; + } + return method.invoke(real, args); + }); + } + + @Override + public Connection getConnection(String u, String p) throws java.sql.SQLException { + return getConnection(); + } + + @Override + public java.io.PrintWriter getLogWriter() { + return null; + } + + @Override + public void setLogWriter(java.io.PrintWriter out) {} + + @Override + public void setLoginTimeout(int seconds) {} + + @Override + public int getLoginTimeout() { + return 0; + } + + @Override + public java.util.logging.Logger getParentLogger() { + return null; + } + + @Override 
+ public T unwrap(Class iface) { + return null; + } + + @Override + public boolean isWrapperFor(Class iface) { + return false; + } + }; + + RunSelectQueryTool timeoutTool = new RunSelectQueryTool(slowDs, 1); + SqlQueryArgs args = new SqlQueryArgs(); + args.sql = "SELECT * FROM large_table"; + String result = timeoutTool.execute(args); + assertTrue("Expected error on timeout", result.startsWith("Error:")); + assertTrue("Expected timeout message", result.contains("timed out")); + } + + // --- Helpers --- + + private SQLiteDataSource createDataSource() { + try { + File tmpFile = File.createTempFile("kyuubi-select-test-", ".db"); + tmpFile.deleteOnExit(); + tempFiles.add(tmpFile); + SQLiteDataSource dataSource = new SQLiteDataSource(); + dataSource.setUrl("jdbc:sqlite:" + tmpFile.getAbsolutePath()); + return dataSource; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void setupLargeTable(SQLiteDataSource dataSource) { + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE large_table (id INTEGER PRIMARY KEY, value TEXT)"); + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= 1500; i++) { + if (sb.length() > 0) sb.append(","); + sb.append("(").append(i).append(", 'row-").append(i).append("')"); + if (i % 500 == 0) { + stmt.execute("INSERT INTO large_table VALUES " + sb); + sb.setLength(0); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyCheckerTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyCheckerTest.java new file mode 100644 index 00000000000..97ad50b9f45 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyCheckerTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.sql; + +import static org.junit.Assert.*; + +import org.junit.Test; + +public class SqlReadOnlyCheckerTest { + + @Test + public void testAcceptsSelect() { + assertTrue(SqlReadOnlyChecker.isReadOnly("SELECT 1")); + assertTrue(SqlReadOnlyChecker.isReadOnly("select * from t")); + assertTrue(SqlReadOnlyChecker.isReadOnly(" SELECT 1 ")); + } + + @Test + public void testAcceptsCte() { + assertTrue(SqlReadOnlyChecker.isReadOnly("WITH cte AS (SELECT id FROM t) SELECT * FROM cte")); + assertTrue(SqlReadOnlyChecker.isReadOnly("with x as (select 1) select * from x")); + } + + @Test + public void testAcceptsShowDescribeExplain() { + assertTrue(SqlReadOnlyChecker.isReadOnly("SHOW TABLES")); + assertTrue(SqlReadOnlyChecker.isReadOnly("SHOW CREATE TABLE t")); + assertTrue(SqlReadOnlyChecker.isReadOnly("SHOW PARTITIONS t")); + assertTrue(SqlReadOnlyChecker.isReadOnly("DESCRIBE t")); + assertTrue(SqlReadOnlyChecker.isReadOnly("DESC FORMATTED t")); + assertTrue(SqlReadOnlyChecker.isReadOnly("EXPLAIN SELECT 1")); + } + + @Test + public void testAcceptsBigDataKeywords() { + assertTrue(SqlReadOnlyChecker.isReadOnly("USE my_db")); + 
assertTrue(SqlReadOnlyChecker.isReadOnly("VALUES (1, 2), (3, 4)")); + assertTrue(SqlReadOnlyChecker.isReadOnly("TABLE my_table")); + assertTrue(SqlReadOnlyChecker.isReadOnly("FROM t SELECT *")); + assertTrue(SqlReadOnlyChecker.isReadOnly("LIST FILE")); + assertTrue(SqlReadOnlyChecker.isReadOnly("LIST JAR")); + assertTrue(SqlReadOnlyChecker.isReadOnly("HELP")); + } + + @Test + public void testRejectsMutations() { + assertFalse(SqlReadOnlyChecker.isReadOnly("INSERT INTO t VALUES (1)")); + assertFalse(SqlReadOnlyChecker.isReadOnly("UPDATE t SET x = 1")); + assertFalse(SqlReadOnlyChecker.isReadOnly("DELETE FROM t")); + assertFalse(SqlReadOnlyChecker.isReadOnly("MERGE INTO t USING s ON ...")); + assertFalse(SqlReadOnlyChecker.isReadOnly("CREATE TABLE t (x INT)")); + assertFalse(SqlReadOnlyChecker.isReadOnly("DROP TABLE t")); + assertFalse(SqlReadOnlyChecker.isReadOnly("ALTER TABLE t ADD COLUMN y INT")); + assertFalse(SqlReadOnlyChecker.isReadOnly("TRUNCATE TABLE t")); + assertFalse(SqlReadOnlyChecker.isReadOnly("GRANT SELECT ON t TO user")); + assertFalse(SqlReadOnlyChecker.isReadOnly("ANALYZE TABLE t COMPUTE STATISTICS")); + assertFalse(SqlReadOnlyChecker.isReadOnly("REFRESH TABLE t")); + assertFalse(SqlReadOnlyChecker.isReadOnly("SET k = v")); + } + + @Test + public void testRejectsEmptyAndNull() { + assertFalse(SqlReadOnlyChecker.isReadOnly(null)); + assertFalse(SqlReadOnlyChecker.isReadOnly("")); + assertFalse(SqlReadOnlyChecker.isReadOnly(" ")); + assertFalse(SqlReadOnlyChecker.isReadOnly("123")); + } + + @Test + public void testStripsLineComments() { + assertTrue(SqlReadOnlyChecker.isReadOnly("-- a comment\nSELECT 1")); + assertTrue(SqlReadOnlyChecker.isReadOnly("-- comment 1\n-- comment 2\nSELECT 1")); + assertFalse(SqlReadOnlyChecker.isReadOnly("-- innocent\nDROP TABLE t")); + } + + @Test + public void testStripsBlockComments() { + assertTrue(SqlReadOnlyChecker.isReadOnly("/* hi */ SELECT 1")); + assertTrue(SqlReadOnlyChecker.isReadOnly("/* multi\nline */SELECT 
1")); + assertFalse(SqlReadOnlyChecker.isReadOnly("/* sneaky */ DELETE FROM t")); + } + + @Test + public void testHandlesUnterminatedBlockComment() { + assertFalse(SqlReadOnlyChecker.isReadOnly("/* never ends")); + } + + @Test + public void testStripsMixedComments() { + assertTrue(SqlReadOnlyChecker.isReadOnly("/* block */ -- line\nSELECT 1")); + assertTrue(SqlReadOnlyChecker.isReadOnly("-- line\n/* block */\nSELECT 1")); + assertFalse(SqlReadOnlyChecker.isReadOnly("/* block */ -- line\nINSERT INTO t VALUES (1)")); + } + + @Test + public void testKeywordFollowedByParen() { + assertTrue(SqlReadOnlyChecker.isReadOnly("SELECT(1)")); + assertTrue(SqlReadOnlyChecker.isReadOnly("SHOW;")); + assertFalse(SqlReadOnlyChecker.isReadOnly("INSERT(1, 2)")); + } + + @Test + public void testLineCommentAtEndNoNewline() { + // "SELECT 1 -- comment" without trailing newline — still read-only + assertTrue(SqlReadOnlyChecker.isReadOnly("SELECT 1 -- comment")); + } + + @Test + public void testTabsAndCarriageReturns() { + assertTrue(SqlReadOnlyChecker.isReadOnly("\t\r\n SELECT 1")); + assertFalse(SqlReadOnlyChecker.isReadOnly("\t\r\n DROP TABLE t")); + } +} diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala index 05e93212990..905aa46d176 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -3824,9 +3824,28 @@ object KyuubiConf { val ENGINE_DATA_AGENT_QUERY_TIMEOUT: ConfigEntry[Long] = buildConf("kyuubi.engine.data.agent.query.timeout") - .doc("The query execution timeout for the Data Agent SQL tool.") + .doc("The JDBC query execution timeout for the Data Agent SQL tools. " + + "Passed to Statement.setQueryTimeout so the server (Spark/Trino/...) " + + "can cooperatively cancel long-running queries and release cluster resources. 
" + + "Should be set lower than " + + "kyuubi.engine.data.agent.tool.call.timeout so server-side cancellation " + + "has time to react before the outer wall-clock cap fires.") .version("1.12.0") .timeConf + .checkValue(_ >= 1000, "must >= 1s") + .createWithDefaultString("PT3M") + + val ENGINE_DATA_AGENT_TOOL_CALL_TIMEOUT: ConfigEntry[Long] = + buildConf("kyuubi.engine.data.agent.tool.call.timeout") + .doc("The maximum wall-clock execution time for any tool call in the Data Agent " + + "engine. Acts as the outer safety net enforced by the agent runtime via " + + "Future.cancel(), applied uniformly to every tool. " + + "For SQL tools the inner JDBC-level timeout is controlled separately by " + + "kyuubi.engine.data.agent.query.timeout, which should be set lower " + + "so server-side cancellation has time to react before this hard cap fires.") + .version("1.12.0") + .timeConf + .checkValue(_ >= 1000, "must >= 1s") .createWithDefaultString("PT5M") val ENGINE_DATA_AGENT_JDBC_URL: OptionalConfigEntry[String] = diff --git a/pom.xml b/pom.xml index db9b46e9132..aa6c1a3865d 100644 --- a/pom.xml +++ b/pom.xml @@ -188,6 +188,8 @@ 4.11.0 4.2.7.Final 0.12.0 + 4.30.0 + 4.37.0 2.9.0 0.8.2 paimon-spark-${spark.binary.version} @@ -218,6 +220,7 @@ 3.46.1.3 2.2.0 2.2.1 + 2.0.3 0.44.1 1.7.0 @@ -580,6 +583,12 @@ + + org.testcontainers + testcontainers-mysql + ${testcontainers-java.version} + + com.dimafeng testcontainers-scala-scalatest_${scala.binary.version} @@ -1201,6 +1210,24 @@ ${openai.java.version} + + com.openai + openai-java + ${openai.sdk.version} + + + + com.github.victools + jsonschema-generator + ${victools.jsonschema.version} + + + + com.github.victools + jsonschema-module-jackson + ${victools.jsonschema.version} + + org.threeten threeten-extra