diff --git a/BUILD b/BUILD index 50753da04..08341aedb 100644 --- a/BUILD +++ b/BUILD @@ -133,6 +133,7 @@ cc_library( "PROXY_WASM_HOST_ENGINE_V8", ], deps = [ + ":base_lib", ":wasm_vm_headers", "//bazel:wee8_no_pointer_compression", ], @@ -220,6 +221,7 @@ cc_library( ], }), deps = [ + ":base_lib", ":wasm_vm_headers", "@com_github_bytecodealliance_wasmtime//:wasmtime_lib", ], diff --git a/include/proxy-wasm/bytecode_util.h b/include/proxy-wasm/bytecode_util.h index f708780e7..9394705d7 100644 --- a/include/proxy-wasm/bytecode_util.h +++ b/include/proxy-wasm/bytecode_util.h @@ -61,15 +61,25 @@ class BytecodeUtil { std::unordered_map &ret); /** - * getStrippedSource gets Wasm module without Custom Sections to save some memory in workers. + * getStrippedSource gets Wasm module without precompiled sections to save some memory in workers. * @param bytecode is the original bytecode. * @param ret is the reference to the stripped bytecode or a copy of the original bytecode. * @return indicates whether parsing succeeded or not. */ static bool getStrippedSource(std::string_view bytecode, std::string &ret); -private: - static bool parseVarint(const char *&pos, const char *end, uint32_t &ret); + /** + * writeModuleWithCustomSection returns a new wasm module comprised of the supplied module with + * the supplied custom section overwriting any existing custom sections containing precompiled + * precompiled content. + * @param bytecode is the bytecode to add the custom section to. + * @param section_name is the section name to add/overwrite. + * @param section_contents are the contents of the new section. + * @return a string containing the sum of the original bytecode with the new section. + */ + static std::optional writeModuleWithCustomSection(std::string_view bytecode, + std::string_view section_name, + std::string_view section_contents); }; } // namespace proxy_wasm diff --git a/include/proxy-wasm/null_vm.h b/include/proxy-wasm/null_vm.h index 703266df3..d4fc38cde 100644 --- a/include/proxy-wasm/null_vm.h +++ b/include/proxy-wasm/null_vm.h @@ -34,6 +34,9 @@ struct NullVm : public WasmVm { // WasmVm std::string_view getEngineName() override { return "null"; } Cloneable cloneable() override { return Cloneable::InstantiatedModule; }; + std::optional serialize(std::string_view original_bytecode) override { + return std::nullopt; + } std::unique_ptr clone() override; bool load(std::string_view plugin_name, std::string_view precompiled, const std::unordered_map &function_names) override; diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 3eb1d4de6..a7a8d39fe 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -235,6 +235,14 @@ class WasmVm { virtual bool load(std::string_view bytecode, std::string_view precompiled, const std::unordered_map &function_names) = 0; + /** + * Serializes the loaded wasm module to a string. Returns true on success. The resulting string + * may be saved to storage and passed as a precompiled wasm module to another instance. + * @param original_bytecode the bytecode to append the serialized section to. + * @return a string containing the serialized wasm module, or nullopt if serialization failed. + */ + virtual std::optional serialize(std::string_view original_bytecode) = 0; + /** * Link the WASM code to the host-provided functions, e.g. the ABI. Prior to linking, the module * should be loaded and the ABI callbacks registered (see above). Linking should be done once diff --git a/src/bytecode_util.cc b/src/bytecode_util.cc index 70a373e01..d778d91e9 100644 --- a/src/bytecode_util.cc +++ b/src/bytecode_util.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "include/proxy-wasm/bytecode_util.h" +#include #if !defined(_MSC_VER) #include @@ -21,6 +22,102 @@ #include namespace proxy_wasm { +namespace { + +constexpr char kCustomSectionType = '\0'; + +bool parseVarint(const char *&pos, const char *end, uint32_t &ret) { + uint32_t shift = 0; + uint32_t total = 0; + uint32_t v; + char b; + while (pos < end) { + if (pos + 1 > end) { + // overread + return false; + } + b = *pos++; + v = (b & 0x7f); + if (shift == 28 && v > 3) { + // overflow + return false; + } + total += v << shift; + if ((b & 0x80) == 0) { + ret = total; + return true; + } + shift += 7; + if (shift > 28) { + // overflow + return false; + } + } + return false; +} + +std::vector encodeVarint(uint32_t value) { + std::vector ret; + while (value >= 0x80) { + ret.push_back(static_cast(value | 0x80)); + value >>= 7; + } + ret.push_back(static_cast(value)); + return ret; +} + +bool stripMatchingSections(std::string_view bytecode, + std::string_view section_name_substring_to_match, std::string &ret) { + // Check Wasm header. + if (!BytecodeUtil::checkWasmHeader(bytecode)) { + return false; + } + + // Skip the Wasm header. + const char *pos = bytecode.data() + 8; + const char *end = bytecode.data() + bytecode.size(); + while (pos < end) { + const auto *const section_start = pos; + if (pos + 1 > end) { + return false; + } + const auto section_type = *pos++; + uint32_t section_len = 0; + if (!parseVarint(pos, end, section_len) || pos + section_len > end) { + return false; + } + if (section_type == kCustomSectionType) { + const auto *const section_data_start = pos; + uint32_t section_name_len = 0; + if (!parseVarint(pos, end, section_name_len) || pos + section_name_len > end) { + return false; + } + auto section_name = std::string_view(pos, section_name_len); + if (section_name.find(section_name_substring_to_match) != std::string::npos) { + // If this is the first matching section, then save everything before + // it, otherwise skip it. + if (ret.empty()) { + const char *start = bytecode.data(); + ret.append(start, section_start); + } + } + pos = section_data_start + section_len; + } else { + pos += section_len; + // Save this section if we already saw a custom "precompiled_" section. + if (!ret.empty()) { + ret.append(section_start, pos); + } + } + } + if (ret.empty()) { + // Copy the original source code if it is empty. + ret = std::string(bytecode); + } + return true; +} + +} // namespace bool BytecodeUtil::checkWasmHeader(std::string_view bytecode) { // Wasm file header is 8 bytes (magic number + version). @@ -115,11 +212,11 @@ bool BytecodeUtil::getCustomSection(std::string_view bytecode, std::string_view if (!parseVarint(pos, end, section_len) || pos + section_len > end) { return false; } - if (section_type == 0) { + if (section_type == kCustomSectionType) { // Custom section. const char *section_end = pos + section_len; uint32_t section_name_len = 0; - if (!BytecodeUtil::parseVarint(pos, section_end, section_name_len) || + if (!parseVarint(pos, section_end, section_name_len) || pos + section_name_len > section_end) { return false; } @@ -195,83 +292,35 @@ bool BytecodeUtil::getFunctionNameIndex(std::string_view bytecode, } bool BytecodeUtil::getStrippedSource(std::string_view bytecode, std::string &ret) { - // Check Wasm header. - if (!checkWasmHeader(bytecode)) { - return false; - } - - // Skip the Wasm header. - const char *pos = bytecode.data() + 8; - const char *end = bytecode.data() + bytecode.size(); - while (pos < end) { - const auto *const section_start = pos; - if (pos + 1 > end) { - return false; - } - const auto section_type = *pos++; - uint32_t section_len = 0; - if (!parseVarint(pos, end, section_len) || pos + section_len > end) { - return false; - } - if (section_type == 0 /* custom section */) { - const auto *const section_data_start = pos; - uint32_t section_name_len = 0; - if (!parseVarint(pos, end, section_name_len) || pos + section_name_len > end) { - return false; - } - auto section_name = std::string_view(pos, section_name_len); - if (section_name.find("precompiled_") != std::string::npos) { - // If this is the first "precompiled_" section, then save everything - // before it, otherwise skip it. - if (ret.empty()) { - const char *start = bytecode.data(); - ret.append(start, section_start); - } - } - pos = section_data_start + section_len; - } else { - pos += section_len; - // Save this section if we already saw a custom "precompiled_" section. - if (!ret.empty()) { - ret.append(section_start, pos); - } - } - } - if (ret.empty()) { - // Copy the original source code if it is empty. - ret = std::string(bytecode); - } - return true; + return stripMatchingSections(bytecode, "precompiled_", ret); } -bool BytecodeUtil::parseVarint(const char *&pos, const char *end, uint32_t &ret) { - uint32_t shift = 0; - uint32_t total = 0; - uint32_t v; - char b; - while (pos < end) { - if (pos + 1 > end) { - // overread - return false; - } - b = *pos++; - v = (b & 0x7f); - if (shift == 28 && v > 3) { - // overflow - return false; - } - total += v << shift; - if ((b & 0x80) == 0) { - ret = total; - return true; - } - shift += 7; - if (shift > 28) { - // overflow - return false; - } +std::optional +BytecodeUtil::writeModuleWithCustomSection(std::string_view bytecode, std::string_view section_name, + std::string_view section_contents) { + if (!checkWasmHeader(bytecode)) { + return std::nullopt; } - return false; + std::vector section_name_len = encodeVarint(section_name.size()); + std::vector section_size = + encodeVarint(section_name_len.size() + section_name.size() + section_contents.size()); + int appended_contents_size = /*section_type*/ 1 + section_size.size() + section_name_len.size() + + section_name.size() + section_contents.size(); + std::string output; + // Copy wasm header. + if (!stripMatchingSections(bytecode, section_name, output)) { + // `bytecode` was an invalid wasm module. + return std::nullopt; + } + output.reserve(output.size() + appended_contents_size); + output.push_back(kCustomSectionType); + output.append( + std::string_view(reinterpret_cast(section_size.data()), section_size.size())); + output.append( + std::string_view(reinterpret_cast(section_name_len.data()), section_name_len.size())); + output.append(section_name); + output.append(section_contents); + return output; } } // namespace proxy_wasm diff --git a/src/v8/v8.cc b/src/v8/v8.cc index 3f372fd88..438462d1b 100644 --- a/src/v8/v8.cc +++ b/src/v8/v8.cc @@ -28,6 +28,7 @@ #include #include "include/proxy-wasm/limits.h" +#include "include/proxy-wasm/bytecode_util.h" #include "absl/strings/str_format.h" #include "include/v8-initialization.h" @@ -78,6 +79,7 @@ class V8 : public WasmVm { bool load(std::string_view bytecode, std::string_view precompiled, const std::unordered_map &function_names) override; + std::optional serialize(std::string_view original_bytecode) override; std::string_view getPrecompiledSectionName() override; bool link(std::string_view debug_name) override; @@ -313,6 +315,20 @@ bool V8::load(std::string_view bytecode, std::string_view precompiled, return true; } +std::optional V8::serialize(std::string_view original_bytecode) { + if (module_ == nullptr) { + return std::nullopt; + } + wasm::vec serialized = module_->serialize(); + if (serialized.invalid()) { + integration()->error("Failed to serialize wasm module."); + return std::nullopt; + } + return BytecodeUtil::writeModuleWithCustomSection( + original_bytecode, getPrecompiledSectionName(), + std::string_view(serialized.get(), serialized.size())); +} + std::unique_ptr V8::clone() { assert(shared_module_ != nullptr); diff --git a/src/wamr/wamr.cc b/src/wamr/wamr.cc index 0c6d401c6..0edf49083 100644 --- a/src/wamr/wamr.cc +++ b/src/wamr/wamr.cc @@ -60,6 +60,9 @@ class Wamr : public WasmVm { std::string_view getPrecompiledSectionName() override { return "wamr-aot"; } Cloneable cloneable() override { return Cloneable::CompiledBytecode; } + std::optional serialize(std::string_view original_bytecode) override { + return std::nullopt; + } std::unique_ptr clone() override; bool load(std::string_view bytecode, std::string_view precompiled, diff --git a/src/wasmedge/wasmedge.cc b/src/wasmedge/wasmedge.cc index 9815bcc49..3359e1b27 100644 --- a/src/wasmedge/wasmedge.cc +++ b/src/wasmedge/wasmedge.cc @@ -252,6 +252,9 @@ class WasmEdge : public WasmVm { std::string_view getPrecompiledSectionName() override { return ""; } Cloneable cloneable() override { return Cloneable::NotCloneable; } + std::optional serialize(std::string_view original_bytecode) override { + return std::nullopt; + } std::unique_ptr clone() override { return nullptr; } bool load(std::string_view bytecode, std::string_view precompiled, diff --git a/src/wasmtime/wasmtime.cc b/src/wasmtime/wasmtime.cc index e51a8b240..e90790c28 100644 --- a/src/wasmtime/wasmtime.cc +++ b/src/wasmtime/wasmtime.cc @@ -20,9 +20,11 @@ #include #include #include +#include #include "include/proxy-wasm/limits.h" #include "include/proxy-wasm/word.h" +#include "include/proxy-wasm/bytecode_util.h" #include "crates/c-api/include/wasmtime.hh" // IWYU pragma: keep @@ -96,10 +98,11 @@ class Wasmtime : public WasmVm { std::string_view getEngineName() override { return "wasmtime"; } Cloneable cloneable() override { return Cloneable::CompiledBytecode; } - std::string_view getPrecompiledSectionName() override { return ""; } + std::string_view getPrecompiledSectionName() override; bool load(std::string_view bytecode, std::string_view precompiled, const std::unordered_map &function_names) override; + std::optional serialize(std::string_view original_bytecode) override; bool link(std::string_view debug_name) override; std::unique_ptr clone() override; uint64_t getMemorySize() override; @@ -172,24 +175,53 @@ void Wasmtime::initStore() { /*memories=*/1); } -bool Wasmtime::load(std::string_view bytecode, std::string_view /*precompiled*/, +bool Wasmtime::load(std::string_view bytecode, std::string_view precompiled, const std::unordered_map & /*function_names*/) { + initStore(); if (!store_.has_value()) { return false; } - Result module = - Module::compile(*engine(), std::span((uint8_t *)bytecode.data(), bytecode.size())); + Result module(::wasmtime::Error("Unable to load Wasm module: zero-length.")); + if (!precompiled.empty()) { + module = Module::deserialize(*engine(), + std::span((uint8_t *)precompiled.data(), precompiled.size())); + if (module) { + module_.emplace(module.ok()); + return true; + } + } + if (bytecode.empty()) { + fail(FailState::UnableToInitializeCode, + "Failed to deserialize Wasm module: " + module.err().message()); + return false; + } + module = Module::compile(*engine(), std::span((uint8_t *)bytecode.data(), bytecode.size())); if (!module) { - fail(FailState::UnableToInitializeCode, "Failed to load Wasm code: " + module.err().message()); + fail(FailState::UnableToInitializeCode, + "Failed to load Wasm module: " + module.err().message()); return false; } module_.emplace(module.ok()); - return true; } +std::optional Wasmtime::serialize(std::string_view original_bytecode) { + if (!module_.has_value()) { + return std::nullopt; + } + Result> serialized = module_->serialize(); + if (!serialized) { + integration()->error("Failed to serialize wasm module: " + serialized.err().message()); + return std::nullopt; + } + return BytecodeUtil::writeModuleWithCustomSection( + original_bytecode, getPrecompiledSectionName(), + std::string_view(reinterpret_cast(serialized.ok_ref().data()), + serialized.ok_ref().size())); +} + std::unique_ptr Wasmtime::clone() { auto clone = std::make_unique(); if (clone == nullptr) { @@ -416,6 +448,9 @@ void Wasmtime::getModuleFunctionImpl(std::string_view function_name, void Wasmtime::warm() { initStore(); } +// Wasmtime sticks +std::string_view Wasmtime::getPrecompiledSectionName() { return "precompiled_wasmtime_bytecode"; } + } // namespace wasmtime std::unique_ptr createWasmtimeVm() { return std::make_unique(); } diff --git a/test/bytecode_util_test.cc b/test/bytecode_util_test.cc index 2e1a813cb..11edc2859 100644 --- a/test/bytecode_util_test.cc +++ b/test/bytecode_util_test.cc @@ -123,4 +123,32 @@ TEST(TestBytecodeUtil, getAbiVersion) { EXPECT_EQ(actual, proxy_wasm::AbiVersion::ProxyWasm_0_2_0); } +TEST(TestWriteModuleWithCustomSection, HasCustomSection) { + const std::string source = readTestWasmFile("abi_export.wasm"); + + std::optional result = BytecodeUtil::writeModuleWithCustomSection( + source, "my_custom_section", "my_custom_section_contents"); + ASSERT_NE(result, std::nullopt); + + std::string_view custom_section_contents; + ASSERT_TRUE( + BytecodeUtil::getCustomSection(*result, "my_custom_section", custom_section_contents)); + EXPECT_EQ(custom_section_contents, "my_custom_section_contents"); +} + +TEST(TestWriteModuleWithCustomSection, OverwritesCustomSectionWithSameName) { + const std::string source = readTestWasmFile("abi_export.wasm"); + + std::optional result = BytecodeUtil::writeModuleWithCustomSection( + source, "my_custom_section", "my_custom_section_contents"); + ASSERT_NE(result, std::nullopt); + result = BytecodeUtil::writeModuleWithCustomSection(*result, "my_custom_section", "new_contents"); + ASSERT_NE(result, std::nullopt); + + std::string_view custom_section_contents; + ASSERT_TRUE( + BytecodeUtil::getCustomSection(*result, "my_custom_section", custom_section_contents)); + EXPECT_EQ(custom_section_contents, "new_contents"); +} + } // namespace proxy_wasm diff --git a/test/runtime_test.cc b/test/runtime_test.cc index f6a9eb28a..002f50b2f 100644 --- a/test/runtime_test.cc +++ b/test/runtime_test.cc @@ -14,6 +14,7 @@ #include "gtest/gtest.h" +#include #include #include #include @@ -194,6 +195,37 @@ TEST_P(TestVm, Trap2) { } } +TEST_P(TestVm, SerializeAndDeserializeRoundTripWorks) { + if (engine_ != "v8" && engine_ != "wasmtime") { + return; + } + auto source = readTestWasmFile("clock.wasm"); + ASSERT_FALSE(source.empty()); + TestWasm wasm(std::move(vm_)); + + std::chrono::time_point unprecompiled_load_start = + std::chrono::system_clock::now(); + ASSERT_TRUE(wasm.load(source, false)); + std::chrono::time_point unprecompiled_load_end = + std::chrono::system_clock::now(); + + std::optional serialized = wasm.wasm_vm()->serialize(source); + ASSERT_NE(serialized, std::nullopt); + + // Still loads with allow_precompiled == false: + ASSERT_TRUE(TestWasm(makeVm(engine_)).load(*serialized, false)); + + // Loads faster now that it is precompiled: + std::chrono::time_point precompiled_load_start = + std::chrono::system_clock::now(); + ASSERT_TRUE(TestWasm(makeVm(engine_)).load(*serialized, true)); + std::chrono::time_point precompiled_load_end = + std::chrono::system_clock::now(); + + EXPECT_LT((precompiled_load_end - precompiled_load_start) * 2, + (unprecompiled_load_end - unprecompiled_load_start)); +} + class TestCounterContext : public TestContext { public: TestCounterContext(WasmBase *wasm) : TestContext(wasm) {}