diff --git a/.gitignore b/.gitignore index aa056f2d5..4c7e9450a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,10 @@ rules.ninja node_modules /test/wasmBenchmarker/ctests/wasm /test/wasmBenchmarker/emsdk +/.cache +/build/.cache +/build/.cmake +/build/_deps +/build/third_party +/build/walrus +compile_commands.json diff --git a/src/interpreter/ByteCode.cpp b/src/interpreter/ByteCode.cpp index 3794f75ba..5115a34a7 100644 --- a/src/interpreter/ByteCode.cpp +++ b/src/interpreter/ByteCode.cpp @@ -71,6 +71,21 @@ size_t ByteCode::getSize() const return ByteCode::pointerAlignedSize(sizeof(CallRef) + sizeof(ByteCodeStackOffset) * callRef->parameterOffsetsSize() + sizeof(ByteCodeStackOffset) * callRef->resultOffsetsSize()); } + case ReturnCallOpcode: { + const ReturnCall* returnCall = reinterpret_cast(this); + return ByteCode::pointerAlignedSize(sizeof(ReturnCall) + sizeof(ByteCodeStackOffset) * returnCall->parameterOffsetsSize() + + sizeof(ByteCodeStackOffset) * returnCall->resultOffsetsSize()); + } + case ReturnCallIndirectOpcode: { + const ReturnCallIndirect* returnCallIndirect = reinterpret_cast(this); + return ByteCode::pointerAlignedSize(sizeof(ReturnCallIndirect) + sizeof(ByteCodeStackOffset) * returnCallIndirect->parameterOffsetsSize() + + sizeof(ByteCodeStackOffset) * returnCallIndirect->resultOffsetsSize()); + } + case ReturnCallRefOpcode: { + const ReturnCallRef* returnCallRef = reinterpret_cast(this); + return ByteCode::pointerAlignedSize(sizeof(ReturnCallRef) + sizeof(ByteCodeStackOffset) * returnCallRef->parameterOffsetsSize() + + sizeof(ByteCodeStackOffset) * returnCallRef->resultOffsetsSize()); + } case EndOpcode: { const End* end = reinterpret_cast(this); return ByteCode::pointerAlignedSize(sizeof(End) + sizeof(ByteCodeStackOffset) * end->offsetsSize()); diff --git a/src/interpreter/ByteCode.h b/src/interpreter/ByteCode.h index 4ced54101..010dfabeb 100644 --- a/src/interpreter/ByteCode.h +++ b/src/interpreter/ByteCode.h @@ -47,6 +47,9 @@ class FunctionType; F(Call) \ F(CallIndirect) \ F(CallRef) \ + F(ReturnCall) \ + F(ReturnCallIndirect) \ + F(ReturnCallRef) \ F(Select) \ F(MemorySize) \ F(MemorySizeM64) \ @@ -2128,6 +2131,175 @@ class CallRef : public ByteCode { uint16_t m_resultOffsetsSize; }; +class ReturnCall : public ByteCode { +public: + ReturnCall(uint32_t index, uint16_t parameterOffsetsSize, uint16_t resultOffsetsSize, + FunctionType* functionType) + : ByteCode(Opcode::ReturnCallOpcode) + , m_index(index) + , m_parameterOffsetsSize(parameterOffsetsSize) + , m_resultOffsetsSize(resultOffsetsSize) + { + } + + uint32_t index() const { return m_index; } + ByteCodeStackOffset* stackOffsets() const + { + return reinterpret_cast(reinterpret_cast(this) + sizeof(ReturnCall)); + } + + uint16_t parameterOffsetsSize() const + { + return m_parameterOffsetsSize; + } + + uint16_t resultOffsetsSize() const + { + return m_resultOffsetsSize; + } + +#if !defined(NDEBUG) + void dump(size_t pos) + { + printf("return_call "); + printf("index: %" PRId32 " ", m_index); + size_t c = 0; + auto arr = stackOffsets(); + printf("paramOffsets: "); + for (size_t i = 0; i < m_parameterOffsetsSize; i++) { + printf("%" PRIu32 " ", (uint32_t)arr[c++]); + } + printf(" "); + + printf("resultOffsets: "); + for (size_t i = 0; i < m_resultOffsetsSize; i++) { + printf("%" PRIu32 " ", (uint32_t)arr[c++]); + } + } +#endif + +protected: + uint32_t m_index; + uint16_t m_parameterOffsetsSize; + uint16_t m_resultOffsetsSize; +}; + +class ReturnCallIndirect : public ByteCode { +public: + ReturnCallIndirect(ByteCodeStackOffset stackOffset, uint32_t tableIndex, FunctionType* functionType, + uint16_t parameterOffsetsSize, uint16_t resultOffsetsSize) + : ByteCode(Opcode::ReturnCallIndirectOpcode) + , m_calleeOffset(stackOffset) + , m_tableIndex(tableIndex) + , m_functionType(functionType) + , m_parameterOffsetsSize(parameterOffsetsSize) + , m_resultOffsetsSize(resultOffsetsSize) + { + } + + ByteCodeStackOffset calleeOffset() const { return m_calleeOffset; } + uint32_t tableIndex() const { return m_tableIndex; } + FunctionType* functionType() const { return m_functionType; } + ByteCodeStackOffset* stackOffsets() const + { + return reinterpret_cast(reinterpret_cast(this) + sizeof(ReturnCallIndirect)); + } + + uint16_t parameterOffsetsSize() const + { + return m_parameterOffsetsSize; + } + + uint16_t resultOffsetsSize() const + { + return m_resultOffsetsSize; + } + +#if !defined(NDEBUG) + void dump(size_t pos) + { + printf("return_call_indirect "); + printf("tableIndex: %" PRId32 " ", m_tableIndex); + DUMP_BYTECODE_OFFSET(calleeOffset); + + size_t c = 0; + auto arr = stackOffsets(); + printf("paramOffsets: "); + for (size_t i = 0; i < m_parameterOffsetsSize; i++) { + printf("%" PRIu32 " ", (uint32_t)arr[c++]); + } + printf(" "); + + printf("resultOffsets: "); + for (size_t i = 0; i < m_resultOffsetsSize; i++) { + printf("%" PRIu32 " ", (uint32_t)arr[c++]); + } + } +#endif + +protected: + ByteCodeStackOffset m_calleeOffset; + uint32_t m_tableIndex; + FunctionType* m_functionType; + uint16_t m_parameterOffsetsSize; + uint16_t m_resultOffsetsSize; +}; + +class ReturnCallRef : public ByteCode { +public: + ReturnCallRef(ByteCodeStackOffset stackOffset, FunctionType* functionType, + uint16_t parameterOffsetsSize, uint16_t resultOffsetsSize) + : ByteCode(Opcode::ReturnCallRefOpcode) + , m_calleeOffset(stackOffset) + , m_functionType(functionType) + , m_parameterOffsetsSize(parameterOffsetsSize) + , m_resultOffsetsSize(resultOffsetsSize) + { + } + + ByteCodeStackOffset calleeOffset() const { return m_calleeOffset; } + FunctionType* functionType() const { return m_functionType; } + ByteCodeStackOffset* stackOffsets() const + { + return reinterpret_cast(reinterpret_cast(this) + sizeof(ReturnCallRef)); + } + + uint16_t parameterOffsetsSize() const + { + return m_parameterOffsetsSize; + } + + uint16_t resultOffsetsSize() const + { + return m_resultOffsetsSize; + } + +#if !defined(NDEBUG) + void dump(size_t pos) + { + printf("return_call_ref "); + size_t c = 0; + auto arr = stackOffsets(); + printf("paramOffsets: "); + for (size_t i = 0; i < m_parameterOffsetsSize; i++) { + printf("%" PRIu32 " ", (uint32_t)arr[c++]); + } + printf(" "); + + printf("resultOffsets: "); + for (size_t i = 0; i < m_resultOffsetsSize; i++) { + printf("%" PRIu32 " ", (uint32_t)arr[c++]); + } + } +#endif + +protected: + ByteCodeStackOffset m_calleeOffset; + FunctionType* m_functionType; + uint16_t m_parameterOffsetsSize; + uint16_t m_resultOffsetsSize; +}; + #define DEFINE_LOAD_OP(className, opcodeType, opStr) \ class className : public ByteCodeOffset2 { \ public: \ diff --git a/src/interpreter/Interpreter.cpp b/src/interpreter/Interpreter.cpp index e9459938b..33c4efabc 100644 --- a/src/interpreter/Interpreter.cpp +++ b/src/interpreter/Interpreter.cpp @@ -17,10 +17,12 @@ #include "Walrus.h" +#include "interpreter/ByteCode.h" #include "interpreter/Interpreter.h" #include "runtime/Instance.h" #include "runtime/Function.h" #include "runtime/Memory.h" +#include "runtime/Store.h" #include "runtime/Table.h" #include "runtime/GCArray.h" #include "runtime/GCStruct.h" @@ -1711,6 +1713,72 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, NEXT_INSTRUCTION(); } + DEFINE_OPCODE(ReturnCall) + : + { + ReturnCall* code = (ReturnCall*)programCounter; + auto target = instance->function(code->index()); + + auto paramSize = code->parameterOffsetsSize(); + auto offsets = code->stackOffsets(); + + auto store = instance->module()->store(); + store->setTCO(bp, offsets, paramSize, code->resultOffsetsSize(), target); + + return nullptr; + } + + DEFINE_OPCODE(ReturnCallIndirect) + : + { + ReturnCallIndirect* code = (ReturnCallIndirect*)programCounter; + Table* table = instance->table(code->tableIndex()); + + uint32_t idx = readValue(bp, code->calleeOffset()); + if (UNLIKELY(idx >= table->size())) { + Trap::throwException(state, "undefined element"); + } + auto target = reinterpret_cast(table->uncheckedGetElement(idx)); + if (UNLIKELY(Value::isNull(target))) { + Trap::throwException(state, "uninitialized element " + std::to_string(idx)); + } + const FunctionType* ft = target->functionType(); + if (UNLIKELY(!ft->equals(code->functionType()))) { + Trap::throwException(state, "indirect call type mismatch"); + } + + auto paramSize = code->parameterOffsetsSize(); + auto offsets = code->stackOffsets(); + + auto store = instance->module()->store(); + store->setTCO(bp, offsets, paramSize, code->resultOffsetsSize(), target); + + return nullptr; + } + + DEFINE_OPCODE(ReturnCallRef) + : + { + ReturnCallRef* code = (ReturnCallRef*)programCounter; + + auto target = readValue(bp, code->calleeOffset()); + if (UNLIKELY(Value::isNull(target))) { + Trap::throwException(state, "null function reference"); + } + const FunctionType* ft = target->functionType(); + if (UNLIKELY(!ft->equals(code->functionType()))) { + Trap::throwException(state, "call by reference type mismatch"); + } + + auto paramSize = code->parameterOffsetsSize(); + auto offsets = code->stackOffsets(); + + auto store = instance->module()->store(); + store->setTCO(bp, offsets, paramSize, code->resultOffsetsSize(), target); + + return nullptr; + } + DEFINE_OPCODE(Select) : { @@ -2976,11 +3044,22 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, uint8_t* ptr = userExceptionData.data(); auto& param = tag->functionType()->param().types(); - for (size_t i = 0; i < param.size(); i++) { - auto sz = valueStackAllocatedSize(param[i]); - memcpy(ptr, bp + code->dataOffsets()[i], sz); - ptr += sz; + auto store = instance->module()->store(); + + if (!store->hasTCO()) { + for (size_t i = 0; i < param.size(); i++) { + auto sz = valueStackAllocatedSize(param[i]); + memcpy(ptr, bp + code->dataOffsets()[i], sz); + ptr += sz; + } + } else { + for (size_t i = 0; i < param.size(); i++) { + auto sz = valueStackAllocatedSize(param[i]); + memcpy(ptr, bp + store->tcoBuffer()[i], sz); + ptr += sz; + } } + Trap::throwException(state, tag, std::move(userExceptionData)); ASSERT_NOT_REACHED(); NEXT_INSTRUCTION(); @@ -3041,7 +3120,15 @@ NEVER_INLINE void Interpreter::callOperation( { Call* code = (Call*)programCounter; Function* target = instance->function(code->index()); + auto store = instance->module()->store(); target->interpreterCall(state, bp, code->stackOffsets(), code->parameterOffsetsSize(), code->resultOffsetsSize()); + + while (UNLIKELY(store->hasTCO())) { + auto resultOffsetCount = store->tcoResultOffsetCount(); + target = store->tcoFunctionTarget(); + target->interpreterCall(state, bp, code->stackOffsets(), code->parameterOffsetsSize(), resultOffsetCount); + } + programCounter += ByteCode::pointerAlignedSize(sizeof(Call) + sizeof(ByteCodeStackOffset) * code->parameterOffsetsSize() + sizeof(ByteCodeStackOffset) * code->resultOffsetsSize()); } @@ -3068,7 +3155,15 @@ NEVER_INLINE void Interpreter::callIndirectOperation( Trap::throwException(state, "indirect call type mismatch"); } + auto store = instance->module()->store(); target->interpreterCall(state, bp, code->stackOffsets(), code->parameterOffsetsSize(), code->resultOffsetsSize()); + + while (UNLIKELY(store->hasTCO())) { + auto resultOffsetCount = store->tcoResultOffsetCount(); + target = store->tcoFunctionTarget(); + target->interpreterCall(state, bp, code->stackOffsets(), code->parameterOffsetsSize(), resultOffsetCount); + } + programCounter += ByteCode::pointerAlignedSize(sizeof(CallIndirect) + sizeof(ByteCodeStackOffset) * code->parameterOffsetsSize() + sizeof(ByteCodeStackOffset) * code->resultOffsetsSize()); } @@ -3090,7 +3185,15 @@ NEVER_INLINE void Interpreter::callRefOperation( Trap::throwException(state, "call by reference type mismatch"); } + auto store = instance->module()->store(); target->interpreterCall(state, bp, code->stackOffsets(), code->parameterOffsetsSize(), code->resultOffsetsSize()); + + while (UNLIKELY(store->hasTCO())) { + auto resultOffsetCount = store->tcoResultOffsetCount(); + target = store->tcoFunctionTarget(); + target->interpreterCall(state, bp, code->stackOffsets(), code->parameterOffsetsSize(), resultOffsetCount); + } + programCounter += ByteCode::pointerAlignedSize(sizeof(CallRef) + sizeof(ByteCodeStackOffset) * code->parameterOffsetsSize() + sizeof(ByteCodeStackOffset) * code->resultOffsetsSize()); } diff --git a/src/interpreter/Interpreter.h b/src/interpreter/Interpreter.h index 35fc7b0e3..7905c8198 100644 --- a/src/interpreter/Interpreter.h +++ b/src/interpreter/Interpreter.h @@ -22,6 +22,7 @@ #include "runtime/Instance.h" #include "runtime/JITExec.h" #include "runtime/Module.h" +#include "runtime/Store.h" #include "runtime/Tag.h" #include "interpreter/ByteCode.h" @@ -50,11 +51,16 @@ class Interpreter { CHECK_STACK_LIMIT(newState); auto moduleFunction = function->moduleFunction(); + auto store = function->instance()->module()->store(); ALLOCA(uint8_t, functionStackBase, moduleFunction->requiredStackSize()); - // init parameter space - for (size_t i = 0; i < parameterOffsetCount; i++) { - ((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i])); + if (store->hasTCO()) { + VectorCopier::copy((size_t*)functionStackBase, store->tcoBuffer(), store->tcoBufferSize()); + store->clearTCO(); + } else { + for (size_t i = 0; i < parameterOffsetCount; i++) { + ((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i])); + } } size_t programCounter = reinterpret_cast(moduleFunction->byteCode()); @@ -72,6 +78,7 @@ class Interpreter { resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance()); break; } catch (std::unique_ptr& e) { + store->clearTCO(); for (size_t i = e->m_programCounterInfo.size(); i > 0; i--) { if (e->m_programCounterInfo[i - 1].first == &newState) { programCounter = e->m_programCounterInfo[i - 1].second; @@ -106,6 +113,10 @@ class Interpreter { resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance()); } + if (store->hasTCO()) { + return; + } + offsets += parameterOffsetCount; for (size_t i = 0; i < resultOffsetCount; i++) { *((size_t*)(bp + offsets[i])) = *((size_t*)(functionStackBase + resultOffsets[i])); diff --git a/src/parser/WASMParser.cpp b/src/parser/WASMParser.cpp index 28cafbe6b..d10595d34 100644 --- a/src/parser/WASMParser.cpp +++ b/src/parser/WASMParser.cpp @@ -648,6 +648,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { // i32.eqz and JumpIf can be unified in some cases static const size_t s_noI32Eqz = SIZE_MAX - sizeof(Walrus::I32Eqz); size_t m_lastI32EqzPos; + bool m_useJIT; Walrus::FunctionType* getFunctionType(Index index) { @@ -860,7 +861,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { } public: - WASMBinaryReader(Walrus::TypeStore& typeStore) + WASMBinaryReader(Walrus::TypeStore& typeStore, bool useJIT = false) : m_readerOffsetPointer(nullptr) , m_readerDataPointer(nullptr) , m_codeEndOffset(0) @@ -876,6 +877,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { , m_segmentMode(Walrus::SegmentMode::None) , m_preprocessData(*this) , m_lastI32EqzPos(s_noI32Eqz) + , m_useJIT(useJIT) { } @@ -1582,6 +1584,71 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { generateCallExpr(code, parameterCount, resultCount, functionType); } + virtual void OnReturnCallExpr(uint32_t index) override + { + m_preprocessData.seenBranch(); + if (UNLIKELY(m_useJIT)) { + OnCallExpr(index); + generateFunctionReturnCode(); + return; + } + auto functionType = m_result.m_functions[index]->functionType(); + auto callPos = m_currentByteCode.size(); + auto parameterCount = computeFunctionParameterOrResultOffsetCount(functionType->param()); + auto resultCount = computeFunctionParameterOrResultOffsetCount(functionType->result()); + pushByteCode(Walrus::ReturnCall(index, parameterCount, resultCount, functionType), WASMOpcode::ReturnCallOpcode); + + expandByteCode(Walrus::ByteCode::pointerAlignedSize(sizeof(Walrus::ByteCodeStackOffset) * (parameterCount + resultCount))); + ASSERT(m_currentByteCode.size() % sizeof(void*) == 0); + auto code = peekByteCode(callPos); + + generateCallExpr(code, parameterCount, resultCount, functionType); + stopToGenerateByteCodeWhileBlockEnd(); + } + + virtual void OnReturnCallIndirectExpr(Index sigIndex, Index tableIndex) override + { + m_preprocessData.seenBranch(); + if (UNLIKELY(m_useJIT)) { + OnCallIndirectExpr(sigIndex, tableIndex); + generateFunctionReturnCode(); + return; + } + auto functionType = getFunctionType(sigIndex); + auto callPos = m_currentByteCode.size(); + auto parameterCount = computeFunctionParameterOrResultOffsetCount(functionType->param()); + auto resultCount = computeFunctionParameterOrResultOffsetCount(functionType->result()); + pushByteCode(Walrus::ReturnCallIndirect(popVMStack(), tableIndex, functionType, parameterCount, resultCount), + WASMOpcode::ReturnCallIndirectOpcode); + expandByteCode(Walrus::ByteCode::pointerAlignedSize(sizeof(Walrus::ByteCodeStackOffset) * (parameterCount + resultCount))); + ASSERT(m_currentByteCode.size() % sizeof(void*) == 0); + + auto code = peekByteCode(callPos); + generateCallExpr(code, parameterCount, resultCount, functionType); + stopToGenerateByteCodeWhileBlockEnd(); + } + + virtual void OnReturnCallRefExpr(Type sig_type) override + { + m_preprocessData.seenBranch(); + if (UNLIKELY(m_useJIT)) { + OnCallRefExpr(sig_type); + generateFunctionReturnCode(); + return; + } + auto functionType = getFunctionType(sig_type.GetReferenceIndex()); + auto callPos = m_currentByteCode.size(); + auto parameterCount = computeFunctionParameterOrResultOffsetCount(functionType->param()); + auto resultCount = computeFunctionParameterOrResultOffsetCount(functionType->result()); + pushByteCode(Walrus::ReturnCallRef(popVMStack(), functionType, parameterCount, resultCount), WASMOpcode::ReturnCallRefOpcode); + expandByteCode(Walrus::ByteCode::pointerAlignedSize(sizeof(Walrus::ByteCodeStackOffset) * (parameterCount + resultCount))); + ASSERT(m_currentByteCode.size() % sizeof(void*) == 0); + + auto code = peekByteCode(callPos); + generateCallExpr(code, parameterCount, resultCount, functionType); + stopToGenerateByteCodeWhileBlockEnd(); + } + bool processConstValue(const Walrus::Value& value) { if (!m_inInitExpr) { @@ -3898,7 +3965,7 @@ void WASMParsingResult::clear() std::pair, std::string> WASMParser::parseBinary(Store* store, const std::string& filename, const uint8_t* data, size_t len, const uint32_t JITFlags, const uint32_t featureFlags) { - wabt::WASMBinaryReader delegate(store->getTypeStore()); + wabt::WASMBinaryReader delegate(store->getTypeStore(), JITFlags & JITFlagValue::useJIT); std::string error = ReadWasmBinary(filename, data, len, &delegate, featureFlags); if (error.length()) { diff --git a/src/runtime/ExecutionState.h b/src/runtime/ExecutionState.h index 07dbb1d0b..4b5f3bf3c 100644 --- a/src/runtime/ExecutionState.h +++ b/src/runtime/ExecutionState.h @@ -32,6 +32,7 @@ class ExecutionState { ExecutionState(ExecutionState& parent) : m_parent(&parent) + , m_currentFunction(nullptr) , m_stackLimit(parent.m_stackLimit) { } @@ -56,6 +57,8 @@ class ExecutionState { private: friend class ByteCodeTable; ExecutionState() + : m_parent(nullptr) + , m_currentFunction(nullptr) { m_stackLimit = (size_t)currentStackPointer(); diff --git a/src/runtime/Function.cpp b/src/runtime/Function.cpp index 92980c724..e078f8e98 100644 --- a/src/runtime/Function.cpp +++ b/src/runtime/Function.cpp @@ -17,9 +17,9 @@ #include "Walrus.h" #include "runtime/Function.h" +#include "runtime/Module.h" #include "runtime/Store.h" #include "interpreter/Interpreter.h" -#include "runtime/Module.h" #include "runtime/Tag.h" #include "runtime/Instance.h" #include "runtime/Value.h" @@ -92,6 +92,13 @@ void DefinedFunction::call(ExecutionState& state, Value* argv, Value* result) ASSERT(offsetIndex == parameterOffsetSize + resultOffsetSize); interpreterCall(state, valueBuffer, offsetBuffer, parameterOffsetSize, resultOffsetSize); + auto store = m_instance->module()->store(); + while (UNLIKELY(store->hasTCO())) { + auto resultOffsetCount = store->tcoResultOffsetCount(); + auto target = store->tcoFunctionTarget(); + target->interpreterCall(state, valueBuffer, offsetBuffer, parameterOffsetSize, resultOffsetCount); + } + size_t resultOffsetIndex = 0; for (size_t i = 0; i < resultTypeInfo.size(); i++) { result[i] = Value(resultTypeInfo[i], valueBuffer + offsetBuffer[resultOffsetIndex + parameterOffsetSize]); diff --git a/src/runtime/Store.cpp b/src/runtime/Store.cpp index ed4cbdfe8..8a3fd9b47 100644 --- a/src/runtime/Store.cpp +++ b/src/runtime/Store.cpp @@ -38,6 +38,8 @@ static const FunctionType g_defaultFunctionTypes[] = { Store::Store(Engine* engine) : m_engine(engine) + , m_tcoResultOffsetCount(0) + , m_tcoFunctionTarget(nullptr) { #ifdef ENABLE_GC GC_INIT(); diff --git a/src/runtime/Store.h b/src/runtime/Store.h index 9bd7ac801..c798d12a7 100644 --- a/src/runtime/Store.h +++ b/src/runtime/Store.h @@ -26,6 +26,7 @@ namespace Walrus { class Engine; +class Function; class Module; class Instance; class Extern; @@ -89,6 +90,29 @@ class Store { Waiter* getWaiter(void* address); + bool hasTCO() const { return m_tcoFunctionTarget != nullptr; } + + Function* tcoFunctionTarget() const { return m_tcoFunctionTarget; } + + void setTCO(const uint8_t* bp, const ByteCodeStackOffset* offsets, size_t paramCount, uint16_t resultOffsetCount, Function* target) + { + m_tcoBuffer.resize(paramCount); + for (size_t i = 0; i < paramCount; i++) { + m_tcoBuffer[i] = *((size_t*)(bp + offsets[i])); + } + m_tcoResultOffsetCount = resultOffsetCount; + m_tcoFunctionTarget = target; + } + + size_t* tcoBuffer() { return m_tcoBuffer.data(); } + size_t tcoBufferSize() const { return m_tcoBuffer.size(); } + uint16_t tcoResultOffsetCount() const { return m_tcoResultOffsetCount; } + + void clearTCO() + { + m_tcoFunctionTarget = nullptr; + } + private: Engine* m_engine; TypeStore m_typeStore; @@ -99,6 +123,10 @@ class Store { std::mutex m_waiterListLock; std::vector m_waiterList; + + std::vector m_tcoBuffer; + uint16_t m_tcoResultOffsetCount; + Function* m_tcoFunctionTarget; }; } // namespace Walrus diff --git a/test/wasi/write_to_this.txt b/test/wasi/write_to_this.txt index 980a0d5f1..e69de29bb 100644 --- a/test/wasi/write_to_this.txt +++ b/test/wasi/write_to_this.txt @@ -1 +0,0 @@ -Hello World! diff --git a/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h b/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h index 716334129..a8fdd1713 100644 --- a/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h +++ b/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h @@ -119,6 +119,9 @@ class WASMBinaryReaderDelegate { virtual void OnCallExpr(Index index) = 0; virtual void OnCallIndirectExpr(Index sigIndex, Index tableIndex) = 0; virtual void OnCallRefExpr(Type sig_type) = 0; + virtual void OnReturnCallExpr(Index index) = 0; + virtual void OnReturnCallIndirectExpr(Index sigIndex, Index tableIndex) = 0; + virtual void OnReturnCallRefExpr(Type sig_type) = 0; virtual void OnI32ConstExpr(uint32_t value) = 0; virtual void OnI64ConstExpr(uint64_t value) = 0; virtual void OnF32ConstExpr(uint32_t value) = 0; diff --git a/third_party/wabt/src/walrus/binary-reader-walrus.cc b/third_party/wabt/src/walrus/binary-reader-walrus.cc index 5f0a9651e..73dc30a5d 100644 --- a/third_party/wabt/src/walrus/binary-reader-walrus.cc +++ b/third_party/wabt/src/walrus/binary-reader-walrus.cc @@ -924,8 +924,8 @@ class BinaryReaderDelegateWalrus: public BinaryReaderDelegate { CHECK_RESULT(m_validator.OnReturnCall(GetLocation(), Var(func_index, GetLocation()))); SHOULD_GENERATE_BYTECODE; - m_externalDelegate->OnCallExpr(func_index); - m_externalDelegate->OnReturnExpr(); + + m_externalDelegate->OnReturnCallExpr(func_index); return Result::Ok; } Result OnReturnCallIndirectExpr(Index sig_index, Index table_index) override { @@ -939,8 +939,8 @@ class BinaryReaderDelegateWalrus: public BinaryReaderDelegate { CHECK_RESULT(m_validator.OnReturnCallIndirect(GetLocation(), Var(sig_index, GetLocation()), Var(table_index, GetLocation()))); SHOULD_GENERATE_BYTECODE; - m_externalDelegate->OnCallIndirectExpr(sig_index, table_index); - m_externalDelegate->OnReturnExpr(); + + m_externalDelegate->OnReturnCallIndirectExpr(sig_index, table_index); return Result::Ok; } Result OnReturnCallRefExpr(Type sig_type) override @@ -955,8 +955,8 @@ class BinaryReaderDelegateWalrus: public BinaryReaderDelegate { CHECK_RESULT(m_validator.OnReturnCallRef(GetLocation(), Var(sig_type, GetLocation()))); SHOULD_GENERATE_BYTECODE; - m_externalDelegate->OnCallRefExpr(sig_type); - m_externalDelegate->OnReturnExpr(); + + m_externalDelegate->OnReturnCallRefExpr(sig_type); return Result::Ok; } Result OnReturnExpr() override {