Skip to content

Fix AMX support using MultiRamp#9122

Draft
abadams wants to merge 78 commits intomainfrom
abadams/fix_amx
Draft

Fix AMX support using MultiRamp#9122
abadams wants to merge 78 commits intomainfrom
abadams/fix_amx

Conversation

@abadams
Copy link
Copy Markdown
Member

@abadams abadams commented May 5, 2026

Rewrites the AMX support to use MultiRamp. This, I believe, fixes the outstanding bugs in AMX support identified by #8350

Validated by running the AMX tests under SDE.

Future work is generalizing the AMX support to be willing to ingest larger vectors, and automatically slice it up into multiple tile-level operations. More TODO scenarios are in the test tiled_matmul_errors.cpp

abadams and others added 30 commits January 26, 2026 15:52
The previous comment reported a time that seemed to have regressed. It
was not 8.2ms on main - more like 11
Before:

Computing best tile sizes for each type
.................................................
bytes, tile width, tile height, bandwidth (GB/s):
1 8 8 20.9997
1 16 8 20.8329
1 8 16 18.5702
1 8 32 17.2463
1 8 64 14.312

2 8 16 19.2047
2 8 8 18.8368
2 16 8 17.0593
2 8 32 17.0591
2 4 8 15.7681

4 8 8 24.9364
4 4 16 22.9699
4 8 16 22.5743
4 4 32 22.255
4 4 8 20.4468

8 8 8 38.4094
8 16 4 28.4167
8 16 8 27.6184
8 8 4 27.6062
8 8 16 26.8693

After:

Computing best tile sizes for each type
.................................................
bytes, tile width, tile height, bandwidth (GB/s):
1 16 32 34.1921
1 16 16 31.8399
1 8 16 25.575
1 16 64 25.1665
1 32 16 25.0061

2 8 32 28.2635
2 8 16 27.7648
2 16 16 27.2126
2 16 32 23.9034
2 8 8 23.6345

4 8 16 34.5303
4 8 8 28.3653
4 16 8 26.8521
4 8 32 26.084
4 16 16 24.4519

8 8 8 33.7163
8 8 4 29.1339
8 4 16 26.418
8 16 4 25.4663
8 2 8 24.3949
Also better algorithm for innermost containing stmt
abadams and others added 12 commits April 22, 2026 14:42
Add hand-picked tests for MultiRamp API properties that weren't
previously covered: mul, operator==, alias_free_slice (unique lanes /
zero-stride peeling / degenerate scalar), rotate_stride_one_innermost
(rotation + transpose round-trip), and is_multiramp round-trips for a
handful of shapes.

Add test_random to transposed_vector_reduce.cpp: 1000 random
quasi-affine store/load index pairs over a 3-dim RDom, each compiled
scalarly and with .atomic().vectorize() across all three RVars and
compared. This test found all three bugs fixed in the preceding
commit.

Co-authored-by: Claude <noreply@anthropic.com>
Ran a weak subagent (Haiku) over the MultiRamp PR as an adversarial
comprehension test — asking it to explain the code in detail, then
fixing whatever it got wrong. The theory: if a weaker model
misreads something, the comment is probably unclear, not the model.

Fixes prompted by the review:

- Simplify_Exprs.cpp / Simplify_Stmts.cpp: stale "outermost" wording
  from before rotate_stride_one_outermost was renamed to
  rotate_stride_one_innermost. The comments contradicted the function
  name and Haiku echoed the contradiction.
- MultiRamp.h alias_free: state explicitly that the returned Expr is
  a sufficient (not necessary) condition for lane uniqueness.
- MultiRamp.h alias_free_slice: clarify that kept dims are a subset
  preserving order, not necessarily a prefix.
- VectorizeLoops.cpp: rename ContainingLoop -> UnrolledLoop and note
  that the peeled dims are fully unrolled into a flat Block, not a
  runtime loop nest (despite the old name).
- MultiRamp.h alias_free_slice: note that stride-zero and purely
  symbolic strides always peel (added by Andrew directly).

