diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 6cf9428d9b14..0563398c518a 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -610,7 +610,20 @@ class Interleaver : public IRMutator { return expr; } + Scope allocation_scope; + Stmt visit(const Allocate *op) override { + ScopedBinding bind(allocation_scope, op->name, op->memory_type); + return IRMutator::visit(op); + } + Stmt visit(const Store *op) override { + // Don't mess with matrix multiply ops, which use natively-supported 2D + // loads and stores. + if (auto *alloc = allocation_scope.find(op->name); + alloc && (*alloc) == MemoryType::AMXTile) { + return op; + } + bool old_should_deinterleave = should_deinterleave; int old_num_lanes = num_lanes; @@ -657,6 +670,13 @@ class Interleaver : public IRMutator { return Stmt(); } + // Don't mess with matrix multiply ops, which use natively-supported 2D + // loads and stores. + if (auto *alloc = allocation_scope.find(store->name); + alloc && (*alloc) == MemoryType::AMXTile) { + return Stmt(); + } + const Ramp *r0 = store->index.as(); // It's not a store of a ramp index. diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index cd315db389b4..53c5b4957a5d 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -1,8 +1,11 @@ #include "ExtractTileOperations.h" -#include "IRMatch.h" +#include "FindIntrinsics.h" +#include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" +#include "MultiRamp.h" +#include "Simplify.h" #include "Util.h" /** \file Support extraction of AMX instructions. */ @@ -24,7 +27,7 @@ * striding over 4 byte areas. │6 │ * Normally the row of the LHS matrix, │7 │ * 123... would multiply with the column │8 │ - * of the RHS matrix 123..., but with AMX └──┘ + * of the RHS matrix 123..., but with AMX │8 │ * this column is split up into a matrix of columns / 4 byte and rows * 4. * which then results in K/4 dot products per row. 
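As an aside on the layout the comment above describes: the RHS is expected pre-packed so that each group of 4 consecutive K values of one column occupies a contiguous 4-byte chunk. Below is a minimal standalone sketch of that repacking (plain C++, int8 inputs; the function and variable names are illustrative and not part of this patch). It produces exactly the layout addressed by the `B(r % 4, x, r / 4)` indexing used in test/correctness/tiled_matmul.cpp further down.

```cpp
#include <cstdint>
#include <vector>

// Repack a row-major K x N int8 RHS matrix (B[k * N + n]) into the AMX-friendly
// layout: a (4, N, K/4) buffer where packed[((k / 4) * N + n) * 4 + (k % 4)]
// holds B[k][n]. This matches B(k % 4, n, k / 4) in Halide's col-major indexing.
std::vector<int8_t> pack_rhs_for_amx(const std::vector<int8_t> &B, int K, int N) {
    std::vector<int8_t> packed(K * N);
    for (int k = 0; k < K; k++) {
        for (int n = 0; n < N; n++) {
            packed[((k / 4) * N + n) * 4 + (k % 4)] = B[k * N + n];
        }
    }
    return packed;
}
```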
* @@ -38,524 +41,375 @@ using std::vector; namespace { -template -struct Tile { - bool result; - Expr base; - Expr stride[Dim]; - int extent[Dim]; -}; - -enum class AMXOpType { - Int8, - Bfloat16, +struct Matmul { + bool result = false; + Stmt stmt; + int I = 0; + int J = 0; + int K = 0; }; -/// returns the appropriate `Halide::Type` for the given operation type -Type amx_op_type_result_type(AMXOpType op_ty) { - switch (op_ty) { - case AMXOpType::Int8: - return Int(32, 256); - case AMXOpType::Bfloat16: - return Float(32, 256); - default: - internal_error << "Unexpected"; - return Type(); - } -} - -int amx_op_type_size(AMXOpType op_ty) { - switch (op_ty) { - case AMXOpType::Int8: - return 1; - case AMXOpType::Bfloat16: - return 2; - default: - internal_error << "Unexpected"; - return -1; - } -} - -const auto wild_i32 = Variable::make(Int(32), "*"); -const auto wild_i32x = Variable::make(Int(32, 0), "*"); - -Tile<1> get_1d_tile_index(const Expr &e) { - if (const auto *r1 = e.as()) { - - const auto stride_var = Variable::make(Int(32), "stride"); - const auto v1 = Variable::make(Int(32), "v1"); - const auto v2 = Variable::make(Int(32), "v2"); - const auto v3 = Variable::make(Int(32), "v3"); - - Expr patterns[] = { - ((v1 * stride_var) + v2) * v3, - v3 * ((v1 * stride_var) + v2), - (v2 + (v1 * stride_var)) * v3, - v3 * (v2 + (v1 * stride_var)), - }; - - std::map matches; - for (const auto &pattern : patterns) { - if (expr_match(pattern, r1->base, matches)) { - auto stride = std::move(matches["stride"]); - // stride must be a constant in order to not be confused with v1 - if (stride.as()) { - return {true, r1->base, {std::move(stride)}, {r1->lanes}}; - } - - // if stride wasn't a constant then v1 could possibly be the stride if constant - auto v1_expr = std::move(matches["v1"]); - if (v1_expr.as()) { - return {true, r1->base, {std::move(v1_expr)}, {r1->lanes}}; - } - } - } - } - - return {}; -} - -Tile<2> get_2d_tile_index(const Expr &e) { - // ramp(ramp(base, 1, 4), x4(stride), 4) - vector matches; - if (const auto *r1 = e.as()) { - if (const auto *r2 = r1->base.as()) { - auto ramp_2d_pattern = Ramp::make(Ramp::make(wild_i32, wild_i32, r2->lanes), Broadcast::make(wild_i32, r2->lanes), r1->lanes); - if (expr_match(ramp_2d_pattern, e, matches)) { - return {true, std::move(matches[0]), {std::move(matches[2]), std::move(matches[1])}, {r1->lanes, r2->lanes}}; - } - } - } - return {}; -} - -Tile<3> get_3d_tile_index(const Expr &e) { - vector matches; - - // there could be a sub node - const Sub *sub = e.as(); - const Add *add = nullptr; - - if (sub) { - add = sub->a.as(); - } else { - add = e.as(); - } - +Matmul convert_to_matmul(const Store *op, const string &new_name) { + // We expect the pattern: + // + // out[idx] = reduce_add(widen(lhs[multiramp]) * widen(rhs[multiramp])) + out[idx] + // + // Though if the multiramp has an outer dimension of stride zero it may have + // been hoisted outwards to just a broadcast of the widened value. + + auto fail = [&](const char *reason) -> Matmul { + user_error << "Matrix multiply not recognized. Store to AMX allocation must be a " + << "zero-initialization or a sum of a vector reduce op and a load from " + << "the same allocation. 
In the following store, " << reason << ".\n" + << Stmt(op); + return Matmul{}; + }; + + // Peel lets + std::vector> peeled_lets; + Expr value = op->value; + while (const Let *let = value.as()) { + peeled_lets.emplace_back(let->name, let->value); + value = let->body; + } + + // The RHS must be an add + const auto *add = value.as(); if (!add) { - return {}; - } - - const auto &first = add->a; - const auto &second = add->b; - - // ramp(x[x*r](base), x[x*r](stride), x) + x[x*y](ramp(idx, 1, r)) - - const auto *r1 = first.as(); - const auto *b2 = second.as(); - if (!r1 && !b2) { - // Try switching the order - r1 = second.as(); - b2 = first.as(); - } - if (!r1 || !b2) { - return {}; - } - - const auto *b1 = r1->base.as(); - const auto *r2 = b2->value.as(); - - if (!b1 || !r2) { - return {}; + return fail("the right-hand-side is not an add"); } - int x_tile = r1->lanes; - int r_tile = r2->lanes; - int y_tile = b1->lanes / r_tile; - if (y_tile != b2->lanes / x_tile) { - return {}; + // The add must be between a vector reduce and a load. The simplifier will + // have placed the vector reduce to the left, due to canonicalization of + // commutative ops. + Expr lhs = add->a; + const auto *reduce = lhs.as(); + if (!reduce || reduce->op != VectorReduce::Add) { + return fail("the right-hand-side is not a vector reduction plus a load"); } - auto pattern1 = Ramp::make(Broadcast::make(wild_i32, b1->lanes), Broadcast::make(wild_i32, b1->lanes), r1->lanes); - if (!expr_match(pattern1, first, matches)) { - return {}; + // The load must be to the same addresses as the store (i.e. this is a +=) + Expr rhs = add->b; + const auto *load = rhs.as(); + if (!load || load->name != op->name || !equal(load->index, op->index)) { + return fail("the right-hand-side load is not from the same address as the store"); } - Expr base = std::move(matches[0]); - Expr x_stride = std::move(matches[1]); - auto pattern2 = Broadcast::make(Ramp::make(wild_i32, wild_i32, r2->lanes), b2->lanes); - if (!expr_match(pattern2, second, matches)) { - return {}; + // There must be no predicate on the load or store + if (!is_const_one(load->predicate) || !is_const_one(op->predicate)) { + return fail("the load or store is predicated"); } - base += std::move(matches[0]); - Expr r_stride = std::move(matches[1]); - if (sub) { - Expr adj = sub->b; - const Broadcast *bcast = adj.as(); + // The vector reduce must be of a multiply. Unpack it and rebind lhs and rhs + // to mean the lhs and rhs of the mul. For integers we normalize various + // ways of doing the widening multiply by running find_intrinsics. For + // floats we do the opposite and canonicalize away from intrinsics, because + // FindIntrinsics does not currently lift float widening_muls. 
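For reference, here is a sketch of the front-end algorithm that produces the store shape matched by this function, lifted from test/correctness/tiled_matmul.cpp later in this patch. The extents and names are illustrative; the scheduling that actually turns the `+=` into the `reduce_add(widening_mul(...)) + load` form lives in the test.

```cpp
#include "Halide.h"
using namespace Halide;

int main() {
    // A is the LHS, indexed A(k, i) and stored densely in k. B is the RHS,
    // pre-packed as B(k % 4, j, k / 4) as described in the file comment above.
    ImageParam A(Int(8), 2, "A");
    ImageParam B(Int(8), 3, "B");
    Var x("x"), y("y");
    RDom r(0, 64, "r");

    Func mm("matmul");
    mm(x, y) = 0;
    mm(x, y) += cast<int32_t>(A(r, y)) * cast<int32_t>(B(r % 4, x, r / 4));

    // Scheduling this update with store_in(MemoryType::AMXTile) and
    // .atomic().vectorize() over the tile and reduction dims (see the test)
    // is what lowers the += into the Store pattern this function peels apart.
    return 0;
}
```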
+ if (reduce->type.is_int_or_uint()) { + Expr reduce_value = find_intrinsics(reduce->value); - if (!bcast) { - return {}; + const auto *cast = reduce_value.as(); + if (!cast) { + return fail("the vector reduction operand or result types are not supported"); } - if (bcast->lanes != b1->lanes * r1->lanes) { - return {}; + if (const auto *call = Call::as_intrinsic(cast->value, {Call::widening_mul})) { + // Simplify to convert bit-math back to multiply, div, mod + lhs = simplify(lower_intrinsics(call->args[0])); + rhs = simplify(lower_intrinsics(call->args[1])); + } else { + return fail("the vector reduction is not of a widening multiply"); } - base -= bcast->value; - } - - return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; -} + if (lhs.type().bits() != 8 || + rhs.type().bits() != 8) { + return fail("the vector reduction operand or result types are not supported"); + } -/** - * \brief Get the 3d rhs tile index configuration - * - * \param e index expression - * \param element_width the width of the elements, 1 for u8/i8, 2 for bf16 - * \return Tile<3> the tile configuration found - * - * The pattern which is getting matched looks roughly like - * `broadcast(ramp(0, 1, r), x*y) / broadcast(4, x*y*r) + optional(broadcast(base, x*y*r)) * broadcast(8, x*y*r) + - * broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + - * broadcast(ramp(broadcast(_, r), broadcast(4, r), x) , y)` - */ -Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { - const auto *sub = e.as(); - const Add *add_lhs = nullptr; - - // there's not always a sub pattern - // This depends on whether we have an ImageParam or a Buffer - if (!sub) { - add_lhs = e.as(); } else { - add_lhs = sub->a.as(); - } - - if (!add_lhs) { - return {}; - } - - // The right hand side of the add expression is used for retrieving the dimensions of the matrix. - // obtain the x, y, r dimensions - // this expr looks like below, the shape of `add_lhs->a` can be seen further down below - // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), x) , y) - const Add *dim_expr = add_lhs->b.as(); - - if (!dim_expr) { - return {}; - } - - // broadcast(ramp(broadcast(_, r), broadcast(4, r), x), y) - const Broadcast *base_stride_bc = dim_expr->b.as(); - - if (!base_stride_bc) { - return {}; - } - - int tile_y = base_stride_bc->lanes; - - // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) - const Mod *mod = dim_expr->a.as(); - - if (!mod) { - return {}; - } - - // broadcast(ramp(0, 1, r), x*y) - const Broadcast *bc_ramp = mod->a.as(); - - if (!bc_ramp) { - return {}; - } - - int tile_xy = bc_ramp->lanes; - int tile_x = tile_xy / tile_y; - - // ramp(0, 1, r) - const Ramp *r_ramp = bc_ramp->value.as(); - - if (!r_ramp) { - return {}; - } - - int tile_r = r_ramp->lanes; - - // get the base and stride - // ramp(broadcast(_, r), broadcast(4, r), x) - const Ramp *base_stride_ramp = base_stride_bc->value.as(); - - if (!base_stride_ramp) { - return {}; - } - - // broadcast(_, r) - const Broadcast *base_bc = base_stride_ramp->base.as(); - - if (!base_bc) { - return {}; - } - - Expr base = base_bc->value; - Expr stride; - - bool found_stride = false; - - // the following pattern will match the following shape - // broadcast(ramp(0, 1, k), x*y) / broadcast(4, x*y*k) * broadcast(_, x*y*k) - // where the stride is marked by _. 
- - // this stride pattern can occur if `tile_r` is the same size as `acc` - auto stride_pattern = Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r); - - std::vector results{}; - if (expr_match(stride_pattern, add_lhs->a, results)) { - found_stride = true; - stride = std::move(results[0]); - } - - // This pattern is similar to the above except with an additional offset to iterate over the tiles in the k dimension - // (broadcast(ramp(0, 1, k), m * n) / broadcast(4, m*n*k) + _) * broadcast(_, m*n*k) - // here the first _ marks the base and the second _ the stride. - if (!found_stride) { - stride_pattern = (Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) + wild_i32) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r); - if (expr_match(stride_pattern, add_lhs->a, results)) { - found_stride = true; - stride = std::move(results[1]); - base = std::move(results[0]) * stride + base; + // Lower a widening_mul intrinsic, as they can be used but aren't lifted to for bf16. + Expr reduce_value = simplify(lower_intrinsics(reduce->value)); + const auto *mul = reduce_value.as(); + if (!mul) { + return fail("the vector reduction is not of a widening multiply"); } + lhs = mul->a; + rhs = mul->b; } - if (!found_stride) { - return {}; - } - - return {true, base, {stride, 0, 0}, {tile_x, tile_y, tile_r}}; -} - -struct BaseStride { - bool result{false}; - Expr base{}; - Expr stride{}; -}; - -BaseStride get_rhs_tile_index(const Expr &index, int element_width, int tile_x, int tile_y, int tile_r) { - const auto rhs_tile2 = get_2d_tile_index(index); - - if (!rhs_tile2.result) { - const auto rhs_tile1 = get_1d_tile_index(index); - - if (!rhs_tile1.result) { - auto rhs_tile3 = get_3d_rhs_tile_index(index, element_width); - if (rhs_tile3.extent[0] != tile_x || rhs_tile3.extent[1] != tile_y || rhs_tile3.extent[2] != tile_r) { - return {}; - } - - return {true, rhs_tile3.base, rhs_tile3.stride[0] * element_width}; + // There may be a broadcast next (it can get hoisted outside of other ops) + auto debroadcast = [](Expr &e) -> int { + if (const Broadcast *b = e.as()) { + e = b->value; + return b->lanes; } else { - if (rhs_tile1.extent[0] != tile_y * tile_r) { - return {}; - } - - // times 4 because of the rhs layout, each vector used by AMX is 4 bytes in size. - // For the 4 gets divided by the element width which means each vector has 4 elements in u8/i8 and - // 2 elements for bf16. 
- return {true, rhs_tile1.base, rhs_tile1.stride[0] * (4 / element_width)}; + return 1; } - } else { - if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { - return {}; - } - - return {true, rhs_tile2.base, rhs_tile2.stride[0]}; - } -} - -struct Matmul { - bool result = false; - Stmt stmt; - int tile_x; - int tile_y; - int tile_r; -}; - -Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_type) { - // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] - const auto wild_i8x = Variable::make(Int(8, 0), "*"); - const auto wild_u8x = Variable::make(UInt(8, 0), "*"); - const auto wild_bf16x = Variable::make(BFloat(16, 0), "*"); - const auto wild_f32x = Variable::make(Float(32, 0), "*"); - - vector matches; - if (op_type == AMXOpType::Int8) { - const auto pattern1 = wild_i32x + wild_i32x; - if (!expr_match(pattern1, op->value, matches)) { - return {}; + }; + int lhs_broadcast = debroadcast(lhs); + int rhs_broadcast = debroadcast(rhs); + + // Unpack the casts, if it was a direct multiply. This should only happen + // for floats (the integer branch above already extracted the cast inputs + // from the widening_mul intrinsic). + if (reduce->type.is_float()) { + const auto *lhs_cast = lhs.as(); + const auto *rhs_cast = rhs.as(); + if (!lhs_cast || !rhs_cast) { + return fail("the vector reduction is not of a widening multiply"); } - } else { // AMXOpType::Bfloat16 - const auto pattern1 = wild_f32x + wild_f32x; - if (!expr_match(pattern1, op->value, matches)) { - return {}; + lhs = lhs_cast->value; + rhs = rhs_cast->value; + if (!lhs.type().is_bfloat() || + rhs.type().element_of() != lhs.type().element_of()) { + return fail("the vector reduction operand or result types are not supported"); } } - const auto *reduce = matches[0].as(); - const auto *load = matches[1].as(); - if (!reduce || reduce->op != VectorReduce::Add) { - return {}; + // Underneath all of this must be a load + // TODO: What if we want to multiply by the same matrix multiple times? It might be a let binding. + const auto *lhs_load = lhs.as(); + const auto *rhs_load = rhs.as(); + if (!lhs_load || !rhs_load) { + return fail("the matrix multiply operands are not loads"); } - if (!load || load->name != op->name || !equal(load->index, op->index)) { - return {}; + // The loads must be unpredicated + if (!is_const_one(lhs_load->predicate) || !is_const_one(rhs_load->predicate)) { + return fail("the matrix multiply operands are predicated loads"); } - if (op_type == AMXOpType::Int8) { - auto pattern2 = cast(Int(32, 0), cast(Int(32, 0), wild_i8x) * wild_i32x); - auto pattern2_unsigned = cast(Int(32, 0), cast(Int(32, 0), wild_u8x) * wild_i32x); + // Now we analyze the load indices as multiramps + MultiRamp lhs_mr, rhs_mr; + Scope empty_scope; + if (!is_multiramp(lhs_load->index, empty_scope, &lhs_mr) || + !is_multiramp(rhs_load->index, empty_scope, &rhs_mr)) { + return fail("the matrix multiply loads indices are not affine"); + } - if (!(expr_match(pattern2, reduce->value, matches) || expr_match(pattern2_unsigned, reduce->value, matches))) { - return {}; + // Add back on any broadcasts as a stride-0 outer dim. 
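The comment block a little further below lays out the canonical [Ki, Ko, J, I] operand shape and the strides expected for the LHS ([1, Ki, 0, ?]) and the RHS ([1, ?, Ki, 0]). As a concrete standalone check of that claim (plain C++, int8 inputs so Ki = 4; the extents and row strides are illustrative only), enumerating those shapes with those strides reproduces exactly a dense-in-K LHS A(k, i) and a VNNI-packed RHS B(k % Ki, j, k / Ki):

```cpp
#include <cassert>

int main() {
    const int Ki = 4, Ko = 2, J = 3, I = 2;
    const int lhs_row = 100;     // elements between LHS rows (the LHS "?")
    const int rhs_row = Ki * J;  // elements between Ko chunks of the packed RHS (the RHS "?")
    const int lhs_strides[4] = {1, Ki, 0, lhs_row};  // [1, Ki, 0, ?]
    const int rhs_strides[4] = {1, rhs_row, Ki, 0};  // [1, ?, Ki, 0]
    for (int i = 0; i < I; i++) {
        for (int j = 0; j < J; j++) {
            for (int ko = 0; ko < Ko; ko++) {
                for (int ki = 0; ki < Ki; ki++) {
                    const int idx[4] = {ki, ko, j, i};
                    int lhs = 0, rhs = 0;
                    for (int d = 0; d < 4; d++) {
                        lhs += idx[d] * lhs_strides[d];
                        rhs += idx[d] * rhs_strides[d];
                    }
                    const int k = ko * Ki + ki;
                    // LHS addresses are A(k, i): row-major, dense in k.
                    assert(lhs == i * lhs_row + k);
                    // RHS addresses are B(k % Ki, j, k / Ki): the packed layout.
                    assert(rhs == ki + j * Ki + ko * rhs_row);
                }
            }
        }
    }
    return 0;
}
```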
+ auto add_broadcast = [](MultiRamp &mr, int extent) { + if (extent > 1) { + mr.strides.push_back(make_zero(mr.base.type())); + mr.lanes.push_back(extent); } - } else { - auto pattern2 = cast(Float(32, 0), cast(Float(32, 0), wild_bf16x) * wild_f32x); - - if (!expr_match(pattern2, reduce->value, matches)) { - return {}; + }; + add_broadcast(lhs_mr, lhs_broadcast); + add_broadcast(rhs_mr, rhs_broadcast); + + // In a matrix multiply with row-major inputs and outputs, the algorithm + // looks like: + // + // C(j, i) += A(k, i) * B(j, k) + // + // (Recall that for matrices where the rows are stored densely in memory, Halide + // is indexed col-major) + // The canonical loop nest order, from innermost out, is k, j, i. So you'd + // expect the following vector shape (again, from innermost out): + // [K, J, I] + // And the following strides for A and B respectively: + // [1, 0, ?] [?, 1, 0] + // Where the question marks are the strides in memory that separate rows the LHS and RHS. + + // AMX however splits the storage of K into 32-bit chunks and reorders that + // innermost for both A and B. This changes the algorithm to: + // C(j, i) += A(ki, ko, i) * B(ki, j, ko) + // the shape to: + // [Ki, Ko, J, I] + // and the strides to: + // [1, Ki, 0, ?] [1, ?, Ki, 0] + + // So next we need to: + // 1) Deduce what Ki, Ko, are. + // 2) Deduce which is the LHS and which is the RHS + // 3) Deduce what I, J are. + // 4) extract those two question marks, and validate the + // rest is as-expected. + + // The reduction's K dimension will be split into inner and outer elements. + int element_width = lhs_load->type.bytes(); + int K = reduce->value.type().lanes() / reduce->type.lanes(); + int Ki = 4 / element_width; + int Ko = K / Ki; + + // Now deduce LHS and RHS. First some helpers. + auto swap_sides = [&]() { + std::swap(lhs, rhs); + std::swap(lhs_mr, rhs_mr); + std::swap(lhs_load, rhs_load); + }; + + auto swizzled = [](const MultiRamp &mr) { + // Count the number of non-broadcast dimensions + int count = 0; + for (int i = 0; i < mr.dimensions(); i++) { + count += !is_const_zero(mr.strides[i]); } - } - - const auto *lhs_load = matches[0].as(); - const auto *rhs_broadcast = matches[1].as(); - - const Cast *rhs_cast = nullptr; - - if (lhs_load && !rhs_broadcast) { - // now working on a larger k dimension - // with a K dimension of 4 (or 2) with bf16 all the elements in the right-hand matrix are - // laid out in a way that multiplying with a column can be done in a single dot product. - // Therefore the indexing can be reused with a broadcast, - // with higher K dimensions this can no longer be done and the broadcast won't exist. - // ┌──┐ - // │1 │ - // │2 │ - // │3 │ ┌────────┐ - // │4 │ │1234 │ - // │5 │ │5678 │ - // │6 │ └────────┘ - // │7 │ - // │8 │ - // └──┘ - rhs_cast = matches[1].as(); - } else { - rhs_cast = rhs_broadcast->value.as(); - } - - if (!lhs_load || !rhs_cast) { - return {}; - } - - if (rhs_cast) { - bool is_i8_u8 = rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8); - bool is_bf16 = rhs_cast->value.type().element_of() == BFloat(16); - - if ((op_type == AMXOpType::Int8 && !is_i8_u8) || (op_type == AMXOpType::Bfloat16 && !is_bf16)) { - user_error << "Expected rhs type of " << (op_type == AMXOpType::Int8 ? 
"i8/u8" : "bf16") - << ", got " << rhs_cast->value.type() << " instead.\nIn Expression: " << Expr(rhs_cast); + return count > 2; + }; + + auto has_trailing_zero = [](const MultiRamp &mr) { + return !mr.lanes.empty() && is_const_zero(mr.strides.back()); + }; + + // The RHS is the one that's swizzled. The LHS should be stored densely in + // K. If both sides are stored densely either Ko is one (there's no outer + // dimension in the swizzle) or J is one (there's no dimension that would go + // between Ki and Ko). The RHS is the one that doesn't depend on I, so it + // should have a trailing zero stride. If neither side is swizzled and + // neither side has a trailing zero stride then it doesn't matter which side + // is which. + if (swizzled(lhs_mr) || (!swizzled(rhs_mr) && has_trailing_zero(lhs_mr))) { + swap_sides(); + } + + auto unique_lanes = [](const MultiRamp &mr) { + int u = 1; + for (int i = 0; i < mr.dimensions(); i++) { + if (!is_const_zero(mr.strides[i])) { + u *= mr.lanes[i]; + } + } + return u; + }; + + // Now deduce I, J. The output has I * J lanes. The LHS has I * K unique + // addresses loaded, and the RHS has J * K unique addresses. + int IJ = reduce->type.lanes(); + int IK = unique_lanes(lhs_mr); + int I = IK / K; + int J = IJ / I; + + // Coerce both MRs into the canonical [Ki, Ko, J, I] shape (innermost + // first). When Ko == 1, the second slot is just an extent-1 dim and + // strides_for_shape will return a 0 stride there. The expected strides + // we'll then validate are: + // lhs: [1, Ki, 0, ?] (with `?` the LHS row stride) + // rhs: [1, ?, Ki, 0] (with `?` the RHS row stride between Ko chunks) + std::vector + shape{Ki, Ko, J, I}; + std::vector lhs_strides, rhs_strides; + if (!lhs_mr.strides_for_shape(shape, &lhs_strides) || + !rhs_mr.strides_for_shape(shape, &rhs_strides)) { + return fail("a matrix multiply operand has an unsupported access pattern"); + } + + if (!is_const_one(lhs_strides[0]) || + !is_const_one(rhs_strides[0]) || + (Ko > 1 && !is_const(lhs_strides[1], Ki)) || + (J > 1 && !is_const(rhs_strides[2], Ki)) || + !is_const_zero(lhs_strides[2]) || + !is_const_zero(rhs_strides[3])) { + return fail("the storage layout for a matrix multiply operand is unsupported by AMX"); + } + + // Both sides of the multiply must be things that fit in AMX registers. We + // could manually split up too-large matrices here into a collection of + // matrix multiply ops, but for now we just assert. 
+ { + Type t = op->value.type(); + bool result_ok = t.bytes() * I * J <= 1024; + bool lhs_ok = lhs.type().bytes() * I * K <= 1024; + bool rhs_ok = rhs.type().bytes() * K * J <= 1024; + if (!result_ok || !lhs_ok || !rhs_ok) { + return fail("one more more matrices are too large to fit in AMX registers (more than 1024 bytes)"); + } + if (I > 16) { + return fail("the result matrix has more than 16 rows"); + } + if (Ko > 16) { + return fail("the RHS matrix has more than 16 rows"); } - } else { - return {}; - } - - const auto *rhs_load = rhs_cast->value.as(); - if (!rhs_load) { - return {}; - } - - const auto lhs_tile = get_3d_tile_index(lhs_load->index); - - if (!lhs_tile.result) { - return {}; - } - - const int tile_x = lhs_tile.extent[0]; - const int tile_y = lhs_tile.extent[1]; - const int tile_r = lhs_tile.extent[2]; - const int factor = reduce->value.type().lanes() / reduce->type.lanes(); - - Expr rhs_base; - Expr rhs_stride; - - auto opt_base_stride = get_rhs_tile_index(rhs_load->index, amx_op_type_size(op_type), tile_x, tile_y, tile_r); - - if (!opt_base_stride.result) { - return {}; } - rhs_base = opt_base_stride.base; - rhs_stride = opt_base_stride.stride; - - if (op->index.type().lanes() != tile_x * tile_y || - factor != tile_r) { - return {}; - } + Expr rhs_stride_bytes = Ko > 1 ? rhs_strides[1] * element_width : make_zero(rhs_mr.base.type()); + Expr lhs_stride_bytes = lhs_strides[3] * element_width; - // {rows, colbytes, var, index} + // Build the AMX intrinsics. auto lhs_var = Variable::make(Handle(), lhs_load->name); - const auto &lhs_load_type = lhs_load->type; - int element_width = lhs_load_type.bytes(); - auto lhs_type = lhs_load_type.with_lanes(1024 / element_width); - auto lhs = Call::make(lhs_type, "tile_load", {tile_x, tile_r * element_width, lhs_var, lhs_tile.base * element_width, lhs_tile.stride[0] * element_width}, Call::Intrinsic); + auto lhs_type = lhs_load->type.with_lanes(1024 / element_width); + auto lhs_call = Call::make(lhs_type, "tile_load", + {I, K * element_width, lhs_var, + lhs_mr.base * element_width, lhs_stride_bytes}, + Call::Intrinsic); auto rhs_var = Variable::make(Handle(), rhs_load->name); - const auto &rhs_load_type = rhs_load->type; - auto rhs_type = rhs_load_type.with_lanes(1024 / element_width); - - auto rhs = Call::make(rhs_type, "tile_load", {tile_r / (4 / element_width), tile_y * 4, rhs_var, rhs_base * element_width, rhs_stride}, Call::Intrinsic); - auto res_type = amx_op_type_result_type(op_type); - - // {rows, colbytes, acc, out, lhs, rhs} - auto out = Load::make(res_type, new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); - - // 4 bytes for i32, f32 - auto colbytes = tile_y * 4; - auto matmul = Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); - auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); - return {true, std::move(store), tile_x, tile_y, tile_r}; + auto rhs_type = rhs_load->type.with_lanes(1024 / element_width); + auto col_bytes = J * 4; // 4 bytes per innermost K slice + auto rhs_call = Call::make(rhs_type, "tile_load", + {Ko, col_bytes, rhs_var, + rhs_mr.base * element_width, rhs_stride_bytes}, + Call::Intrinsic); + + Type res_type = op->value.type().with_lanes(256); + Expr subtile_idx = Ramp::make(0, 1, 256); + auto out_load = Load::make(res_type, new_name, subtile_idx, {}, {}, const_true(256), {}); + + auto matmul = Call::make(res_type, "tile_matmul", + {I, col_bytes, K, out_load, lhs_call, rhs_call}, + 
Call::Intrinsic); + auto store = Store::make(new_name, matmul, std::move(subtile_idx), Parameter(), const_true(256), ModulusRemainder()); + for (auto &[name, value] : reverse_view(peeled_lets)) { + store = LetStmt::make(name, std::move(value), store); + } + return {true, std::move(store), I, J, K}; } -Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { - if (const auto *ramp = op->index.as()) { - if (const auto *bcast = op->value.as()) { - if (is_const_one(ramp->stride) && - is_const_zero(bcast->value) && - (bcast->lanes == tile_x * tile_y)) { - auto rows = Cast::make(Int(16), tile_x); - auto bytes = op->value.type().bytes(); - auto colbytes = Cast::make(Int(16), tile_y * bytes); - const auto &store_type = op->value.type(); - // will be f32 or i32 - auto tile_zero_type = store_type.with_lanes(1024 / store_type.bytes()); - auto val = Call::make(tile_zero_type, "tile_zero", {rows, colbytes}, Call::Intrinsic); - auto store = Store::make(new_name, std::move(val), Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); - return store; - } - } - } - return {}; +Stmt convert_to_zero(const Store *op, const string &new_name, int I, int J) { + auto rows = Cast::make(Int(16), I); + auto bytes = op->value.type().bytes(); + auto colbytes = Cast::make(Int(16), J * bytes); + const auto &store_type = op->value.type(); + auto tile_zero_type = store_type.with_lanes(1024 / store_type.bytes()); + auto val = Call::make(tile_zero_type, "tile_zero", {rows, colbytes}, Call::Intrinsic); + Expr subtile_idx = Ramp::make(0, 1, 256); + return Store::make(new_name, std::move(val), std::move(subtile_idx), Parameter(), const_true(256), ModulusRemainder()); } -Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { - auto tile = get_2d_tile_index(op->index); - if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { - auto out = Variable::make(Handle(), op->name); - auto tile_type = op->value.type().with_lanes(256); - auto tile_val = Load::make(tile_type, amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); - auto bytes = op->value.type().bytes(); - internal_assert(bytes == 4) << "AMX store only supported for int32 and float32 output, not for " << op->value.type() << "\n"; - // {tile_x, tile_y, var, base, stride} - auto store = Call::make(Int(32), "tile_store", {tile_x, tile_y * bytes, std::move(out), tile.base * bytes, tile.stride[0] * bytes, std::move(tile_val)}, Call::Intrinsic); - return Evaluate::make(std::move(store)); - } - return {}; +Stmt convert_to_tile_store(const Store *op, const string &amx_name, int I, int J) { + auto fail = [&](const char *reason) { + user_error << "Store of AMX register to memory not supported. " + << reason << ".\n" + << Stmt(op); + return Stmt{}; + }; + + if (!is_const_one(op->predicate)) { + return fail("The store has a predicate"); + } + MultiRamp mr; + if (!is_multiramp(op->index, Scope::empty_scope(), &mr)) { + return fail("The store index is not affine"); + } + if (mr.total_lanes() != I * J) { + return fail("There are too many lanes for the deduced matrix shape"); + } + + // Coerce the index into the canonical 2D shape: stride-1 inner of + // lanes J, row-stride outer of lanes I. Either dim may be + // extent 1 — strides_for_shape returns a zero stride for those slots. 
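MultiRamp::strides_for_shape is added later in this patch (MultiRamp.cpp), and this file contains its call sites. Below is a miniature integer-only model of the factorization walk together with the coercions it is asked to perform here. It is a simplified sketch with illustrative extents, not the real Expr-based implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified model of the factorization walk: re-express source dims
// (strides/lanes, innermost first, already in collapsed canonical form) over a
// target list of lane counts. Extent-1 target slots receive a zero stride.
bool strides_for_shape(const std::vector<int> &strides, const std::vector<int> &lanes,
                       const std::vector<int> &target_lanes, std::vector<int> *out) {
    int src_total = 1, dst_total = 1;
    for (int n : lanes) src_total *= n;
    for (int n : target_lanes) dst_total *= n;
    if (src_total != dst_total) return false;
    out->clear();
    std::size_t t = 0;
    auto consume_ones = [&]() {
        while (t < target_lanes.size() && target_lanes[t] == 1) {
            out->push_back(0);
            t++;
        }
    };
    for (std::size_t d = 0; d < lanes.size(); d++) {
        int needed = lanes[d], accumulated = 1;
        while (needed > 1) {
            consume_ones();
            if (t >= target_lanes.size()) return false;
            int n = target_lanes[t];
            if (needed % n != 0) return false;  // coprime factors: give up
            out->push_back(strides[d] * accumulated);
            accumulated *= n;
            needed /= n;
            t++;
        }
    }
    consume_ones();
    return true;
}

int main() {
    std::vector<int> out;
    // convert_to_matmul: an int8 LHS A(k, i) loaded over lanes [K=8, J=3, I=2]
    // with strides [1, 0, 100], coerced to the canonical [Ki=4, Ko=2, J, I]
    // shape, yields the expected [1, Ki, 0, row_stride] pattern.
    assert(strides_for_shape({1, 0, 100}, {8, 3, 2}, {4, 2, 3, 2}, &out));
    assert(out == std::vector<int>({1, 4, 0, 100}));
    // convert_to_tile_store: an I x J subtile of a row-major output with row
    // stride 64, coerced to shape {J=8, I=4}.
    assert(strides_for_shape({1, 64}, {8, 4}, {8, 4}, &out));
    assert(out == std::vector<int>({1, 64}));
    // Degenerate J == 1 (a column of results): the extent-1 slot gets a zero
    // stride, which is why the inner stride-one check here is guarded by J > 1.
    assert(strides_for_shape({64}, {4}, {1, 4}, &out));
    assert(out == std::vector<int>({0, 64}));
    return 0;
}
```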
+ std::vector mr_strides; + if (!mr.strides_for_shape({J, I}, &mr_strides)) { + return fail("The store index is incompatible with the deduced matrix shape"); + } + if (J > 1 && !is_const_one(mr_strides[0])) { + return fail("The innermost stride of the store index is not one"); + } + Expr x_stride = mr_strides[1]; + + auto out_var = Variable::make(Handle(), op->name); + auto tile_type = op->value.type().with_lanes(256); + Expr subtile_idx = Ramp::make(0, 1, 256); + auto tile_val = Load::make(tile_type, amx_name, std::move(subtile_idx), {}, {}, const_true(256), {}); + auto bytes = op->value.type().bytes(); + // This should have been caught earlier, so internal assert + internal_assert(bytes == 4) + << "AMX store only supported for int32 and float32 output, not for " + << op->value.type() << "\n"; + auto store = Call::make(Int(32), "tile_store", + {I, J * bytes, std::move(out_var), + mr.base * bytes, x_stride * bytes, std::move(tile_val)}, + Call::Intrinsic); + return Evaluate::make(std::move(store)); } class ExtractTileOperations : public IRMutator { @@ -563,12 +417,89 @@ class ExtractTileOperations : public IRMutator { string tile_name; string amx_name; - vector pending_stores; + int pass = 0; bool in_allocate = false; - int found_tile_x = -1; - int found_tile_y = -1; - int found_tile_r = -1; - AMXOpType op_type; + int found_I = -1; + int found_J = -1; + int found_K = -1; + + // An AMXTile allocation may represent multiple AMX accumulator + // registers as 2D sub-tiles. This map tracks those. + std::vector amx_subtiles; + + // Returns a unique subtile index for a load or store index, or -1 if it + // overlaps with an existing subtile, or is otherwise poorly behaved. + int get_subtile(const Expr &index) { + MultiRamp mr; + if (!is_multiramp(index, Scope::empty_scope(), &mr)) { + user_error << "Access to AMX tile not affine: " << index << "\n"; + } + if (!can_prove(mr.alias_free())) { + // What are you doing? + user_error << "Access to AMX tile may have duplicated lanes: " << index << "\n"; + } + if (amx_subtiles.empty()) { + amx_subtiles.push_back(std::move(mr)); + return 0; + } + + // All strides and lanes must match across all subtiles, or we give up. + const MultiRamp &first = amx_subtiles[0]; + if (mr.dimensions() != first.dimensions()) { + user_error + << "Access to AMX tile does not have the same shape as other accesses to the same memory."; + return -1; + } + for (int i = 0; i < first.dimensions(); i++) { + if (!can_prove(mr.strides[i] == first.strides[i]) || + mr.lanes[i] != first.lanes[i]) { + user_error + << "Access to AMX tile has different size and strides to other " + << "accesses to the same memory. All accesses must have the same " + << "subtile size and strides: " << index; + } + } + + // Now check for disjointedness + // Add a synthetic dimension, the purpose of which will become clear. + mr.strides.emplace_back(); + mr.lanes.push_back(2); + for (int i = 0; i < (int)amx_subtiles.size(); i++) { + auto &other = amx_subtiles[i]; + // One of two things must be true: + // 1) All of the lanes of mr equal the corresponding lane of + // other. We've already checked the strides and lanes, so it's just + // a matter of checking the base. + if (can_prove(mr.base == other.base)) { + return i; + } + + // 2) None of the lanes or mr equal any of the lanes of other. To do + // this we'll construct a combined mr that can be either 'mr' or + // 'other', and ask if it's alias-free. This is what the synthetic + // dimension was for. 
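The synthetic-dimension trick described in the comment above can be illustrated with a small brute-force model (plain C++; Halide proves the same property symbolically with can_prove rather than by enumeration). Two equal-shaped subtiles are disjoint exactly when the combined multiramp, extended by an extent-2 dimension whose stride is the difference of the bases, has no duplicate addresses:

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Enumerate every address of an affine lane set (base + sum of index * stride)
// and report whether the addresses are pairwise distinct.
bool alias_free(int base, const std::vector<int> &strides, const std::vector<int> &lanes) {
    long total = 1;
    for (int l : lanes) total *= l;
    std::set<int> seen;
    for (long flat = 0; flat < total; flat++) {
        long rem = flat;
        int addr = base;
        for (std::size_t d = 0; d < lanes.size(); d++) {
            addr += int(rem % lanes[d]) * strides[d];
            rem /= lanes[d];
        }
        if (!seen.insert(addr).second) {
            return false;
        }
    }
    return true;
}

int main() {
    // Two 2x2 subtiles of a row-major accumulator with row stride 8, at bases
    // 0 and 2. Appending a dim of extent 2 whose stride is the base difference
    // enumerates the union of both subtiles, so alias-free means disjoint.
    assert(alias_free(0, {1, 8, 2 - 0}, {2, 2, 2}));
    // Bases 0 and 1 would partially overlap, and the check catches it.
    assert(!alias_free(0, {1, 8, 1 - 0}, {2, 2, 2}));
    return 0;
}
```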
+ mr.strides.back() = mr.base - other.base; + if (!can_prove(mr.alias_free())) { + user_error + << "Failed to prove access to AMX does not partially overlap " + << "another distinct access: " << index; + return -1; + } + } + + // Didn't already exist and didn't alias with anything. + mr.strides.pop_back(); + mr.lanes.pop_back(); + amx_subtiles.push_back(std::move(mr)); + return (int)amx_subtiles.size() - 1; + } + + // Returns an index expression for a given load or store index. user_asserts if impossible + std::string get_subtile_name(const Expr &index) { + int idx = get_subtile(index); + internal_assert(idx >= 0); // errors handled already + return amx_name + std::to_string(idx); + } Stmt visit(const Allocate *op) override { if (op->memory_type == MemoryType::AMXTile) { @@ -577,30 +508,29 @@ class ExtractTileOperations : public IRMutator { (op->type.is_float() && op->type.bits() == 32)) << "scheduled tile operations must yield 32-bit integers or 32-bit floats"; - if (op->type.is_int() && op->type.bits() == 32) { - op_type = AMXOpType::Int8; - } else { - op_type = AMXOpType::Bfloat16; - } - - user_assert(!in_allocate) << "Already in AMX allocation: " << amx_name; - ScopedValue old_amx_name(amx_name, op->name + ".amx"); + // We only support one live AMX allocation at a time for now + user_assert(!in_allocate) + << "Already in AMX allocation at allocation for " << op->name + << ". We do not currently support multiple nested AMX matrix multiplies."; + ScopedValue old_amx_name(amx_name, op->name + ".amx."); ScopedValue old_tile_name(tile_name, op->name); ScopedValue old_in_alloc(in_allocate, true); Stmt body = op->body; - pending_stores.clear(); + pass = 0; + body = mutate(body); + user_assert(found_I >= 0 && found_J >= 0 && found_K >= 0) + << op->name << " is stored in AMXTile memory, but no matrix multiply " + << "operation was found that stores to it, so the shape of the tile " + << "was unable to be determined.\n"; + pass = 1; body = mutate(body); - if (found_tile_x < 0 || found_tile_y < 0 || found_tile_r < 0) { - return op; - } - if (!pending_stores.empty()) { - // Really only need to go over the pending stores - body = mutate(body); - } - auto alloc_type = amx_op_type_result_type(op_type); - return Allocate::make(amx_name, alloc_type, MemoryType::AMXTile, {1}, const_true(), body); + for (int i = 0; i < (int)amx_subtiles.size(); i++) { + body = Allocate::make(amx_name + std::to_string(i), op->type.element_of(), + MemoryType::AMXTile, {256}, const_true(), body); + } + return body; } return IRMutator::visit(op); } @@ -609,63 +539,81 @@ class ExtractTileOperations : public IRMutator { if (op->name != tile_name) { return op; } - return Free::make(amx_name); + Stmt s; + for (int i = 0; i < (int)amx_subtiles.size(); i++) { + Stmt f = Free::make(amx_name + std::to_string(i)); + if (s.defined()) { + s = Block::make(std::move(s), std::move(f)); + } else { + s = std::move(f); + } + } + return s; } Stmt visit(const ProducerConsumer *op) override { if (op->name != tile_name) { return IRMutator::visit(op); } - auto body = mutate(op->body); return ProducerConsumer::make(amx_name, op->is_producer, std::move(body)); } Expr visit(const Load *op) override { - // Any tile load will be matched elsewhere, so a load here means that - // the AMX tile is used outside of a tile instruction. 
user_assert(op->name != tile_name) << "AMX tile allocation used outside a tile instruction"; return IRMutator::visit(op); } Stmt visit(const Store *op) override { + // There are three operations on a tile register: + // 1) Zero-initialization + // 2) Matrix multiply + // 3) Stores to memory + + // For the matrix multiply we can deduce the tile shape. The stores to + // memory and zero-intialization may be flat loads and stores, but to + // emit the code we need to know the shape. We do two passes - in the + // first we just recognize the matrix multiplies, and in the second we + // recognize the initializations and stores. + + // All three convert ops either succeed, or do their own user_error internally. + if (op->name != tile_name) { const auto *load = op->value.as(); if (!load || load->name != tile_name) { return op; } - auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y); - user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value"; - return store; + if (pass == 1) { + return convert_to_tile_store(op, get_subtile_name(load->index), found_I, found_J); + } else { + return op; + } } - auto matmul = convert_to_matmul(op, amx_name, op_type); - if (matmul.result) { - user_assert( - (found_tile_x < 0 || matmul.tile_x == found_tile_x) && - (found_tile_y < 0 || matmul.tile_y == found_tile_y) && - (found_tile_r < 0 || matmul.tile_r == found_tile_r)) - << "Found different tile sizes for AMX tile allocation"; - found_tile_x = matmul.tile_x; - found_tile_y = matmul.tile_y; - found_tile_r = matmul.tile_r; + std::string subtile_name = get_subtile_name(op->index); - return matmul.stmt; + if (is_const_zero(op->value)) { + if (pass == 1) { + return convert_to_zero(op, subtile_name, found_I, found_J); + } else { + return op; + } } - if (found_tile_x < 0 || found_tile_y < 0) { - pending_stores.emplace_back(op); + if (pass == 0) { + auto matmul = convert_to_matmul(op, subtile_name); + user_assert((found_I < 0 || matmul.I == found_I) && + (found_J < 0 || matmul.J == found_J) && + (found_K < 0 || matmul.K == found_K)) + << "Found inconsistent tile sizes for AMX tile allocation across multiple " + << "matrix multiplies that store to it."; + found_I = matmul.I; + found_J = matmul.J; + found_K = matmul.K; + return matmul.stmt; + } else { return op; } - - auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name); - if (zero.defined()) { - return zero; - } - - // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions - user_error << "Found non-tile operations for AMX tile allocation"; - return op; } }; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index c729539daa29..ceb22c4fccf0 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -441,7 +441,7 @@ Expr lossless_cast(Type t, return lossless_cast(t, c->value, scope, cache); } } else if (const Broadcast *b = e.as()) { - Expr v = lossless_cast(t.element_of(), b->value, scope, cache); + Expr v = lossless_cast(t.with_lanes(b->value.type().lanes()), b->value, scope, cache); if (v.defined()) { return Broadcast::make(v, b->lanes); } diff --git a/src/MultiRamp.cpp b/src/MultiRamp.cpp index 6ecbb4d3fcd0..3bd52e14d99e 100644 --- a/src/MultiRamp.cpp +++ b/src/MultiRamp.cpp @@ -36,6 +36,19 @@ void collapse_adjacent_dims(MultiRamp *m) { } // namespace +MultiRamp::MultiRamp(Expr base, std::vector strides, std::vector lanes) + : base(std::move(base)), strides(std::move(strides)), lanes(std::move(lanes)) { + internal_assert(this->strides.size() == 
this->lanes.size()); + for (size_t i = this->lanes.size(); i-- > 0;) { + internal_assert(this->lanes[i] >= 1); + if (this->lanes[i] == 1) { + this->strides.erase(this->strides.begin() + i); + this->lanes.erase(this->lanes.begin() + i); + } + } + collapse_adjacent_dims(this); +} + // Multiramps with compatible lanes form a vector space. Here is scalar multiplication. void MultiRamp::mul(const Expr &e) { internal_assert(e.type().is_scalar()); @@ -113,6 +126,62 @@ bool MultiRamp::add(const MultiRamp &other) { } } +bool MultiRamp::strides_for_shape(const std::vector &target_lanes, + std::vector *out_strides) const { + // We rely on the canonical-form invariant that adjacent dims of *this + // with aligned strides have already been collapsed (see + // collapse_adjacent_dims). That makes the mapping easy: each MR dim + // entirely consumes some consecutive target dims whose lane-counts + // multiply up to that MR dim's lane count. If a target dim's lane + // count doesn't divide what's left of the current MR dim, no mapping + // exists — merging multiple MR dims to fill one target dim would + // require their strides to align, but if they did, they'd already be + // a single dim. + int total = 1; + for (int n : target_lanes) { + internal_assert(n >= 1) << "target_lanes entries must be >= 1\n"; + total *= n; + } + if (total != total_lanes()) { + return false; + } + + out_strides->clear(); + out_strides->reserve(target_lanes.size()); + Type t = base.type(); + + size_t target_idx = 0; + auto consume_extent_one_targets = [&]() { + while (target_idx < target_lanes.size() && target_lanes[target_idx] == 1) { + out_strides->push_back(make_zero(t)); + target_idx++; + } + }; + + for (size_t mr_idx = 0; mr_idx < lanes.size(); mr_idx++) { + int needed = lanes[mr_idx]; + int accumulated = 1; + while (needed > 1) { + consume_extent_one_targets(); + if (target_idx >= target_lanes.size()) { + return false; + } + int n = target_lanes[target_idx]; + if (needed % n != 0) { + return false; + } + out_strides->push_back(simplify(strides[mr_idx] * accumulated)); + accumulated *= n; + needed /= n; + target_idx++; + } + } + consume_extent_one_targets(); + // Total-lane check up front guarantees we've consumed every target dim. + internal_assert(target_idx == target_lanes.size()); + return true; +} + namespace { // Divide (or mod) a MultiRamp by a positive integer k. Returns a new diff --git a/src/MultiRamp.h b/src/MultiRamp.h index 35daef494b8e..583d257d93f4 100644 --- a/src/MultiRamp.h +++ b/src/MultiRamp.h @@ -43,6 +43,30 @@ struct MultiRamp { std::vector strides; std::vector lanes; + MultiRamp() = default; + + /** Construct from explicit dim lists, applying the standard normalization + * — extent-1 dims are dropped (the invariants forbid them) and adjacent + * dims whose strides line up (`strides[i+1] == strides[i] · lanes[i]`) + * are merged. Strides and lanes must be the same length, and every lane + * count must be >= 1. */ + MultiRamp(Expr base, std::vector strides, std::vector lanes); + + /** Try to express this MultiRamp's lane sequence using a different + * dim shape. The target shape is given as a vector of lane counts + * (innermost first). On success, *out_strides is filled with one Expr + * per target dim — these are the strides such that a MultiRamp with + * `base = this->base`, `lanes = target_lanes`, `strides = *out_strides` + * enumerates the same lane values as `*this`. Slots in target_lanes of + * value 1 receive a zero stride. 
+ * + * Returns false if target_lanes has a different total lane count than + * this MultiRamp, or if the dim factorizations can't be aligned (the + * gcd walk gets stuck on coprime factors, or merging adjacent dims + * requires their strides to align in a way that doesn't hold). */ + bool strides_for_shape(const std::vector &target_lanes, + std::vector *out_strides) const; + /** Multiply by a scalar. Always a multiramp. */ void mul(const Expr &e); diff --git a/src/StageStridedLoads.cpp b/src/StageStridedLoads.cpp index 896a33b5193e..7496fa31f850 100644 --- a/src/StageStridedLoads.cpp +++ b/src/StageStridedLoads.cpp @@ -110,7 +110,10 @@ class FindStridedLoads : public IRVisitor { if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { a = *a_ptr; } - found_loads[Key{op->name, base, stride, r->lanes, op->type, a, s}][offset].push_back(op); + // Don't mess with loads from tile memory + if (!a || a->memory_type != MemoryType::AMXTile) { + found_loads[Key{op->name, base, stride, r->lanes, op->type, a, s}][offset].push_back(op); + } } } } @@ -152,6 +155,17 @@ class FindStridedLoads : public IRVisitor { IRVisitor::visit(op); } + void visit(const Store *op) override { + // Don't mess with the loads inside a matrix multiply op. Those are + // natively supported as 2D loads and must remain naked loads. + if (auto *alloc = allocation_scope.find(op->name); + alloc && (*alloc)->memory_type == MemoryType::AMXTile) { + return; + } else { + IRVisitor::visit(op); + } + } + using IRVisitor::visit; }; diff --git a/src/Type.h b/src/Type.h index 8a023429e5e6..174c53de4961 100644 --- a/src/Type.h +++ b/src/Type.h @@ -374,6 +374,10 @@ struct Type { /** Return Type with the same type code and number of lanes, but with at least twice as many bits. */ Type widen() const { + if (is_bfloat()) { + // Widening a bfloat16 should produce a float32. + return with_code(Float).with_bits(32); + } if (bits() == 1) { // Widening a 1-bit type should produce an 8-bit type. return with_bits(8); diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp index c25c3f706653..ea44064902db 100644 --- a/test/correctness/tiled_matmul.cpp +++ b/test/correctness/tiled_matmul.cpp @@ -1,4 +1,5 @@ #include "Halide.h" +#include "halide_test_dirs.h" #include using namespace Halide; @@ -45,7 +46,7 @@ void fill_buffer_b(Buffer &buf, int col, int acc) { } bool equal_eps(float lhs, float rhs, float eps) { - return std::abs(lhs - rhs) < eps; + return std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1e-10f) < eps; } struct make_uint_t { @@ -88,7 +89,7 @@ void print_mat_rhs(const Buffer &buf, int rows, int cols) { } template -bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { +bool matmul(int col, int row, int acc, int tile_x, int tile_y, int tile_r, bool use_intrinsic) { Buffer A_buf(acc, row); Buffer B_buf(4, col, acc / 4); @@ -96,12 +97,25 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { RDom r(0, acc); Func mm("matmul"); - mm(x, y) = cast(0); - mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 4, x, r / 4)); - Var rxi("rxi"), ryi("ryi"); + mm(x, y) = 0; + + if (use_intrinsic) { + mm(x, y) += widening_mul(A_buf(r, y), B_buf(r % 4, x, r / 4)); + } else { + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 4, x, r / 4)); + } + + Var rxi("rxi"), ryi("ryi"), xi("xi"), yi("yi"); RVar rri("rri"), rro("rro"); + // An outer layer of tiling is necessary to reuse repeated subtile + // loads. 
But if you go too big you'll run out of tile registers and + // compilation will fail (The LLVM AMX register allocator will spill, but it + // seems to be fussy about it). Doing this also tests the case of more than + // one matrix multiply operation applied to a single AMXTile allocation. + int outer_tile_x = col > tile_x ? 2 : 1, outer_tile_y = row > tile_y ? 2 : 1; + mm.compute_at(mm.in(), x) .store_in(MemoryType::AMXTile) .update() @@ -111,20 +125,28 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { .atomic() .vectorize(rri) .vectorize(rxi) - .vectorize(ryi); + .vectorize(ryi) + .tile(x, y, xi, yi, outer_tile_x, outer_tile_y) + .reorder(rri, rxi, ryi, xi, yi, rro, x, y) + .unroll(xi) + .unroll(yi); Var ixi("ixi"), iyi("iyi"); mm.compute_at(mm.in(), x) .tile(x, y, ixi, iyi, tile_x, tile_y) .vectorize(ixi) - .vectorize(iyi); + .vectorize(iyi) + .unroll(x) + .unroll(y); // schedule the consumer Var mmxi("mmxi"), mmyi("mmyi"); mm.in() - .tile(x, y, mmxi, mmyi, tile_x, tile_y) - .vectorize(mmxi) - .vectorize(mmyi); + .tile(x, y, mmxi, mmyi, tile_x * outer_tile_x, tile_y * outer_tile_y) + .vectorize(mmxi, tile_x) + .vectorize(mmyi, tile_y) + .unroll(mmxi) + .unroll(mmyi); Func result = mm.in(); @@ -133,7 +155,15 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Buffer out(col, row); - result.realize(out); + Target target = get_jit_target_from_environment(); + if (target.has_feature(Target::AVX512_SapphireRapids)) { + result.realize(out); + } else { + // Just compile it to see if anything crashes + result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", + {A_buf, B_buf}, Target{"x86-64-linux-avx512_sapphirerapids-no_asserts-no_runtime-no_bounds_query"}); + return true; + } // uncomment to check the matrices // std::cout << "Matrix A\n"; @@ -159,11 +189,10 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } } - std::cout << "Success!\n"; return true; } -bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { +bool matmul_bf16(int col, int row, int acc, int tile_x, int tile_y, int tile_r, bool use_intrinsics) { Var x("x"), y("y"); Buffer A(acc, row); Buffer B(2, col, acc / 2); @@ -171,8 +200,12 @@ bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) RDom r(0, acc, "acc"); Func mm("matmul"); - mm(x, y) = cast(0); - mm(x, y) += cast(cast(A(r.x, y))) * cast(B(r.x % 2, x, r.x / 2)); + mm(x, y) = 0.f; + if (use_intrinsics) { + mm(x, y) += widening_mul(A(r.x, y), B(r.x % 2, x, r.x / 2)); + } else { + mm(x, y) += cast(A(r.x, y)) * cast(B(r.x % 2, x, r.x / 2)); + } Var rxi("rxi"), ryi("ryi"); RVar rri("rri"), rro("rro"); @@ -212,7 +245,14 @@ bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); - result.realize(out); + Target target = get_jit_target_from_environment(); + if (target.has_feature(Target::AVX512_SapphireRapids)) { + result.realize(out); + } else { + // Just compile it to see if anything crashes + result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, Target{"x86-64-linux-avx512_sapphirerapids"}); + return true; + } // uncomment to check the matrices // std::cout << "Matrix A\n"; @@ -229,7 +269,7 @@ bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) 
for (int k = 0; k < acc; ++k) { val += static_cast(A(k, j)) * static_cast(B(k % 2, i, k / 2)); } - if (!equal_eps(val, out(i, j), 0.03f)) { + if (!equal_eps(val, out(i, j), 0.01f)) { std::cerr << "Invalid result at " << i << ", " << j << "\n" << out(i, j) << " != " << val << "\n" << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; @@ -238,7 +278,6 @@ bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) } } - std::cout << "Success!\n"; return true; } @@ -247,17 +286,62 @@ auto matmul_us = &matmul; auto matmul_su = &matmul; auto matmul_uu = &matmul; -bool run_tests(bool (*fn)(int, int, int, int, int, int), int element_width) { - return fn(2, 2, 16, 2, 2, 8 / element_width) && fn(4, 4, 8, 4, 4, 8 / element_width) && fn(32, 32, 32, 8, 8, 8 / element_width) && fn(32, 32, 32, 8, 8, 4 / element_width); +bool run_tests(bool (*fn)(int, int, int, int, int, int, bool), int element_width) { + struct Cfg { + int col, row, acc, tx, ty, tr; + bool intrin; + }; + Cfg cfgs[] = { + {2, 2, 16, 2, 2, 8 / element_width, true}, + {4, 4, 8, 4, 4, 8 / element_width, false}, + {32, 32, 32, 8, 8, 8 / element_width, true}, + {32, 32, 32, 8, 8, 4 / element_width, false}, + + // Asymmetric tiles + {32, 16, 32, 8, 4, 8 / element_width, true}, + {16, 32, 32, 4, 8, 8 / element_width, false}, + {32, 32, 32, 8, 4, 4 / element_width, true}, + {32, 32, 32, 4, 8, 4 / element_width, false}, + + // Degenerate along I or J axis (a matrix-vector multiply) + {8, 8, 64, 2, 1, 64 / element_width, false}, + + {8, 8, 64, 1, 2, 64 / element_width, false}, + + // Degenerate along both (a dot product) + {1, 1, 64, 1, 1, 64 / element_width, false}, + + // A matmul scheduled as individual dot products + {8, 8, 64, 1, 1, 64 / element_width, false}, + + // The size of the intermediate vector (and the number of multiplies + // done) is IJK. AMX requires: + // I <= 16, + // K <= 64 / element_width + // IJ <= 256 + // JK <= 1024 / element_width + // IK <= 1024 / element_width + // Under these constraints, IJK is maximized when I = J = 16, K = 64 / element_width. + // So this is the setting that does the most work for us. + {32, 32, 64, 16, 16, 64 / element_width, false}, + + // Larger-than-native tiles (unsupported for now, may destructure into + // multiple ops in future) + // {64, 64, 64, 32, 16, 8 / element_width, true}, + }; + for (const auto &c : cfgs) { + std::cerr << "Testing col=" << c.col << " row=" << c.row << " acc=" << c.acc + << " tx=" << c.tx << " ty=" << c.ty << " tr=" << c.tr << "\n"; + if (!fn(c.col, c.row, c.acc, c.tx, c.ty, c.tr, c.intrin)) { + std::cerr << "Failed at col=" << c.col << " row=" << c.row << " acc=" << c.acc + << " tx=" << c.tx << " ty=" << c.ty << " tr=" << c.tr << "\n"; + return false; + } + } + return true; } int main(int argc, char **argv) { - Target t = get_jit_target_from_environment(); - if (!t.has_feature(Target::AVX512_SapphireRapids)) { - printf("[SKIP] No AMX target enabled\n"); - return 0; - } - printf("Running AMX matmul (signed/signed)\n"); if (!run_tests(matmul_ss, 1)) { return 1; @@ -283,5 +367,7 @@ int main(int argc, char **argv) { return 1; } + printf("Success!\n"); + return 0; } diff --git a/test/correctness/tiled_matmul_errors.cpp b/test/correctness/tiled_matmul_errors.cpp new file mode 100644 index 000000000000..303c927fff13 --- /dev/null +++ b/test/correctness/tiled_matmul_errors.cpp @@ -0,0 +1,322 @@ +// Exercises the user-facing error paths in ExtractTileOperations.cpp. 
Each +// scenario below is the most natural-looking pattern that triggers a particular +// error. This doubles as a TODO list of patterns we'd ideally support but +// currently reject. +// +// The test verifies that each scenario produces a Halide::CompileError +// (a user error) rather than crashing or hitting an internal assert. + +#include "Halide.h" +#include + +using namespace Halide; + +namespace { + +const Target amx_target("x86-64-linux-avx512_sapphirerapids"); + +// Run `body` and assert it produces a Halide user error. +template +bool expect_user_error(const char *name, F body) { + try { + body(); + } catch (const CompileError &e) { + printf("[%s] OK: %s\n", name, e.what()); + return true; + } catch (...) { + printf("[%s] FAIL: expected a CompileError but got a different exception\n", name); + return false; + } + printf("[%s] FAIL: expected a user error but none was raised\n", name); + return false; +} + +// Apply a stock AMX matmul schedule to `mm` (with reduction var `r`) and +// the given tile sizes. +void schedule_matmul(Func mm, RVar r, int tile_x, int tile_y, int tile_r) { + Var x("x"), y("y"), rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); +} + +// A tile too large for an AMX register (rows > 16). Triggers the explicit +// row-count check in convert_to_matmul. +void scenario_too_large() { + Buffer A(64, 64); + Buffer B(4, 64, 16); + Var x("x"), y("y"); + RDom r(0, 64, "r"); + + Func mm("matmul_large"); + mm(x, y) = cast(0); + mm(x, y) += cast(A(r, y)) * cast(B(r % 4, x, r / 4)); + schedule_matmul(mm, r.x, /*tile_x=*/32, /*tile_y=*/16, /*tile_r=*/8); + mm.in().compile_jit(amx_target); +} + +// AMXTile allocated for a non-i32/f32 result. AMX always accumulates into +// 32-bit registers, so we reject. Triggers the user_assert in +// visit(Allocate). +void scenario_bad_result_type() { + Buffer A(64, 64); + Buffer B(4, 64, 16); + Var x("x"), y("y"); + RDom r(0, 64, "r"); + + Func mm("matmul_i16"); + mm(x, y) = cast(0); + mm(x, y) += cast(A(r, y)) * cast(B(r % 4, x, r / 4)); + schedule_matmul(mm, r.x, 8, 8, 8); + mm.in().compile_jit(amx_target); +} + +// The most natural form of an int8 matmul: row-major LHS and RHS, no VNNI +// packing. AMX requires the RHS to be pre-packed as (4, cols, rows/4); a +// row-major (col, row) RHS isn't expressible as an AMX tile load, so we +// reject. Triggers the layout check in convert_to_matmul. +void scenario_naive_rhs() { + Buffer A(64, 64); + Buffer B(64, 64); + Var x("x"), y("y"); + RDom r(0, 64, "r"); + + Func mm("matmul_naive"); + mm(x, y) = cast(0); + mm(x, y) += cast(A(r, y)) * cast(B(x, r)); + schedule_matmul(mm, r.x, 8, 8, 8); + mm.in().compile_jit(amx_target); +} + +// A gather-style matmul with an indirect row index — natural for sparse / +// pruned matmul, indirect attention, embedding lookups. The LHS load index +// goes through a table lookup, so the multiramp lift fails. Triggers the +// "loads indices are not affine" path in convert_to_matmul. 
+void scenario_indirect() {
+    Buffer<int8_t> A(64, 64);
+    Buffer<int8_t> B(4, 64, 16);
+    Buffer<int32_t> row_indices(64);
+    Var x("x"), y("y");
+    RDom r(0, 64, "r");
+
+    Func mm("matmul_indirect");
+    mm(x, y) = cast<int32_t>(0);
+    mm(x, y) +=
+        cast<int32_t>(A(r, clamp(row_indices(y), 0, 10))) *
+        cast<int32_t>(B(r % 4, x, r / 4));
+    schedule_matmul(mm, r.x, 8, 8, 8);
+    mm.in().compile_jit(amx_target);
+}
+
+// A 1D convolution of a 2D signal with per-row kernels, aggressively
+// vectorized. Structurally a sum-of-widening-multiplies with a contiguous
+// inner K, but the LHS depends on x, k, and y simultaneously (no broadcast
+// dim) and so doesn't match the AMX matmul shape. Triggers the
+// access-pattern / layout check.
+void scenario_conv1d() {
+    Buffer<int8_t> input(128, 128);
+    Buffer<int8_t> kernels(4, 64, 8);
+    Var x("x"), y("y");
+    const int K = 32;
+    RDom r(0, K, "r");
+
+    Func conv("conv");
+    conv(x, y) = cast<int32_t>(0);
+    conv(x, y) +=
+        cast<int32_t>(input(x + r, y)) *
+        cast<int32_t>(kernels(r % 4, x, r / 4));
+    schedule_matmul(conv, r.x, 8, 8, 8);
+    conv.in().compile_jit(amx_target);
+}
+
+// A non-matmul value scheduled into an AMXTile allocation by mistake.
+// Triggers the "no matrix multiply was found" assertion.
+void scenario_no_matmul() {
+    Var x("x"), y("y"), xo("xo"), yo("yo"), xi("xi"), yi("yi");
+
+    Func f("f");
+    f(x, y) = 0;
+    f.compute_at(f.in(), xo)
+        .store_in(MemoryType::AMXTile)
+        .vectorize(x, 8)
+        .vectorize(y, 8);
+
+    f.in().tile(x, y, xo, yo, xi, yi, 8, 8).vectorize(xi).vectorize(yi);
+    f.in().compile_jit(amx_target);
+}
+
+// A user wants i16 × i16 → i32, expecting AMX to widen-multiply 16-bit
+// inputs the way it does 8-bit. AMX's tdpb*sd instructions only support
+// 8-bit input lanes (and bf16 for floats); we reject anything else. The
+// reduction inside is still a widening multiply, but with the wrong inner
+// element width, so the type check fires.
+void scenario_widening_16bit() {
+    Buffer<int16_t> A(64, 64);
+    Buffer<int16_t> B(4, 64, 16);
+    Var x("x"), y("y");
+    RDom r(0, 64, "r");
+
+    Func mm("matmul_i16_input");
+    mm(x, y) = cast<int32_t>(0);
+    mm(x, y) += cast<int32_t>(A(r, y)) * cast<int32_t>(B(r % 4, x, r / 4));
+    schedule_matmul(mm, r.x, 8, 8, 8);
+    mm.in().compile_jit(amx_target);
+}
+
+// A user gives the same Func two update definitions that each store into
+// the same AMXTile allocation but with different tile sizes (e.g. a fast
+// path for the bulk of K and a smaller fallback). The matcher requires
+// every matmul touching a given allocation to agree on tile dimensions.
+void scenario_inconsistent_tiles() {
+    Buffer<int8_t> A(64, 64), C(64, 64);
+    Buffer<int8_t> B(4, 64, 16), D(4, 64, 16);
+    Var x("x"), y("y");
+    RDom r1(0, 64, "r1"), r2(0, 64, "r2");
+
+    Func mm("matmul_two_updates");
+    mm(x, y) = cast<int32_t>(0);
+    mm(x, y) += cast<int32_t>(A(r1, y)) * cast<int32_t>(B(r1 % 4, x, r1 / 4));
+    mm(x, y) += cast<int32_t>(C(r2, y)) * cast<int32_t>(D(r2 % 4, x, r2 / 4));
+
+    Var rxi("rxi"), ryi("ryi");
+    RVar rri("rri"), rro("rro");
+
+    mm.compute_at(mm.in(), x).store_in(MemoryType::AMXTile);
+    mm.update(0)
+        .tile(x, y, rxi, ryi, 8, 8, TailStrategy::GuardWithIf)
+        .split(r1.x, rro, rri, 8)
+        .reorder(rri, rxi, ryi, rro, x, y)
+        .atomic()
+        .vectorize(rri)
+        .vectorize(rxi)
+        .vectorize(ryi);
+    mm.update(1)
+        .tile(x, y, rxi, ryi, 4, 4, TailStrategy::GuardWithIf)
+        .split(r2.x, rro, rri, 8)
+        .reorder(rri, rxi, ryi, rro, x, y)
+        .atomic()
+        .vectorize(rri)
+        .vectorize(rxi)
+        .vectorize(ryi);
+
+    Var ixi("ixi"), iyi("iyi");
+    mm.compute_at(mm.in(), x)
+        .tile(x, y, ixi, iyi, 8, 8)
+        .vectorize(ixi)
+        .vectorize(iyi);
+
+    Var mmxi("mmxi"), mmyi("mmyi");
+    mm.in()
+        .tile(x, y, mmxi, mmyi, 8, 8)
+        .vectorize(mmxi)
+        .vectorize(mmyi);
+
+    mm.in().compile_jit(amx_target);
+}
+
+// A reduction inside an AMXTile that's a sum-of-something-else, not a
+// vector_reduce_add of a widening multiply. Here we accumulate a
+// non-multiplied value, which produces a Store whose RHS is not a
+// vector-reduce-of-multiply.
+void scenario_not_a_matmul_pattern() {
+    Buffer<int32_t> A(64, 64);
+    Var x("x"), y("y");
+    RDom r(0, 64, "r");
+
+    Func mm("not_matmul");
+    mm(x, y) = cast<int32_t>(0);
+    // Sum without a multiply.
+    mm(x, y) += A(r, y);
+
+    Var rxi("rxi"), ryi("ryi");
+    RVar rri("rri"), rro("rro");
+    mm.compute_at(mm.in(), x)
+        .store_in(MemoryType::AMXTile)
+        .update()
+        .tile(x, y, rxi, ryi, 8, 8, TailStrategy::GuardWithIf)
+        .split(r.x, rro, rri, 8)
+        .reorder(rri, rxi, ryi, rro, x, y)
+        .atomic()
+        .vectorize(rri)
+        .vectorize(rxi)
+        .vectorize(ryi);
+
+    Var ixi("ixi"), iyi("iyi");
+    mm.compute_at(mm.in(), x)
+        .tile(x, y, ixi, iyi, 8, 8)
+        .vectorize(ixi)
+        .vectorize(iyi);
+
+    Var mmxi("mmxi"), mmyi("mmyi");
+    mm.in()
+        .tile(x, y, mmxi, mmyi, 8, 8)
+        .vectorize(mmxi)
+        .vectorize(mmyi);
+
+    mm.in().compile_jit(amx_target);
+}
+
+// Multiplication of a matrix by a value. In theory we could materialize the
+// value into memory (e.g. as a scaled identity matrix), but we don't, and the
+// matcher rejects.
+void scenario_matmul_by_constant() {
+    Buffer<int8_t> A(64, 64);
+    Var x("x"), y("y");
+    RDom r(0, 64, "r");
+
+    Func mm("matmul_by_constant");
+    mm(x, y) = cast<int32_t>(0);
+    mm(x, y) += cast<int32_t>(A(r, y)) * select(x == r, cast<int32_t>(3), cast<int32_t>(0));
+    schedule_matmul(mm, r.x, 8, 8, 8);
+    mm.in().compile_jit(amx_target);
+}
+
+}  // namespace
+
+int main(int argc, char **argv) {
+    if (!exceptions_enabled()) {
+        printf("[SKIP] Halide was compiled without exceptions; this test "
+               "needs them to catch each scenario's user error.\n");
+        return 0;
+    }
+
+    int failures = 0;
+
+    failures += !expect_user_error("too_large", scenario_too_large);
+    failures += !expect_user_error("bad_result_type", scenario_bad_result_type);
+    failures += !expect_user_error("naive_rhs", scenario_naive_rhs);
+    failures += !expect_user_error("indirect", scenario_indirect);
+    failures += !expect_user_error("conv1d", scenario_conv1d);
+    failures += !expect_user_error("no_matmul", scenario_no_matmul);
+    failures += !expect_user_error("widening_16bit", scenario_widening_16bit);
+    failures += !expect_user_error("inconsistent_tiles", scenario_inconsistent_tiles);
+    failures += !expect_user_error("not_a_matmul_pattern", scenario_not_a_matmul_pattern);
+    failures += !expect_user_error("matmul_by_constant", scenario_matmul_by_constant);
+
+    if (failures != 0) {
+        printf("%d scenario(s) failed to produce a user-facing CompileError\n", failures);
+        return 1;
+    }
+    printf("Success!\n");
+    return 0;
+}
diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp
index 03bd243ef554..916aa354eb57 100644
--- a/test/performance/tiled_matmul.cpp
+++ b/test/performance/tiled_matmul.cpp
@@ -84,6 +84,13 @@ bool matmul(Halide::Target target) {
     // should be (4, cols, rows / 4) for int8, or (2, cols, rows / 2) for bf16.
     // This means that the rows must always be divisible by 4 (or 2 for bf16).
     ImageParam B(rhs(8), 3, "rhs");
+    // Constrain B's innermost dim to exactly 4 contiguous elements (the
+    // VNNI K-pack), and the next dim's stride to 4. AMX's tile_load expects
+    // each output column's K-pack (4 bytes for int8) to be stored
+    // contiguously; without these constraints the strides are symbolic and
+    // the AMX matcher conservatively rejects.
+    B.dim(0).set_stride(1).set_extent(4);
+    B.dim(1).set_stride(4);
 
     RDom r(0, acc);
 
@@ -171,6 +178,9 @@ bool matmul_bf16(Halide::Target target) {
     Var x("x"), y("y");
     ImageParam A(BFloat(16), 2, "lhs");
     ImageParam B(BFloat(16), 3, "rhs");
+    // Same VNNI-pack constraint as the int8 case, but with K=2 for bf16.
+    B.dim(0).set_stride(1).set_extent(2);
+    B.dim(1).set_stride(2);
 
     RDom r(0, acc, "acc");
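
Aside (not part of the diff): the set_stride / set_extent constraints above pin down the (4, cols, rows / 4) VNNI layout that the tests index as B(r % 4, x, r / 4). A minimal sketch of how a row-major int8 RHS could be repacked into that layout; the helper name pack_rhs_vnni and the use of std::vector are illustrative assumptions, not anything defined by this patch.

#include <cstdint>
#include <vector>

// Repack a row-major int8 RHS (rows = K, cols = N) into the (4, cols, rows / 4)
// layout, so that packed(k % 4, j, k / 4) == rhs[k][j]. Assumes rows % 4 == 0;
// bf16 would use a pack of 2 instead of 4.
std::vector<int8_t> pack_rhs_vnni(const std::vector<int8_t> &rhs, int rows, int cols) {
    std::vector<int8_t> packed(static_cast<size_t>(rows) * cols);
    for (int k = 0; k < rows; k++) {
        for (int j = 0; j < cols; j++) {
            // dim 0: extent 4, stride 1 (the K-pack)
            // dim 1: extent cols, stride 4
            // dim 2: extent rows / 4, stride 4 * cols
            size_t idx = (k % 4) + 4 * static_cast<size_t>(j) + 4 * static_cast<size_t>(cols) * (k / 4);
            packed[idx] = rhs[static_cast<size_t>(k) * cols + j];
        }
    }
    return packed;
}

With this packing, the element a row-major kernel would read as B(x, r) lives at B_packed(r % 4, x, r / 4), which is exactly the indexing used throughout the correctness and error tests above.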