diff --git a/src/Core/PostgreSQLProtocol.cpp b/src/Core/PostgreSQLProtocol.cpp
index 553d195605..1febcc0711 100644
--- a/src/Core/PostgreSQLProtocol.cpp
+++ b/src/Core/PostgreSQLProtocol.cpp
@@ -22,6 +22,13 @@ ColumnTypeSpec convertTypeIndexToPostgresColumnTypeSpec(TypeIndex type_index)
         case TypeIndex::Int64:
             return {ColumnType::INT8, 8};
 
+        case TypeIndex::UInt64: // this and the four wider integer types below share one NUMERIC spec
+        case TypeIndex::Int128:
+        case TypeIndex::UInt128:
+        case TypeIndex::Int256:
+        case TypeIndex::UInt256:
+            return {ColumnType::NUMERIC, -1}; // same {NUMERIC, -1} spec the Decimal* cases return
+
         case TypeIndex::Float32:
             return {ColumnType::FLOAT4, 4};
         case TypeIndex::Float64:
@@ -32,16 +39,35 @@ ColumnTypeSpec convertTypeIndexToPostgresColumnTypeSpec(TypeIndex type_index)
             return {ColumnType::VARCHAR, -1};
 
         case TypeIndex::Date:
+        case TypeIndex::Date32: // shares the DATE spec with Date
             return {ColumnType::DATE, 4};
 
+        case TypeIndex::DateTime:
+            return {ColumnType::TIMESTAMP, 8};
+
+        case TypeIndex::DateTime64:
+            return {ColumnType::TIMESTAMPTZ, 8}; // NOTE(review): DateTime maps to TIMESTAMP but DateTime64 to TIMESTAMPTZ - confirm the asymmetry is intended
+
         case TypeIndex::Decimal32:
         case TypeIndex::Decimal64:
         case TypeIndex::Decimal128:
+        case TypeIndex::Decimal256:
             return {ColumnType::NUMERIC, -1};
 
         case TypeIndex::UUID:
             return {ColumnType::UUID, 16};
 
+        case TypeIndex::Enum8:
+        case TypeIndex::Enum16:
+            return {ColumnType::VARCHAR, -1}; // identical to the default branch; listed explicitly
+
+        case TypeIndex::Map:
+            return {ColumnType::JSONB, -1};
+
+        case TypeIndex::Array:
+        case TypeIndex::Tuple:
+            return {ColumnType::VARCHAR, -1}; // identical to the default branch; listed explicitly
+
         default:
             return {ColumnType::VARCHAR, -1};
     }
diff --git a/src/Core/PostgreSQLProtocol.h b/src/Core/PostgreSQLProtocol.h
index d873c89b75..60e7191072 100644
--- a/src/Core/PostgreSQLProtocol.h
+++ b/src/Core/PostgreSQLProtocol.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include
+#include
 #include
 #include
 #include
@@ -46,6 +47,7 @@ enum class FrontMessageType : Int32
     PARSE = 'P',
     BIND = 'B',
     DESCRIBE = 'D',
+    EXECUTE = 'E', // dispatched to processExecute() by PostgreSQLHandler::run()
     SYNC = 'S',
     FLUSH = 'H',
     CLOSE = 'C',
@@ -122,16 +124,21 @@ enum class MessageType : Int32
 
 //// Column 'typelem' from 'pg_type' table.
NB: not all types are compatible with PostgreSQL's ones
 enum class ColumnType : Int32
 {
+    BOOL = 16,
     CHAR = 18,
     INT8 = 20,
     INT2 = 21,
     INT4 = 23,
+    TEXT = 25,
     FLOAT4 = 700,
     FLOAT8 = 701,
     VARCHAR = 1043,
     DATE = 1082,
+    TIMESTAMP = 1114,
+    TIMESTAMPTZ = 1184,
     NUMERIC = 1700,
     UUID = 2950,
+    JSONB = 3802,
 };
 
 class ColumnTypeSpec
@@ -793,6 +800,266 @@ class CommandComplete : BackendMessage
     }
 };
 