A second Haiku pass after these edits answered every question
correctly, including the ones it got wrong the first time.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the four hand-rolled AST pattern matchers in
ExtractTileOperations with a uniform multiramp-based approach. The
previous matchers were brittle (Add operand ordering, tile_x/tile_y
swap, missing stride checks, etc. — see #8350) and had separate
branches for the broadcast / collapsed / general RHS cases.

The new flow lifts each load index to a MultiRamp, then coerces it
into the canonical AMX shape via MultiRamp::strides_for_shape (a new
one-sided gcd-walk that returns per-target-dim strides), and reads
off the extracted strides from known slot positions. The inner-K /
broadcast / row-stride bits all fall out of the same path.

Also adds:
- A normalizing MultiRamp constructor that drops extent-1 dims.
- MultiRamp::strides_for_shape to map an MR's lane sequence onto a
  caller-specified dim shape.
- Asymmetric tile-size cases to the correctness test (tile_x !=
  tile_y) — these would have caught the swap bug listed in #8350.
- Constraints on the perf-test ImageParams' VNNI inner stride so the
  contiguity check now folds away cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@abadams abadams marked this pull request as draft May 5, 2026 22:53
@alexreinking
Copy link
Copy Markdown
Member

alexreinking commented May 8, 2026

Ignore this PR until after MultiRamp has gone in

We might want to look into this: https://github.github.com/gh-stack/

abadams and others added 5 commits May 8, 2026 09:59
A correctness test that exercises ten of the user-facing error paths in
ExtractTileOperations.cpp. Each scenario is the most natural-looking
matmul pattern that triggers a particular reject, doubling as a TODO
list of cases we'd ideally support but currently don't:

- too_large            tile_x > 16
- bad_result_type      i8 * i8 -> i16 (AMX always accumulates i32/f32)
- naive_rhs            row-major RHS without VNNI packing
- indirect             gather-style A(r, row_indices(y)) * B(...)
- conv1d               1D conv of a 2D signal (LHS depends on x, k, y)
- no_matmul            store_in(AMXTile) on a non-matmul Func
- widening_16bit       i16 * i16 -> i32 (only 8-bit inputs supported)
- inconsistent_tiles   one Func with two updates at different tile sizes
- not_a_matmul_pattern row-sum into AMXTile (no multiply)
- scaled_matmul        A(r, y) * 3 (RHS hoisted out of the reduce)

The harness wraps each scenario in try/catch around Halide::CompileError.
Halide is sometimes built without exceptions; in that case the test
prints [SKIP] and exits 0 since the catch path can't fire.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment on lines +6 to +32
// The test verifies that each scenario produces a Halide::CompileError
// (a user error) rather than crashing or hitting an internal assert.

#include "Halide.h"
#include <stdio.h>

using namespace Halide;

namespace {

const Target amx_target("x86-64-linux-avx512_sapphirerapids");

// Run `body` and assert it produces a Halide user error.
template<typename F>
bool expect_user_error(const char *name, F body) {
try {
body();
} catch (const CompileError &e) {
printf("[%s] OK: %s\n", name, e.what());
return true;
} catch (...) {
printf("[%s] FAIL: expected a CompileError but got a different exception\n", name);
return false;
}
printf("[%s] FAIL: expected a user error but none was raised\n", name);
return false;
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this what the test/error directory is for?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test/error are tests we expect to crash, which makes it hard to test more than one thing per binary. In at least one other case where we test lots of failure modes in a single test we catch exceptions instead (see test/correctness/invalid_gpu_loop_nests.cpp)

Comment on lines +305 to +314
failures += !expect_user_error("too_large", scenario_too_large);
failures += !expect_user_error("bad_result_type", scenario_bad_result_type);
failures += !expect_user_error("naive_rhs", scenario_naive_rhs);
failures += !expect_user_error("indirect", scenario_indirect);
failures += !expect_user_error("conv1d", scenario_conv1d);
failures += !expect_user_error("no_matmul", scenario_no_matmul);
failures += !expect_user_error("widening_16bit", scenario_widening_16bit);
failures += !expect_user_error("inconsistent_tiles", scenario_inconsistent_tiles);
failures += !expect_user_error("not_a_matmul_pattern", scenario_not_a_matmul_pattern);
failures += !expect_user_error("matmul_by_constant", scenario_matmul_by_constant);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At a glance, I thought that the string was something that we expected to find in e.what(), but this isn't the case. That would be useful, though.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 8, 2026

Codecov Report

❌ Patch coverage is 6.05187% with 326 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@aade603). Learn more about missing BASE report.

Files with missing lines Patch % Lines
src/ExtractTileOperations.cpp 1.40% 278 Missing and 3 partials ⚠️
src/MultiRamp.cpp 7.31% 35 Missing and 3 partials ⚠️
src/Deinterleave.cpp 55.55% 2 Missing and 2 partials ⚠️
src/StageStridedLoads.cpp 62.50% 1 Missing and 2 partials ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #9122   +/-   ##
=======================================
  Coverage        ?   69.87%           
=======================================
  Files           ?      256           
  Lines           ?    77749           
  Branches        ?    18555           
=======================================
  Hits            ?    54326           
  Misses          ?    17953           
  Partials        ?     5470           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants