diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe45077c234f..668f977188ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: rev: v2.4.1 hooks: - id: codespell - args: [ --ignore-words-list, "CreateOr,implementors,PADD,re-use,re-used,re-using,subtile,subtiles,tRe" ] + args: [ --ignore-words-list, "CreateOr,implementors,PADD,re-use,re-used,re-using,SME,+sme,subtile,subtiles,tRe" ] exclude: | (?x)( ^src/autoschedulers/common/cmdline\.h$ diff --git a/Makefile b/Makefile index 62c1e5f1a68f..e257bfb930bb 100644 --- a/Makefile +++ b/Makefile @@ -543,6 +543,7 @@ SOURCE_FILES = \ LoopCarry.cpp \ Lower.cpp \ LowerParallelTasks.cpp \ + LowerSMEStreamingTasks.cpp \ LowerWarpShuffles.cpp \ Memoization.cpp \ Module.cpp \ @@ -747,6 +748,7 @@ HEADER_FILES = \ LoopPartitioningDirective.h \ Lower.h \ LowerParallelTasks.h \ + LowerSMEStreamingTasks.h \ LowerWarpShuffles.h \ MainPage.h \ Memoization.h \ diff --git a/python_bindings/src/halide/halide_/PyEnums.cpp b/python_bindings/src/halide/halide_/PyEnums.cpp index af5519896b9e..f8abe550bb90 100644 --- a/python_bindings/src/halide/halide_/PyEnums.cpp +++ b/python_bindings/src/halide/halide_/PyEnums.cpp @@ -26,7 +26,8 @@ void define_enums(py::module &m) { .value("Vulkan", DeviceAPI::Vulkan) .value("OpenCL", DeviceAPI::OpenCL) .value("Metal", DeviceAPI::Metal) - .value("Hexagon", DeviceAPI::Hexagon); + .value("Hexagon", DeviceAPI::Hexagon) + .value("Host_SMEStreaming", DeviceAPI::Host_SMEStreaming); py::enum_(m, "LinkageType") .value("External", LinkageType::External) @@ -186,6 +187,12 @@ void define_enums(py::module &m) { .value("WebGPU", Target::Feature::WebGPU) .value("SVE", Target::Feature::SVE) .value("SVE2", Target::Feature::SVE2) + .value("SME2", Target::Feature::SME2) + .value("SME_SVL128", Target::Feature::SME_SVL128) + .value("SME_SVL256", Target::Feature::SME_SVL256) + .value("SME_SVL512", Target::Feature::SME_SVL512) + .value("SME_SVL1024", 
Target::Feature::SME_SVL1024) + .value("SME_SVL2048", Target::Feature::SME_SVL2048) .value("ARMDotProd", Target::Feature::ARMDotProd) .value("ARMFp16", Target::Feature::ARMFp16) .value("LLVMLargeCodeModel", Target::Feature::LLVMLargeCodeModel) diff --git a/python_bindings/src/halide/halide_/PyScheduleMethods.h b/python_bindings/src/halide/halide_/PyScheduleMethods.h index 75475dc73f9d..f540c3c087c7 100644 --- a/python_bindings/src/halide/halide_/PyScheduleMethods.h +++ b/python_bindings/src/halide/halide_/PyScheduleMethods.h @@ -101,6 +101,8 @@ HALIDE_NEVER_INLINE void add_schedule_methods(PythonClass &class_instance) { .def("hexagon", &T::hexagon, py::arg("x") = Var::outermost()) + .def("sme_streaming", &T::sme_streaming, py::arg("enable"), py::arg("x") = Var::outermost()) + .def("prefetch", (T & (T::*)(const Func &, const VarOrRVar &, const VarOrRVar &, Expr, PrefetchBoundStrategy)) & T::prefetch, py::arg("func"), py::arg("at"), py::arg("from"), py::arg("offset") = 1, py::arg("strategy") = PrefetchBoundStrategy::GuardWithIf) .def("prefetch", // [](T &t, const ImageParam &image, const VarOrRVar &at, const VarOrRVar &from, const Expr &offset, PrefetchBoundStrategy strategy) -> T & { diff --git a/python_bindings/src/halide/halide_/PyTarget.cpp b/python_bindings/src/halide/halide_/PyTarget.cpp index 073f95ee9be1..80412b35f8b5 100644 --- a/python_bindings/src/halide/halide_/PyTarget.cpp +++ b/python_bindings/src/halide/halide_/PyTarget.cpp @@ -15,7 +15,8 @@ std::string target_repr(const Target &t) { void define_target(py::module &m) { // Disambiguate some ambiguous methods - int (Target::*natural_vector_size_method)(const Type &t) const = &Target::natural_vector_size; + int (Target::*natural_vector_size1_method)(const Type &t) const = &Target::natural_vector_size; + int (Target::*natural_vector_size2_method)(const Type &t, bool is_sme_streaming) const = &Target::natural_vector_size; bool (Target::*supports_type1_method)(const Type &t) const = &Target::supports_type; 
bool (Target::*supports_type2_method)(const Type &t, DeviceAPI device) const = &Target::supports_type; @@ -52,10 +53,13 @@ void define_target(py::module &m) { .def("supports_type", supports_type1_method, py::arg("type")) .def("supports_type", supports_type2_method, py::arg("type"), py::arg("device")) .def("supports_device_api", &Target::supports_device_api, py::arg("device")) - .def("natural_vector_size", natural_vector_size_method, py::arg("type")) + .def("natural_vector_size", natural_vector_size1_method, py::arg("type")) + .def("natural_vector_size", natural_vector_size2_method, py::arg("type"), py::arg("is_sme_streaming")) + .def("sme_streaming_vector_bits", &Target::sme_streaming_vector_bits) .def("has_large_buffers", &Target::has_large_buffers) .def("maximum_buffer_size", &Target::maximum_buffer_size) .def("supported", &Target::supported) + .def_static("sme_svl_feature_from_bits", &Target::sme_svl_feature_from_bits, py::arg("bits")) .def_static("validate_target_string", &Target::validate_target_string, py::arg("name")); ; diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index b9ef94564049..248e23c90409 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -39,7 +39,8 @@ class FindBuffers : public IRGraphVisitor { op->max.accept(this); bool old = in_device_loop; if (op->device_api != DeviceAPI::None && - op->device_api != DeviceAPI::Host) { + op->device_api != DeviceAPI::Host && + op->device_api != DeviceAPI::Host_SMEStreaming) { in_device_loop = true; } op->body.accept(this); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8e4e4f1e7afd..8f4e3df1743d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -154,6 +154,7 @@ target_sources( LoopPartitioningDirective.h Lower.h LowerParallelTasks.h + LowerSMEStreamingTasks.h LowerWarpShuffles.h MainPage.h Memoization.h @@ -332,6 +333,7 @@ target_sources( LoopCarry.cpp Lower.cpp LowerParallelTasks.cpp + LowerSMEStreamingTasks.cpp LowerWarpShuffles.cpp Memoization.cpp 
Module.cpp diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index a61d12222b5d..fc32c06aee68 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -89,6 +89,7 @@ Target complete_arm_target(Target t) { static const Target::Feature features_with_fp16[] = { Target::SVE, Target::SVE2, + Target::SME2, }; for (const auto &f : features_with_fp16) { @@ -97,6 +98,7 @@ Target complete_arm_target(Target t) { static const Target::Feature features_with_dotprod[] = { Target::SVE2, + Target::SME2, }; for (const auto &f : features_with_dotprod) { @@ -202,7 +204,10 @@ class CodeGen_ARM : public CodeGen_CPU { /** Determine feasible vscale (vector_bits/128 or 0) by checking vector lanes used in the function. * Raise user_warning in case of not feasible */ - int check_feasible_vscale(int vector_bits, const std::set &lanes_used, const std::string &simple_name); + int check_feasible_vscale(int vector_bits, + const std::set &lanes_used, + const std::string &streaming_or_none, + const std::string &simple_name); /** Nodes for which we want to emit specific ARM vector intrinsics */ // @{ @@ -297,7 +302,9 @@ class CodeGen_ARM : public CodeGen_CPU { int feasible_vscale = 0; IntrinsicsMap intrinsics_neon; IntrinsicsMap intrinsics_sve2; + IntrinsicsMap intrinsics_streaming; IntrinsicsMap *effective_intrinsics; + bool in_streaming = false; }; CodeGen_ARM::CodeGen_ARM(const Target &target) @@ -1042,6 +1049,9 @@ llvm::Function *CodeGen_ARM::define_intrin_wrapper(const std::string &inner_name // Always inline these wrappers. wrapper->addFnAttr(llvm::Attribute::AlwaysInline); + // Available regardless of SME streaming mode + wrapper->addFnAttr("aarch64_pstate_sm_compatible"); + builder->restoreIP(here); llvm::verifyFunction(*wrapper); @@ -1059,10 +1069,15 @@ void CodeGen_ARM::init_module() { } else if (target.has_feature(Target::SVE)) { user_warning << "Halide does not support SVE for now. 
Use SVE2 if your target device supports it.\n"; } + if (target.has_feature(Target::SME2)) { + user_assert(target.sme_streaming_vector_bits() != 0) + << "For SME2 support, exactly one Target::SME_SVL* feature must be set. For generator target strings, add \"sme_svl\".\n"; + } const bool has_neon = !target.has_feature(Target::NoNEON); const bool has_sve = target.has_feature(Target::SVE2); - if (!(has_neon || has_sve)) { + const bool has_sme = target.has_feature(Target::SME2); + if (!(has_neon || has_sve || has_sme)) { return; } @@ -1070,6 +1085,7 @@ void CodeGen_ARM::init_module() { NeonWidthX1, NeonWidthX2, SVE, + Streaming, }; std::vector flavors; @@ -1080,6 +1096,9 @@ void CodeGen_ARM::init_module() { if (has_sve) { flavors.push_back(SIMDFlavors::SVE); } + if (has_sme) { + flavors.push_back(SIMDFlavors::Streaming); + } for (const ArmIntrinsic &intrin : intrinsic_defs) { if ((intrin.flags & ArmIntrinsic::RequireFp16) && !target.has_feature(Target::ARMFp16)) { @@ -1104,7 +1123,11 @@ void CodeGen_ARM::init_module() { // scaled, and one of two opcodes may be selected by different // iterations of this loop. for (const auto flavor : flavors) { - const bool is_sve = flavor == SIMDFlavors::SVE; + // Assuming intrinsics in Streaming can be handled in the same way as SVE + // except for the vscale value. + // This could change when we add a SME specific intrin or a SVE intrin which is + // unavailable in streaming mode. 
+ const bool is_sve_or_streaming = (flavor == SIMDFlavors::SVE || flavor == SIMDFlavors::Streaming); int vscale = 0; IntrinsicsMap *intrinsics_map = nullptr; @@ -1117,13 +1140,17 @@ void CodeGen_ARM::init_module() { vscale = target.vector_bits / 128; intrinsics_map = &intrinsics_sve2; break; + case SIMDFlavors::Streaming: + vscale = target.sme_streaming_vector_bits() / 128; + intrinsics_map = &intrinsics_streaming; + break; default: internal_error << "unreachable\n"; break; } // Skip intrinsics that are NEON or SVE only depending on whether compiling for SVE. - if (is_sve) { + if (is_sve_or_streaming) { if (intrin.flags & ArmIntrinsic::SveUnavailable) { continue; } @@ -1134,7 +1161,7 @@ void CodeGen_ARM::init_module() { } if ((target.bits == 64) && (intrin.flags & ArmIntrinsic::Neon64Unavailable) && - !is_sve) { + !is_sve_or_streaming) { continue; } // Already declared in the x1 pass. @@ -1147,9 +1174,9 @@ void CodeGen_ARM::init_module() { const bool is_vanilla_intrinsic = starts_with(intrin_name, "llvm."); if (!is_vanilla_intrinsic && (intrin.flags & ArmIntrinsic::NoPrefix) == 0) { const char *prefix = - target.bits == 32 ? "llvm.arm.neon." : - is_sve ? "llvm.aarch64.sve." : - "llvm.aarch64.neon."; + target.bits == 32 ? "llvm.arm.neon." : + is_sve_or_streaming ? "llvm.aarch64.sve." : + "llvm.aarch64.neon."; return concat_strings(prefix, intrin_name); } return intrin_name; @@ -1165,6 +1192,7 @@ void CodeGen_ARM::init_module() { width_factor = 2; break; case SIMDFlavors::SVE: + case SIMDFlavors::Streaming: width_factor = (intrin.flags & ArmIntrinsic::HalfWidth) ? 2 : 1; width_factor *= vscale; break; @@ -1196,7 +1224,7 @@ void CodeGen_ARM::init_module() { if (starts_with(full_name, "llvm.") && (intrin.flags & ArmIntrinsic::NoMangle) == 0) { // Append LLVM name mangling for either the return type or the arguments, or both. 
vector types; - if (intrin.flags & ArmIntrinsic::MangleArgs && !is_sve) { + if (intrin.flags & ArmIntrinsic::MangleArgs && !is_sve_or_streaming) { types = arg_types; } else if (intrin.flags & ArmIntrinsic::MangleRetArgs) { types = {ret_type}; @@ -1205,8 +1233,8 @@ void CodeGen_ARM::init_module() { types = {ret_type}; } for (const Type &t : types) { - std::string llvm_vector_prefix = is_sve ? ".nxv" : ".v"; - int mangle_lanes = t.lanes() / (is_sve ? vscale : 1); + std::string llvm_vector_prefix = is_sve_or_streaming ? ".nxv" : ".v"; + int mangle_lanes = t.lanes() / (is_sve_or_streaming ? vscale : 1); mangled_name_builder << llvm_vector_prefix << mangle_lanes; if (t.is_int() || t.is_uint()) { mangled_name_builder << "i"; @@ -1220,7 +1248,7 @@ void CodeGen_ARM::init_module() { llvm::Function *intrin_impl = define_intrin_wrapper( intrin.name, ret_type, mangled_name, arg_types, - intrin.flags, is_sve, vscale); + intrin.flags, is_sve_or_streaming, vscale); function_does_not_access_memory(intrin_impl); intrin_impl->addFnAttr(llvm::Attribute::NoUnwind); @@ -1240,6 +1268,9 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f, const string &extern_name) { LoweredFunc func = f; + llvm::Function *llvm_func = module->getFunction(extern_name); + internal_assert(llvm_func); + bool is_streaming_task = (f.attributes & LoweredFunc::Attribute::SME_STREAMING_TASK) && target.has_feature(Target::SME2); if (target.os != Target::IOS && target.os != Target::OSX) { // Substitute in strided loads to get vld2/3/4 emission. We don't do it @@ -1254,7 +1285,8 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f, // Inspect vector lanes used in this function to determine feasible vscale. 
// TODO: Target::SVE not supported https://github.com/halide/Halide/issues/8872 feasible_vscale = 0; - if (target.features_any_of({Target::SVE2})) { + in_streaming = false; + if (target.features_any_of({Target::SVE2, Target::SME2})) { std::set lanes_used; mutate_with(func.body, [&](auto *self, const Expr &e) { @@ -1262,7 +1294,23 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f, return self->mutate_base(e); }); - feasible_vscale = check_feasible_vscale(target.vector_bits, lanes_used, simple_name); + if (is_streaming_task) { + feasible_vscale = check_feasible_vscale(target.sme_streaming_vector_bits(), // SVL + lanes_used, "streaming_", simple_name); + } + in_streaming = (feasible_vscale > 0) && is_streaming_task; + + if (!in_streaming && target.has_feature(Target::SVE2)) { + feasible_vscale = check_feasible_vscale(target.vector_bits, // VL + lanes_used, "", simple_name); + } + } + + if (in_streaming) { + llvm_func->addFnAttr("aarch64_pstate_sm_body"); + llvm_func->addFnAttr(llvm::Attribute::NoInline); + } else if (f.attributes & LoweredFunc::Attribute::SME_NONSTREAMING_TASK) { + llvm_func->addFnAttr(llvm::Attribute::NoInline); } if (feasible_vscale > 0) { @@ -1274,20 +1322,31 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f, } // Select intrinsics map for neon or sve2, depending on vscale - effective_intrinsics = feasible_vscale > 0 ? &intrinsics_sve2 : &intrinsics_neon; + effective_intrinsics = in_streaming ? &intrinsics_streaming : feasible_vscale > 0 ? &intrinsics_sve2 : + &intrinsics_neon; - CodeGen_CPU::set_effective_vscale(feasible_vscale); + set_effective_vscale(feasible_vscale); // Make sure run-time vscale is equal to compile-time vscale. // Avoiding the assert on inner functions is both an efficiency and a correctness issue // as the assertion code may not compile in all contexts. 
- if (f.linkage != LinkageType::Internal) { + if (f.linkage != LinkageType::Internal && !target.has_feature(Target::NoAsserts)) { int effective_vscale = target_vscale(); - if (effective_vscale != 0 && !target.has_feature(Target::NoAsserts)) { + if (effective_vscale != 0) { + internal_assert(!in_streaming) << "Streaming mode in non-internal linkage func is unexpected\n"; Expr runtime_vscale = Call::make(Int(32), Call::get_runtime_vscale, {}, Call::PureIntrinsic); Expr compiletime_vscale = Expr(effective_vscale); - Expr error = Call::make(Int(32), "halide_error_vscale_invalid", - {simple_name, runtime_vscale, compiletime_vscale}, Call::Extern); + std::vector args{simple_name, runtime_vscale, compiletime_vscale}; + Expr error = Call::make(Int(32), "halide_error_vscale_invalid", args, Call::Extern); + func.body = Block::make(AssertStmt::make(runtime_vscale == compiletime_vscale, error), func.body); + } + if (target.has_feature(Target::SME2)) { + // We check regardless of streaming mode enabled or not + // because streaming task is basically internal linkage. 
+ Expr runtime_vscale = Call::make(Int(32), Call::get_runtime_streaming_vscale, {}, Call::PureIntrinsic); + Expr compiletime_vscale = Expr(target.sme_streaming_vector_bits() / 128); + std::vector args{simple_name, runtime_vscale, compiletime_vscale}; + Expr error = Call::make(Int(32), "halide_error_streaming_vscale_invalid", args, Call::Extern); func.body = Block::make(AssertStmt::make(runtime_vscale == compiletime_vscale, error), func.body); } } @@ -1295,7 +1354,10 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f, CodeGen_CPU::compile_func(func, simple_name, extern_name); } -int CodeGen_ARM::check_feasible_vscale(int vector_bits, const std::set &lanes_used, const std::string &simple_name) { +int CodeGen_ARM::check_feasible_vscale(int vector_bits, + const std::set &lanes_used, + const std::string &streaming_or_none, + const std::string &simple_name) { internal_assert(vector_bits != 0 && (vector_bits % 128) == 0); int vscale = vector_bits / 128; bool feasible = true; @@ -1318,8 +1380,9 @@ int CodeGen_ARM::check_feasible_vscale(int vector_bits, const std::set &lan if (!feasible) { user_warning << "In " << simple_name << ", Vectorization factor is not suitable of scalable vector with " + << streaming_or_none << "vector_bits=" << vector_bits - << ". Disabling SVE\n"; + << ". Disabling " << streaming_or_none << "SVE\n"; return 0; } @@ -1815,6 +1878,13 @@ void CodeGen_ARM::visit(const Store *op) { } } else if (op->index.type().is_vector()) { // Scatter + if (in_streaming) { + user_warning << "Scatter store is not vectorized in streaming mode." 
+ << " It will result in slow performance due to scalarization.\n"; + CodeGen_CPU::visit(op); + return; + } + Type elt = op->value.type().element_of(); // Rewrite float16 case into reinterpret and Store in uint16, as it is unsupported in LLVM @@ -1991,6 +2061,12 @@ void CodeGen_ARM::visit(const Load *op) { } } else if (op->index.type().is_vector()) { // General Gather Load + if (in_streaming) { + user_warning << "Gather load is not vectorized in streaming mode." + << " It will result in slow performance due to scalarization.\n"; + CodeGen_CPU::visit(op); + return; + } // Rewrite float16 case into load in uint16 and reinterpret, as it is unsupported in LLVM if (is_float16_and_has_feature(op->type)) { @@ -2619,6 +2695,38 @@ void CodeGen_ARM::visit(const Call *op) { } } + if (op->is_intrinsic(Call::get_runtime_streaming_vscale)) { + // This intrin function must be defined independently. + // Otherwise, vscale_range(n, n) attribute is added and llvm compiler optimize away the runtime call, + // which makes runtime assertion of vscale useless. + llvm::Function *fn = module->getFunction(op->name); + if (!fn) { + FunctionType *func_t = FunctionType::get(i32_t, {}, false); + fn = llvm::Function::Create(func_t, llvm::Function::InternalLinkage, op->name, module.get()); + + llvm::BasicBlock *block = llvm::BasicBlock::Create(module->getContext(), "entry", fn); + IRBuilderBase::InsertPoint here = builder->saveIP(); + builder->SetInsertPoint(block); + + // Body + FunctionType *intrin_fn_type = FunctionType::get(i64_t, {}, false); + FunctionCallee intrin_fn = module->getOrInsertFunction("llvm.aarch64.sme.cntsd", intrin_fn_type); // codespell:ignore sme + CallInst *intrin_call = builder->CreateCall(intrin_fn, {}); + Value *i32_cntsd = builder->CreateIntCast(intrin_call, i32_t, false); + // Divide by 2, as cnts"d" returns the number of lanes for 64bit type, while vscale=1 means 128bit. 
+ Value *ret = builder->CreateLShr(i32_cntsd, ConstantInt::get(i32_t, 1)); + builder->CreateRet(ret); + + // To avoid vscale_range(n,n) added in CodeGen_Internal + fn->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(*context, 1, 16)); + fn->addFnAttr(llvm::Attribute::NoInline); + internal_assert(!verifyFunction(*fn, &llvm::errs())); + builder->restoreIP(here); + } + value = builder->CreateCall(fn, {}); + return; + } + CodeGen_CPU::visit(op); } @@ -2715,7 +2823,8 @@ bool CodeGen_ARM::codegen_dot_product_vector_reduce(const VectorReduce *op, cons if (op->op != p.reduce_op || factor % p.factor != 0) { continue; } - if (!target.has_feature(p.required_feature)) { + if (!target.has_feature(p.required_feature) && + !(in_streaming && p.required_feature == Target::SVE2)) { continue; } if (expr_match(p.pattern, op->value, matches)) { @@ -3017,6 +3126,9 @@ string CodeGen_ARM::mattrs() const { // TODO: https://github.com/halide/Halide/issues/8872 // attrs.emplace_back("+sve"); } + if (target.has_feature(Target::SME2)) { + attrs.emplace_back("+sme2"); + } if (target.os == Target::IOS || target.os == Target::OSX) { attrs.emplace_back("+reserve-x18"); } diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 9263aef95f75..78b72e8aa006 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -199,7 +199,7 @@ CodeGen_LLVM::CodeGen_LLVM(const Target &t) void CodeGen_LLVM::set_context(llvm::LLVMContext &context) { this->context = &context; - effective_vscale = target_vscale(); + set_effective_vscale(target_vscale()); } std::unique_ptr CodeGen_LLVM::new_for_target(const Target &target, llvm::LLVMContext &context) { diff --git a/src/DeviceAPI.h b/src/DeviceAPI.h index 12476a23b724..f9b8b6b8932a 100644 --- a/src/DeviceAPI.h +++ b/src/DeviceAPI.h @@ -24,6 +24,7 @@ enum class DeviceAPI { D3D12Compute, Vulkan, WebGPU, + Host_SMEStreaming, }; /** An array containing all the device apis. 
Useful for iterating @@ -38,7 +39,8 @@ const DeviceAPI all_device_apis[] = {DeviceAPI::None, DeviceAPI::HexagonDma, DeviceAPI::D3D12Compute, DeviceAPI::Vulkan, - DeviceAPI::WebGPU}; + DeviceAPI::WebGPU, + DeviceAPI::Host_SMEStreaming}; } // namespace Halide diff --git a/src/DeviceInterface.cpp b/src/DeviceInterface.cpp index 27f6b549ee7d..28525117a370 100644 --- a/src/DeviceInterface.cpp +++ b/src/DeviceInterface.cpp @@ -169,7 +169,7 @@ DeviceAPI get_default_device_api_for_target(const Target &target) { namespace Internal { Expr make_device_interface_call(DeviceAPI device_api, MemoryType memory_type) { - if (device_api == DeviceAPI::Host) { + if (device_api == DeviceAPI::Host || device_api == DeviceAPI::Host_SMEStreaming) { return make_zero(type_of()); } diff --git a/src/Func.cpp b/src/Func.cpp index 623fe30ad4d5..c1e3313b8633 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -2053,6 +2053,11 @@ Stage &Stage::hexagon(const VarOrRVar &x) { return *this; } +Stage &Stage::sme_streaming(bool enable, const VarOrRVar &x) { + set_dim_device_api(x, enable ? 
DeviceAPI::Host_SMEStreaming : DeviceAPI::Host); + return *this; +} + Stage &Stage::prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) { definition.schedule().touched() = true; PrefetchDirective prefetch = {f.name(), at.name(), from.name(), std::move(offset), strategy, Parameter()}; @@ -2845,6 +2850,12 @@ Func &Func::hexagon(const VarOrRVar &x) { return *this; } +Func &Func::sme_streaming(bool enable, const VarOrRVar &x) { + invalidate_cache(); + Stage(func, func.definition(), 0).sme_streaming(enable, x); + return *this; +} + Func &Func::prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) { invalidate_cache(); Stage(func, func.definition(), 0).prefetch(f, at, from, std::move(offset), strategy); diff --git a/src/Func.h b/src/Func.h index 0bfb591871c7..0d937d26dfa2 100644 --- a/src/Func.h +++ b/src/Func.h @@ -463,6 +463,8 @@ class Stage { Stage &hexagon(const VarOrRVar &x = Var::outermost()); + Stage &sme_streaming(bool enable, const VarOrRVar &x = Var::outermost()); + Stage &prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1, PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf); Stage &prefetch(const Parameter ¶m, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1, @@ -2020,6 +2022,12 @@ class Func { * Hexagon, that loop is executed on a Hexagon DSP. */ Func &hexagon(const VarOrRVar &x = Var::outermost()); + /** Schedule for aarch64 SME Streaming Mode. + * When a loop is marked with sme_streaming(true), that loop including its inner loops + * are executed in Streaming mode. Marking with sme_streaming(false) prevents the loop + * from being executed in Streaming mode. */ + Func &sme_streaming(bool enable, const VarOrRVar &x = Var::outermost()); + /** Prefetch data written to or read from a Func or an ImageParam by a * subsequent loop iteration, at an optionally specified iteration offset. 
You may specify * specification of different vars for the location of the prefetch() instruction diff --git a/src/IR.cpp b/src/IR.cpp index c5158728f367..9ba68bafd4d7 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -629,6 +629,8 @@ const char *const intrinsic_op_names[] = { "dynamic_shuffle", "extract_bits", "extract_mask_element", + "get_runtime_vscale", + "get_runtime_streaming_vscale", "get_user_context", "gpu_thread_barrier", "halving_add", @@ -704,7 +706,6 @@ const char *const intrinsic_op_names[] = { "widening_shift_left", "widening_shift_right", "widening_sub", - "get_runtime_vscale", }; static_assert(sizeof(intrinsic_op_names) / sizeof(intrinsic_op_names[0]) == Call::IntrinsicOpCount, diff --git a/src/IR.h b/src/IR.h index 16016fca819a..3d1c8131cb5d 100644 --- a/src/IR.h +++ b/src/IR.h @@ -660,6 +660,11 @@ struct Call : public ExprNode { // Extracts a single element from a mask vector extract_mask_element, + // Returns the runtime value of ARM SVE vscale (the vector length multiplier) + get_runtime_vscale, + + // Returns the runtime value of ARM SME streaming vscale (the vector length multiplier in streaming mode) + get_runtime_streaming_vscale, get_user_context, gpu_thread_barrier, halving_add, @@ -818,9 +823,6 @@ struct Call : public ExprNode { widening_shift_right, widening_sub, - // Returns the runtime value of ARM SVE vscale (the vector length multiplier) - get_runtime_vscale, - IntrinsicOpCount // Sentinel: keep last. 
}; diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 331e5de658b2..6851f4c39cfb 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -135,6 +135,9 @@ ostream &operator<<(ostream &out, const DeviceAPI &api) { case DeviceAPI::WebGPU: out << ""; break; + case DeviceAPI::Host_SMEStreaming: + out << ""; + break; } return out; } diff --git a/src/InjectHostDevBufferCopies.cpp b/src/InjectHostDevBufferCopies.cpp index 6e8b21686657..baf2f53d5a4e 100644 --- a/src/InjectHostDevBufferCopies.cpp +++ b/src/InjectHostDevBufferCopies.cpp @@ -116,7 +116,12 @@ class FindBufferUsage : public IRVisitor { << "A GPU API should have been selected by this stage in lowering\n"; DeviceAPI old = current_device_api; if (op->device_api != DeviceAPI::None) { - current_device_api = op->device_api; + // In the context of device buffer, we treat Host_SMEStreaming as Host. + if (op->device_api == DeviceAPI::Host_SMEStreaming) { + current_device_api = DeviceAPI::Host; + } else { + current_device_api = op->device_api; + } } IRVisitor::visit(op); current_device_api = old; @@ -662,6 +667,7 @@ class InjectBufferCopies : public IRMutator { Stmt visit(const For *op) override { if (op->device_api != DeviceAPI::Host && + op->device_api != DeviceAPI::Host_SMEStreaming && op->device_api != DeviceAPI::None) { // Don't enter device loops return op; diff --git a/src/Lower.cpp b/src/Lower.cpp index c248684ea5d9..a0be742aa2e6 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -44,6 +44,7 @@ #include "LICM.h" #include "LoopCarry.h" #include "LowerParallelTasks.h" +#include "LowerSMEStreamingTasks.h" #include "LowerWarpShuffles.h" #include "Memoization.h" #include "OffloadGPULoops.h" @@ -523,6 +524,21 @@ void lower_impl(const vector &output_funcs, debug(2) << "Lowering after generating parallel tasks and closures:\n" << s << "\n\n"; + debug(1) << "Lowering SME Streaming Tasks...\n"; + closure_implementations.clear(); + s = lower_sme_streaming_tasks(s, closure_implementations, pipeline_name, t); + for 
(size_t i = initial_lowered_function_count; i < result_module.functions().size(); i++) { + // Note that lower_sme_streaming_tasks() appends to the end of closure_implementations + result_module.functions()[i].body = + lower_sme_streaming_tasks(result_module.functions()[i].body, closure_implementations, + result_module.functions()[i].name, t); + } + for (auto &lowered_func : closure_implementations) { + result_module.append(lowered_func); + } + debug(2) << "Lowering after generating SME streaming tasks and closures:\n" + << s << "\n\n"; + vector public_args = args; for (const auto &out : outputs) { for (const Parameter &buf : out.output_buffers()) { diff --git a/src/LowerSMEStreamingTasks.cpp b/src/LowerSMEStreamingTasks.cpp new file mode 100644 index 000000000000..5a741056e9ba --- /dev/null +++ b/src/LowerSMEStreamingTasks.cpp @@ -0,0 +1,144 @@ +#include "LowerSMEStreamingTasks.h" + +#include "Argument.h" +#include "Closure.h" +#include "IREquality.h" +#include "IRMutator.h" +#include "IRPrinter.h" +#include "InjectHostDevBufferCopies.h" + +namespace Halide { +namespace Internal { + +namespace { +constexpr int DBG = 2; + +LoweredArgument make_scalar_arg(const std::string &name, const Type &type) { + return LoweredArgument(name, Argument::Kind::InputScalar, type, 0, ArgumentEstimates()); +} + +template +LoweredArgument make_scalar_arg(const std::string &name) { + return make_scalar_arg(name, type_of()); +} + +struct LowerSMEStreamingTasks : public IRMutator { + using IRMutator::visit; + + Stmt visit(const For *loop) override { + if (loop->device_api != DeviceAPI::None) { + internal_assert(loop->device_api == DeviceAPI::Host || loop->device_api == DeviceAPI::Host_SMEStreaming); + // After this mutation, it doesn't need to be marked as SMEStreaming anymore + Stmt body; + if (equal(loop->min, loop->max)) { + body = LetStmt::make(loop->name, loop->min, loop->body); + } else { + body = For::make(loop->name, loop->min, loop->max, loop->for_type, loop->partition_policy, 
DeviceAPI::None, loop->body); + } + + const bool next_is_streaming = (loop->device_api == DeviceAPI::Host_SMEStreaming); + // We extract a separate task only when transiting to/from SMEStreaming + // 1. Any(except for Host_SMEStreaming) -> Host_SMEStreaming + // 2. Host_SMEStreaming -> Host + if (in_streaming != next_is_streaming && target.has_feature(Target::SME2)) { + debug(DBG) << "Switching to " << to_streaming_str(next_is_streaming) + << " from " << to_streaming_str(in_streaming) + << " in loop " << loop->name << "\n"; + + ScopedValue streaming_state(in_streaming, next_is_streaming); + + auto s = do_as_streaming_task(body, loop->name); + return s; + } else { + return mutate(body); + } + } + + return IRMutator::visit(loop); + } + + // Create a separate function that executes the body as a streaming (or non-streaming) task. + // Inject a Call op to call the extracted task function. + // The extracted task function is added to closure_implementations, which will be added to Module. + Stmt do_as_streaming_task(Stmt &body, const std::string &name) { + auto task_name = unique_name(concat_strings(name, ".", to_streaming_str(in_streaming), ".task")); + + Closure closure; + debug(DBG) << "Closure include for " << task_name << "\n" + << body << "\n"; + closure.include(body); + + // The same name can appear as a var and a buffer. Remove the var name in this case. 
+ for (auto const &b : closure.buffers) { + closure.vars.erase(b.first); + } + + const std::string closure_name = unique_name("streaming_closure"); + const std::string closure_arg_name = unique_name("closure_arg"); + Expr closure_struct_allocation = closure.pack_into_struct(); + Expr closure_struct = Variable::make(Handle(), closure_name); + Expr closure_struct_arg = Cast::make(type_of(), closure_struct); + auto closure_arg = make_scalar_arg(closure_arg_name); + Expr closure_arg_var = Variable::make(closure_struct_allocation.type(), closure_arg_name); + + // Mutate body recursively, where further transition may happen + body = mutate(body); + Stmt wrapped_body = closure.unpack_from_struct(closure_arg_var, body); + + const std::string new_function_name = c_print_name(task_name, false); + auto attributes = in_streaming ? LoweredFunc::Attribute::SME_STREAMING_TASK : + LoweredFunc::Attribute::SME_NONSTREAMING_TASK; + LoweredFunc closure_func{new_function_name, + std::vector{std::move(closure_arg)}, + std::move(wrapped_body), + LinkageType::Internal, + NameMangling::C, + attributes}; + closure_implementations.emplace_back(std::move(closure_func)); + + Stmt stmt = call_extern_and_assert(new_function_name, {std::move(closure_struct_arg)}); + stmt = LetStmt::make(closure_name, closure_struct_allocation, stmt); + return stmt; + } + + std::string to_streaming_str(bool is_streaming) const { + return is_streaming ? "streaming" : "nonstreaming"; + } + + LowerSMEStreamingTasks(const Target &t) : target(t) { + } + + Target target; + bool in_streaming = false; + std::vector closure_implementations; +}; + +} // namespace + +Stmt lower_sme_streaming_tasks(const Stmt &s, std::vector &closure_implementations, + const std::string &name, const Target &t) { + + LowerSMEStreamingTasks lowering_mutator(t); + Stmt result = lowering_mutator(s); + + // Main body will be dumped as part of standard lowering debugging, but closures will not be. 
+ debug(2) << [&] { + std::stringstream ss; + for (const auto &lf : lowering_mutator.closure_implementations) { + ss << "lower_sme_streaming_tasks generated closure lowered function " << lf.name << ":\n" + << lf.body << "\n\n"; + } + return ss.str(); + }(); + + // Append to the end rather than replacing the list entirely. + closure_implementations.insert(closure_implementations.end(), + lowering_mutator.closure_implementations.begin(), + lowering_mutator.closure_implementations.end()); + + return result; +} + +} // namespace Internal +} // namespace Halide diff --git a/src/LowerSMEStreamingTasks.h b/src/LowerSMEStreamingTasks.h new file mode 100644 index 000000000000..a5568523a9b6 --- /dev/null +++ b/src/LowerSMEStreamingTasks.h @@ -0,0 +1,36 @@ +#ifndef HALIDE_LOWER_SME_STREAMING_TASKS_H +#define HALIDE_LOWER_SME_STREAMING_TASKS_H + +/** \file + * Defines a lowering pass to pull loops marked with SMEStreaming device API to a separate function. + * In aarch64 SME, an execution can switch to 'streaming mode' by smstart/smstop instruction. + * In LLVM, a function with special attributes is compiled so it transits to/from streaming mode. + * In Halide, to follow this mechanism, a For loop with sme_streaming enabled is extracted + * as a separate function called a streaming task to have the different attribute in CodeGen. + * We also handle the case where a streaming task has a non-streaming loop in it. + */ + +#include <string> +#include <vector> + +namespace Halide { + +struct Target; + +namespace Internal { + +struct Stmt; +struct LoweredFunc; + +/** + * Create a separate function that executes the body as a streaming (or non-streaming) task. + * Inject a Call op to call the extracted task function. + * The extracted task functions are appended to closure_implementations.
+ */ +Stmt lower_sme_streaming_tasks(const Stmt &s, std::vector<LoweredFunc> &closure_implementations, + const std::string &name, const Target &t); + +} // namespace Internal +} // namespace Halide + +#endif // HALIDE_LOWER_SME_STREAMING_TASKS_H diff --git a/src/Module.cpp b/src/Module.cpp index 372669ec6e57..24500a28d93d 100644 --- a/src/Module.cpp +++ b/src/Module.cpp @@ -352,16 +352,18 @@ LoweredFunc::LoweredFunc(const std::string &name, const std::vector<LoweredArgument> &args, Stmt body, LinkageType linkage, - NameMangling name_mangling) - : name(name), args(args), body(std::move(body)), linkage(linkage), name_mangling(name_mangling) { + NameMangling name_mangling, + uint64_t attributes) + : name(name), args(args), body(std::move(body)), linkage(linkage), name_mangling(name_mangling), attributes(attributes) { } LoweredFunc::LoweredFunc(const std::string &name, const std::vector<Argument> &args, Stmt body, LinkageType linkage, - NameMangling name_mangling) - : name(name), body(std::move(body)), linkage(linkage), name_mangling(name_mangling) { + NameMangling name_mangling, + uint64_t attributes) + : name(name), body(std::move(body)), linkage(linkage), name_mangling(name_mangling), attributes(attributes) { for (const Argument &i : args) { this->args.emplace_back(i); } diff --git a/src/Module.h b/src/Module.h index f470ced389bf..e92f9aa6d9c7 100644 --- a/src/Module.h +++ b/src/Module.h @@ -116,16 +116,26 @@ struct LoweredFunc { * the Target. */ NameMangling name_mangling; + /** Attribute bit flags carrying additional information used in lowering and codegen.
*/ + enum Attribute : uint64_t { + NO_ATTRIBUTE = 0, + SME_STREAMING_TASK = 1 << 0, + SME_NONSTREAMING_TASK = 1 << 1, + }; + uint64_t attributes; + LoweredFunc(const std::string &name, const std::vector<LoweredArgument> &args, Stmt body, LinkageType linkage, - NameMangling mangling = NameMangling::Default); + NameMangling mangling = NameMangling::Default, + uint64_t attributes = 0); LoweredFunc(const std::string &name, const std::vector<Argument> &args, Stmt body, LinkageType linkage, - NameMangling mangling = NameMangling::Default); + NameMangling mangling = NameMangling::Default, + uint64_t attributes = 0); }; } // namespace Internal diff --git a/src/Profiling.cpp b/src/Profiling.cpp index 2ce8c7c5a194..e686321a0c92 100644 --- a/src/Profiling.cpp +++ b/src/Profiling.cpp @@ -492,7 +492,8 @@ class InjectProfiling : public IRMutator { body = substitute(names.profiler_instance, Variable::make(Handle(), names.hvx_profiler_instance), body); body = LetStmt::make(names.hvx_profiler_instance, get_state, body); } else if (op->device_api == DeviceAPI::None || - op->device_api == DeviceAPI::Host) { + op->device_api == DeviceAPI::Host || + op->device_api == DeviceAPI::Host_SMEStreaming) { body = mutate(body); } else { body = op->body; diff --git a/src/Target.cpp b/src/Target.cpp index 6ba2ae2044a4..fbe0c8b57d9a 100644 --- a/src/Target.cpp +++ b/src/Target.cpp @@ -48,6 +48,18 @@ #ifndef HWCAP2_SVE2 #define HWCAP2_SVE2 0 #endif +#ifndef HWCAP2_SME2 +#define HWCAP2_SME2 0 +#endif +#endif + +/* Detect SME target attribute support */ +#if defined(__aarch64__) && !defined(__arm__) && \ + ((defined(__GNUC__) && !defined(__clang__) && (__GNUC__ >= 14)) || \ + (defined(__clang__) && (__clang_major__ >= 17))) +#define HAS_ATTR_TARGET_SME 1 +#else +#define HAS_ATTR_TARGET_SME 0 #endif namespace Halide { @@ -63,7 +75,20 @@ __attribute__((target("+sve"))) int get_sve_vector_length() { __asm__("cntb %x0, all, mul #8" : "=r"(result)); return result; } + +#if HAS_ATTR_TARGET_SME +__attribute__((target("+sme"))) int
get_sme_streaming_vector_length() { // codespell:ignore sme + register int result asm("w0"); + __asm__("rdsvl %x0, #8" : "=r"(result)); + return result; +} +#else +int get_sme_streaming_vector_length() { + user_error << "Trying to get streaming_vector_length where SME is supposed to be unsupported\n"; + return 0; +} #endif +#endif // defined(__aarch64__) struct cpuid_result { int eax, ebx, ecx, edx; @@ -259,6 +284,7 @@ Target calculate_host_target() { Target::Arch arch = Target::ARM; #if !defined(__arm__) bool has_scalable_vector = false; + bool has_streaming_scalable_vector = false; #endif #ifdef __APPLE__ @@ -273,7 +299,13 @@ Target calculate_host_target() { if (sysctl_is_set("hw.optional.arm.FEAT_FP16")) { initial_features.push_back(Target::ARMFp16); } +#if HAS_ATTR_TARGET_SME + if (sysctl_is_set("hw.optional.arm.FEAT_SME2")) { + initial_features.push_back(Target::SME2); + has_streaming_scalable_vector = true; + } #endif +#endif // __APPLE__ #ifdef __linux__ unsigned long hwcaps = getauxval(AT_HWCAP); @@ -299,6 +331,14 @@ Target calculate_host_target() { has_scalable_vector = true; #endif } + +#if HAS_ATTR_TARGET_SME + if (hwcaps2 & HWCAP2_SME2) { + initial_features.push_back(Target::SME2); + has_streaming_scalable_vector = true; + } +#endif + #endif #ifdef _MSC_VER @@ -306,6 +346,7 @@ Target calculate_host_target() { // https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-isprocessorfeaturepresent #define PF_ARM_SVE_INSTRUCTIONS_AVAILABLE (46) #define PF_ARM_SVE2_INSTRUCTIONS_AVAILABLE (47) +#define PF_ARM_SME2_INSTRUCTIONS_AVAILABLE (71) // This is the strategy used by Google's cpuinfo library for // detecting fp16 arithmetic support on Windows. 
@@ -331,12 +372,25 @@ Target calculate_host_target() { #endif } +#if HAS_ATTR_TARGET_SME + if (IsProcessorFeaturePresent(PF_ARM_SME2_INSTRUCTIONS_AVAILABLE)) { + initial_features.push_back(Target::SME2); + has_streaming_scalable_vector = true; + } +#endif #endif #if !defined(__arm__) if (has_scalable_vector) { vector_bits = get_sve_vector_length(); } + if (has_streaming_scalable_vector) { + const int streaming_vector_bits = get_sme_streaming_vector_length(); + Target::Feature sme_svl = Target::sme_svl_feature_from_bits(streaming_vector_bits); + user_assert(sme_svl != Target::FeatureEnd) + << "Detected unsupported SME streaming vector length " << streaming_vector_bits << " bits.\n"; + initial_features.push_back(sme_svl); + } #endif #else @@ -790,6 +844,12 @@ const std::map feature_name_map = { {"webgpu", Target::WebGPU}, {"sve", Target::SVE}, {"sve2", Target::SVE2}, + {"sme2", Target::SME2}, + {"sme_svl128", Target::SME_SVL128}, + {"sme_svl256", Target::SME_SVL256}, + {"sme_svl512", Target::SME_SVL512}, + {"sme_svl1024", Target::SME_SVL1024}, + {"sme_svl2048", Target::SME_SVL2048}, {"arm_dot_prod", Target::ARMDotProd}, {"arm_fp16", Target::ARMFp16}, {"llvm_large_code_model", Target::LLVMLargeCodeModel}, @@ -1069,6 +1129,12 @@ void Target::validate_features() const { NoNEON, POWER_ARCH_2_07, RVV, + SME2, + SME_SVL128, + SME_SVL256, + SME_SVL512, + SME_SVL1024, + SME_SVL2048, SVE, SVE2, VSX, @@ -1130,11 +1196,31 @@ void Target::validate_features() const { POWER_ARCH_2_07, RVV, SSE41, + SME_SVL128, + SME_SVL256, + SME_SVL512, + SME_SVL1024, + SME_SVL2048, + SME2, SVE, SVE2, VSX, }); } + + const int num_sme_svl_features = + (int)has_feature(SME_SVL128) + + (int)has_feature(SME_SVL256) + + (int)has_feature(SME_SVL512) + + (int)has_feature(SME_SVL1024) + + (int)has_feature(SME_SVL2048); + + user_assert(num_sme_svl_features <= 1) + << "Target may have at most one SME_SVL feature.\n"; + user_assert(!has_feature(SME2) || num_sme_svl_features == 1) + << "Target feature sme2 
requires exactly one SME_SVL feature.\n"; + user_assert(has_feature(SME2) || num_sme_svl_features == 0) + << "Target features SME_SVL128, SME_SVL256, SME_SVL512, SME_SVL1024, and SME_SVL2048 require target feature sme2.\n"; } Target::Target(const std::string &target) { @@ -1178,6 +1264,23 @@ Target::Feature Target::feature_from_name(const std::string &name) { return Target::FeatureEnd; } +Target::Feature Target::sme_svl_feature_from_bits(int bits) { + switch (bits) { + case 128: + return Target::SME_SVL128; + case 256: + return Target::SME_SVL256; + case 512: + return Target::SME_SVL512; + case 1024: + return Target::SME_SVL1024; + case 2048: + return Target::SME_SVL2048; + default: + return Target::FeatureEnd; + } +} + std::string Target::to_string() const { string result; for (const auto &arch_entry : arch_name_map) { @@ -1539,6 +1642,35 @@ Target::Feature target_feature_for_device_api(DeviceAPI api) { } int Target::natural_vector_size(const Halide::Type &t) const { + return natural_vector_size(t, false); +} + +int Target::sme_streaming_vector_bits() const { + int result = 0; + auto set_result = [&result](int bits) { + user_assert(result == 0) + << "Target may have at most one SME_SVL feature.\n"; + result = bits; + }; + if (has_feature(Target::SME_SVL128)) { + set_result(128); + } + if (has_feature(Target::SME_SVL256)) { + set_result(256); + } + if (has_feature(Target::SME_SVL512)) { + set_result(512); + } + if (has_feature(Target::SME_SVL1024)) { + set_result(1024); + } + if (has_feature(Target::SME_SVL2048)) { + set_result(2048); + } + return result; +} + +int Target::natural_vector_size(const Halide::Type &t, bool is_sme_streaming) const { user_assert(!has_unknowns()) << "natural_vector_size cannot be used on a Target with Unknown values.\n"; @@ -1546,9 +1678,13 @@ int Target::natural_vector_size(const Halide::Type &t) const { const int data_size = t.bytes(); if (arch == Target::ARM) { - if (vector_bits != 0 && - (has_feature(Halide::Target::SVE2) || - 
(t.is_float() && has_feature(Halide::Target::SVE)))) { + if (is_sme_streaming && + sme_streaming_vector_bits() != 0 && + has_feature(Halide::Target::SME2)) { + return sme_streaming_vector_bits() / (data_size * 8); + } else if (vector_bits != 0 && + (has_feature(Halide::Target::SVE2) || + (t.is_float() && has_feature(Halide::Target::SVE)))) { return vector_bits / (data_size * 8); } else { return 16 / data_size; @@ -1891,11 +2027,21 @@ void target_test() { } } + // Tests for vector_bits internal_assert(Target().vector_bits == 0) << "Default Target vector_bits not 0.\n"; internal_assert(Target("arm-64-linux-sve2-vector_bits_512").vector_bits == 512) << "Vector bits not parsed correctly.\n"; - Target with_vector_bits(Target::Linux, Target::ARM, 64, Target::ProcessorGeneric, {Target::SVE}, 512); + Target with_vector_bits(Target::Linux, Target::ARM, 64, Target::ProcessorGeneric, {Target::SVE2}, 512); internal_assert(with_vector_bits.vector_bits == 512) << "Vector bits not populated in constructor.\n"; internal_assert(Target(with_vector_bits.to_string()).vector_bits == 512) << "Vector bits not round tripped properly.\n"; + internal_assert(with_vector_bits.natural_vector_size(Int(32)) == 16) << "Wrong natural_vector_size.\n"; + + // Tests for SME streaming vector length + internal_assert(Target().sme_streaming_vector_bits() == 0) << "Default Target SME SVL not 0.\n"; + internal_assert(Target::sme_svl_feature_from_bits(1024) == Target::SME_SVL1024) << "SME SVL feature lookup failed.\n"; + Target with_sme_svl(Target::Linux, Target::ARM, 64, Target::ProcessorGeneric, {Target::SVE2, Target::SME2, Target::SME_SVL1024}, 512); + internal_assert(with_sme_svl.sme_streaming_vector_bits() == 1024) << "SME SVL not populated in constructor.\n"; + internal_assert(with_sme_svl.natural_vector_size(Int(32), true) == 32) << "Wrong natural_vector_size with SME streaming.\n"; + internal_assert(with_sme_svl.natural_vector_size(Int(32), false) == 16) << "Wrong natural_vector_size without SME 
streaming.\n"; std::cout << "Target test passed\n"; } diff --git a/src/Target.h b/src/Target.h index 9b53ede39c62..e7417fb1493d 100644 --- a/src/Target.h +++ b/src/Target.h @@ -154,6 +154,12 @@ struct Target { WebGPU = halide_target_feature_webgpu, SVE = halide_target_feature_sve, SVE2 = halide_target_feature_sve2, + SME2 = halide_target_feature_sme2, + SME_SVL128 = halide_target_feature_sme_svl128, + SME_SVL256 = halide_target_feature_sme_svl256, + SME_SVL512 = halide_target_feature_sme_svl512, + SME_SVL1024 = halide_target_feature_sme_svl1024, + SME_SVL2048 = halide_target_feature_sme_svl2048, ARMDotProd = halide_target_feature_arm_dot_prod, ARMFp16 = halide_target_feature_arm_fp16, LLVMLargeCodeModel = halide_llvm_large_code_model, @@ -281,7 +287,8 @@ struct Target { arch == other.arch && bits == other.bits && processor_tune == other.processor_tune && - features == other.features; + features == other.features && + vector_bits == other.vector_bits; } bool operator!=(const Target &other) const { @@ -316,11 +323,19 @@ struct Target { * for that data type when compiling for this Target. */ int natural_vector_size(const Halide::Type &t) const; + /** Given a data type, return an estimate of the "natural" vector size + * for that data type in streaming mode in aarch64 SME. */ + int natural_vector_size(const Halide::Type &t, bool is_sme_streaming) const; + + /** Return the fixed SME streaming vector length in bits selected by this target, + * or 0 if no SME_SVL feature is set. */ + int sme_streaming_vector_bits() const; + /** Given a data type, return an estimate of the "natural" vector size * for that data type when compiling for this Target. */ template - int natural_vector_size() const { - return natural_vector_size(type_of()); + int natural_vector_size(bool is_sme_streaming = false) const { + return natural_vector_size(type_of(), is_sme_streaming); } /** Return true iff 64 bits and has_feature(LargeBuffers). 
*/ @@ -372,6 +387,10 @@ struct Target { * If the string is not a known feature name, return FeatureEnd. */ static Target::Feature feature_from_name(const std::string &name); + /** Return the SME_SVL feature corresponding to an SME streaming vector + * length in bits, or FeatureEnd if no exact SME_SVL feature exists. */ + static Target::Feature sme_svl_feature_from_bits(int bits); + private: /** A bitmask that stores the active features. */ std::bitset features; diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 46c5a89cfbb4..4881f631fe36 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -1279,6 +1279,10 @@ enum halide_error_code_t { /** Profiling failed for a pipeline invocation. */ halide_error_code_cannot_profile_pipeline = -48, + + /** "vscale" value of Streaming Scalable Vector detected in runtime does not + * match the streaming vscale value used in compilation. */ + halide_error_code_streaming_vscale_invalid = -49, }; /** Halide calls the functions below on various error conditions. The @@ -1355,6 +1359,7 @@ extern int halide_error_storage_bound_too_small(void *user_context, const char * extern int halide_error_device_crop_failed(void *user_context); extern int halide_error_split_factor_not_positive(void *user_context, const char *func_name, const char *orig, const char *outer, const char *inner, const char *factor_str, int factor); extern int halide_error_vscale_invalid(void *user_context, const char *func_name, int runtime_vscale, int compiletime_vscale); +extern int halide_error_streaming_vscale_invalid(void *user_context, const char *func_name, int runtime_vscale, int compiletime_vscale); // @} /** Optional features a compilation Target can have. @@ -1447,6 +1452,12 @@ typedef enum halide_target_feature_t { halide_target_feature_webgpu, ///< Enable the WebGPU runtime. 
halide_target_feature_sve, ///< Enable ARM Scalable Vector Extensions halide_target_feature_sve2, ///< Enable ARM Scalable Vector Extensions v2 + halide_target_feature_sme2, ///< Enable ARM Scalable Matrix Extensions v2 + halide_target_feature_sme_svl128, ///< Assume ARM SME streaming vector length is 128 bits. + halide_target_feature_sme_svl256, ///< Assume ARM SME streaming vector length is 256 bits. + halide_target_feature_sme_svl512, ///< Assume ARM SME streaming vector length is 512 bits. + halide_target_feature_sme_svl1024, ///< Assume ARM SME streaming vector length is 1024 bits. + halide_target_feature_sme_svl2048, ///< Assume ARM SME streaming vector length is 2048 bits. halide_target_feature_egl, ///< Force use of EGL support. halide_target_feature_arm_dot_prod, ///< Enable ARMv8.2-a dotprod extension (i.e. udot and sdot instructions) halide_target_feature_arm_fp16, ///< Enable ARMv8.2-a half-precision floating point data processing diff --git a/src/runtime/aarch64_cpu_features.cpp b/src/runtime/aarch64_cpu_features.cpp index cc591341e42d..e5b7790c71c6 100644 --- a/src/runtime/aarch64_cpu_features.cpp +++ b/src/runtime/aarch64_cpu_features.cpp @@ -18,6 +18,7 @@ extern "C" unsigned long getauxval(unsigned long type); #define HWCAP_ASIMDDP (1 << 20) #define HWCAP_SVE (1 << 22) #define HWCAP2_SVE2 (1 << 1) +#define HWCAP2_SME2 (1ULL << 37) namespace { @@ -40,6 +41,10 @@ void set_platform_features(CpuFeatures *features) { if (hwcaps2 & HWCAP2_SVE2) { halide_set_available_cpu_feature(features, halide_target_feature_sve2); } + + if (hwcaps2 & HWCAP2_SME2) { + halide_set_available_cpu_feature(features, halide_target_feature_sme2); + } } } // namespace @@ -64,6 +69,10 @@ void set_platform_features(CpuFeatures *features) { if (sysctl_is_set("hw.optional.arm.FEAT_FP16")) { halide_set_available_cpu_feature(features, halide_target_feature_arm_fp16); } + + if (sysctl_is_set("hw.optional.arm.FEAT_SME2")) { + halide_set_available_cpu_feature(features, 
halide_target_feature_sme2); + } } } // namespace @@ -82,6 +91,7 @@ extern "C" BOOL IsProcessorFeaturePresent(DWORD feature); // https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-isprocessorfeaturepresent #define PF_ARM_SVE_INSTRUCTIONS_AVAILABLE (46) #define PF_ARM_SVE2_INSTRUCTIONS_AVAILABLE (47) +#define PF_ARM_SME2_INSTRUCTIONS_AVAILABLE (71) namespace { @@ -104,6 +114,10 @@ void set_platform_features(CpuFeatures *features) { if (IsProcessorFeaturePresent(PF_ARM_SVE2_INSTRUCTIONS_AVAILABLE)) { halide_set_available_cpu_feature(features, halide_target_feature_sve2); } + + if (IsProcessorFeaturePresent(PF_ARM_SME2_INSTRUCTIONS_AVAILABLE)) { + halide_set_available_cpu_feature(features, halide_target_feature_sme2); + } } } // namespace @@ -126,6 +140,12 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) { halide_set_known_cpu_feature(features, halide_target_feature_no_neon); halide_set_known_cpu_feature(features, halide_target_feature_sve); halide_set_known_cpu_feature(features, halide_target_feature_sve2); + halide_set_known_cpu_feature(features, halide_target_feature_sme2); + halide_set_known_cpu_feature(features, halide_target_feature_sme_svl128); + halide_set_known_cpu_feature(features, halide_target_feature_sme_svl256); + halide_set_known_cpu_feature(features, halide_target_feature_sme_svl512); + halide_set_known_cpu_feature(features, halide_target_feature_sme_svl1024); + halide_set_known_cpu_feature(features, halide_target_feature_sme_svl2048); // All ARM architectures support "No Neon". 
halide_set_available_cpu_feature(features, halide_target_feature_no_neon); diff --git a/src/runtime/errors.cpp b/src/runtime/errors.cpp index acb640c44b52..30b9df99d6a8 100644 --- a/src/runtime/errors.cpp +++ b/src/runtime/errors.cpp @@ -308,4 +308,12 @@ WEAK int halide_error_vscale_invalid(void *user_context, const char *func_name, return halide_error_code_vscale_invalid; } +WEAK int halide_error_streaming_vscale_invalid(void *user_context, const char *func_name, int runtime_vscale, int compiletime_vscale) { + error(user_context) + << "The function " << func_name + << " is compiled with the assumption that streaming vscale of Scalable Vector is " << compiletime_vscale + << ". However, the detected runtime streaming vscale is " << runtime_vscale << "."; + return halide_error_code_streaming_vscale_invalid; +} + } // extern "C" diff --git a/src/runtime/runtime_api.cpp b/src/runtime/runtime_api.cpp index 734af982bf91..92f164f454b5 100644 --- a/src/runtime/runtime_api.cpp +++ b/src/runtime/runtime_api.cpp @@ -89,6 +89,7 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = { (void *)&halide_error_unaligned_host_ptr, (void *)&halide_error_storage_bound_too_small, (void *)&halide_error_device_crop_failed, + (void *)&halide_error_streaming_vscale_invalid, (void *)&halide_error_vscale_invalid, (void *)&halide_float16_bits_to_double, (void *)&halide_float16_bits_to_float, diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 4918b8474b51..7b435135fcfa 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -295,6 +295,7 @@ tests(GROUPS correctness sliding_over_guard_with_if.cpp sliding_reduction.cpp sliding_window.cpp + sme_streaming.cpp sort_exprs.cpp specialize.cpp specialize_to_gpu.cpp diff --git a/test/correctness/fallback_vscale_sve.cpp b/test/correctness/fallback_vscale_sve.cpp index e8110e910339..18e5ac122a30 100644 --- a/test/correctness/fallback_vscale_sve.cpp +++ 
b/test/correctness/fallback_vscale_sve.cpp @@ -1,11 +1,24 @@ #include "Halide.h" +#include "parse_llvm_ir.h" #include #include +#include using namespace Halide; +bool starts_with(std::string_view str, std::string_view prefix) { + return str.size() >= prefix.size() && + str.compare(0, prefix.size(), prefix) == 0; +} + +bool ends_with(std::string_view str, std::string_view suffix) { + return str.size() >= suffix.size() && + str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; +} + bool compile_and_check_vscale(Func &f, const std::string &name, + const std::string &suffix, const Target &t, int exp_vscale, const std::string &exp_intrin) { @@ -14,24 +27,19 @@ bool compile_and_check_vscale(Func &f, auto llvm_file_name = name + ".ll"; f.compile_to_llvm_assembly(llvm_file_name, f.infer_arguments(), t); - Internal::assert_file_exists(llvm_file_name); - std::ifstream llvm_file; - llvm_file.open(llvm_file_name); - std::string line; - // Pattern to extract "n" and "m" in "vscale_range(n,m)" - std::regex vscale_regex(R"(vscale_range\(\s*([0-9]+)\s*,\s*([0-9]+)\s*\))"); - int act_vscale = 0; bool intrin_found = false; + std::regex vscale_regex(R"(vscale_range\(\s*([0-9]+)\s*,\s*([0-9]+)\s*\))"); - while (getline(llvm_file, line)) { - // Check vscale_range - std::smatch match; - if (std::regex_search(line, match, vscale_regex) && match[1] == match[2]) { - act_vscale = std::stoi(match[1]); - } - // Check intrin - if (line.find(exp_intrin) != std::string::npos) { + for (auto &[func_name, attrs_line] : parse_llvm_ir_attributes_from_file(llvm_file_name)) { + if (starts_with(func_name, name) && ends_with(func_name, suffix)) { + // Check vscale_range + std::smatch match; + if (std::regex_search(attrs_line, match, vscale_regex) && + match[1] == match[2]) { + act_vscale = std::stoi(match[1]); + } + } else if (func_name.find(exp_intrin) != std::string::npos) { intrin_found = true; } } @@ -50,20 +58,57 @@ bool compile_and_check_vscale(Func &f, Var x("x"), y("y"); bool 
test_vscale(int vectorization_factor, int vector_bits, int exp_vscale) { - Func f("f"); + std::stringstream name_ss; + name_ss << "test_vscale_v" << vectorization_factor + << "_vb_" << vector_bits; + const std::string name = name_ss.str(); + + Func f(name); f(x, y) = absd(x, y); f.compute_root().vectorize(x, vectorization_factor); Target t("arm-64-linux-sve2-no_asserts-no_runtime-no_bounds_query"); t.vector_bits = vector_bits; - std::stringstream name; - name << "test_vscale_v" << vectorization_factor << "_vector_bits_" << vector_bits; + // sve or neon + std::string intrin = exp_vscale > 0 ? "llvm.aarch64.sve.sabd" : "llvm.aarch64.neon.sabd"; + + return compile_and_check_vscale(f, name, "", t, exp_vscale, intrin); +} + +bool test_streaming_vscale(int vectorization_factor, int vector_bits, int streaming_vector_bits, int exp_vscale) { + std::stringstream name_ss; + name_ss << "test_vscale_v" << vectorization_factor + << "_vb_" << vector_bits + << "_svb_" << streaming_vector_bits; + const std::string name = name_ss.str(); + + Func f(name); + f(x, y) = absd(x, y); + f.compute_root() + .sme_streaming(true) // This extracts streaming task + .vectorize(x, vectorization_factor); + + Target t("arm-64-linux-no_asserts-no_runtime-no_bounds_query"); + if (vector_bits != 0) { + t = t.with_feature(Target::SVE2); + t.vector_bits = vector_bits; + } + if (streaming_vector_bits != 0) { + t = t.with_feature(Target::SME2); + Target::Feature sme_svl = Target::sme_svl_feature_from_bits(streaming_vector_bits); + if (sme_svl == Target::FeatureEnd) { + printf("[%s] Unsupported streaming_vector_bits %d\n", name.c_str(), streaming_vector_bits); + return false; + } + t.set_feature(sme_svl); + } // sve or neon std::string intrin = exp_vscale > 0 ? 
"llvm.aarch64.sve.sabd" : "llvm.aarch64.neon.sabd"; - return compile_and_check_vscale(f, name.str(), t, exp_vscale, intrin); + // Check func for streaming task + return compile_and_check_vscale(f, name, "_streaming_task", t, exp_vscale, intrin); } int main(int argc, char **argv) { @@ -75,6 +120,21 @@ int main(int argc, char **argv) { ok &= test_vscale(8, 512, 4); // Regular case: with vscale=4 ok &= test_vscale(4, 512, 0); // Fallback due to + // Regular case: with streaming_vscale=4 + ok &= test_streaming_vscale(8, 128, 512, 4); + + // Fallback to non-streaming SVE due to + // with vscale=1 + ok &= test_streaming_vscale(4, 128, 512, 1); + + // Fallback to non-streaming SVE due to + // And then fallback to NEON due to + ok &= test_streaming_vscale(2, 256, 512, 0); + + // Fallback to non-streaming SVE due to + // But the target does not have non-streaming SVE2 feature + ok &= test_streaming_vscale(4, 0, 512, 0); + if (!ok) { return 1; } diff --git a/test/correctness/parse_llvm_ir.h b/test/correctness/parse_llvm_ir.h new file mode 100644 index 000000000000..4cbcc9639c2d --- /dev/null +++ b/test/correctness/parse_llvm_ir.h @@ -0,0 +1,68 @@ +#pragma once +#include "Halide.h" +#include +#include +#include +#include + +// Returns map (key: function name, val: line of attributes) +std::unordered_map parse_llvm_ir_attributes(const std::string &llvm_ir_str) { + std::unordered_map result; + + // attribute id (#N) -> list of functions waiting for it + std::unordered_map> pending; + std::istringstream iss(llvm_ir_str); + std::string line; + + // define|declare ... @func(...) ... 
#N + std::regex func_regex(R"(^\s*(define|declare)\b.*@([A-Za-z_.$][\w.$]*)\b.*#(\d+).*)"); + + // attributes #N = + std::regex attr_def_regex(R"(^\s*attributes\s+#\d+\s*=.*)"); + + while (std::getline(iss, line)) { + std::smatch m; + + // function definition + if (std::regex_match(line, m, func_regex)) { + std::string func_name = m[2].str(); + std::string attr_id = "#" + m[3].str(); + pending[attr_id].push_back(func_name); + continue; + } + + // attribute definition + if (std::regex_match(line, m, attr_def_regex)) { + std::string &attr_line = line; + + std::smatch id_match; + if (std::regex_search(attr_line, id_match, std::regex(R"(#\d+)"))) { + std::string attr_id = id_match[0].str(); + + auto it = pending.find(attr_id); + if (it != pending.end()) { + for (const auto &func : it->second) { + result[func] = attr_line; + } + pending.erase(it); + } + } + } + } + + return result; +} + +std::unordered_map parse_llvm_ir_attributes_from_file(const std::string &llvm_file_name) { + Halide::Internal::assert_file_exists(llvm_file_name); + std::ifstream llvm_ir; + llvm_ir.open(llvm_file_name, std::ios::in); + if (!llvm_ir.is_open()) { + std::cerr << "Error: cannot open file: " << llvm_file_name << "\n"; + return {}; + } + std::ostringstream buffer; + buffer << llvm_ir.rdbuf(); + llvm_ir.close(); + return parse_llvm_ir_attributes(buffer.str()); +} diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index b8ed0118c9b3..00b09e608910 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -149,6 +149,7 @@ class SimdOpCheckTest { Target::SSE41, Target::SVE, Target::SVE2, + Target::SME2, Target::VSX, }) { if (target.has_feature(f) != host_target.has_feature(f)) { @@ -324,6 +325,7 @@ class SimdOpCheckTest { f(x, y) = e; f.bound(x, 0, W).vectorize(x, vector_width); f.compute_root(); + apply_additional_schedule(f); // Include a scalar version Halide::Func f_scalar("scalar_" + name); @@ -342,12 +344,13 @@ class 
SimdOpCheckTest { // Do the reduction separately in f_scalar g.clone_in(f_scalar); - g.compute_at(f, x) - .update() - .split(x, xo, xi, vector_width) - .atomic(true) - .vectorize(g.rvars()[0]) - .vectorize(xi); + auto stage = g.compute_at(f, x) + .update() + .split(x, xo, xi, vector_width) + .atomic(true) + .vectorize(g.rvars()[0]) + .vectorize(xi); + apply_additional_schedule(stage); } compile_and_check(f, op, name, vector_width, arg_types, error_msg); @@ -486,6 +489,14 @@ class SimdOpCheckTest { return Halide::Tools::ThreadPool::num_processors_online(); } + virtual void apply_additional_schedule(Stage &stage) const { + return; + } + + virtual void apply_additional_schedule(Func &f) const { + return; + } + virtual bool test_all() { /* First add some tests based on the target */ add_tests(); diff --git a/test/correctness/simd_op_check_sve2.cpp b/test/correctness/simd_op_check_sve2.cpp index 4f58556682fe..39f8c839bc3e 100644 --- a/test/correctness/simd_op_check_sve2.cpp +++ b/test/correctness/simd_op_check_sve2.cpp @@ -31,20 +31,31 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { cout << "HL_TARGET is: " << target.to_string() << endl; cout << "HL_JIT_TARGET is: " << jit_target.to_string() << endl; - auto is_runtime_compatible = [](const Target &t1, const Target &t2) -> bool { + auto is_runtime_compatible = [](const Target &t1, const Target &t2, + const vector features = {Target::SVE2}) -> bool { bool yes = true; yes &= (t1.arch == t2.arch && t1.bits == t2.bits && t1.os == t2.os); yes &= (t1.vector_bits == t2.vector_bits); // A bunch of feature flags also need to match between the // compiled code and the host in order to run the code. 
- for (Target::Feature f : {Target::SVE2}) { + for (Target::Feature f : features) { yes &= (t1.has_feature(f) == t2.has_feature(f)); } return yes; }; - can_run_the_code = is_runtime_compatible(host, target) && is_runtime_compatible(jit_target, target); + if (target.has_feature(Target::SME2)) { + // In this case, we run tests for streaming mode in SME2 but not for SVE2. + // At the moment, host with native SME2 is unavailable, so check only JIT target. + can_run_the_code = is_runtime_compatible(jit_target, target, + {Target::SME2, Target::SME_SVL128, + Target::SME_SVL256, Target::SME_SVL512, + Target::SME_SVL1024, Target::SME_SVL2048}); + } else { + // Run tests for SVE2, so don't care SME2. + can_run_the_code = is_runtime_compatible(host, target) && is_runtime_compatible(jit_target, target); + } if (!can_run_the_code) { debug(0) << "[WARN] To perform verification of realization, " @@ -65,6 +76,17 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { check_arm_pairwise(); } + void apply_additional_schedule(Stage &stage) const override { + if (target.has_feature(Target::SME2)) { + stage.sme_streaming(true); + } + } + void apply_additional_schedule(Func &f) const override { + if (target.has_feature(Target::SME2)) { + f.sme_streaming(true); + } + } + private: void check_arm_integer() { @@ -98,20 +120,21 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // to peephole match any with vector, so we just try 64-bits, 128 // bits, 192 bits, and 256 bits for everything. 
std::vector simd_bit_widths; - if (has_neon()) { + + if (has_sve_or_sme()) { + simd_bit_widths.push_back(native_vector_bits()); + } else if (has_neon()) { simd_bit_widths.push_back(64); simd_bit_widths.push_back(128); } - if (has_sve() && ((target.vector_bits > 128) || !has_neon())) { - simd_bit_widths.push_back(target.vector_bits); - } + for (auto &total_bits : simd_bit_widths) { const int vf = total_bits / bits; // Due to workaround for SVE LLVM issues, in case of vector of half length of natural_lanes, // there is some inconsistency in generated SVE instruction about the number of lanes. // So the verification of lanes is skipped for this specific case. - const int instr_lanes = (total_bits == 64 && has_sve()) ? + const int instr_lanes = (total_bits == 64 && has_sve_or_sme()) ? Instruction::ANY_LANES : Instruction::get_instr_lanes(bits, vf, target); const int widen_lanes = Instruction::get_instr_lanes(bits * 2, vf, target); @@ -124,14 +147,14 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { AddTestFunctor add_16_32(*this, bits, instr_lanes, vf, bits == 16 || bits == 32); AddTestFunctor add_32(*this, bits, instr_lanes, vf, bits == 32); - AddTestFunctor add_8_16_32_widen(*this, bits, widen_lanes, vf, bits != 64 && !has_sve()); + AddTestFunctor add_8_16_32_widen(*this, bits, widen_lanes, vf, bits != 64 && !has_sve_or_sme()); - AddTestFunctor add_16_32_64_narrow(*this, bits, narrow_lanes, vf * 2, bits != 8 && !has_sve()); - AddTestFunctor add_16_32_narrow(*this, bits, narrow_lanes, vf * 2, (bits == 16 || bits == 32) && !has_sve()); - AddTestFunctor add_16_narrow(*this, bits, narrow_lanes, vf * 2, bits == 16 && !has_sve()); + AddTestFunctor add_16_32_64_narrow(*this, bits, narrow_lanes, vf * 2, bits != 8 && !has_sve_or_sme()); + AddTestFunctor add_16_32_narrow(*this, bits, narrow_lanes, vf * 2, (bits == 16 || bits == 32) && !has_sve_or_sme()); + AddTestFunctor add_16_narrow(*this, bits, narrow_lanes, vf * 2, bits == 16 && !has_sve_or_sme()); // VABA I - Absolute 
Difference and Accumulate - if (!has_sve()) { + if (!has_sve_or_sme()) { // Relying on LLVM to detect accumulation add_8_16_32(sel_op("vaba.s", "saba"), i_1 + absd(i_2, i_3)); add_8_16_32(sel_op("vaba.u", "uaba"), u_1 + absd(u_2, u_3)); @@ -222,7 +245,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // We skip this // VCNT I - Count Number of Set Bits - if (!has_sve()) { + if (!has_sve_or_sme()) { // In NEON, there is only cnt for bytes, and then horizontal adds. add_8_16_32({{sel_op("vcnt.", "cnt"), 8, total_bits == 64 ? 8 : 16}}, vf, popcount(i_1)); add_8_16_32({{sel_op("vcnt.", "cnt"), 8, total_bits == 64 ? 8 : 16}}, vf, popcount(u_1)); @@ -409,7 +432,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { add_8_16_32(sel_op("vqshl.u", "uqshl"), cast_u(min(widen_u(u_1) * 16, max_u))); // VQSHLU I - Saturating Shift Left Unsigned - if (!has_sve()) { + if (!has_sve_or_sme()) { add_8_16_32(sel_op("vqshlu.s", "sqshlu"), satcast_u(widen_i(i_1) * 16)); } @@ -467,7 +490,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { add_16_32_narrow(sel_op("vrshrn.i", "rshrn"), narrow_u((widen_u(u_1) + (1 << (bits / 4))) >> (bits / 4 + 1))); // VRSRA I - Rounding Shift Right and Accumulate - if (!has_sve()) { + if (!has_sve_or_sme()) { // Relying on LLVM to detect accumulation add_8_16_32(sel_op("vrsra.s", "srsra"), i_2 + cast_i((widen_i(i_1) + 1) >> 1)); add_8_16_32(sel_op("vrsra.u", "ursra"), i_2 + cast_u((widen_u(u_1) + 1) >> 1)); @@ -481,7 +504,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { add_all_vec(sel_op("vshl.i", "shl", "lsl"), i_1 * 16); add_all_vec(sel_op("vshl.i", "shl", "lsl"), u_1 * 16); - if (!has_sve()) { // No equivalent instruction in SVE. + if (!has_sve_or_sme()) { // No equivalent instruction in SVE. 
add_all_vec(sel_op("vshl.s", "sshl"), i_1 << shift); add_all_vec(sel_op("vshl.s", "sshl"), i_1 >> shift); add_all_vec(sel_op("vshl.u", "ushl"), u_1 << shift); @@ -507,7 +530,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // I guess this could be used for (x*256) | (y & 255)? We don't do bitwise ops on integers, so skip it. // VSRA I - Shift Right and Accumulate - if (!has_sve()) { + if (!has_sve_or_sme()) { // Relying on LLVM to detect accumulation add_all_vec(sel_op("vsra.s", "ssra"), i_2 + i_1 / 16); add_all_vec(sel_op("vsra.u", "usra"), u_2 + u_1 / 16); @@ -563,14 +586,14 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { } std::vector simd_bit_widths; - if (has_sve()) { - simd_bit_widths.push_back(target.vector_bits); + if (has_sve_or_sme()) { + simd_bit_widths.push_back(native_vector_bits()); } else if (has_neon()) { simd_bit_widths.push_back(64); simd_bit_widths.push_back(128); } - if (bits != 64) { + if (bits != 64 && !has_sme()) { // Add scalar case to verify float16 native operation simd_bit_widths.push_back(bits); } @@ -591,7 +614,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { add(sel_op("vsub.f", "fsub"), f_1 - f_2); add(sel_op("vmul.f", "fmul"), f_1 * f_2); add("fdiv", sel_op("vdiv.f", "fdiv", "(fdiv|fdivr)"), f_1 / f_2_clamped); - auto fneg_lanes = has_sve() ? force_vectorized_lanes : instr_lanes; + auto fneg_lanes = has_sve_or_sme() ? force_vectorized_lanes : instr_lanes; add({{sel_op("vneg.f", "fneg"), bits, fneg_lanes}}, vf, -f_1); add({{sel_op("vsqrt.f", "fsqrt"), bits, force_vectorized_lanes}}, vf, sqrt(f_1_clamped)); @@ -624,8 +647,8 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { add_arm64("fmin", is_vector ? "fmin" : "fminnm", min(f_1, f_2)); if (bits != 64 && total_bits != 192) { // Halide relies on LLVM optimization for this pattern, and in some case it doesn't work - add_arm64("fmla", is_vector ? (has_sve() ? "(fmla|fmad)" : "fmla") : "fmadd", f_1 + f_2 * f_3); - add_arm64("fmls", is_vector ? (has_sve() ? 
"(fmls|fmsb)" : "fmls") : "fmsub", f_1 - f_2 * f_3); + add_arm64("fmla", is_vector ? (has_sve_or_sme() ? "(fmla|fmad)" : "fmla") : "fmadd", f_1 + f_2 * f_3); + add_arm64("fmls", is_vector ? (has_sve_or_sme() ? "(fmls|fmsb)" : "fmls") : "fmsub", f_1 - f_2 * f_3); } if (bits != 64) { add_arm64(vector{"frecpe", "frecps"}, fast_inverse(f_1_clamped)); @@ -637,7 +660,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // and then lowered to Internal::halide_xxx() function. // In case the target has FP16 feature, native type conversion between fp16 and fp32 should be generated // instead of emulated equivalent code with other types. - if (is_vector && !has_sve()) { + if (is_vector && !has_sve_or_sme()) { add_arm64("exp", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, exp(f_1_clamped)); add_arm64("log", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, log(f_1_clamped)); add_arm64("pow", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, pow(f_1_clamped, f_2_clamped)); @@ -655,7 +678,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { add_arm64("finite", is_vector ? sel_op("", "fcmge", "fcmeq") : "", is_inf(f_1)); } - if (bits == 16) { + if (bits == 16 && target.os != Target::IOS && target.os != Target::OSX) { // Actually, the following ops are not vectorized because SIMD instruction is unavailable. // The purpose of the test is just to confirm no error. // In case the target has FP16 feature, native type conversion between fp16 and fp32 should be generated @@ -683,7 +706,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { vector> test_params = { {Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}}; - const int base_vec_bits = has_sve() ? 
target.vector_bits : 128; + const int base_vec_bits = native_vector_bits(); const int vscale = base_vec_bits / 128; for (const auto &[elt, in_im] : test_params) { @@ -702,7 +725,8 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // which makes it prone to false-positive detection as we only search strings line-by-line. // LDn - Structured Load strided elements - if (Halide::Internal::get_llvm_version() >= 220) { + if (target.os != Target::IOS && target.os != Target::OSX && + Halide::Internal::get_llvm_version() >= 220) { for (int stride = 2; stride <= 4; ++stride) { for (int factor : {1, 2, 4}) { @@ -717,7 +741,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { Expr load_n = in_im(x * stride) + in_im(x * stride + stride - 1); const string ldn_str = "ld" + to_string(stride); - if (has_sve()) { + if (has_sve_or_sme()) { add_ldn({get_sve_ls_instr(ldn_str, bits)}, vector_lanes, load_n); } else { add_ldn(sel_op("v" + ldn_str + ".", ldn_str), load_n); @@ -735,7 +759,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { Expr load_n = in_im(x * stride) + in_im(x * stride + stride - 1); - if (has_sve()) { + if (has_sve_or_sme()) { add("tbl", load_n); } } @@ -754,10 +778,10 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { tmp1(x) = cast(elt, x); tmp1.compute_root(); tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16)); - tmp2.compute_root().vectorize(x, total_lanes); + apply_additional_schedule(tmp2.compute_root().vectorize(x, total_lanes)); Expr store_2 = tmp2(0, 0) + tmp2(0, 127); - if (has_sve()) { + if (has_sve_or_sme()) { add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2); } else { add_stn(sel_op("vst2.", "st2"), store_2); @@ -780,10 +804,10 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { tmp1.compute_root(); Expr e = (tmp1(x / 2) * 2 + 7) / 4; tmp2(x, y) = select(x % 2 == 0, e * 3, e + 17); - tmp2.compute_root().vectorize(x, total_lanes); + apply_additional_schedule(tmp2.compute_root().vectorize(x, total_lanes)); Expr 
store_2 = tmp2(0, 0) + tmp2(0, 127); - if (has_sve()) { + if (has_sve_or_sme()) { add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2); } else { add_stn(sel_op("vst2.", "st2"), store_2); @@ -806,10 +830,10 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { tmp2(x, y) = select(x % 3 == 0, tmp1(x / 3), x % 3 == 1, tmp1(x / 3 + 16), tmp1(x / 3 + 32)); - tmp2.compute_root().vectorize(x, total_lanes); + apply_additional_schedule(tmp2.compute_root().vectorize(x, total_lanes)); Expr store_3 = tmp2(0, 0) + tmp2(0, 127); - if (has_sve()) { + if (has_sve_or_sme()) { if (Halide::Internal::get_llvm_version() >= 220) { add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3); } @@ -835,10 +859,10 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { x % 4 == 1, tmp1(x / 4 + 16), x % 4 == 2, tmp1(x / 4 + 32), tmp1(x / 4 + 48)); - tmp2.compute_root().vectorize(x, total_lanes); + apply_additional_schedule(tmp2.compute_root().vectorize(x, total_lanes)); Expr store_4 = tmp2(0, 0) + tmp2(0, 127); - if (has_sve()) { + if (has_sve_or_sme()) { if (Halide::Internal::get_llvm_version() >= 220) { add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4); } @@ -848,7 +872,8 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { } // SVE Gather/Scatter - if (has_sve()) { + // Not supported in streaming mode in SME2.0 + if (has_sve() && !has_sme()) { for (float factor : {0.5f, 1.f, 2.f}) { const int width = base_vec_bits * factor; const int total_lanes = width / bits; @@ -900,7 +925,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { {64, in_i64, in_u64, i64, i64, u64, u64}, }; - const int base_vec_bits = has_sve() ? 
target.vector_bits : 128; + const int base_vec_bits = native_vector_bits(); const int vscale = base_vec_bits / 128; for (const auto &[bits, in_i, in_u, widen_i, widenx4_i, widen_u, widenx4_u] : test_params) { @@ -913,7 +938,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { const int widen_lanes = Instruction::get_instr_lanes(bits, vf * 2, target); AddTestFunctor add_widen(*this, bits, widen_lanes, vf, bits != 64); - if (!has_sve()) { + if (!has_sve_or_sme()) { // VPADD I, F - Pairwise Add // VPMAX I, F - Pairwise Maximum // VPMIN I, F - Pairwise Minimum @@ -954,7 +979,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { } const bool is_arm_dot_prod_available = (!is_arm32() && target.has_feature(Target::ARMDotProd) && bits == 8) || - (has_sve() && (bits == 8 || bits == 16)); + (has_sve_or_sme() && (bits == 8 || bits == 16)); if ((bits == 8 || bits == 16) && !is_arm_dot_prod_available) { // udot/sdot is applied if available int f = 4; RDom r(0, f); @@ -1011,7 +1036,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { {64, in_f64}, }; - if (!has_sve()) { + if (!has_sve_or_sme()) { for (const auto &[bits, in_f] : test_params) { for (auto &total_bits : {64, 128}) { const int vf = total_bits / bits; @@ -1057,9 +1082,12 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { : opcode(opcode), operand(nullopt), bits(bits), pattern_lanes(lanes) { } + static bool is_sve_instr(const Target &target) { + return target.features_any_of({Target::SVE, Target::SVE2, Target::SME2}); + } + string generate_pattern(const Target &target) const { bool is_arm32 = target.bits == 32; - bool has_sve = target.has_feature(Target::SVE2); string opcode_pattern; string operand_pattern; @@ -1067,7 +1095,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { if (is_arm32) { opcode_pattern = get_opcode_neon32(); operand_pattern = get_reg_neon32(); - } else if (!has_sve) { + } else if (!is_sve_instr(target)) { opcode_pattern = opcode; operand_pattern = get_reg_neon64(); } else { @@ -1093,7 
+1121,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { static int get_force_vectorized_instr_lanes(int bits, int vec_factor, const Target &target) { // For some cases, where scalar operation is forced to vectorize - if (target.has_feature(Target::SVE2)) { + if (is_sve_instr(target)) { if (vec_factor == 1) { return 1; } else { @@ -1123,18 +1151,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { } string get_reg_sve() const { - if (pattern_lanes == ANY_LANES) { - return R"((z\d\d?\.[bhsd])|(s\d\d?))"; + const char *bits_designator = get_bits_designator(bits.value()); + if (pattern_lanes == 1) { + return std::string(bits_designator) + R"(\d\d?)"; // e.g. "h15" + } else if (pattern_lanes == ANY_LANES) { + return R"(z\d\d?\.[bhsd])"; } else { - const char *bits_designator = get_bits_designator(bits.value()); - // TODO(need issue): This should only match the scalar register, and likely a NEON instruction opcode. - // Generating a full SVE vector instruction for a scalar operation is inefficient. However this is - // happening and fixing it involves changing intrinsic selection. Likely to use NEON intrinsics where - // applicable. For now, accept both a scalar operation and a vector one. - std::string scalar_reg_pattern = (pattern_lanes > 1) ? "" : std::string("|(") + bits_designator + R"(\d\d?))"; // e.g. "h15" - - return std::string(R"(((z\d\d?\.)") + bits_designator + ")|(" + - R"(v\d\d?\.)" + to_string(pattern_lanes.value()) + bits_designator + ")" + scalar_reg_pattern + ")"; + return std::string(R"(z\d\d?\.)") + bits_designator; // e.g. "z15.h" } } @@ -1147,7 +1170,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { if (pattern_lanes == 1) { return std::string(bits_designator) + R"(\d\d?)"; // e.g. "h15" } else if (pattern_lanes == ANY_LANES) { - return R"(v\d\d?\.[bhsd])"; + return R"(v\d\d?\.\d\d?[bhsd])"; } else { return R"(v\d\d?\.)" + to_string(pattern_lanes.value()) + bits_designator; // e.g. 
"v15.4h" } @@ -1392,9 +1415,9 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { } inline const string &sel_op(const string &neon32, const string &neon64, const string &sve) { - return is_arm32() ? neon32 : - target.has_feature(Target::SVE) || target.has_feature(Target::SVE2) ? sve : - neon64; + return is_arm32() ? neon32 : + target.features_any_of({Target::SVE, Target::SVE2, Target::SME2}) ? sve : + neon64; } inline bool is_arm32() const { @@ -1406,9 +1429,22 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { inline bool has_sve() const { return target.has_feature(Target::SVE2); }; + inline bool has_sme() const { + return target.has_feature(Target::SME2); + }; + inline bool has_sve_or_sme() const { + return has_sve() || has_sme(); + }; + + int native_vector_bits() const { + // In this test, if target has SME, we run test in streaming mode, + // so the target's SME_SVL feature is applied. + return target.natural_vector_size(Int(8), has_sme()) * 8; + } bool is_float16_supported() const { - return (target.bits == 64) && target.features_any_of({Target::ARMFp16, Target::SVE, Target::SVE2}); + return (target.bits == 64) && + target.features_any_of({Target::ARMFp16, Target::SVE, Target::SVE2, Target::SME2}); } bool can_run_the_code; @@ -1425,15 +1461,30 @@ int main(int argc, char **argv) { return 0; } - return SimdOpCheckTest::main( - argc, argv, - { - // IMPORTANT: - // When adding new targets here, make sure to also update - // can_run_code in simd_op_check.h to include any new features used. - - Target("arm-64-linux-sve2-no_neon-vector_bits_128"), - Target("arm-64-linux-sve2-no_neon-vector_bits_256"), - Target("arm-64-linux-sve2-no_neon-vector_bits_512"), - }); + // IMPORTANT: + // When adding new targets here, make sure to also update + // can_run_code() in this file to include any new features used. 
+ + std::vector targets; + if (auto target_set_env = Internal::get_env_variable("HL_SIMDOPCHECK_SVE2_TARGET"); !target_set_env.empty()) { + // Only test with the target set by environmental variable + targets.emplace_back(target_set_env); + } else { + targets.emplace_back("arm-64-linux-sve2-vector_bits_128"); + targets.emplace_back("arm-64-linux-sve2-vector_bits_256"); + targets.emplace_back("arm-64-linux-sve2-vector_bits_512"); + + // For SME2, try to select a target which runs natively if possible. + auto host_target = get_host_target(); + int svb = host_target.has_feature(Target::SME2) ? host_target.sme_streaming_vector_bits() : 512; + Target::Feature sme_svl = Target::sme_svl_feature_from_bits(svb); + if (sme_svl == Target::FeatureEnd) { + std::cerr << "Unsupported SME SVL " << svb << "\n"; + return 1; + } + auto sme_target = Target(host_target.os, Target::ARM, 64, Target::ProcessorGeneric, {Target::SME2, sme_svl}); + targets.emplace_back(sme_target); + } + + return SimdOpCheckTest::main(argc, argv, targets); } diff --git a/test/correctness/sme_streaming.cpp b/test/correctness/sme_streaming.cpp new file mode 100644 index 000000000000..a648fe3c9684 --- /dev/null +++ b/test/correctness/sme_streaming.cpp @@ -0,0 +1,474 @@ +#include "Halide.h" +#include "parse_llvm_ir.h" +#include +#include + +using namespace Halide; +using namespace Halide::Internal; +using Attribute = Internal::LoweredFunc::Attribute; + +constexpr bool DEBUG = false; +bool can_run_code = false; + +using TaskCall = std::pair; + +inline std::ostream &operator<<(std::ostream &os, Attribute attr) { + switch (attr) { + case Attribute::NO_ATTRIBUTE: + return os << "Default"; + case Attribute::SME_STREAMING_TASK: + return os << "SME_STREAMING_TASK"; + case Attribute::SME_NONSTREAMING_TASK: + return os << "SME_NONSTREAMING_TASK"; + default: + return os << "Attribute::Unknown"; + } +} + +Attribute get_streaming_attr(uint64_t attrs) { + int num_streaming_attrs = 0; + Attribute ret = 
Attribute::NO_ATTRIBUTE; + for (const auto &attr : {Attribute::SME_STREAMING_TASK, Attribute::SME_NONSTREAMING_TASK}) { + if (attr & attrs) { + ret = attr; + num_streaming_attrs++; + } + } + assert(num_streaming_attrs <= 1); // Should not have both + return ret; +} + +bool check_calling_tasks(const std::string &name, + Module &m, + const std::map &func_attr_map, + const std::multiset &exp_calls) { + + std::multiset act_calls{}; + for (auto &caller : m.functions()) { + auto caller_attr = get_streaming_attr(caller.attributes); + + std::vector callees; + visit_with(caller.body, [&](auto *self, const Call *op) { + callees.emplace_back(op->name); + self->visit_base(op); + }); + + for (const auto &callee : callees) { + if (const auto itr = func_attr_map.find(callee); itr != func_attr_map.end()) { + auto callee_attr = itr->second; + if (callee_attr == Attribute::SME_STREAMING_TASK || + callee_attr == Attribute::SME_NONSTREAMING_TASK) { + act_calls.emplace(caller_attr, callee_attr); + if (DEBUG) { + std::cout << name << ": " << caller_attr << " -> " << callee_attr << "\n"; + } + } + } + } + } + + if (act_calls != exp_calls) { + std::cerr << name << " failed! 
Calling tasks does not match.\n"; + auto print_calls = [](const std::string &header, const std::multiset &calls) { + std::cout << header << ":\n"; + for (const auto &[caller, callee] : calls) { + std::cout << " " << caller << " -> " << callee << "\n"; + } + }; + + print_calls("expected", exp_calls); + print_calls("actual", act_calls); + return false; + } + + return true; +} + +bool check_llvm_attribute(const std::string &name, + Module &m, + const std::map &func_attr_map) { + // Verify that streaming task must have LLVM function attribute "aarch64_pstate_sm_body" + // and vice versa for non-streaming task + auto llvm_file_name = name + ".ll"; + m.compile({{OutputFileType::llvm_assembly, llvm_file_name}}); + std::unordered_map attributes = parse_llvm_ir_attributes_from_file(llvm_file_name); + + auto check_streaming_attribute = [&](const std::string f_name, bool expect_streaming) -> bool { + if (auto it = attributes.find(f_name); it != attributes.end()) { + bool has_streaming_attribute = (it->second.find("aarch64_pstate_sm_body") != std::string::npos); + if (has_streaming_attribute != expect_streaming) { + std::cerr << "Streaming attribute does not match in " << f_name << "\n"; + return false; + } + return true; + } else { + std::cerr << "Cannot find function in llvm-ir: " << f_name << "\n"; + return false; + } + }; + + for (const auto &[func_name, func_attr] : func_attr_map) { + if (func_attr == Attribute::SME_STREAMING_TASK || + func_attr == Attribute::SME_NONSTREAMING_TASK) { + bool should_be_streaming = (func_attr == Attribute::SME_STREAMING_TASK); + if (!check_streaming_attribute(func_name, should_be_streaming)) { + return false; + } + } + } + + return true; +} + +bool check_correctness(const std::string &name, Func &f) { + auto target = get_jit_target_from_environment(); + constexpr int WIDTH = 1000; + Buffer with_streaming = f.realize({WIDTH}, target); + Target target_without_sme = target.without_feature(Target::SME2) + .without_feature(Target::SME_SVL128) + 
.without_feature(Target::SME_SVL256) + .without_feature(Target::SME_SVL512) + .without_feature(Target::SME_SVL1024) + .without_feature(Target::SME_SVL2048); + Buffer without_streaming = f.realize({WIDTH}, target_without_sme); + for (int x = 0; x < WIDTH; x++) { + if (with_streaming(x) != without_streaming(x)) { + std::cerr << "im(" << x << ") = " << with_streaming(x) + << " instead of " << without_streaming(x) + << " in " << name << "\n"; + return false; + } + } + + return true; +} + +bool check(Func &f, const std::string &name, + const std::multiset &exp_calls) { + + Target target("arm-64-linux-sme2-sme_svl512-no_asserts-no_runtime-no_bounds_query"); + Module m = f.compile_to_module(f.infer_arguments(), "", target); + + if (DEBUG) { + f.print_loop_nest(); + m.compile({{OutputFileType::stmt, "/dev/stdout"}}); + } + + std::map func_attr_map; + for (auto &func : m.functions()) { + func_attr_map.emplace(func.name, get_streaming_attr(func.attributes)); + } + + if (!check_calling_tasks(name, m, func_attr_map, exp_calls)) { + return false; + }; + + if (!check_llvm_attribute(name, m, func_attr_map)) { + return false; + }; + + if (can_run_code) { + if (!check_correctness(name, f)) { + return false; + } + } + + return true; +} + +Var x("x"), xo("xo"), xi("xi"); + +bool test_1_stage_non_streaming() { + const std::string name("test_1_stage_non_streaming"); + Func f("f"); + + f(x) = x * 0.1f; + + f.compute_root(); + + std::multiset expected_calls{}; + return check(f, name, expected_calls); +} + +bool test_1_stage_streaming_outermost() { + const std::string name("test_1_stage_streaming_outermost"); + Func f("f"); + + f(x) = x * 0.1f; + + f.compute_root().sme_streaming(true); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}}; + return check(f, name, expected_calls); +} + +bool test_1_stage_streaming_inner() { + const std::string name("test_1_stage_streaming_inner"); + Func f("f"); + + f(x) = x * 0.1f; + + f.compute_root() + .split(x, 
xo, xi, 256) + .sme_streaming(true, xi); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}}; + return check(f, name, expected_calls); +} + +bool test_2_stages_both_streaming() { + const std::string name("test_2_stages_both_streaming"); + Func f("f"); + Func g("g"); + + f(x) = x * 0.1f; + g(x) = f(x) * f(x); + + g.compute_root().sme_streaming(true, x); + f.compute_root().sme_streaming(true, x); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}, + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}}; + return check(g, name, expected_calls); +} + +bool test_2_stages_producer_streaming() { + const std::string name("test_2_stages_producer_streaming"); + Func f("f"); + Func g("g"); + + f(x) = x * 0.1f; + g(x) = f(x) * f(x); + + g.compute_root(); + f.compute_root().sme_streaming(true, x); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}}; + return check(g, name, expected_calls); +} + +bool test_2_stages_consumer_streaming() { + const std::string name("test_2_stages_consumer_streaming"); + Func f("f"); + Func g("g"); + + f(x) = x * 0.1f; + g(x) = f(x) * f(x); + + g.compute_root().sme_streaming(true, x); + f.compute_root(); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}}; + return check(g, name, expected_calls); +} + +bool test_2_stages_both_streaming_at() { + const std::string name("test_2_stages_both_streaming_at"); + Func f("f"); + Func g("g"); + + f(x) = x * 0.1f; + g(x) = f(x) * f(x); + + g.compute_root().sme_streaming(true, x).split(x, xo, xi, 256); + f.compute_at(g, xo); // Computed in streaming mode implicitly + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}}; + return check(g, name, expected_calls); +} + +bool test_2_stages_producer_streaming_at() { + const std::string name("test_2_stages_producer_streaming_at"); + Func f("f"); + Func g("g"); + 
+ f(x) = x * 0.1f; + g(x) = f(x) * f(x); + + g.compute_root().split(x, xo, xi, 256); + f.compute_at(g, xo).sme_streaming(true, x); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}}; + return check(g, name, expected_calls); +} + +bool test_2_stages_consumer_streaming_at() { + const std::string name("test_2_stages_consumer_streaming_at"); + Func f("f"); + Func g("g"); + + f(x) = x * 0.1f; + g(x) = f(x) * f(x); + + g.compute_root().sme_streaming(true, x).split(x, xo, xi, 256); + // explicitly set false, otherwise streaming is enabled + f.compute_at(g, xo).sme_streaming(false); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}, + {Attribute::SME_STREAMING_TASK, Attribute::SME_NONSTREAMING_TASK}, + }; + return check(g, name, expected_calls); +} + +bool test_2_stages_consumer_streaming_at_2() { + const std::string name("test_2_stages_consumer_streaming_at_2"); + Func f("f"); + Func g("g"); + Func h("h"); + + f(x) = x * 0.1f; + g(x) = f(x) * f(x); + h(x) = g(x) + g(x); + + // Nested twice + h.compute_root().sme_streaming(true, x).split(x, xo, xi, 256); + g.compute_at(h, xo).sme_streaming(false, x).split(x, xo, xi, 64); + f.compute_at(g, xo).sme_streaming(true, x); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}, + {Attribute::SME_STREAMING_TASK, Attribute::SME_NONSTREAMING_TASK}, + {Attribute::SME_NONSTREAMING_TASK, Attribute::SME_STREAMING_TASK}, + }; + return check(h, name, expected_calls); +} + +bool test_update_rdom() { + const std::string name("test_update_rdom"); + Func f("f"); + Func g("g"); + RDom r(0, 3); + + f(x) = sin(x); + g(x) = 0.f; + g(x) += f(x + r - 1); + + g.compute_root().sme_streaming(true, x); + g.update().sme_streaming(true, x); + f.compute_root().sme_streaming(true, x); + + std::multiset expected_calls{ + {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK}, + {Attribute::NO_ATTRIBUTE, 
Attribute::SME_STREAMING_TASK},
+        {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK},
+    };
+    return check(g, name, expected_calls);
+}
+
+bool test_update_rdom_2() {
+    const std::string name("test_update_rdom_2");
+    Func f("f");
+    Func g("g");
+    RDom r(0, 3);
+
+    f(x) = sin(x);
+    g(x) = 0.f;
+    g(x) += f(x + r - 1);
+
+    g.compute_at(g.in(), x);
+    g.in().compute_root().sme_streaming(true, x);
+    g = g.in();
+    f.compute_root();
+
+    std::multiset<std::pair<Attribute, Attribute>> expected_calls{
+        {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK},
+    };
+    return check(g, name, expected_calls);
+}
+
+bool test_update_rdom_rvar() {
+    const std::string name("test_update_rdom_rvar");
+    Func f("f");
+    Func g("g");
+    RDom r(0, 256);
+
+    f(x) = sin(x);
+    g(x) = 0.f;
+    g(x) += f(x + r);
+
+    g.compute_root().update().sme_streaming(true, r);
+
+    std::multiset<std::pair<Attribute, Attribute>> expected_calls{
+        {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK},
+    };
+    return check(g, name, expected_calls);
+}
+
+bool test_compute_with() {
+    const std::string name("test_compute_with");
+    Func f("f");
+    Func g("g");
+    Func h("h");
+
+    f(x) = sin(x);
+    g(x) = x * 0.1f;
+    h(x) = f(x) + g(x);
+
+    h.compute_root();
+    // DeviceAPI of g and f must match to compute with
+    g.compute_root().compute_with(f, x).sme_streaming(true, x);
+    f.compute_root().sme_streaming(true, x);
+
+    std::multiset<std::pair<Attribute, Attribute>> expected_calls{
+        {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK},
+    };
+    return check(h, name, expected_calls);
+}
+
+bool test_parallel() {
+    const std::string name("test_parallel");
+    Func f("f");
+    Var xso("xso"), xsi("xsi");
+
+    f(x) = x * 0.1f;
+
+    // Streaming task is called for each thread spawned by parallel(),
+    // rather than one streaming task spawning threads in it.
+    f.compute_root()
+        .split(x, xo, xi, 256)
+        .parallel(xo)
+        .sme_streaming(true, xi);
+
+    std::multiset<std::pair<Attribute, Attribute>> expected_calls{
+        {Attribute::NO_ATTRIBUTE, Attribute::SME_STREAMING_TASK},
+    };
+    return check(f, name, expected_calls);
+}
+
+int main(int argc, char **argv) {
+    can_run_code = get_jit_target_from_environment().has_feature(Target::SME2);
+    if (!can_run_code) {
+        std::cout << "(skip) Cannot run correctness check of sme_streaming on this target\n";
+    }
+
+    bool ok = true;
+    ok &= test_1_stage_non_streaming();
+    ok &= test_1_stage_streaming_outermost();
+    ok &= test_1_stage_streaming_inner();
+    ok &= test_2_stages_both_streaming();
+    ok &= test_2_stages_producer_streaming();
+    ok &= test_2_stages_consumer_streaming();
+    ok &= test_2_stages_both_streaming_at();
+    ok &= test_2_stages_producer_streaming_at();
+    ok &= test_2_stages_consumer_streaming_at();
+    ok &= test_2_stages_consumer_streaming_at_2();
+    ok &= test_update_rdom();
+    ok &= test_update_rdom_2();
+    ok &= test_update_rdom_rvar();
+    ok &= test_compute_with();
+    ok &= test_parallel();
+
+    if (!ok) {
+        return 1;
+    }
+    printf("Success!\n");
+    return 0;
+}
diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt
index fc9496af0244..20db7ae6f5c4 100644
--- a/test/error/CMakeLists.txt
+++ b/test/error/CMakeLists.txt
@@ -83,6 +83,7 @@ tests(GROUPS error
       memoize_redefine_eviction_key.cpp
       metal_threads_too_large.cpp
       metal_vector_too_large.cpp
+      mismatch_runtime_streaming_vscale.cpp
       mismatch_runtime_vscale.cpp
       missing_args.cpp
       no_default_device.cpp
diff --git a/test/error/mismatch_runtime_streaming_vscale.cpp b/test/error/mismatch_runtime_streaming_vscale.cpp
new file mode 100644
index 000000000000..cab0e43199a0
--- /dev/null
+++ b/test/error/mismatch_runtime_streaming_vscale.cpp
@@ -0,0 +1,34 @@
+#include "Halide.h"
+#include <stdio.h>
+
+using namespace Halide;
+
+int main(int argc, char **argv) {
+    auto target = get_jit_target_from_environment();
+    if (!target.has_feature(Target::SME2)) {
+        printf("[SKIP] 
Streaming SVE is not supported on this target.\n");
+        _halide_user_assert(0);
+        return 1;
+    }
+
+    Func f("f");
+    Var x("x");
+
+    f(x) = x;
+
+    const int correct_vector_bits = target.sme_streaming_vector_bits();
+    const int wrong_vector_bits = correct_vector_bits == 512 ? 256 : 512;
+    Target::Feature correct_sme_svl = Target::sme_svl_feature_from_bits(correct_vector_bits);
+    Target::Feature wrong_sme_svl = Target::sme_svl_feature_from_bits(wrong_vector_bits);
+    if (correct_sme_svl == Target::FeatureEnd || wrong_sme_svl == Target::FeatureEnd) {
+        printf("Unexpected behavior in getting sme_vl feature!\n");
+        return 0; // Normal return is test failure
+    }
+    target = target.without_feature(correct_sme_svl).with_feature(wrong_sme_svl);
+
+    // Compile with wrong vscale and run on host, which should end up with assertion failure.
+    Buffer<int> out = f.realize({100}, target);
+
+    printf("Success!\n");
+    return 0;
+}