+// Extended query protocol front messages
+
+/// Accounts for `bytes` payload bytes that are about to be read from a message whose
+/// unread payload size is `remaining`. Throws when the field cannot fit (negative size,
+/// or larger than what the declared message length leaves), so client-supplied lengths
+/// never reach an allocation unchecked. Before throwing, the unread rest of the frame is
+/// skipped so that the next message still starts on a frame boundary.
+/// NOTE(review): assumes ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT is declared earlier in
+/// this header (not visible in this hunk) - confirm.
+inline void requireMessageBytes(ReadBuffer & in, Int64 & remaining, Int64 bytes, const char * message_name)
+{
+    if (bytes < 0 || bytes > remaining)
+    {
+        if (remaining > 0)
+            in.ignore(static_cast<size_t>(remaining));
+        throw Exception(
+            ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT,
+            "Malformed {} message: a field of {} bytes does not fit in the {} bytes left",
+            message_name, bytes, remaining);
+    }
+    remaining -= bytes;
+}
+
+/// Parse ('P'): statement name, query text and the optional list of parameter type OIDs.
+/// NOTE(review): the template arguments in this hunk were missing from the reviewed dump
+/// and are reconstructed from usage (4-byte OIDs, 2-byte format codes) - confirm.
+class Parse : public FrontMessage
+{
+public:
+    String statement_name;
+    String query;
+    std::vector<Int32> param_oids;
+
+    void deserialize(ReadBuffer & in) override
+    {
+        Int32 sz;
+        readBinaryBigEndian(sz, in);
+        /// Unread payload bytes; the length field counts itself. 64-bit so the
+        /// bookkeeping below cannot overflow.
+        Int64 remaining = static_cast<Int64>(sz) - 4;
+
+        readNullTerminated(statement_name, in);
+        remaining -= static_cast<Int64>(statement_name.size()) + 1;
+
+        readNullTerminated(query, in);
+        remaining -= static_cast<Int64>(query.size()) + 1;
+
+        requireMessageBytes(in, remaining, 2, "Parse");
+        Int16 raw_num_params;
+        readBinaryBigEndian(raw_num_params, in);
+        /// The count is a 16-bit field; reading it as unsigned keeps a value above 32767
+        /// from turning into a negative number (and then into a huge resize()).
+        const size_t num_params = static_cast<UInt16>(raw_num_params);
+
+        requireMessageBytes(in, remaining, static_cast<Int64>(num_params) * 4, "Parse");
+        param_oids.resize(num_params);
+        for (auto & oid : param_oids)
+            readBinaryBigEndian(oid, in);
+
+        if (remaining > 0)
+            in.ignore(static_cast<size_t>(remaining));
+    }
+
+    MessageType getMessageType() const override { return MessageType::PARSE; }
+};
+
+/// Bind ('B'): portal name, source statement name, parameter format codes, parameter
+/// values (std::nullopt for SQL NULL) and result-column format codes.
+class Bind : public FrontMessage
+{
+public:
+    String portal_name;
+    String statement_name;
+    std::vector<Int16> param_format_codes;
+    std::vector<std::optional<String>> param_values;
+    std::vector<Int16> result_format_codes;
+
+    void deserialize(ReadBuffer & in) override
+    {
+        Int32 sz;
+        readBinaryBigEndian(sz, in);
+        Int64 remaining = static_cast<Int64>(sz) - 4;
+
+        readNullTerminated(portal_name, in);
+        remaining -= static_cast<Int64>(portal_name.size()) + 1;
+
+        readNullTerminated(statement_name, in);
+        remaining -= static_cast<Int64>(statement_name.size()) + 1;
+
+        readFormatCodes(in, remaining, param_format_codes);
+
+        requireMessageBytes(in, remaining, 2, "Bind");
+        Int16 raw_num_params;
+        readBinaryBigEndian(raw_num_params, in);
+        const size_t num_params = static_cast<UInt16>(raw_num_params);
+
+        param_values.clear();
+        for (size_t i = 0; i < num_params; ++i)
+        {
+            requireMessageBytes(in, remaining, 4, "Bind");
+            Int32 len;
+            readBinaryBigEndian(len, in);
+            if (len == -1)
+            {
+                param_values.emplace_back(std::nullopt); // NULL parameter
+            }
+            else
+            {
+                /// Rejects len < -1 and any length larger than the rest of the frame
+                /// before the buffer is allocated.
+                requireMessageBytes(in, remaining, len, "Bind");
+                String val;
+                val.resize(static_cast<size_t>(len));
+                in.readStrict(val.data(), static_cast<size_t>(len));
+                param_values.emplace_back(std::move(val));
+            }
+        }
+
+        readFormatCodes(in, remaining, result_format_codes);
+
+        if (remaining > 0)
+            in.ignore(static_cast<size_t>(remaining));
+    }
+
+    MessageType getMessageType() const override { return MessageType::BIND; }
+
+private:
+    /// Reads a 16-bit count followed by that many 16-bit format codes into `codes`.
+    static void readFormatCodes(ReadBuffer & in, Int64 & remaining, std::vector<Int16> & codes)
+    {
+        requireMessageBytes(in, remaining, 2, "Bind");
+        Int16 raw_count;
+        readBinaryBigEndian(raw_count, in);
+        const size_t count = static_cast<UInt16>(raw_count);
+
+        requireMessageBytes(in, remaining, static_cast<Int64>(count) * 2, "Bind");
+        codes.resize(count);
+        for (auto & code : codes)
+            readBinaryBigEndian(code, in);
+    }
+};
+
+/// Describe ('D'): asks for a description of a prepared statement or a portal.
+class Describe : public FrontMessage
+{
+public:
+    char describe_type = 0; // 'S' = statement, 'P' = portal
+    String name;
+
+    void deserialize(ReadBuffer & in) override
+    {
+        Int32 sz;
+        readBinaryBigEndian(sz, in);
+        in.readStrict(describe_type);
+        readNullTerminated(name, in);
+    }
+
+    MessageType getMessageType() const override { return MessageType::DESCRIBE; }
+};
+
+/// Execute ('E'): runs the named portal; `max_rows` is the client's row limit field.
+class Execute : public FrontMessage
+{
+public:
+    String portal_name;
+    Int32 max_rows = 0;
+
+    void deserialize(ReadBuffer & in) override
+    {
+        Int32 sz;
+        readBinaryBigEndian(sz, in);
+        readNullTerminated(portal_name, in);
+        readBinaryBigEndian(max_rows, in);
+    }
+
+    MessageType getMessageType() const override { return MessageType::EXECUTE; }
+};
+
+/// Close ('C'): closes a prepared statement or a portal.
+class CloseMsg : public FrontMessage
+{
+public:
+    char close_type = 0; // 'S' or 'P'
+    String name;
+
+    void deserialize(ReadBuffer & in) override
+    {
+        Int32 sz;
+        readBinaryBigEndian(sz, in);
+        in.readStrict(close_type);
+        readNullTerminated(name, in);
+    }
+
+    MessageType getMessageType() const override { return MessageType::CLOSE; }
+};
+
+/// Flush ('H'): carries only the length field.
+class FlushMsg : public FrontMessage
+{
+public:
+    void deserialize(ReadBuffer & in) override
+    {
+        Int32 sz;
+        readBinaryBigEndian(sz, in);
+    }
+
+    MessageType getMessageType() const override { return MessageType::FLUSH; }
+};
+
+/// Sync ('S'): carries only the length field.
+class SyncMsg : public FrontMessage
+{
+public:
+    void deserialize(ReadBuffer & in) override
+    {
+        Int32 sz;
+        readBinaryBigEndian(sz, in);
+    }
+
+    MessageType getMessageType() const override { return
MessageType::SYNC; }
+};
+
+// Extended query protocol backend messages
+
+class ParseComplete : public BackendMessage
+{
+public:
+    void serialize(WriteBuffer & out) const override
+    {
+        out.write('1'); // one-byte message tag
+        writeBinaryBigEndian(size(), out); // length field only; this message has no payload
+    }
+
+    Int32 size() const override { return 4; } // just the 4-byte length field
+    MessageType getMessageType() const override { return MessageType::PARSE_COMPLETE; }
+};
+
+class BindComplete : public BackendMessage
+{
+public:
+    void serialize(WriteBuffer & out) const override
+    {
+        out.write('2'); // one-byte message tag
+        writeBinaryBigEndian(size(), out);
+    }
+
+    Int32 size() const override { return 4; }
+    MessageType getMessageType() const override { return MessageType::BIND_COMPLETE; }
+};
+
+class CloseCompleteMsg : public BackendMessage
+{
+public:
+    void serialize(WriteBuffer & out) const override
+    {
+        out.write('3'); // one-byte message tag
+        writeBinaryBigEndian(size(), out);
+    }
+
+    Int32 size() const override { return 4; }
+    MessageType getMessageType() const override { return MessageType::CLOSE_COMPLETE; }
+};
+
+class NoDataMsg : public BackendMessage
+{
+public:
+    void serialize(WriteBuffer & out) const override
+    {
+        out.write('n'); // one-byte message tag
+        writeBinaryBigEndian(size(), out);
+    }
+
+    Int32 size() const override { return 4; }
+    MessageType getMessageType() const override { return MessageType::NODATA; }
+};
+
+class ParameterDescription : public BackendMessage
+{
+private:
+    std::vector param_oids; // NOTE(review): template argument is missing in this dump (as on the casts below) - restore from the original commit
+
+public:
+    ParameterDescription() = default;
+    explicit ParameterDescription(std::vector oids) : param_oids(std::move(oids)) {}
+
+    void serialize(WriteBuffer & out) const override
+    {
+        out.write('t'); // one-byte message tag
+        writeBinaryBigEndian(size(), out);
+        writeBinaryBigEndian(static_cast(param_oids.size()), out); // parameter count, then one OID per parameter
+        for (const auto & oid : param_oids)
+            writeBinaryBigEndian(oid, out);
+    }
+
+    Int32 size() const override
+    {
+        return static_cast(4 + 2 + param_oids.size() * 4); // length field + count field + 4 bytes per OID
+    }
+
+    MessageType getMessageType() const override { return MessageType::PARAMETER_DESCRIPTION; }
+};
+
+class PortalSuspended : public
BackendMessage
+{
+public:
+    void serialize(WriteBuffer & out) const override
+    {
+        out.write('s'); // one-byte message tag
+        writeBinaryBigEndian(size(), out); // length field only; this message has no payload
+    }
+
+    Int32 size() const override { return 4; }
+    MessageType getMessageType() const override { return MessageType::PORTAL_SUSPENDED; }
+};
+
 }
 
 namespace PGAuthentication
diff --git a/src/Server/PostgreSQLHandler.cpp b/src/Server/PostgreSQLHandler.cpp
index 6d3e40c1f3..a97c208c56 100644
--- a/src/Server/PostgreSQLHandler.cpp
+++ b/src/Server/PostgreSQLHandler.cpp
@@ -1,3 +1,4 @@
+#include
 #include
 #include
 #include
@@ -39,6 +40,7 @@ PostgreSQLHandler::PostgreSQLHandler(
     , ssl_enabled(ssl_enabled_)
     , connection_id(connection_id_)
     , authentication_manager(auth_methods_)
+    , stmt_manager(std::make_unique()) // NOTE(review): template argument is missing in this dump; the member is declared in PostgreSQLHandler.h
 {
     changeIO(socket());
 }
@@ -67,7 +69,11 @@ void PostgreSQLHandler::run()
 
         while (tcp_server.isOpen())
         {
-            message_transport->send(PostgreSQLProtocol::Messaging::ReadyForQuery(), true);
+            if (send_ready_for_query) // set by the simple-query and unsupported-command paths below
+            {
+                message_transport->send(PostgreSQLProtocol::Messaging::ReadyForQuery(), true);
+                send_ready_for_query = false;
+            }
 
             constexpr size_t connection_check_timeout = 1; // 1 second
             while (!in->poll(1000000 * connection_check_timeout))
@@ -77,28 +83,45 @@ void PostgreSQLHandler::run()
 
             if (!tcp_server.isOpen())
                 return;
+
+            /// Skip messages until Sync after error in extended query
+            if (extended_query_error
+                && message_type != PostgreSQLProtocol::Messaging::FrontMessageType::SYNC
+                && message_type != PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE)
+            {
+                message_transport->dropMessage();
+                continue;
+            }
+
             switch (message_type)
             {
                 case PostgreSQLProtocol::Messaging::FrontMessageType::QUERY:
                     processQuery();
+                    send_ready_for_query = true; // ReadyForQuery goes out at the top of the next loop iteration
                     break;
                 case PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE:
                     LOG_DEBUG(log, "Client closed the connection");
                     return;
                 case PostgreSQLProtocol::Messaging::FrontMessageType::PARSE:
+                    processParse();
+                    break;
                 case PostgreSQLProtocol::Messaging::FrontMessageType::BIND:
+                    processBind();
+                    break;
                 case PostgreSQLProtocol::Messaging::FrontMessageType::DESCRIBE:
+                    processDescribe();
+                    break;
                 case PostgreSQLProtocol::Messaging::FrontMessageType::SYNC:
+                    processSync(); // sends ReadyForQuery itself, so the flag is not touched here
+                    break;
                 case PostgreSQLProtocol::Messaging::FrontMessageType::FLUSH:
+                    processFlush();
+                    break;
+                case PostgreSQLProtocol::Messaging::FrontMessageType::EXECUTE:
+                    processExecute();
+                    break;
                 case PostgreSQLProtocol::Messaging::FrontMessageType::CLOSE:
-                    message_transport->send(
-                        PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
-                            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
-                            "0A000",
-                            "proton doesn't support extended query mechanism"),
-                        true);
-                    LOG_ERROR(log, "Client tried to access via extended query protocol");
-                    message_transport->dropMessage();
+                    processClose();
                     break;
                 default:
                     message_transport->send(
@@ -109,6 +132,7 @@ void PostgreSQLHandler::run()
                         true);
                     LOG_ERROR(log, "Command is not supported. Command code {:d}", static_cast(message_type));
                     message_transport->dropMessage();
+                    send_ready_for_query = true;
             }
         }
     }
@@ -285,6 +309,9 @@ void PostgreSQLHandler::processQuery()
         return;
     }
 
