diff --git a/tmva/sofie/CMakeLists.txt b/tmva/sofie/CMakeLists.txt
index 395af4a9ad104..7a3e74225d23c 100644
--- a/tmva/sofie/CMakeLists.txt
+++ b/tmva/sofie/CMakeLists.txt
@@ -72,6 +72,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
   TMVA/ROperator_Not.hxx
   TMVA/ROperator_Clip.hxx
   TMVA/ROperator_Gelu.hxx
+  TMVA/ROperator_HardSigmoid.hxx
   TMVA/SOFIE_common.hxx
   TMVA/SOFIEHelpers.hxx
diff --git a/tmva/sofie/inc/TMVA/ROperator_HardSigmoid.hxx b/tmva/sofie/inc/TMVA/ROperator_HardSigmoid.hxx
new file mode 100644
index 0000000000000..ede5ff8650537
--- /dev/null
+++ b/tmva/sofie/inc/TMVA/ROperator_HardSigmoid.hxx
@@ -0,0 +1,86 @@
+#ifndef TMVA_SOFIE_ROPERATOR_HARDSIGMOID
+#define TMVA_SOFIE_ROPERATOR_HARDSIGMOID
+
+#include "TMVA/SOFIE_common.hxx"
+#include "TMVA/ROperator.hxx"
+#include "TMVA/RModel.hxx"
+
+#include <sstream>
+
+namespace TMVA {
+namespace Experimental {
+namespace SOFIE {
+
+template <typename T>
+class ROperator_HardSigmoid final : public ROperator {
+
+private:
+   float fAlpha = 0.2;
+   float fBeta = 0.5;
+   std::string fNX;
+   std::string fNY;
+   std::vector<size_t> fShape;
+   std::string fType;
+
+public:
+   ROperator_HardSigmoid() {}
+   ROperator_HardSigmoid(float alpha, float beta, std::string nameX, std::string nameY)
+      : fAlpha(alpha), fBeta(beta), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
+   {
+      if (std::is_same<T, float>::value) {
+         fType = "float";
+      } else {
+         throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a HardSigmoid operator");
+      }
+
+      fInputTensorNames = {fNX};
+      fOutputTensorNames = {fNY};
+   }
+
+   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
+
+   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override
+   {
+      auto ret = input; // suggest copy to compiler
+      return ret;
+   }
+
+   void Initialize(RModel &model) override
+   {
+      if (model.CheckIfTensorAlreadyExist(fNX) == false) {
+         throw std::runtime_error("TMVA SOFIE HardSigmoid Op Input Tensor is not found in model");
+      }
+      fShape = model.GetTensorShape(fNX);
+      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
+   }
+
+   std::string Generate(std::string OpName) override
+   {
+      OpName = "op_" + OpName;
+      if (fShape.empty()) {
+         throw std::runtime_error("TMVA SOFIE Operator HardSigmoid called to Generate without being initialized first");
+      }
+      std::stringstream out;
+      size_t length = ConvertShapeToLength(fShape);
+
+      out << SP << "constexpr float " << OpName
+          << "_alpha = " << std::setprecision(std::numeric_limits<float>::max_digits10) << fAlpha << ";\n";
+      out << SP << "constexpr float " << OpName
+          << "_beta = " << std::setprecision(std::numeric_limits<float>::max_digits10) << fBeta << ";\n";
+
+      out << "\n//------ HardSigmoid\n";
+      out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
+      out << SP << SP << "tensor_" << fNY << "[id] = std::max(0.0f, std::min(1.0f, " << OpName << "_alpha * tensor_"
+          << fNX << "[id] + " << OpName << "_beta));\n";
+      out << SP << "}\n";
+      return out.str();
+   }
+
+   std::vector<std::string> GetStdLibs() override { return {std::string("algorithm")}; }
+};
+
+} // namespace SOFIE
+} // namespace Experimental
+} // namespace TMVA
+
+#endif // TMVA_SOFIE_ROPERATOR_HARDSIGMOID
diff --git a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx
index 75d341495b65f..80dd5c395eae8 100644
--- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx
+++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx
@@ -93,6 +93,7 @@ constexpr auto modelDataSuffix = "_FromONNX.dat";
 #include "input_models/references/Log.ref.hxx"
 #include "input_models/references/Elu.ref.hxx"
 #include "input_models/references/Gelu.ref.hxx"
+#include "input_models/references/HardSigmoid.ref.hxx"
 #include "input_models/references/Equal.ref.hxx"
 #include "input_models/references/EluAlpha.ref.hxx"
 #include "input_models/references/LessOrEqual.ref.hxx"
@@ -3195,3 +3196,23 @@ TEST(ONNX, Gelu)
       EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
    }
 }
