diff --git a/.github/workflows/bedrock.yml b/.github/workflows/bedrock.yml index b942133d9..dd997748c 100644 --- a/.github/workflows/bedrock.yml +++ b/.github/workflows/bedrock.yml @@ -62,6 +62,8 @@ jobs: - name: Mark failure if debugging if: runner.debug == '1' run: exit 1 + - name: Install zstd + run: apt-get update -y && apt-get install -y libzstd-dev - name: Build Bedrock run: "./ci_build.sh" - name: Upload binaries @@ -94,6 +96,8 @@ jobs: with: repository: Expensify/Bedrock path: . + - name: Install zstd + run: apt-get update -y && apt-get install -y libzstd-dev - name: Download binaries uses: ./.github/actions/composite/download-binaries - name: Setup tmate session @@ -129,6 +133,8 @@ jobs: with: repository: Expensify/Bedrock path: . + - name: Install zstd + run: apt-get update -y && apt-get install -y libzstd-dev - name: Download binaries uses: ./.github/actions/composite/download-binaries - name: Setup tmate session diff --git a/BedrockPlugin.cpp b/BedrockPlugin.cpp index ca904c520..37a451357 100644 --- a/BedrockPlugin.cpp +++ b/BedrockPlugin.cpp @@ -81,6 +81,10 @@ void BedrockPlugin::upgradeDatabase(SQLite& db) { } +void BedrockPlugin::initializeFromDB(SQLite& db) +{ +} + bool BedrockPlugin::shouldLockCommitPageOnConflict(const string& conflictLocation) const { return true; diff --git a/BedrockPlugin.h b/BedrockPlugin.h index 11d7f2975..33c6ab412 100644 --- a/BedrockPlugin.h +++ b/BedrockPlugin.h @@ -38,6 +38,10 @@ class BedrockPlugin { // Called at some point during initiation to allow the plugin to verify/change the database schema. virtual void upgradeDatabase(SQLite& db); + // Called once after upgradeDatabase has completed, allowing plugins to read from the DB at startup. + // This runs on the sync thread before any commands are processed. + virtual void initializeFromDB(SQLite& db); + // The plugin can register any number of timers it wants. 
When any of them `ding`, then the `timerFired` // function will be called, and passed the timer that is dinging. set timers; diff --git a/BedrockServer.cpp b/BedrockServer.cpp index a06cd46d9..b3862bf4d 100644 --- a/BedrockServer.cpp +++ b/BedrockServer.cpp @@ -106,10 +106,16 @@ void BedrockServer::sync() // We use fewer FDs on test machines that have other resource restrictions in place. + SQLite::journalZstdDictionaryID = args.calc("-journalZstdDictionaryID"); SINFO("Setting dbPool size to: " << _dbPoolSize); _dbPool = make_shared(_dbPoolSize, args["-db"], args.calc("-cacheSize"), args.calc("-maxJournalSize"), journalTables, mmapSizeGB, args.isSet("-newDBsUseHctree"), args["-checkpointMode"]); SQLite& db = _dbPool->getBase(); + // Allow plugins to read from the DB at startup. + for (auto plugin : plugins) { + plugin.second->initializeFromDB(db); + } + // Initialize the command processor. BedrockCore core(db, *this); diff --git a/Makefile b/Makefile index cf3fb3d45..64bc3e2b8 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ INTERMEDIATEDIR = .build # We use the same library paths and required libraries for all binaries. LIBPATHS =-L$(PROJECT) -Lmbedtls/library -LIBRARIES =-Wl,--start-group -lbedrock -lstuff -Wl,--end-group -ldl -lpcre2-8 -lpthread -lmbedtls -lmbedx509 -lmbedcrypto -lz -lm +LIBRARIES =-Wl,--start-group -lbedrock -lstuff -Wl,--end-group -ldl -lpcre2-8 -lpthread -lmbedtls -lmbedx509 -lmbedcrypto -lz -lzstd -lm # These targets aren't actual files. 
.PHONY: all test clustertest clean testplugin diff --git a/ci_style.sh b/ci_style.sh index 5d448fd6b..67cd69b75 100755 --- a/ci_style.sh +++ b/ci_style.sh @@ -50,7 +50,7 @@ do fi # Counts occurrences of std:: that aren't in comments and have a leading space (except if it's inside pointer brackets, eg: ) - RETURN_VAL=$(sed -n '/^.*\/\/.*/!s/ std:://p; /^.* std::.*\/\//s/ std:://p; /^.*\<.*std::.*\>/s/std:://p;' "${FILENAME}" | wc -l) + RETURN_VAL=$(sed -n '/^.*\/\/.*/!s/ std:://p; /^.* std::.*\/\//s/ std:://p; /^.*\<.* std::.*\>/s/ std:://p;' "${FILENAME}" | wc -l) if [[ $RETURN_VAL -gt 0 ]] && [[ "$FAILED" != "true" ]]; then echo -e "${RED}${OUT} failed style checks, do not use std:: prefix.${RESET}" EXIT_VAL=$RETURN_VAL diff --git a/docker/Dockerfile b/docker/Dockerfile index ad8673311..02095e71c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -23,7 +23,7 @@ RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys BA9EF27F && \ # Update and install dependencies RUN apt-fast update && apt-fast install -y \ wget gnupg software-properties-common lsb-release ccache zlib1g-dev rsyslog cmake \ - libpcre2-dev libpcre3-dev libsodium-dev libgpgme11-dev libstdc++-13-dev make \ + libpcre2-dev libpcre3-dev libsodium-dev libgpgme11-dev libstdc++-13-dev libzstd-dev make \ linux-headers-generic git clang-18 lldb-18 lld-18 clangd-18 clang-tidy-18 \ clang-format-18 clang-tools-18 llvm-18-dev llvm-18-tools libomp-18-dev libc++-18-dev \ libc++abi-18-dev libclang-common-18-dev libclang-18-dev libclang-cpp18-dev libunwind-18-dev @@ -69,7 +69,7 @@ ARG DEBIAN_FRONTEND=noninteractive # Install necessary packages RUN apt-get update && apt-get install -y software-properties-common && \ - apt-get install -y build-essential libpcre++ zlib1g && \ + apt-get install -y build-essential libpcre++ zlib1g libzstd1 && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \ rm -rf /etc/apt/sources.list.d/* diff --git a/main.cpp b/main.cpp index 0eff7efc1..d53885273 
100644 --- a/main.cpp +++ b/main.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -103,6 +104,9 @@ set loadPlugins(SData& args) BedrockPlugin::g_registeredPluginList.emplace(make_pair("MYSQL", [](BedrockServer& s){ return new BedrockPlugin_MySQL(s); })); + BedrockPlugin::g_registeredPluginList.emplace(make_pair("ZSTD", [](BedrockServer& s){ + return new BedrockPlugin_Zstd(s); + })); for (string pluginName : plugins) { // If it's one of our standard plugins, just move on to the next one. @@ -324,11 +328,12 @@ int main(int argc, char* argv[]) SETDEFAULT("-controlPort", "localhost:9999"); SETDEFAULT("-nodeName", SGetHostName()); SETDEFAULT("-cacheSize", SToStr(0)); - SETDEFAULT("-plugins", "db,jobs,cache,mysql"); + SETDEFAULT("-plugins", "db,jobs,cache,mysql,zstd"); SETDEFAULT("-priority", "100"); SETDEFAULT("-maxJournalSize", "1000000"); SETDEFAULT("-queryLog", "queryLog.csv"); SETDEFAULT("-enableMultiWrite", "true"); + SETDEFAULT("-journalZstdDictionaryID", "0"); // We default to PASSIVE checkpoint everywhere as that has been the value proven to work fine for many years. 
SETDEFAULT("-checkpointMode", "PASSIVE"); diff --git a/plugins/Compression.cpp b/plugins/Compression.cpp new file mode 100644 index 000000000..5ad4c99f9 --- /dev/null +++ b/plugins/Compression.cpp @@ -0,0 +1,313 @@ +#include "Compression.h" + +#include +#include +#include + +const string BedrockPlugin_Zstd::name("Zstd"); +map BedrockPlugin_Zstd::_dictionaries; + +const string& BedrockPlugin_Zstd::getName() const +{ + return name; +} + +BedrockPlugin_Zstd::BedrockPlugin_Zstd(BedrockServer& s) : + BedrockPlugin(s) +{ +} + +BedrockPlugin_Zstd::~BedrockPlugin_Zstd() +{ + for (auto& [id, dicts] : _dictionaries) { + if (dicts.compression) { + ZSTD_freeCDict(dicts.compression); + } + if (dicts.decompression) { + ZSTD_freeDDict(dicts.decompression); + } + } + _dictionaries.clear(); +} + +unique_ptr BedrockPlugin_Zstd::getCommand(SQLiteCommand&& baseCommand) +{ + return nullptr; +} + +void BedrockPlugin_Zstd::initializeFromDB(SQLite& db) +{ + loadDictionariesFromDB(db); +} + +void BedrockPlugin_Zstd::upgradeDatabase(SQLite& db) +{ + bool created; + SASSERT(db.verifyTable("zstdDictionaries", + "CREATE TABLE zstdDictionaries (" + "dictionaryID INTEGER PRIMARY KEY, " + "description TEXT, " + "dictionary BLOB" + ")", + created)); +} + +ZSTD_CDict* BedrockPlugin_Zstd::getCompressionDictionary(size_t id) +{ + auto it = _dictionaries.find(id); + if (it != _dictionaries.end()) { + return it->second.compression; + } + return nullptr; +} + +ZSTD_DDict* BedrockPlugin_Zstd::getDecompressionDictionary(size_t id) +{ + auto it = _dictionaries.find(id); + if (it != _dictionaries.end()) { + return it->second.decompression; + } + return nullptr; +} + +void BedrockPlugin_Zstd::loadDictionariesFromDB(SQLite& db) +{ + SQResult result; + if (!db.read("SELECT dictionaryID, dictionary FROM zstdDictionaries;", result)) { + SWARN("Failed to read zstdDictionaries table."); + return; + } + + for (size_t i = 0; i < result.size(); i++) { + size_t id = SToUInt64(result[i][0]); + const string dictData = 
result[i][1]; + + if (dictData.empty()) { + SWARN("Empty dictionary data for dictionaryID " << id << ", skipping."); + continue; + } + + int compressionLevel = 3; + ZSTD_CDict* cdict = ZSTD_createCDict(dictData.data(), dictData.size(), compressionLevel); + ZSTD_DDict* ddict = ZSTD_createDDict(dictData.data(), dictData.size()); + + if (!cdict || !ddict) { + SWARN("Failed to compile dictionary " << id); + if (cdict) { + ZSTD_freeCDict(cdict); + } + if (ddict) { + ZSTD_freeDDict(ddict); + } + continue; + } + + _dictionaries[id] = {cdict, ddict}; + SINFO("Loaded zstd dictionary " << id << " (" << dictData.size() << " bytes)"); + } + + SINFO("Loaded " << _dictionaries.size() << " zstd dictionaries from DB."); +} + +// SQLite UDF implementations + +static void sqliteCompress(sqlite3_context* ctx, int argc, sqlite3_value** argv) +{ + if (argc != 2) { + sqlite3_result_error(ctx, "compress() requires 2 arguments: data and dictionaryID", -1); + return; + } + + // If data is NULL, return NULL. + if (sqlite3_value_type(argv[0]) == SQLITE_NULL) { + sqlite3_result_null(ctx); + return; + } + + // Get dictionary ID. If NULL or 0, return data unchanged. + if (sqlite3_value_type(argv[1]) == SQLITE_NULL) { + sqlite3_result_value(ctx, argv[0]); + return; + } + size_t dictId = (size_t) sqlite3_value_int64(argv[1]); + if (dictId == 0) { + sqlite3_result_value(ctx, argv[0]); + return; + } + + // Look up the compiled compression dictionary. + ZSTD_CDict* cdict = BedrockPlugin_Zstd::getCompressionDictionary(dictId); + if (!cdict) { + string err = "compress(): no dictionary found for ID " + to_string(dictId); + sqlite3_result_error(ctx, err.c_str(), -1); + return; + } + + // Get source data. + const void* src = sqlite3_value_blob(argv[0]); + int srcLen = sqlite3_value_bytes(argv[0]); + if (srcLen == 0) { + sqlite3_result_value(ctx, argv[0]); + return; + } + + // Create compression context. 
+ ZSTD_CCtx* cctx = ZSTD_createCCtx(); + if (!cctx) { + sqlite3_result_error(ctx, "compress(): failed to create compression context", -1); + return; + } + + // Enable dictionary ID in compressed output, disable checksums to save space. + ZSTD_CCtx_setParameter(cctx, ZSTD_c_dictIDFlag, 1); + ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, 0); + + // Allocate output buffer. + size_t dstCap = ZSTD_compressBound(srcLen); + void* dst = sqlite3_malloc64(dstCap); + if (!dst) { + ZSTD_freeCCtx(cctx); + sqlite3_result_error_nomem(ctx); + return; + } + + // Compress. + size_t compressedSize = ZSTD_compress_usingCDict(cctx, dst, dstCap, src, srcLen, cdict); + ZSTD_freeCCtx(cctx); + + if (ZSTD_isError(compressedSize)) { + sqlite3_free(dst); + string err = "compress(): " + string(ZSTD_getErrorName(compressedSize)); + sqlite3_result_error(ctx, err.c_str(), -1); + return; + } + + sqlite3_result_blob(ctx, dst, (int) compressedSize, sqlite3_free); +} + +static void sqliteDecompress(sqlite3_context* ctx, int argc, sqlite3_value** argv) +{ + if (argc != 1) { + sqlite3_result_error(ctx, "decompress() requires 1 argument", -1); + return; + } + + // If data is NULL, return NULL. + if (sqlite3_value_type(argv[0]) == SQLITE_NULL) { + sqlite3_result_null(ctx); + return; + } + + const void* src = sqlite3_value_blob(argv[0]); + int srcLen = sqlite3_value_bytes(argv[0]); + + // If not a zstd frame, return data unchanged (backward compatible with uncompressed data). + if (srcLen == 0 || !ZSTD_isFrame(src, srcLen)) { + sqlite3_result_value(ctx, argv[0]); + return; + } + + // Get dictionary ID from the compressed frame. + unsigned dictId = ZSTD_getDictID_fromFrame(src, srcLen); + ZSTD_DDict* ddict = nullptr; + if (dictId != 0) { + ddict = BedrockPlugin_Zstd::getDecompressionDictionary(dictId); + if (!ddict) { + string err = "decompress(): no dictionary found for ID " + to_string(dictId); + sqlite3_result_error(ctx, err.c_str(), -1); + return; + } + } + + // Get the decompressed size. 
+ unsigned long long decompressedSize = ZSTD_getFrameContentSize(src, srcLen); + if (decompressedSize == ZSTD_CONTENTSIZE_UNKNOWN || decompressedSize == ZSTD_CONTENTSIZE_ERROR) { + sqlite3_result_error(ctx, "decompress(): unable to determine decompressed size", -1); + return; + } + + // Allocate output buffer. + void* dst = sqlite3_malloc64(decompressedSize); + if (!dst) { + sqlite3_result_error_nomem(ctx); + return; + } + + // Decompress. + ZSTD_DCtx* dctx = ZSTD_createDCtx(); + if (!dctx) { + sqlite3_free(dst); + sqlite3_result_error(ctx, "decompress(): failed to create decompression context", -1); + return; + } + + size_t actualSize = ZSTD_decompress_usingDDict(dctx, dst, decompressedSize, src, srcLen, ddict); + ZSTD_freeDCtx(dctx); + + if (ZSTD_isError(actualSize)) { + sqlite3_free(dst); + string err = "decompress(): " + string(ZSTD_getErrorName(actualSize)); + sqlite3_result_error(ctx, err.c_str(), -1); + return; + } + + sqlite3_result_blob(ctx, dst, (int) actualSize, sqlite3_free); +} + +void BedrockPlugin_Zstd::registerSQLite(sqlite3* db) +{ + sqlite3_create_function_v2(db, "compress", 2, SQLITE_UTF8 | SQLITE_DETERMINISTIC, + nullptr, ::sqliteCompress, nullptr, nullptr, nullptr); + sqlite3_create_function_v2(db, "decompress", 1, SQLITE_UTF8 | SQLITE_DETERMINISTIC, + nullptr, ::sqliteDecompress, nullptr, nullptr, nullptr); +} + +string BedrockPlugin_Zstd::decompress(const string& input) +{ + if (input.empty()) { + return input; + } + + // If not a zstd frame, return unchanged. + if (!ZSTD_isFrame(input.data(), input.size())) { + return input; + } + + // Get dictionary ID from the compressed frame. + unsigned dictId = ZSTD_getDictID_fromFrame(input.data(), input.size()); + ZSTD_DDict* ddict = nullptr; + if (dictId != 0) { + ddict = BedrockPlugin_Zstd::getDecompressionDictionary(dictId); + if (!ddict) { + SWARN("decompress(): no dictionary found for ID " << dictId); + return input; + } + } + + // Get the decompressed size. 
+ unsigned long long decompressedSize = ZSTD_getFrameContentSize(input.data(), input.size()); + if (decompressedSize == ZSTD_CONTENTSIZE_UNKNOWN || decompressedSize == ZSTD_CONTENTSIZE_ERROR) { + SWARN("decompress(): unable to determine decompressed size"); + return input; + } + + // Decompress. + string output(decompressedSize, '\0'); + ZSTD_DCtx* dctx = ZSTD_createDCtx(); + if (!dctx) { + SWARN("decompress(): failed to create decompression context"); + return input; + } + + size_t actualSize = ZSTD_decompress_usingDDict(dctx, output.data(), decompressedSize, + input.data(), input.size(), ddict); + ZSTD_freeDCtx(dctx); + + if (ZSTD_isError(actualSize)) { + SWARN("decompress(): " << ZSTD_getErrorName(actualSize)); + return input; + } + + output.resize(actualSize); + return output; +} diff --git a/plugins/Compression.h b/plugins/Compression.h new file mode 100644 index 000000000..70ef1c5d7 --- /dev/null +++ b/plugins/Compression.h @@ -0,0 +1,47 @@ +#pragma once +#include +#define ZSTD_STATIC_LINKING_ONLY +#include +#include "../BedrockPlugin.h" + +// Forward-declare sqlite3 types to avoid forcing all consumers to include sqlite3 headers. +struct sqlite3; + +class BedrockPlugin_Zstd : public BedrockPlugin { +public: + BedrockPlugin_Zstd(BedrockServer& s); + ~BedrockPlugin_Zstd(); + + virtual const string& getName() const; + virtual void upgradeDatabase(SQLite& db); + virtual void initializeFromDB(SQLite& db); + virtual unique_ptr getCommand(SQLiteCommand&& baseCommand); + + // Returns the compiled compression dictionary for the given ID, or nullptr if not found. + static ZSTD_CDict* getCompressionDictionary(size_t id); + + // Returns the compiled decompression dictionary for the given ID, or nullptr if not found. + static ZSTD_DDict* getDecompressionDictionary(size_t id); + + // Loads all dictionaries from the zstdDictionaries table into compiled in-memory maps. + // Called once at startup from the sync thread, before any queries run. 
+ static void loadDictionariesFromDB(SQLite& db); + + // Register the compress(data, dictID) and decompress(data) SQLite UDFs. + static void registerSQLite(sqlite3* db); + + // Non-SQL decompression for use in the synchronization path. + // Returns decompressed data if input is a zstd frame, otherwise returns input unchanged. + static string decompress(const string& input); + + static const string name; + +private: + struct ZDictionaries + { + ZSTD_CDict* compression = nullptr; + ZSTD_DDict* decompression = nullptr; + }; + + static map _dictionaries; +}; diff --git a/sqlitecluster/SQLite.cpp b/sqlitecluster/SQLite.cpp index 7e7767878..e928de09e 100644 --- a/sqlitecluster/SQLite.cpp +++ b/sqlitecluster/SQLite.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -13,6 +14,7 @@ // Tracing can only be enabled or disabled globally, not per object. atomic SQLite::enableTrace(false); +atomic SQLite::journalZstdDictionaryID(0); sqlite3* SQLite::getDBHandle() { @@ -54,7 +56,7 @@ const set& SQLite::getTablesUsed() const return _tablesUsed; } -SQLite::SharedData& SQLite::initializeSharedData(sqlite3* db, const string& filename, const vector& journalNames, bool hctree) +SQLite::SharedData& SQLite::initializeSharedData() { static struct SharedDataLookupMapType { @@ -70,30 +72,30 @@ SQLite::SharedData& SQLite::initializeSharedData(sqlite3* db, const string& file static mutex instantiationMutex; lock_guard lock(instantiationMutex); - auto sharedDataIterator = sharedDataLookupMap.m.find(filename); + auto sharedDataIterator = sharedDataLookupMap.m.find(_filename); if (sharedDataIterator == sharedDataLookupMap.m.end()) { SharedData* sharedData = new SharedData(); // This is never deleted. // Look up the existing wal setting for this DB. 
SQResult result; - SQuery(db, "PRAGMA journal_mode;", result); + SQuery(_db, "PRAGMA journal_mode;", result); bool isDBCurrentlyUsingWAL2 = result.size() && result[0][0] == "wal2"; // If the intended wal setting doesn't match the existing wal setting, change it. - if (!hctree && !isDBCurrentlyUsingWAL2) { - SASSERT(!SQuery(db, "PRAGMA journal_mode = delete;", result)); - SASSERT(!SQuery(db, "PRAGMA journal_mode = WAL2;", result)); + if (!_hctree && !isDBCurrentlyUsingWAL2) { + SASSERT(!SQuery(_db, "PRAGMA journal_mode = delete;", result)); + SASSERT(!SQuery(_db, "PRAGMA journal_mode = WAL2;", result)); } // Read the highest commit count from the database, and store it in commitCount. - string query = "SELECT MAX(maxIDs) FROM (" + _getJournalQuery(journalNames, {"SELECT MAX(id) as maxIDs FROM"}, true) + ")"; - SASSERT(!SQuery(db, query, result)); + string query = "SELECT MAX(maxIDs) FROM (" + _getJournalQuery(_journalNames, {"SELECT MAX(id) as maxIDs FROM"}, true) + ")"; + SASSERT(!SQuery(_db, query, result)); uint64_t commitCount = result.empty() ? 0 : SToUInt64(result[0][0]); sharedData->commitCount = commitCount; // And then read the hash for that transaction. - string lastCommittedHash, ignore; - getCommit(db, journalNames, commitCount, ignore, lastCommittedHash); + string lastCommittedHash; + getCommit(commitCount, nullptr, &lastCommittedHash); sharedData->lastCommittedHash.store(lastCommittedHash); // If we have a commit count, we should have a hash as well. @@ -102,7 +104,7 @@ SQLite::SharedData& SQLite::initializeSharedData(sqlite3* db, const string& file } // Insert our SharedData object into the global map. - sharedDataLookupMap.m.emplace(filename, sharedData); + sharedDataLookupMap.m.emplace(_filename, sharedData); return *sharedData; } else { // Otherwise, use the existing one. @@ -241,6 +243,9 @@ void SQLite::commonConstructorInitialization(bool hctree) // Register application-defined deburr function. 
SDeburr::registerSQLite(_db); + // Register compress/decompress functions for zstd compression. + BedrockPlugin_Zstd::registerSQLite(_db); + // We saw queries where the progress counter never exceeds 551,000, so we're setting it to a lower number // based on Richard Hipp's recommendation. sqlite3_progress_handler(_db, 100'000, _progressHandlerCallback, this); @@ -262,7 +267,8 @@ SQLite::SQLite(const string& filename, int cacheSize, int maxJournalSize, _hctree(validateDBFormat(_filename, hctree)), _db(initializeDB(_filename, mmapSizeGB, _hctree)), _journalNames(initializeJournal(_db, minJournalTables)), - _sharedData(initializeSharedData(_db, _filename, _journalNames, _hctree)), + // Note that it's significant that _sharedData is initialized after _hctree, _db, and _journalNames. Re-ordering these initializations will break this. + _sharedData(initializeSharedData()), _transactionTimer("transaction timer"), _cacheSize(cacheSize), _mmapSizeGB(mmapSizeGB), @@ -838,8 +844,9 @@ bool SQLite::prepare(uint64_t* transactionID, string* transactionhash) *transactionhash = _uncommittedHash; } - // Create our query. - string query = "INSERT INTO " + _journalName + " VALUES (" + SQ(commitCount + 1) + ", " + SQ(_uncommittedQuery) + ", " + SQ(_uncommittedHash) + " )"; + // Create our query. Wrap _uncommittedQuery with compress() for zstd compression. + // When journalZstdDictionaryID is 0 (the default), compress() returns data unchanged. + string query = "INSERT INTO " + _journalName + " VALUES (" + SQ(commitCount + 1) + ", compress(" + SQ(_uncommittedQuery) + ", " + SQ(journalZstdDictionaryID.load()) + "), " + SQ(_uncommittedHash) + " )"; // These are the values we're currently operating on, until we either commit or rollback. 
_sharedData.prepareTransactionInfo(commitCount + 1, _uncommittedQuery, _uncommittedHash, _dbCountAtStart); @@ -1093,33 +1100,24 @@ void SQLite::logLastTransactionTiming(const string& message, const string& comma }); } -bool SQLite::getCommit(uint64_t id, string& query, string& hash) +bool SQLite::getCommit(uint64_t id, string* query, string* hash) { - return getCommit(_db, _journalNames, id, query, hash); -} - -bool SQLite::getCommit(sqlite3* db, const vector& journalNames, uint64_t id, string& query, string& hash) -{ - // TODO: This can fail if called after `BEGIN TRANSACTION`, if the id we want to look up was committed by another - // thread. We may or may never need to handle this case. - // Look up the query and hash for the given commit - string internalQuery = _getJournalQuery(journalNames, {"SELECT query, hash FROM", "WHERE id = " + SQ(id)}); + // Look up the query and/or hash (whichever are supplied) for the given commit + string firstQueryPart = "SELECT "s + (query ? "decompress(query)" : "1") + ", " + (hash ? "hash" : "1") + " FROM"; + string internalQuery = _getJournalQuery(_journalNames, {firstQueryPart, "WHERE id = " + SQ(id)}); SQResult result; - SASSERT(!SQuery(db, internalQuery, result)); - if (!result.empty()) { - query = result[0][0]; - hash = result[0][1]; - } else { - query = ""; - hash = ""; + SASSERT(!SQuery(_db, internalQuery, result)); + if (result.empty()) { + return false; } - if (id) { - SASSERTWARN(!query.empty()); - SASSERTWARN(!hash.empty()); + if (query) { + *query = result[0][0]; + } + if (hash) { + *hash = result[0][1]; } - // If we found a hash, we assume this was a good commit, as we'll allow an empty commit. 
- return !hash.empty(); + return true; } string SQLite::getCommittedHash() @@ -1127,9 +1125,9 @@ string SQLite::getCommittedHash() return _sharedData.lastCommittedHash.load(); } -int SQLite::getCommits(uint64_t fromIndex, uint64_t toIndex, SQResult& result, uint64_t timeoutLimitUS) +int SQLite::getCompressedCommits(uint64_t fromIndex, uint64_t toIndex, SQResult& result, uint64_t timeoutLimitUS) { - // Look up all the queries within that range + // Look up all the queries within that range. Returns raw query data which may be compressed. SASSERTWARN(SWITHIN(1, fromIndex, toIndex)); string query = _getJournalQuery({"SELECT id, hash, query FROM", "WHERE id >= " + SQ(fromIndex) + (toIndex ? " AND id <= " + SQ(toIndex) : "")}); diff --git a/sqlitecluster/SQLite.h b/sqlitecluster/SQLite.h index e105900da..21ad24c29 100644 --- a/sqlitecluster/SQLite.h +++ b/sqlitecluster/SQLite.h @@ -265,15 +265,12 @@ class SQLite { return _insideTransaction; } - // Looks up the exact SQL of a paricular commit to the database, as well as gets the SHA1 hash of the database - // immediately following tha commit. - bool getCommit(uint64_t index, string& query, string& hash); + // Get the SQL and/or hash of a particular commit to the DB. + // Returns true on success, false on error or if the commit is not found. + bool getCommit(uint64_t index, string* query = nullptr, string* hash = nullptr); - // A static version of the above that can be used in initializers. - static bool getCommit(sqlite3* db, const vector& journalNames, uint64_t index, string& query, string& hash); - - // Looks up a range of commits. - int getCommits(uint64_t fromIndex, uint64_t toIndex, SQResult& result, uint64_t timeoutLimitUS = 0); + // Looks up a range of commits. Returns raw query data which may be compressed. + int getCompressedCommits(uint64_t fromIndex, uint64_t toIndex, SQResult& result, uint64_t timeoutLimitUS = 0); // Set a time limit for this transaction, in US from the current time. 
void setTimeout(uint64_t timeLimitUS); @@ -306,6 +303,9 @@ class SQLite { // Enable/disable SQL statement tracing. static atomic enableTrace; + // The zstd dictionary ID to use when compressing journal entries. 0 means no compression. + static atomic journalZstdDictionaryID; + // public read-only accessor for _dbCountAtStart. uint64_t getDBCountAtStart() const; @@ -413,13 +413,16 @@ class SQLite { // Initializers to support RAII-style allocation in constructors. static string initializeFilename(const string& filename); - static SharedData& initializeSharedData(sqlite3* db, const string& filename, const vector& journalNames, bool hctree); static bool validateDBFormat(const string& filename, bool hctree); static sqlite3* initializeDB(const string& filename, int64_t mmapSizeGB, bool hctree); static vector initializeJournal(sqlite3* db, int minJournalTables); void commonConstructorInitialization(bool hctree = false); static int getCheckpointModeFromString(const string& checkpointModeString); + // This is also an initializer to support RAII-style allocation in the constructor but is pulled out separately as + // it's not static and depends on several other members being initialized before it. + SharedData& initializeSharedData(); + // The filename of this DB, canonicalized to its full path on disk. 
const string _filename; diff --git a/sqlitecluster/SQLiteNode.cpp b/sqlitecluster/SQLiteNode.cpp index c53e2e007..cc4b13cfe 100644 --- a/sqlitecluster/SQLiteNode.cpp +++ b/sqlitecluster/SQLiteNode.cpp @@ -11,6 +11,7 @@ #include #include #include +#include // Convenience class for maintaining connections with a mesh of peers #define PDEBUG(_MSG_) SDEBUG("->{" << peer->name << "} " << _MSG_) @@ -1328,8 +1329,8 @@ void SQLiteNode::_onMESSAGE(SQLitePeer* peer, const SData& message) bool hashesMatch = true; peer->getCommit(peerCommitCount, peerCommitHash); if (!peerCommitHash.empty() && peerCommitCount <= getCommitCount()) { - string query, hash; - _db.getCommit(peerCommitCount, query, hash); + string hash; + _db.getCommit(peerCommitCount, nullptr, &hash); hashesMatch = (peerCommitHash == hash); } @@ -1527,7 +1528,7 @@ void SQLiteNode::_onMESSAGE(SQLitePeer* peer, const SData& message) if (message.isSet("hashMismatchValue") || message.isSet("hashMismatchNumber")) { SQResult result; uint64_t commitNum = SToUInt64(message["hashMismatchNumber"]); - _db.getCommits(commitNum, commitNum, result); + _db.getCompressedCommits(commitNum, commitNum, result); peer->forked = true; SALERT("Hash mismatch. Peer " << peer->name << " and I have forked at commit " << message["hashMismatchNumber"] @@ -2055,8 +2056,8 @@ void SQLiteNode::_queueSynchronize(const SQLiteNode* const node, SQLitePeer* pee } if (peerCommitCount) { // It has some data -- do we agree on what we share? 
- string myHash, ignore; - if (!db.getCommit(peerCommitCount, ignore, myHash)) { + string myHash; + if (!db.getCommit(peerCommitCount, nullptr, &myHash)) { PWARN("Error getting commit for peer's commit: " << peerCommitCount << ", my commit count is: " << db.getCommitCount()); STHROW("error getting hash"); } else if (myHash != peerHash) { @@ -2095,7 +2096,7 @@ void SQLiteNode::_queueSynchronize(const SQLiteNode* const node, SQLitePeer* pee } else { toIndex = min(toIndex, fromIndex + 100); // 100 transactions at a time } - int resultCode = db.getCommits(fromIndex, toIndex, result, timeoutAfterUS); + int resultCode = db.getCompressedCommits(fromIndex, toIndex, result, timeoutAfterUS); if (resultCode) { if (resultCode == SQLITE_INTERRUPT) { STHROW("synchronization query timeout"); @@ -2164,7 +2165,7 @@ void SQLiteNode::_recvSynchronize(SQLitePeer* peer, const SData& message) if (!_db.beginTransaction()) { STHROW("failed to begin transaction"); } - if (!_db.writeUnmodified(commit.content)) { + if (!_db.writeUnmodified(BedrockPlugin_Zstd::decompress(commit.content))) { STHROW("failed to write transaction"); } if (!_db.prepare()) { diff --git a/style.sh b/style.sh index a39529d58..8ea0f237e 100755 --- a/style.sh +++ b/style.sh @@ -12,6 +12,6 @@ for FILENAME in $(git diff --name-only origin/main...`git branch | grep \* | cut esac uncrustify -c /vagrant/Bedrock/.uncrustify.cfg --replace --no-backup -l CPP "$FILENAME"; # Removes occurrences of std:: that aren't in comments and have a leading space (except if it's inside pointer brackets, eg: ) - sed '/^.*\/\/.*/!s/ std::/ /; /^.* std::.*\/\//s/ std::/ /; /^.*\<.*std::.*\>/s/std:://;' "$FILENAME" > "$FILENAME.new"; + sed '/^.*\/\/.*/!s/ std::/ /; /^.* std::.*\/\//s/ std::/ /; /^.*\<.* std::.*\>/s/ std::/ /;' "$FILENAME" > "$FILENAME.new"; mv -f "$FILENAME.new" "$FILENAME" done diff --git a/test/clustertest/BedrockClusterTester.h b/test/clustertest/BedrockClusterTester.h index 97ecf4718..df1ae87a5 100644 --- 
a/test/clustertest/BedrockClusterTester.h +++ b/test/clustertest/BedrockClusterTester.h @@ -20,7 +20,7 @@ class ClusterTester { // Creates a cluster of the given size and brings up all the nodes. The nodes will have priority in the order of // their creation (i.e., node 0 is highest priority and will become leader. // You can also specify plugins to load if for some reason you need to override the default configuration. - ClusterTester(ClusterSize size, list queries = {}, map _args = {}, list uniquePorts = {}, string pluginsToLoad = "db,cache,jobs", const string& processPath = ""); + ClusterTester(ClusterSize size, list queries = {}, map _args = {}, list uniquePorts = {}, string pluginsToLoad = "db,cache,jobs,zstd", const string& processPath = ""); ClusterTester(const string& pluginString = "db,cache,jobs", const string& processPath = ""); ~ClusterTester(); diff --git a/test/clustertest/tests/CompressionTest.cpp b/test/clustertest/tests/CompressionTest.cpp new file mode 100644 index 000000000..5f3c3454e --- /dev/null +++ b/test/clustertest/tests/CompressionTest.cpp @@ -0,0 +1,298 @@ +#include + +#include +#include +#include + +struct CompressionTest : tpunit::TestFixture +{ + CompressionTest() + : tpunit::TestFixture("Compression", + BEFORE_CLASS(CompressionTest::setup), + AFTER_CLASS(CompressionTest::teardown), + TEST(CompressionTest::testCompressionDisabled), + TEST(CompressionTest::testCompressionEnabled), + TEST(CompressionTest::testAllNodesCompressed)) + { + } + + BedrockClusterTester* tester; + + // Helper to read the dictionary file from test/sample_data. + string readDictionaryFile() + { + string dictPath = "test/sample_data/journal.dict"; + ifstream file(dictPath, ios::binary); + if (!file.is_open()) { + // Try from the clustertest working directory. 
+ dictPath = "../sample_data/journal.dict"; + file.open(dictPath, ios::binary); + } + if (!file.is_open()) { + STHROW("Could not open journal.dict"); + } + return string((istreambuf_iterator(file)), istreambuf_iterator()); + } + + void setup() + { + // Read the dictionary file and hex-encode it for insertion via SQL. + string dictData = readDictionaryFile(); + string hexDict; + for (unsigned char c : dictData) { + char buf[3]; + snprintf(buf, sizeof(buf), "%02x", c); + hexDict += buf; + } + + // Create the zstdDictionaries table and insert our test dictionary. + // The table is normally created by the Zstd plugin's upgradeDatabase, but we insert the data here + // so it's available before the cluster starts. + list queries = { + "CREATE TABLE IF NOT EXISTS zstdDictionaries (dictionaryID INTEGER PRIMARY KEY, description TEXT, dictionary BLOB);", + "INSERT INTO zstdDictionaries (dictionaryID, description, dictionary) VALUES (1, 'journal test dictionary', X'" + hexDict + "');", + }; + + // Start a 3-node cluster without -journalZstdDictionaryID (compression disabled). + tester = new BedrockClusterTester(ClusterSize::THREE_NODE_CLUSTER, queries); + } + + void teardown() + { + delete tester; + } + + // Generate a long query string (~10KB) of repeated INSERT statements. + // startID offsets the IDs so different tests don't overwrite each other's data. + string generateLongQuery(int startID = 0) + { + string query; + for (int i = 0; i < 200; i++) { + if (!query.empty()) { + query += ";"; + } + query += "INSERT INTO test VALUES(" + SQ(startID + i) + ", " + SQ("value_" + to_string(startID + i) + "_padding_data_to_make_this_longer_" + string(20, 'x')) + ")"; + } + query += ";"; + return query; + } + + // Run a query via the DB plugin and return the result as an SQResult. 
+ SQResult queryServer(BedrockTester& node, const string& sql) + { + SData command("Query"); + command["Format"] = "json"; + command["Query"] = sql; + auto results = node.executeWaitMultipleData({command}); + SQResult result; + if (results[0].methodLine == "200 OK" && !results[0].content.empty()) { + result.deserialize(results[0].content); + } + return result; + } + + // Get the current commit count from a node via the Status command. + uint64_t getCommitCount(BedrockTester& node) + { + SData status("Status"); + string response = node.executeWaitVerifyContent(status); + STable json = SParseJSONObject(response); + return SToUInt64(json["commitCount"]); + } + + // Run a SELECT query against the journal tables via the bedrock server's DB plugin. + // All reads go through the server so that UDFs and dictionaries are available. + string queryJournal(BedrockTester& node, const string& selectExpr, uint64_t commitID) + { + // Build a UNION query across all journal tables. + // First, get the list of journal table names from the server. + SData tableCmd("Query"); + tableCmd["Format"] = "json"; + tableCmd["Query"] = "SELECT tbl_name FROM sqlite_master WHERE tbl_name LIKE 'journal%' ORDER BY tbl_name;"; + auto tableResults = node.executeWaitMultipleData({tableCmd}); + SQResult tables; + tables.deserialize(tableResults[0].content); + + // Build UNION query across all journal tables. 
+ string sql; + for (size_t i = 0; i < tables.size(); i++) { + if (!sql.empty()) { + sql += " UNION "; + } + sql += "SELECT " + selectExpr + " FROM " + tables[i][0] + " WHERE id = " + SQ(commitID); + } + sql += ";"; + + SData command("Query"); + command["Format"] = "json"; + command["Query"] = sql; + auto results = node.executeWaitMultipleData({command}); + if (results[0].methodLine == "200 OK" && !results[0].content.empty()) { + SQResult result; + result.deserialize(results[0].content); + if (!result.empty()) { + return result[0][0]; + } + } + return ""; + } + + // Get the byte length of the raw query column (may be compressed binary). + // We use LENGTH(CAST(...AS BLOB)) because compressed data is binary and can't round-trip through JSON. + size_t readRawJournalEntrySize(BedrockTester& node, uint64_t commitID) + { + string result = queryJournal(node, "LENGTH(CAST(query AS BLOB))", commitID); + return result.empty() ? 0 : SToUInt64(result); + } + + // Read the query column with decompression applied (returns text, safe for JSON). + string readDecompressedJournalEntry(BedrockTester& node, uint64_t commitID) + { + return queryJournal(node, "decompress(query)", commitID); + } + + void testCompressionDisabled() + { + BedrockTester& leader = tester->getTester(0); + BedrockTester& follower = tester->getTester(1); + + // Stop node 2 before we write. It will miss this uncompressed commit and will need to + // sync it later when compression is enabled, verifying that sync works with mixed data. + tester->stopNode(2); + + // Record the commit count before our write. + uint64_t commitBefore = getCommitCount(leader); + + // Generate and execute a long query on the leader. Use startID=0 for test 1. + string longQuery = generateLongQuery(0); + SData command("Query"); + command["Query"] = longQuery; + leader.executeWaitVerifyContent(command, "200"); + + // The new commit should be commitBefore + 1. 
+ uint64_t commitAfter = getCommitCount(leader); + ASSERT_EQUAL(commitBefore + 1, commitAfter); + + // Read the journal entry by ID on the leader and verify it matches. + // With compression disabled, raw size should equal the original and decompressed content should match. + size_t leaderRawSize = readRawJournalEntrySize(leader, commitAfter); + ASSERT_EQUAL(leaderRawSize, longQuery.size()); + string leaderDecompressed = readDecompressedJournalEntry(leader, commitAfter); + ASSERT_EQUAL(leaderDecompressed, longQuery); + + // Wait for the follower to replicate, then verify its journal entry matches too. + follower.waitForStatusTerm("commitCount", to_string(commitAfter)); + size_t followerRawSize = readRawJournalEntrySize(follower, commitAfter); + ASSERT_EQUAL(followerRawSize, longQuery.size()); + string followerDecompressed = readDecompressedJournalEntry(follower, commitAfter); + ASSERT_EQUAL(followerDecompressed, longQuery); + } + + void testCompressionEnabled() + { + BedrockTester& leader = tester->getTester(0); + BedrockTester& follower = tester->getTester(1); + + // Stop the leader and restart it with compression enabled. + tester->stopNode(0); + leader.updateArgs({{"-journalZstdDictionaryID", "1"}}); + tester->startNode(0); + ASSERT_TRUE(leader.waitForState("LEADING")); + + // Record the commit count before our write. + uint64_t commitBefore = getCommitCount(leader); + + // Generate and execute a long query on the leader. Use startID=1000 for test 2 to avoid collisions. + string longQuery = generateLongQuery(1000); + SData command("Query"); + command["Query"] = longQuery; + leader.executeWaitVerifyContent(command, "200"); + + // The new commit should be commitBefore + 1. + uint64_t commitAfter = getCommitCount(leader); + ASSERT_EQUAL(commitBefore + 1, commitAfter); + + // Read the raw journal entry size on the leader. It should be compressed and thus shorter than the original. 
+ size_t leaderRawSize = readRawJournalEntrySize(leader, commitAfter); + ASSERT_GREATER_THAN(leaderRawSize, (size_t) 0); + ASSERT_LESS_THAN(leaderRawSize, longQuery.size()); + + // Read with decompress() and verify the content matches the original query. + string leaderDecompressed = readDecompressedJournalEntry(leader, commitAfter); + ASSERT_EQUAL(leaderDecompressed, longQuery); + + // Wait for the follower to replicate. + follower.waitForStatusTerm("commitCount", to_string(commitAfter)); + + // The follower does not have compression enabled, so its journal stores uncompressed data. + // Verify the follower's decompressed journal entry still matches (decompress is a no-op on uncompressed data). + string followerDecompressed = readDecompressedJournalEntry(follower, commitAfter); + ASSERT_EQUAL(followerDecompressed, longQuery); + } + + void testAllNodesCompressed() + { + // Stop all remaining nodes (node 0 is running with compression, node 1 without, node 2 was stopped in test 1). + tester->stopNode(0); + tester->stopNode(1); + + // Restart all three nodes with compression enabled. + // Use startNodeDontWait to avoid blocking — a node can't open its command port + // until it has quorum, which requires its peers to be running. + for (int i = 0; i < 3; i++) { + tester->getTester(i).updateArgs({{"-journalZstdDictionaryID", "1"}}); + tester->startNodeDontWait(i); + } + ASSERT_TRUE(tester->getTester(0).waitForState("LEADING")); + ASSERT_TRUE(tester->getTester(1).waitForState("FOLLOWING")); + ASSERT_TRUE(tester->getTester(2).waitForState("FOLLOWING")); + + BedrockTester& leader = tester->getTester(0); + BedrockTester& follower1 = tester->getTester(1); + BedrockTester& follower2 = tester->getTester(2); + + // Record the commit count before our write. + uint64_t commitBefore = getCommitCount(leader); + + // Insert a simple, verifiable row. Use ID 9999 to avoid collision with earlier tests. 
+ SData command("Query"); + command["Query"] = "INSERT INTO test VALUES(9999, 'Verifying test 3');"; + leader.executeWaitVerifyContent(command, "200"); + + uint64_t commitAfter = getCommitCount(leader); + ASSERT_EQUAL(commitBefore + 1, commitAfter); + + // Wait for both followers to replicate. + follower1.waitForStatusTerm("commitCount", to_string(commitAfter)); + follower2.waitForStatusTerm("commitCount", to_string(commitAfter)); + + // Verify the journal entry is compressed on all three nodes. + string originalQuery = "INSERT INTO test VALUES(9999, 'Verifying test 3');"; + for (int i = 0; i < 3; i++) { + BedrockTester& node = tester->getTester(i); + size_t rawSize = readRawJournalEntrySize(node, commitAfter); + ASSERT_GREATER_THAN(rawSize, (size_t) 0); + ASSERT_LESS_THAN(rawSize, originalQuery.size()); + + string decompressed = readDecompressedJournalEntry(node, commitAfter); + ASSERT_EQUAL(decompressed, originalQuery); + } + + // Verify the actual DB content on all three nodes. + for (int i = 0; i < 3; i++) { + BedrockTester& node = tester->getTester(i); + SQResult result = queryServer(node, "SELECT id, value FROM test WHERE id = 9999;"); + ASSERT_EQUAL(result.size(), (size_t) 1); + ASSERT_EQUAL(result[0][0], "9999"); + ASSERT_EQUAL(result[0][1], "Verifying test 3"); + } + + // Sanity check: verify total row count in the test table across all tests. + // Test 1 inserted 200 rows (IDs 0-199), test 2 inserted 200 rows (IDs 1000-1199), test 3 inserted 1 row (ID 9999). + for (int i = 0; i < 3; i++) { + BedrockTester& node = tester->getTester(i); + SQResult result = queryServer(node, "SELECT COUNT(*) FROM test;"); + ASSERT_EQUAL(result[0][0], "401"); + } + } +} __CompressionTest; diff --git a/test/sample_data/journal.dict b/test/sample_data/journal.dict new file mode 100644 index 000000000..3e01f2c13 Binary files /dev/null and b/test/sample_data/journal.dict differ