+    if (tryAnswerCatalogQuery(query->query)) // answered locally; the query never reaches the engine
+        return;
+
     const auto & settings = session->sessionContext()->getSettingsRef();
     std::vector queries;
     auto parse_res = splitMultipartQuery(query->query, queries,
@@ -333,4 +360,384 @@ bool PostgreSQLHandler::isEmptyQuery(const String & query)
     return regex.match(query);
 }
 
+void PostgreSQLHandler::processParse()
+{
+    try
+    {
+        auto msg = message_transport->receive(); // NOTE(review): template argument is missing in this dump - presumably the Parse message type
+        stmt_manager->parseStatement(msg->statement_name, msg->query, msg->param_oids);
+        message_transport->send(PostgreSQLProtocol::Messaging::ParseComplete());
+    }
+    catch (const Exception & e)
+    {
+        extended_query_error = true; // run() then drops every message until Sync or Terminate
+        message_transport->send(
+            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                "42601", "Parse failed.\n" + e.displayText()),
+            true);
+    }
+}
+
+void
PostgreSQLHandler::processBind()
+{
+    /// Handles Bind: substitutes the supplied parameter values into a prepared statement
+    /// and stores the result as a portal. Any failure marks the extended-query sequence
+    /// as failed, so run() discards the following messages until Sync.
+    /// NOTE(review): the receive<...>() template arguments in this hunk were missing from
+    /// the reviewed dump and are reconstructed from the message classes - confirm.
+    try
+    {
+        auto msg = message_transport->receive<PostgreSQLProtocol::Messaging::Bind>();
+
+        /// Parameter values are spliced into the query as quoted text, so a value sent in
+        /// binary format (format code != 0) cannot be used; refuse it instead of building
+        /// a query from raw bytes.
+        const bool has_binary_params = !msg->param_values.empty()
+            && std::any_of(msg->param_format_codes.begin(), msg->param_format_codes.end(),
+                           [](Int16 code) { return code != 0; });
+        if (has_binary_params)
+        {
+            extended_query_error = true;
+            message_transport->send(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                    PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                    "0A000", "Binary parameter format is not supported"),
+                true);
+            return;
+        }
+
+        stmt_manager->bindPortal(
+            msg->portal_name, msg->statement_name,
+            msg->param_values, msg->result_format_codes);
+        message_transport->send(PostgreSQLProtocol::Messaging::BindComplete());
+    }
+    catch (const Exception & e)
+    {
+        extended_query_error = true;
+        message_transport->send(
+            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                "42601", "Bind failed.\n" + e.displayText()),
+            true);
+    }
+}
+
+/// Handles Describe. A statement is answered with ParameterDescription + NoData, a portal
+/// with NoData. Every error path sets extended_query_error so that run() skips to Sync.
+/// NOTE(review): NoData is sent even when the statement will return rows - confirm that
+/// clients relying on RowDescription from Describe are not affected.
+void PostgreSQLHandler::processDescribe()
+{
+    try
+    {
+        auto msg = message_transport->receive<PostgreSQLProtocol::Messaging::Describe>();
+
+        if (msg->describe_type == 'S')
+        {
+            const auto * stmt = stmt_manager->getStatement(msg->name);
+            if (!stmt)
+            {
+                extended_query_error = true;
+                message_transport->send(
+                    PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                        PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                        "26000", "Prepared statement does not exist"),
+                    true);
+                return;
+            }
+            std::vector<Int32> oids;
+            if (!stmt->param_oids.empty())
+                oids = stmt->param_oids;
+            else
+                oids.assign(stmt->param_count, 0); // 0 = unspecified type
+            message_transport->send(PostgreSQLProtocol::Messaging::ParameterDescription(std::move(oids)));
+            message_transport->send(PostgreSQLProtocol::Messaging::NoDataMsg());
+        }
+        else if (msg->describe_type == 'P')
+        {
+            message_transport->send(PostgreSQLProtocol::Messaging::NoDataMsg());
+        }
+        else
+        {
+            /// Without a reply the client would wait forever for the description.
+            extended_query_error = true;
+            message_transport->send(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                    PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                    "08P01", "Invalid Describe target, expected 'S' or 'P'"),
+                true);
+        }
+    }
+    catch (const Exception & e)
+    {
+        extended_query_error = true;
+        message_transport->send(
+            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                "26000", "Describe failed.\n" + e.displayText()),
+            true);
+    }
+}
+
+/// Handles Execute: runs the query bound to the portal and reports CommandComplete.
+/// NOTE(review): msg->max_rows is not consulted, so the whole result is always sent.
+void PostgreSQLHandler::processExecute()
+{
+    try
+    {
+        auto msg = message_transport->receive<PostgreSQLProtocol::Messaging::Execute>();
+        const auto * portal = stmt_manager->getPortal(msg->portal_name);
+
+        if (!portal)
+        {
+            extended_query_error = true;
+            message_transport->send(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                    PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                    "34000", "Portal does not exist"),
+                true);
+            return;
+        }
+
+        const String & query = portal->bound_query;
+
+        if (isEmptyQuery(query))
+        {
+            message_transport->send(PostgreSQLProtocol::Messaging::EmptyQueryResponse());
+            return;
+        }
+
+        if (tryAnswerCatalogQuery(query))
+            return;
+
+        /// Driver session statements that slipped past tryAnswerCatalogQuery() (for
+        /// example with leading whitespace) are acknowledged without being executed.
+        bool psycopg2_cond = query == "BEGIN" || query == "COMMIT" || query == "ROLLBACK";
+        bool jdbc_cond = query.find("SET extra_float_digits") != String::npos
+            || query.find("SET application_name") != String::npos;
+        if (psycopg2_cond || jdbc_cond)
+        {
+            message_transport->send(
+                PostgreSQLProtocol::Messaging::CommandComplete(
+                    PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(query), 0));
+            return;
+        }
+
+        std::random_device rd;
+        std::mt19937 gen(rd());
+        std::uniform_int_distribution<Int32> dis(0, INT32_MAX);
+        secret_key = dis(gen);
+
+        auto query_context = session->makeQueryContext();
+        query_context->setCurrentQueryId(fmt::format("postgres:{:d}:{:d}", connection_id, secret_key));
+
+        CurrentThread::QueryScope query_scope{query_context};
+        ReadBufferFromString read_buf(query);
+        executeQuery(read_buf, *out, false, query_context, {});
+
+        PostgreSQLProtocol::Messaging::CommandComplete::Command command =
+            PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(query);
+        message_transport->send(PostgreSQLProtocol::Messaging::CommandComplete(command, 0), true);
+    }
+    catch (const Exception & e)
+    {
+        extended_query_error = true;
+        message_transport->send(
+            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                "2F000", "Execute failed.\n" + e.displayText()),
+            true);
+    }
 }
+
+/// Handles Sync: ends the current extended-query sequence. The error flag is cleared and
+/// ReadyForQuery is sent on both the normal and the failure path.
+void PostgreSQLHandler::processSync()
+{
+    try
+    {
+        auto msg = message_transport->receive<PostgreSQLProtocol::Messaging::SyncMsg>();
+        extended_query_error = false;
+        message_transport->send(PostgreSQLProtocol::Messaging::ReadyForQuery(), true);
+    }
+    catch (const Exception & e)
+    {
+        extended_query_error = false;
+        message_transport->send(
+            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                "XX000", "Sync failed.\n" + e.displayText()),
+            true);
+        message_transport->send(PostgreSQLProtocol::Messaging::ReadyForQuery(), true);
+    }
+}
+
+/// Handles Close for a prepared statement ('S') or a portal ('P'); closing a name that
+/// does not exist still yields CloseComplete.
+void PostgreSQLHandler::processClose()
+{
+    try
+    {
+        auto msg = message_transport->receive<PostgreSQLProtocol::Messaging::CloseMsg>();
+        if (msg->close_type == 'S')
+            stmt_manager->closeStatement(msg->name);
+        else if (msg->close_type == 'P')
+            stmt_manager->closePortal(msg->name);
+
+        message_transport->send(PostgreSQLProtocol::Messaging::CloseCompleteMsg());
+    }
+    catch (const Exception & e)
+    {
+        /// Same rule as for the other extended-query messages: after an error,
+        /// everything up to the next Sync is discarded.
+        extended_query_error = true;
+        message_transport->send(
+            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                "XX000", "Close failed.\n" + e.displayText()),
+            true);
+    }
+}
+
+/// Handles Flush: pushes any buffered responses to the client.
+void PostgreSQLHandler::processFlush()
+{
+    try
+    {
+        auto msg = message_transport->receive<PostgreSQLProtocol::Messaging::FlushMsg>();
+        message_transport->flush();
+    }
+    catch (const Exception & e)
+    {
+        extended_query_error = true;
+        message_transport->send(
+            PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+                PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+                "XX000", "Flush failed.\n" + e.displayText()),
+            true);
+    }
+}
+
+/// Answers, without touching the query engine, the session and catalog statements that
+/// PostgreSQL drivers issue on their own (SET/RESET/DISCARD, SHOW, DEALLOCATE, version(),
+/// current_schema, transaction control, LISTEN/UNLISTEN, pg_catalog/information_schema
+/// lookups). Matching is case-insensitive.
+/// Returns true when a response has been sent and the caller must not execute the query.
+/// NOTE(review): most checks are plain substring matches, so a user query that merely
+/// mentions e.g. "pg_stat" or "version()" in a literal is swallowed as well, and any SHOW
+/// statement other than the three handled settings gets an empty result - confirm that
+/// this is acceptable for engine-level SHOW statements.
+bool PostgreSQLHandler::tryAnswerCatalogQuery(const String & query)
+{
+    namespace Messaging = PostgreSQLProtocol::Messaging;
+
+    String upper_query = query;
+    std::transform(upper_query.begin(), upper_query.end(), upper_query.begin(),
+        [](unsigned char c) { return std::toupper(c); });
+
+    const auto starts_with = [&upper_query](const char * prefix) { return upper_query.rfind(prefix, 0) == 0; };
+    const auto contains = [&upper_query](const char * needle) { return upper_query.find(needle) != String::npos; };
+
+    /// Strips surrounding whitespace and trailing semicolons.
+    const auto trim = [](const String & text) -> String
+    {
+        const auto first = text.find_first_not_of(" \t\r\n");
+        if (first == String::npos)
+            return {};
+        const auto last = text.find_last_not_of(" \t\r\n;");
+        if (last == String::npos || last < first)
+            return {};
+        return text.substr(first, last - first + 1);
+    };
+
+    /// Acknowledges a statement that produces no rows.
+    const auto reply_empty = [this]
+    {
+        message_transport->send(
+            Messaging::CommandComplete(Messaging::CommandComplete::Command::SELECT, 0));
+        return true;
+    };
+
+    /// Sends a result set of exactly one string column and one row.
+    /// NOTE(review): FieldDescription / ISerializable / StringField are the element types
+    /// RowDescription and DataRow are assumed to take; the template arguments were missing
+    /// from the reviewed dump - confirm against PostgreSQLProtocol.h.
+    const auto reply_single_value = [this](const String & column, const String & value)
+    {
+        std::vector<Messaging::FieldDescription> columns;
+        columns.emplace_back(column, TypeIndex::String);
+        message_transport->send(Messaging::RowDescription(columns));
+
+        std::vector<std::shared_ptr<Messaging::ISerializable>> row;
+        row.push_back(std::make_shared<Messaging::StringField>(value));
+        message_transport->send(Messaging::DataRow(row));
+
+        message_transport->send(
+            Messaging::CommandComplete(Messaging::CommandComplete::Command::SELECT, 1));
+        return true;
+    };
+
+    if (starts_with("SET "))
+        return reply_empty();
+
+    if (starts_with("RESET ") || starts_with("DISCARD "))
+        return reply_empty();
+
+    if (starts_with("SHOW "))
+    {
+        if (contains("SERVER_VERSION"))
+            return reply_single_value("server_version", "14.0");
+        if (contains("TRANSACTION_ISOLATION"))
+            return reply_single_value("transaction_isolation", "read committed");
+        if (contains("STANDARD_CONFORMING_STRINGS"))
+            return reply_single_value("standard_conforming_strings", "on");
+        return reply_empty();
+    }
+
+    if (starts_with("DEALLOCATE"))
+    {
+        /// Grammar: DEALLOCATE [ PREPARE ] { name | ALL }.
+        /// upper_query has the same length as query, so both are cut at the same offsets
+        /// and the statement name keeps its original case.
+        static constexpr size_t keyword_length = 10; // strlen("DEALLOCATE")
+        String target = trim(query.substr(keyword_length));
+        String upper_target = trim(upper_query.substr(keyword_length));
+
+        static constexpr size_t prepare_length = 7; // strlen("PREPARE")
+        const bool has_prepare_keyword = upper_target.rfind("PREPARE", 0) == 0
+            && upper_target.size() > prepare_length
+            && (upper_target[prepare_length] == ' ' || upper_target[prepare_length] == '\t');
+        if (has_prepare_keyword)
+        {
+            target = trim(target.substr(prepare_length));
+            upper_target = trim(upper_target.substr(prepare_length));
+        }
+
+        if (!target.empty())
+        {
+            if (upper_target == "ALL")
+                stmt_manager->clearAll();
+            else
+                stmt_manager->closeStatement(target);
+        }
+        return reply_empty();
+    }
+
+    if (contains("VERSION()"))
+        return reply_single_value("version", "Proton (PostgreSQL compatible, ClickHouse based)");
+
+    if (contains("CURRENT_SCHEMA"))
+        return reply_single_value("current_schema", "default");
+
+    /// Exact keywords are compared after trimming, so "BEGIN;" is recognised as well.
+    const String command = trim(upper_query);
+    if (command == "BEGIN" || command == "COMMIT" || command == "ROLLBACK"
+        || command == "END" || starts_with("SAVEPOINT ")
+        || starts_with("RELEASE "))
+    {
+        message_transport->send(
+            Messaging::CommandComplete(Messaging::CommandComplete::classifyQuery(query), 0));
+        return true;
+    }
+
+    if (starts_with("LISTEN ") || starts_with("UNLISTEN "))
+        return reply_empty();
+
+    static constexpr const char * catalog_markers[] = {
+        "PG_CATALOG", "PG_TYPE", "PG_NAMESPACE", "PG_CLASS", "PG_ATTRIBUTE", "PG_DESCRIPTION",
+        "PG_CONSTRAINT", "PG_INDEX", "PG_DATABASE", "PG_ROLES", "PG_SETTINGS", "PG_AM",
+        "PG_PROC", "PG_TABLES", "PG_STAT", "PG_EXTENSION", "PG_COLLATION", "PG_MATVIEWS",
+        "PG_SHDESCRIPTION", "INFORMATION_SCHEMA",
+    };
+    for (const char * marker : catalog_markers)
+    {
+        if (contains(marker))
+            return reply_empty();
+    }
+
+    return false;
+}
+
+}
+
diff --git a/src/Server/PostgreSQLHandler.h b/src/Server/PostgreSQLHandler.h
index 7fca167b30..3e772ad557 100644
--- a/src/Server/PostgreSQLHandler.h
+++ b/src/Server/PostgreSQLHandler.h
@@ -5,6 +5,7 @@
 #include
 #include
 #include "IServer.h"