+
+TEST(ONNX, HardSigmoid)
+{
+   constexpr float TOLERANCE = DEFAULT_TOLERANCE;
+
+   // Preparing the standard input
+   std::vector<float> input{1.0, -2.0, 3.0, 0.5, -1.0, 2.0};
+
+   ASSERT_INCLUDE_AND_RUN(std::vector<float>, "HardSigmoid", input);
+
+   // Checking output size
+   EXPECT_EQ(output.size(), std::size(HardSigmoid_ExpectedOutput::outputs));
+
+   float *correct = HardSigmoid_ExpectedOutput::outputs;
+
+   // Checking every output value, one by one
+   for (size_t i = 0; i < output.size(); ++i) {
+      EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
+   }
+}
diff --git a/tmva/sofie/test/input_models/HardSigmoid.onnx b/tmva/sofie/test/input_models/HardSigmoid.onnx
new file mode 100644
index 0000000000000..df4952c4992b9
Binary files /dev/null and b/tmva/sofie/test/input_models/HardSigmoid.onnx differ
diff --git a/tmva/sofie/test/input_models/references/HardSigmoid.ref.hxx b/tmva/sofie/test/input_models/references/HardSigmoid.ref.hxx
new file mode 100644
index 0000000000000..67b8db1e1ed75
--- /dev/null
+++ b/tmva/sofie/test/input_models/references/HardSigmoid.ref.hxx
@@ -0,0 +1,3 @@
+namespace HardSigmoid_ExpectedOutput {
+float outputs[] = {0.700000f, 0.100000f, 1.000000f, 0.600000f, 0.300000f, 0.900000f};
+} // namespace HardSigmoid_ExpectedOutput
diff --git a/tmva/sofie_parsers/CMakeLists.txt b/tmva/sofie_parsers/CMakeLists.txt
index 4814b62c6ec51..957c43ec4b6b0 100644
--- a/tmva/sofie_parsers/CMakeLists.txt
+++ b/tmva/sofie_parsers/CMakeLists.txt
@@ -81,6 +81,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser
   src/ParseNot.cxx
   src/ParseClip.cxx
   src/ParseGelu.cxx
+  src/ParseHardSigmoid.cxx
   ${PROTO_SRCS}
   LIBRARIES PUBLIC
   protobuf::libprotobuf
diff --git a/tmva/sofie_parsers/src/ParseHardSigmoid.cxx b/tmva/sofie_parsers/src/ParseHardSigmoid.cxx
new file mode 100644
index 0000000000000..3d3222cf7c30b
--- /dev/null
+++ b/tmva/sofie_parsers/src/ParseHardSigmoid.cxx
@@ -0,0 +1,52 @@
+#include "TMVA/RModelParser_ONNX.hxx"
+#include "TMVA/ROperator_HardSigmoid.hxx"
+#include "onnx_proto3.pb.h"
+
+namespace TMVA {
+namespace Experimental {
+namespace SOFIE {
+
+ParserFuncSignature ParseHardSigmoid = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
+   ETensorType input_type;
+
+   auto input_name = nodeproto.input(0);
+   if (parser.IsRegisteredTensorType(input_name)) {
+      input_type = parser.GetTensorType(input_name);
+   } else {
+      throw std::runtime_error("TMVA::SOFIE ONNX Parser HardSigmoid op has input tensor" + input_name +
+                               " but its type is not yet registered");
+   }
+
+   std::unique_ptr<ROperator> op;
+
+   float attr_alpha = 0.2;
+   float attr_beta = 0.5;
+
+   for (int i = 0; i < nodeproto.attribute_size(); i++) {
+      std::string attribute_name = nodeproto.attribute(i).name();
+      if (attribute_name == "alpha")
+         attr_alpha = nodeproto.attribute(i).f();
+      if (attribute_name == "beta")
+         attr_beta = nodeproto.attribute(i).f();
+   }
+
+   std::string output_name = nodeproto.output(0);
+   switch (input_type) {
+   case ETensorType::FLOAT:
+      op.reset(new ROperator_HardSigmoid<float>(attr_alpha, attr_beta, input_name, output_name));
+      break;
+   default:
+      throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator HardSigmoid does not yet support input type " +
+                               std::to_string(static_cast<int>(input_type)));
+   }
+
+   if (!parser.IsRegisteredTensorType(output_name)) {
+      parser.RegisterTensorType(output_name, input_type);
+   }
+
+   return op;
+};
+
+} // namespace SOFIE
+} // namespace Experimental
+} // namespace TMVA
diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx
index aa444c05d5f76..c76de40ae6dc5 100644
--- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx
+++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx
@@ -86,6 +86,7 @@ extern ParserFuncSignature ParseGather;
 extern ParserFuncSignature ParseGatherND;
 extern ParserFuncSignature ParseErf;
 extern ParserFuncSignature ParseElu;
+extern ParserFuncSignature ParseHardSigmoid;
 extern ParserFuncSignature ParseEyeLike;
 extern ParserFuncSignature ParseRange;
 extern ParserFuncSignature ParseTopK;
@@ -324,6 +325,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
    RegisterOperator("GatherND", ParseGatherND);
    RegisterOperator("Erf", ParseErf);
    RegisterOperator("Elu", ParseElu);
+   RegisterOperator("HardSigmoid", ParseHardSigmoid);
    RegisterOperator("EyeLike", ParseEyeLike);
    RegisterOperator("Range", ParseRange);
    RegisterOperator("TopK", ParseTopK);