+#include "PostgreSQLPreparedStatement.h"
 
 #if USE_SSL
 # include
@@ -75,7 +76,22 @@ class PostgreSQLHandler : public Poco::Net::TCPServerConnection
 
     void processQuery();
 
+    void
processParse();
+    void processBind();
+    void processDescribe();
+    void processExecute();
+    void processSync();
+    void processClose();
+    void processFlush();
+
+    /// Returns true when the query was answered locally and must not be executed.
+    bool tryAnswerCatalogQuery(const String & query);
+
     static bool isEmptyQuery(const String & query);
+
+    /// Prepared statements and portals of this connection.
+    std::unique_ptr<PreparedStatementManager> stmt_manager;
+
+    /// Set when an extended-query message failed; run() then drops messages until Sync.
+    bool extended_query_error = false;
+    /// Tells run() to send ReadyForQuery before waiting for the next message.
+    bool send_ready_for_query = true;
 };
 
 }
diff --git a/src/Server/PostgreSQLPreparedStatement.h b/src/Server/PostgreSQLPreparedStatement.h
new file mode 100644
index 0000000000..d5c605bcfe
--- /dev/null
+++ b/src/Server/PostgreSQLPreparedStatement.h
@@ -0,0 +1,229 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include <cctype>
+#include
+#include
+#include
+
+/// NOTE(review): the targets of the bare #include lines above, and the template arguments
+/// restored below, were missing from the reviewed dump - confirm against the commit.
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int BAD_ARGUMENTS;
+}
+
+/// A statement created by Parse.
+struct PreparedStatement
+{
+    String query;
+    /// Parameter type OIDs as declared by the client; may be empty.
+    std::vector<Int32> param_oids;
+    /// Highest $N placeholder number found in `query` outside string literals.
+    size_t param_count = 0;
+};
+
+/// A statement with its parameters substituted, created by Bind.
+struct Portal
+{
+    String statement_name;
+    String bound_query;
+    std::vector<Int16> result_formats;
+};
+
+/// Manages named and unnamed prepared statements and portals for a single
+/// PostgreSQL connection. Each connection gets its own manager instance.
+class PreparedStatementManager
+{
+public:
+    PreparedStatementManager() = default;
+
+    /// Stores `query` under `name` (the empty name is the unnamed statement), replacing
+    /// any statement already stored under that name.
+    void parseStatement(const String & name, const String & query, const std::vector<Int32> & param_oids)
+    {
+        PreparedStatement stmt;
+        stmt.query = query;
+        stmt.param_oids = param_oids;
+        stmt.param_count = countParameters(query);
+
+        LOG_DEBUG(log, "Parsed statement '{}': {} params, query: {}",
+            name.empty() ? "<unnamed>" : name, stmt.param_count, query);
+
+        if (name.empty())
+            unnamed_statement = std::move(stmt);
+        else
+            named_statements[name] = std::move(stmt);
+    }
+
+    /// Creates portal `portal_name` from statement `stmt_name` by substituting
+    /// `param_values` (text format; std::nullopt is SQL NULL) for the $N placeholders.
+    /// Throws BAD_ARGUMENTS when the statement does not exist or fewer values are supplied
+    /// than the statement has placeholders. Both are client mistakes, not internal logic
+    /// errors.
+    void bindPortal(
+        const String & portal_name,
+        const String & stmt_name,
+        const std::vector<std::optional<String>> & param_values,
+        const std::vector<Int16> & result_formats)
+    {
+        const auto * stmt = getStatement(stmt_name);
+        if (!stmt)
+            throw Exception(ErrorCodes::BAD_ARGUMENTS,
+                "Prepared statement '{}' does not exist", stmt_name);
+
+        /// Otherwise a raw $N would be left in the query handed to the engine.
+        if (param_values.size() < stmt->param_count)
+            throw Exception(ErrorCodes::BAD_ARGUMENTS,
+                "Bind supplies {} parameters, but prepared statement '{}' requires {}",
+                param_values.size(), stmt_name, stmt->param_count);
+
+        Portal portal;
+        portal.statement_name = stmt_name;
+        portal.bound_query = substituteParams(stmt->query, param_values);
+        portal.result_formats = result_formats;
+
+        LOG_DEBUG(log, "Bound portal '{}' from statement '{}': {}",
+            portal_name.empty() ? "<unnamed>" : portal_name,
+            stmt_name.empty() ? "<unnamed>" : stmt_name,
+            portal.bound_query);
+
+        if (portal_name.empty())
+            unnamed_portal = std::move(portal);
+        else
+            named_portals[portal_name] = std::move(portal);
+    }
+
+    /// Returns the statement stored under `name`, or nullptr. The pointer is invalidated
+    /// by the next modification of the statement set.
+    const PreparedStatement * getStatement(const String & name) const
+    {
+        if (name.empty())
+            return unnamed_statement ? &*unnamed_statement : nullptr;
+
+        auto it = named_statements.find(name);
+        return it != named_statements.end() ? &it->second : nullptr;
+    }
+
+    /// Returns the portal stored under `name`, or nullptr.
+    const Portal * getPortal(const String & name) const
+    {
+        if (name.empty())
+            return unnamed_portal ? &*unnamed_portal : nullptr;
+
+        auto it = named_portals.find(name);
+        return it != named_portals.end() ? &it->second : nullptr;
+    }
+
+    /// Removes the statement `name`; a missing name is not an error.
+    void closeStatement(const String & name)
+    {
+        if (name.empty())
+            unnamed_statement.reset();
+        else
+            named_statements.erase(name);
+    }
+
+    /// Removes the portal `name`; a missing name is not an error.
+    void closePortal(const String & name)
+    {
+        if (name.empty())
+            unnamed_portal.reset();
+        else
+            named_portals.erase(name);
+    }
+
+    /// Drops every statement and portal of the connection.
+    void clearAll()
+    {
+        named_statements.clear();
+        named_portals.clear();
+        unnamed_statement.reset();
+        unnamed_portal.reset();
+    }
+
+private:
+    LoggerPtr log = getLogger("PreparedStatementManager");
+
+    std::optional<PreparedStatement> unnamed_statement;
+    std::optional<Portal> unnamed_portal;
+    std::unordered_map<String, PreparedStatement> named_statements;
+    std::unordered_map<String, Portal> named_portals;
+
+    static bool isDigit(char c) { return std::isdigit(static_cast<unsigned char>(c)) != 0; }
+
+    /// True when query[pos] is '$' directly followed by a decimal digit.
+    static bool isParamStart(const String & query, size_t pos)
+    {
+        return query[pos] == '$' && pos + 1 < query.size() && isDigit(query[pos + 1]);
+    }
+
+    /// Parses the decimal number starting at `pos` and advances `pos` past its digits.
+    /// The value stops growing once it exceeds what a 16-bit parameter count can address,
+    /// so an absurdly long digit run cannot overflow; such a number never matches a value.
+    static size_t parseParamNumber(const String & query, size_t & pos)
+    {
+        static constexpr size_t max_param_number = 65536;
+        size_t num = 0;
+        for (; pos < query.size() && isDigit(query[pos]); ++pos)
+        {
+            if (num <= max_param_number)
+                num = num * 10 + static_cast<size_t>(query[pos] - '0');
+        }
+        return num;
+    }
+
+    /// `pos` is the index of an opening single quote. Returns the index just past the
+    /// closing quote, or query.size() when the literal is not terminated. Both '' and a
+    /// backslash-escaped character are treated as part of the literal; substituteParams()
+    /// escapes values for the same backslash-aware dialect.
+    static size_t skipStringLiteral(const String & query, size_t pos)
+    {
+        for (++pos; pos < query.size(); ++pos)
+        {
+            if (query[pos] == '\\')
+            {
+                ++pos; // the escaped character belongs to the literal
+            }
+            else if (query[pos] == '\'')
+            {
+                if (pos + 1 < query.size() && query[pos + 1] == '\'')
+                    ++pos; // doubled quote
+                else
+                    return pos + 1;
+            }
+        }
+        return query.size();
+    }
+
+    /// Returns the highest $N placeholder number used outside string literals (0 if none).
+    static size_t countParameters(const String & query)
+    {
+        size_t max_param = 0;
+        for (size_t i = 0; i < query.size();)
+        {
+            if (query[i] == '\'')
+            {
+                i = skipStringLiteral(query, i);
+            }
+            else if (isParamStart(query, i))
+            {
+                ++i; // skip '$'
+                const size_t num = parseParamNumber(query, i);
+                if (num > max_param)
+                    max_param = num;
+            }
+            else
+            {
+                ++i;
+            }
+        }
+        return max_param;
+    }
+
+    /// Returns `query` with every $N outside string literals replaced by values[N - 1]:
+    /// NULL for std::nullopt, otherwise a single-quoted literal in which quotes are
+    /// doubled and backslashes are escaped. A $N without a matching value is copied as is.
+    /// NOTE(review): every value becomes a string literal, including numeric parameters;
+    /// this relies on the engine converting such literals - confirm.
+    static String substituteParams(const String & query, const std::vector<std::optional<String>> & values)
+    {
+        String result;
+        result.reserve(query.size() * 2);
+
+        for (size_t i = 0; i < query.size();)
+        {
+            if (query[i] == '\'')
+            {
+                /// Literals are copied verbatim; placeholders inside them are text.
+                const size_t end = skipStringLiteral(query, i);
+                result.append(query, i, end - i);
+                i = end;
+                continue;
+            }
+
+            if (isParamStart(query, i))
+            {
+                size_t j = i + 1;
+                const size_t num = parseParamNumber(query, j);
+
+                if (num >= 1 && num <= values.size())
+                {
+                    const auto & val = values[num - 1];
+                    if (!val.has_value())
+                    {
+                        result += "NULL";
+                    }
+                    else
+                    {
+                        result += '\'';
+                        for (char c : *val)
+                        {
+                            if (c == '\'')
+                                result += "''";
+                            else if (c == '\\')
+                                result += "\\\\";
+                            else
+                                result += c;
+                        }
+                        result += '\'';
+                    }
+                    i = j; // skip past the $N
+                    continue;
+                }
+            }
+
+            result += query[i];
+            ++i;
+        }
+        return result;
+    }
+};
+
+}
diff --git a/src/Storages/PruneShards.cpp b/src/Storages/PruneShards.cpp
index 1c5ecff09e..3f5cd9f068 100644
--- a/src/Storages/PruneShards.cpp
+++ b/src/Storages/PruneShards.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -257,6 +258,22 @@ QueryMode getQueryMode(ConstStoragePtr storage, const SelectQueryInfo & query_in
 
     if (require_back_fill_from_historical)
     {
+        /// For time-based seeks, try resolving via NativeLog first.
+        /// If the streaming store still has the data, skip the expensive
+        /// historical MergeTree scan entirely.
+        if (query_info.seek_to_info->isTimeBased())
+        {
+            if (const auto * stream_storage = dynamic_cast<const StorageStream *>(storage.get()))
+            {
+                if (auto resolved = stream_storage->tryResolveTimeSeekViaStreamingStore(query_info.seek_to_info))
+                {
+                    query_info.seek_to_info->seek_points = std::move(*resolved);
+                    query_info.seek_to_info->type = SeekToType::SEQUENCE_NUMBER;
+                    return QueryMode::Streaming;
+                }
+            }
+        }
+
         /// By default, we will seek to earliest for backfill concat
         if (query_info.seek_to_info->getSeekTo().empty())
             query_info.seek_to_info->seek_points = {cluster::Constants::EarliestSN};
diff --git a/src/Storages/Stream/StorageStream.cpp b/src/Storages/Stream/StorageStream.cpp
index 2dd4f8404f..2ec30a3fd2 100644
--- a/src/Storages/Stream/StorageStream.cpp
+++ b/src/Storages/Stream/StorageStream.cpp
@@ -1682,4 +1682,52 @@ IStorage::SnapshotDataWithExpiration StorageStream::readSnapshot(const Names & r
     return {std::move(result), std::move(snapshot_expired)};
 }
+
+/// NOTE(review): the element type of the returned vector was missing from the reviewed
+/// dump; Int64 is assumed because the values are compared with sequenceRange().first and
+/// assigned to seek_points - confirm.
+std::optional<std::vector<Int64>> StorageStream::tryResolveTimeSeekViaStreamingStore(const SeekToInfoPtr & seek_to_info) const
+{
+    if (!seek_to_info ||
!seek_to_info->isTimeBased())
+        return std::nullopt; // nothing to resolve for a missing or non-time-based seek
+
+    auto local_shards = stream_shards; // work on a local copy of the shard list
+    if (local_shards.empty())
+        return std::nullopt;
+
+    try
+    {
+        for (const auto & shard : local_shards)
+        {
+            if (shard->isVirtualReplica() || shard->isInmemory())
+                continue;
+
+            auto seek_copy = std::make_shared(*seek_to_info); // copy, so the caller's seek info is not modified here; NOTE(review): template argument is missing in this dump
+            seek_copy->replicateForShards(shards); // NOTE(review): `shards` is not declared in this function - presumably the member holding the shard count; confirm
+
+            auto resolved_sns = shard->sequencesForTimestamps(seek_copy->getSeekPoints());
+
+            /// Check all resolved SNs are still available in NativeLog
+            bool all_available = true;
+            for (UInt32 i = 0; i < shards && all_available; ++i) // NOTE(review): indexes resolved_sns and local_shards without checking that either has `shards` elements; confirm
+            {
+                auto range = local_shards[i]->sequenceRange(); // NOTE(review): also called on shards the outer loop skips as virtual or in-memory; confirm
+                if (range.first < 0 || resolved_sns[i] < range.first)
+                    all_available = false;
+            }
+
+            if (all_available)
+            {
+                LOG_INFO(log, "Time-based seek resolved via streaming store, skipping historical backfill");
+                return resolved_sns;
+            }
+
+            LOG_DEBUG(log, "Time-based seek data partially compacted, falling back to historical backfill");
+            return std::nullopt; // NOTE(review): both outcomes return, so only the first shard that is neither virtual nor in-memory is ever queried; confirm this is intended
+        }
+    }
+    catch (...)
+    {
+        LOG_DEBUG(log, "Failed to resolve time-based seek via streaming store, falling back to historical backfill"); // the exception itself is not logged
+    }
+
+    return std::nullopt; // reached when every shard was skipped or an exception was caught
+}
 
 }
diff --git a/src/Storages/Stream/StorageStream.h b/src/Storages/Stream/StorageStream.h
index 08677d8373..552e42dffc 100644
--- a/src/Storages/Stream/StorageStream.h
+++ b/src/Storages/Stream/StorageStream.h
@@ -2,6 +2,7 @@
 
 #include
 #include
+#include
 #include
 #include
 #include
@@ -269,6 +270,11 @@ class StorageStream final : public shared_ptr_helper, public Merg
 
     std::vector getLastSNs() const;
 
+    /// For time-based seek_to, probe NativeLog to check if the streaming store
+    /// still has the requested data. Returns resolved sequence numbers per shard
+    /// if available, std::nullopt if the data has been compacted away.
+    std::optional> tryResolveTimeSeekViaStreamingStore(const SeekToInfoPtr & seek_to_info) const;
+
     bool supportsStreamingQuery() const override { return true; }
 
     friend class StreamSink;