feat(moe): layout-API MXFP4 (a4w4/a8w4) MoE gemm#753
Open
coderfeli wants to merge 75 commits into
Add a layout-API MXFP4 MoE up/gate + down-proj gemm that consumes the
standard (opus) sort contract from moe_sorting_kernel directly -- gemm1
gathers A rows via sorted_token_ids & 0xFFFFFF; gemm2 scatters per sorted
row via global atomic add weighted by sorted_weights. No fused-sort extras
(m_indices / reverse_sorted) are needed.
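The gather/scatter contract in the paragraph above can be illustrated on the host with a minimal pure-Python sketch. It assumes only what the message states: the low 24 bits of each `sorted_token_ids` entry select the token row, and gemm2 accumulates each sorted row scaled by `sorted_weights`. The helper names, the treatment of out-of-range ids as padding, and the omission of the actual expert math are illustrative assumptions, not the kernel code.

```python
TOKEN_MASK = 0xFFFFFF  # low 24 bits select the token row (from the commit message)


def gather_rows(a, sorted_token_ids):
    """gemm1-style gather: one A row per sorted slot; padding slots become zeros."""
    width = len(a[0])
    rows = []
    for packed in sorted_token_ids:
        tok = packed & TOKEN_MASK
        # Assumption: an id past the last token marks a padding slot.
        rows.append(list(a[tok]) if tok < len(a) else [0.0] * width)
    return rows


def scatter_add(out, rows, sorted_token_ids, sorted_weights):
    """gemm2-style scatter: weighted accumulate per sorted row (atomic add on the GPU)."""
    for packed, w, row in zip(sorted_token_ids, sorted_weights, rows):
        tok = packed & TOKEN_MASK
        if tok >= len(out):
            continue  # padding slot contributes nothing
        for c, v in enumerate(row):
            out[tok][c] += w * v
    return out


a = [[1.0, 2.0], [3.0, 4.0]]
# High bits are discarded by the mask; what they encode is not stated in the source.
ids = [(1 << 24) | 1, 0, 1, 0xFFFFFF]
weights = [0.25, 1.0, 0.75, 0.0]
gathered = gather_rows(a, ids)
out = scatter_add([[0.0, 0.0], [0.0, 0.0]], gathered, ids, weights)
```

Because the expert math is left out, the weighted scatter simply reconstructs each token row from its routed copies, which makes the index handling easy to check by hand.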
kernels/mxfp4_moe_layout.py - layout-API building blocks (fx.copy B/B-scale,
fx.gemm scaled-MFMA atoms)
kernels/mxfp4_moe_common.py - shared raw helpers / constants / K-derived size
formulas / atomic bf16 epilogue
kernels/mxfp4_moe_gemm1.py - BM32 up/gate gemm (a4w4 + a8w4, interleave +
separated, nt/cached, out fp4/fp8)
kernels/mxfp4_moe_gemm2.py - BM32 atomic down-proj (a4w4 + a8w4 fp8 input)
kernels/mxfp4_moe_gemm_2stage.py - public API + host launchers
Wire a4w4/a8w4 of tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage to the
new pipeline (opus sort -> gemm1 -> gemm2 atomic) vs an independent dequant-MoE
reference; a8w4 added to the in_dtype matrix.
Validated on gfx950: chain cosine a4w4=0.988, a8w4=1.000 (interleave + separated);
test_moe_gemm_2stage fp4/a8w4 over FP4-S/M/L pass.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…oe_dispatcher
Collapse the 5 partly-duplicated MXFP4 MoE layout-API modules into 3:
* utils.py -- basics (shape constants, K-size formulas, pointer/LDS
helpers, e8m0/SwiGLU quant math); no mma/load
* moegemm.py -- gemm1 + gemm2 device bodies in one file, plus the shared
B/B-scale fx.copy + fx.gemm layout primitives, the A-LDS
loaders, and the atomic bf16 epilogue
* moe_dispatcher.py-- compile_gemm1/2_a4w4_port + launch dispatch
(mxfp4_moe_gemm1/2, caches, gemm1_grid)
Removes the helper duplication (gemm1 previously re-defined _raw/_lds_ptr3/
_silu_mul_batch/... already in common) and makes the stack self-contained (no
imports from the v1 modules). Pure code-motion -- no behavioral change.
Deletes mxfp4_moe_{common,layout,gemm1,gemm2,gemm_2stage}.py and rewires the
test import (kernels.mxfp4_moe_gemm_2stage -> kernels.moe_dispatcher).
Validated on gfx950: test_moe_gemm_2stage layout-API path (gemm1->gemm2,
opus-sort), a4w4 + a8w4, S/M/L -- numeric gate (mean_row_cos > 0.85) passes.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
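The numeric gate quoted above (mean_row_cos > 0.85) can be sketched as a mean of per-row cosine similarities. The test's real implementation is not shown on this page, so the function names and the zero-norm handling below are assumptions.

```python
import math


def row_cos(x, y):
    """Cosine similarity of two rows; two all-zero rows count as identical."""
    dot = sum(a * b for a, b in zip(x, y))
    nx = math.sqrt(sum(a * a for a in x))
    ny = math.sqrt(sum(b * b for b in y))
    if nx == 0.0 or ny == 0.0:
        return 1.0 if nx == ny else 0.0
    return dot / (nx * ny)


def mean_row_cos(out, ref):
    """Average the per-row cosine over all output rows."""
    return sum(row_cos(o, r) for o, r in zip(out, ref)) / len(out)


def passes_gate(out, ref, threshold=0.85):
    """Gate used as a pass/fail check; the 0.85 threshold is the one quoted above."""
    return mean_row_cos(out, ref) > threshold
```

A per-row mean penalises a single badly wrong row more than one cosine over the flattened tensor would, which is presumably why it is used as the gate.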
01a5bd8 to
6db7724
- Drop the redundant BM32 constants: N_LOAD_WAVES (only fed ROWS_PER_WAVE), ROWS_PER_WAVE -> BM//4, BN_INT -> inline BN//4, M_REPS -> kMChunks (both BM//16). Keep BM, kAStages, kSubBlocks, kMChunks.
- Condense the module docstrings and the verbose inline comments across all three files (keep the load-bearing ones).
Numerically identical substitutions; GPU re-run (a4w4 + a8w4, gemm1->gemm2) still passes the numeric gate.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
7eed205 to
5802e13
…lpers
Globals: collapse the gemm1/gemm2 value-duplicate pairs into one name each -- NE_DEFAULT/NE -> NE, K_DEFAULT/N_OUT -> H_DEFAULT (model_dim), INTER_DEFAULT/K -> INTER_DEFAULT (inter_dim).
Funcs: the per-body issue_a_ds_read and the per-J MMA cluster were duplicated between gemm1 and gemm2; hoist them to module-level _issue_a_ds_read_dt / _mma_one_j and call from both. Drop the now-unused _gemm_mma closure.
Trace-time-only change (identical emitted IR); GPU re-run (a4w4 + a8w4, gemm1->gemm2, S+M) still passes the numeric gate.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…smem)
Replace the raw flydsl.utils.smem_allocator SmemAllocator/SmemPtr usage with the layout-API shared-memory idiom used by the main gemm kernels (preshuffle_gemm_v2, flash_attn_gfx950):
    @fx.struct
    class SharedStorage:
        buf: fx.Array[Int8, lds_bytes, 16]
    lds = fx.SharedAllocator().allocate(SharedStorage).peek()
    lds_base_i32 = fx.Int32(fx.ptrtoint(lds.buf.ptr))
The single i8 buffer mirrors the prior union LDS (s_aq | s_asc carved from the front; the f32 acc aliases the whole region); the bodies now take one i32 base and derive s_aq/s_asc/acc sub-bases by offset. All raw buffer_load_lds / ds-read addressing is unchanged.
Drops the manual allocator.finalize() in the launch jit (SharedAllocator finalizes automatically) and the SmemPtr _view_cache resets. utils: _lds_base_ptr3(view) -> _lds_base3(base_i32); drop memref_dialect.
GPU re-run (a4w4 + a8w4, gemm1->gemm2, S+M) passes the numeric gate.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Migrate the activation gather (a), gemm1 activation-scale (ascale), and the output/ascaleout stores of the layout-API MXFP4 MoE GEMM from raw rocdl/llvm intrinsics to fx.copy copy-atoms, matching the existing b/bscale paths so all global memory traffic goes through one layout-API abstraction. Direct-to-LDS loads and global stores port 1:1 via BufferCopy / BufferCopyLDS atoms, including the data-dependent A-gather with bounds-checked OOB-zero for padded rows. The gemm2 ascale load (a scalar global->register feed consumed inline by the MFMA) is kept raw: routing it through a register fragment regresses gemm2 ~4-5%, with no perf-neutral fx.copy equivalent. A single _flat_buf_view helper backs all four buffer views (fold/no-fold, bounds/max-size via args). Correctness unchanged (fp4 cos~0.988, a8w4 cos~0.9996); per-shape perf at parity on the stable t=4096 shape (interleaved A/B). Adds an MXFP4_BENCH-gated run_perftest harness; the layout-API MXFP4 path had no built-in benchmark. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…pers Hoist the flat buffer-tensor view and the BufferCopyLDS128b DMA atom out of moegemm.py into fp8_gemm_utils.py as the single source of truth, removing the duplicated copies the fx.copy migration had introduced: - flat_buffer_view(arg, base_elems, elem_ty, *, align, elem_bytes, fold, num_records_bytes): the flat (1,1) buffer view from a raw i64 address (fold/no-fold, max-size/num-records bounds). All 5 moegemm call sites use it. - lds_dma_atom_128(): the shared BufferCopyLDS128b atom, now also used by G2SLoader (was inlined identically in both). StoreC's typed-buffer view path and G2SLoader's hardcoded-stride load loop are left as-is: they start from typed kernel args / fixed LDS strides and do not fit the MoE raw-address, swizzled, data-dependent-gather case. Behavior byte-identical: cos unchanged (fp4 ~0.988, a8w4 ~0.9996), per-shape perf at parity (interleaved A/B). fp8 gemm rowscale test still passes (G2SLoader touched). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…format Remove the raw-path leftovers ruff flagged after the migration: the unused `_lds_ptr3` import and the now-dead `aq_rsrc` / `ascale_rsrc` (+ its `ascale_num`) buffer resources and `out_row` (the fx.copy views carry this addressing). Apply black/ruff (line-length 120). No behavior change: cos unchanged (fp4 ~0.988, a8w4 ~0.9996). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…onflict) main already ignores .humanize/ (plus benchmark CSV + .rocprofv3 patterns); the migration's incidental .humanize* line conflicted with that. Net .gitignore diff vs base is now empty. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the raw llvm.load / llvm.StoreOp at the 7 scalar i32/f32 sites (expert-id and sorted-token-id global reads; the gemm1/gemm2 f32 accumulator LDS store + read-back) with FlyDSL's idiomatic typed-pointer indexing via two small helpers (_global_typed_ptr / _lds_typed_ptr). Element-indexing also drops the manual *4 byte arithmetic. Vector LDS ds-reads, the invariant=True metadata loads (no invariant flag on ptr_load), and the swizzled ds-read paths stay on llvm.* by design. Identical IR: cos unchanged (fp4 ~0.988, a8w4 ~0.9996), perf parity (interleaved A/B). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…calar reads Replace the raw llvm.load LDS reads (vec<2xi64>/vec<4xi32> A ds-reads, the scalar A-scale ds-read, and the vec<2xf32> epilogue read-back) with a reusable _lds_vec_load helper built on fx.ptr_load(result_type=...). Same load instruction, same byte offsets, swizzle math untouched. Drops the now-unused _gep3 / _lds_base3 byte-GEP helpers. Only llvm.* left are the invariant=True metadata loads (ptr_load has no invariant flag) and the atomic fadd accumulate (no typed-ptr equivalent). Identical IR: cos unchanged (fp4 ~0.988, a8w4 ~0.9996), perf parity (interleaved A/B, re-confirmed at 200 iters). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The mi355 CI job drives test_moe_gemm.py as a script with --gemm2_mode both. For the layout-API MXFP4 path (fp4/a8w4) the reduce-mode combo calls pytest.skip(), which outside a pytest session raises an uncaught Skipped and crashes the runner (exit 1) despite all correctness checks passing. Catch pytest.skip.Exception in run_one and report it as a printed skip. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Collapse the module + helper docstrings, the multi-line inline comments, and the 3-line `# ===` section banners to one/two lines across moegemm.py + utils.py. Hard-won rationale kept tersely (B-load waterfall warning, OOB-zero, e8m0 path, AddressSpace.Shared=2 gotcha). Comment-only; cos unchanged (fp4 ~0.988 / a8w4 0.9996). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Token-level rename (NAME tokens only; strings/comments untouched) of 86 leading-underscore identifiers across moegemm.py / utils.py / moe_dispatcher.py / fp8_gemm_utils.py: module funcs, MoE-only helpers, and locals. Two locals that would shadow module funcs renamed distinctly (_b_copy_atom->b_catom, _scale_mma_atoms->mma_atoms). Kept flydsl._mlir module path. Dropped dead _lds_ptr3/_PTR3 and the unused gemm1 asc_per_mb (surfaced by F841 post-rename). Pure rename: cos unchanged (fp4 0.988 / a8w4 0.9997), ruff/black clean. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the last raw arith ops with fx spellings; all operands are wave-uniform non-negative values (<2^31), so signed/unsigned and zero/sign-extend are bit-identical here:
ExtUIOp(i64) -> fx.Int64(x) (col/scale base extend)
TruncIOp(i16) -> fx.Int16(x) (e8m0 pair pack)
index_cast(i32) -> fx.Int32(thread_id)
index_cast(index) -> fx.Index(x) (buffer num_records)
cmpi(ult)+select -> (bexp < 254).select(bexp, 254)
divui -> // (udiv helper)
Drops the now-unused `arith` import from both files. cos unchanged (fp4 0.9881 / a8w4 0.9996); interleaved A/B parity held (within noise).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
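The bit-identity argument above (zero-extend and sign-extend agree for non-negative values below 2^31) and the `bexp` clamp can be checked with a small pure-Python model. This is only an arithmetic illustration; the helper names are made up here and no FlyDSL API is used.

```python
import struct


def zext_i32_to_i64(bits):
    """Zero-extend a 32-bit pattern: the upper 32 bits become 0."""
    return bits & 0xFFFFFFFF


def sext_i32_to_i64(bits):
    """Sign-extend a 32-bit pattern and return it as a 64-bit pattern."""
    (signed,) = struct.unpack("<i", struct.pack("<I", bits & 0xFFFFFFFF))
    return signed & 0xFFFFFFFFFFFFFFFF


def clamp_bexp(bexp):
    """Plain-Python spelling of (bexp < 254).select(bexp, 254)."""
    return bexp if bexp < 254 else 254
```

The two extensions only diverge once bit 31 is set, which is exactly the range the commit rules out for these operands.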
…m load
- Comments/docstrings collapsed to a single line across moegemm/utils/moe_dispatcher/fp8_gemm_utils (incl the long flat_buffer_view + module docstrings and the `# ===` banners); key rationale kept tersely.
- run_compiled: drop moe_dispatcher's duplicate; reuse the shared tensor_shim._run_compiled (aliased), folding in the ir.Context-leak cleanup and fixing its exe.cf -> exe._cf cache bug. Call sites pass *args.
- cumsum0: raw `llvm.load(T.i32, global_ptr1(...))` -> typed `global_typed_ptr(arg_cumsum, T.i32)[0]`; drop now-dead global_ptr1 + unused llvm/ir imports.
cos unchanged (fp4 0.9881 / a8w4 0.9997); ruff/black clean.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…pecific)
kernels/utils.py was 100% MoE-specific (created by the MoE consolidation, only imported by the MoE files). A generic name for MoE config + kernel geometry + quant helpers is misleading, so fold it all into moegemm.py where it's used:
- shape constants (NE/TOPK/H/INTER defaults, BN/BK/KH_TILE/kStages/...), K-derived size formulas, raw/pointer/LDS helpers, e8m0/SwiGLU math.
- inline the two identical *_c_k1_for wrappers (0 external users).
- moe_dispatcher now imports these from .moegemm; decouple fp8_gemm_utils by taking `raw` from flydsl.expr.arith (_to_raw) instead of the MoE module.
- delete kernels/utils.py.
Pure move (no logic change): cos unchanged (fp4 0.9882 / a8w4 0.9996); ruff/black clean.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rs, dedup raw/udiv Rename kernels/moegemm.py to kernels/mxmoe_gemm_v2.py and update the moe_dispatcher import + fp8_gemm_utils comment. Remove the K-/N-derived size helper defs (k_half_for, k_tiles_total_for, kunroll_for, kbs_stride_n0_dw_for, kas_per_chunk_dw_for, num_n_blocks_for, kbs_per_expert_dw_for, kmchunks) and inline their compile-time arithmetic at the gemm2 + epilogue call sites, mirroring gemm1's existing inline style. Drop the local raw()/udiv() helpers: use fx._to_raw (imported as _raw) and plain // (signed floordiv, matching the old udiv). Correctness unchanged: tests/kernels/test_moe_gemm.py fp4+a8w4 subset 69 passed / 30 skipped (gfx950), identical to baseline. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rs, dedup raw/udiv (#765) Rename kernels/moegemm.py to kernels/mxmoe_gemm_v2.py and update the moe_dispatcher import + fp8_gemm_utils comment. Remove the K-/N-derived size helper defs (k_half_for, k_tiles_total_for, kunroll_for, kbs_stride_n0_dw_for, kas_per_chunk_dw_for, num_n_blocks_for, kbs_per_expert_dw_for, kmchunks) and inline their compile-time arithmetic at the gemm2 + epilogue call sites, mirroring gemm1's existing inline style. Drop the local raw()/udiv() helpers: use fx._to_raw (imported as _raw) and plain // (signed floordiv, matching the old udiv). Correctness unchanged: tests/kernels/test_moe_gemm.py fp4+a8w4 subset 69 passed / 30 skipped (gfx950), identical to baseline. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
topk only feeds gemm1_grid host-side block sizing and never enters the compiled kernel (not in the kernel name, not in the body). Removing it from the get_g1 cache key and compile signature collapses all topk values for a given shape onto a single compiled kernel. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
NE is host-only: it caps the active-expert count in gemm1_grid but never
enters either compiled body (it was an unused kwarg on gemm1_body_v2 /
gemm2_body_v2 and an ne{NE} segment in both kernel names). Remove it from the
kernel names, the get_g1/get_g2 cache keys, and the compile signatures so all
expert counts for a given shape share one compiled kernel. Public
mxfp4_moe_gemm1/2 keep NE for the host grid calc.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
inter_dim is the gemm2 down-proj contraction. It was baked into the kernel
name (i{INTER}) and the get_g2 cache key, forcing one compiled kernel per
inter_dim. Convert it to a runtime fx.Int32 (i32_inter):
- The contraction K-loop becomes an scf.for over a runtime trip count
(K_TILES = inter_dim/256), carrying the C accumulators (fp4 fragments /
fp8 accm) as loop-carried state.
- B and B-scale are now streamed one K-tile per iteration into fresh
fragments instead of preloading all K_TILES tiles into registers; the
bq_view/bscale_view strides are K-independent constants so only the
wave-uniform base offsets become runtime. The view shape K-axis and the
triple-buffered A LDS are bounded by a compile-time cap INTER_MAX.
- A->LDS is streamed inside the loop for K_TILES > kStages.
gemm2 kernel name now carries imax{INTER_MAX} (not i{INTER}); get_g2 keys on
INTER_MAX, so all inter_dim values <= cap share one compiled kernel. The
public mxfp4_moe_gemm2 API is unchanged.
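The runtime-K conversion described above can be sketched as a host-side loop: the trip count comes from a runtime `inter_dim`, the accumulators are carried across iterations, and B is sliced one K-tile at a time instead of being preloaded. The tile width 256 follows the message (K_TILES = inter_dim/256); the `INTER_MAX` value and the function name are assumptions for illustration.

```python
BK = 256          # K-tile width, as in K_TILES = inter_dim / 256
INTER_MAX = 4096  # hypothetical compile-time cap used only by this sketch


def down_proj_row(a_row, b_cols, inter_dim):
    """One output row: contract a_row[k] with b_cols[n][k], one K-tile per iteration."""
    assert inter_dim % BK == 0 and inter_dim <= INTER_MAX
    k_tiles = inter_dim // BK          # runtime trip count
    acc = [0.0] * len(b_cols)          # loop-carried accumulators
    for kt in range(k_tiles):
        lo, hi = kt * BK, (kt + 1) * BK
        a_tile = a_row[lo:hi]
        for n, col in enumerate(b_cols):
            b_tile = col[lo:hi]        # streamed: only this K-tile is live
            acc[n] += sum(x * y for x, y in zip(a_tile, b_tile))
    return acc
```

The same function body serves every `inter_dim` up to the cap, which mirrors how one compiled gemm2 kernel is now shared across inter_dim values.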
Perf (gemm2, median-of-5, gfx950): fp4 t8192/7168,512 909->901us;
fp4 t16384/5120,1536 1042->994us; a8w4 t8192/7168,512 922->905us;
a8w4 t16384/5120,1536 1855->1070us (streaming B removes the
register-resident-B spill at larger K). Cold correctness unchanged
(69 passed).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
inter_dim is the gemm1 up/gate-proj N-output dim. It has no contraction
K-loop dependency (the K-loop is hidden_dim), so converting it to a runtime
fx.Int32 (i32_inter) only turns N-block sizing, B/B-scale base offsets, and
the epilogue output strides (N_OUT, NUM_N_BLOCKS, kBS_per_expert_dw,
OUT_AS_PER_CHUNK_DW, K_G2_BYTES) into runtime i32 ops. gemm1 LDS and the
B-view K-axis depend only on hidden_dim, so no compile-time cap is needed.
The gemm1 kernel name drops i{INTER}; get_g1 drops inter_dim from its cache
key. Together with the gemm2 conversion, inter_dim is now fully runtime:
all inter_dim values share ONE gemm1 and ONE gemm2 compiled kernel. The
public mxfp4_moe_gemm1 API is unchanged.
Perf (median-of-5, gfx950) at parity vs base 9f359ee: gemm1 fp4
t8192/7168,512 777->736us, fp4 t16384/5120,1536 ~2527us (base ~2527 stable;
the earlier 2314 was a transient high-clock sample), a8w4 ~870/~2643us.
Cold correctness unchanged (69 passed).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ipe_config
Extend the per-(shape,token) dispatcher to emit (BM, epilog, bm_stage1, persist):
- bm_stage1=128 for high-expert small-inter families (DSV3/Kimi/DSV4) at large M (>=4096 tok, allow_bm128-gated); stage2 tile stays <=64 (BM128 is stage1-only, regresses stage2/GPT-OSS per F1).
- persist ON for those same high-expert large-M families (F2 +5-17%); OFF for GPT-OSS (byte-identical one-shot grid).
Harness uses BM_S1 for gemm1 + A-scale shuffle, BM for gemm2, and a shared SBM=lcm(BM_S1,BM). Default path (BM_S1==BM==32) stays byte-identical.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The fp8-A (a8w4) gemm2 persist path is a known-broken F2 combo: the multi-iteration grid-stride corrupts the fp8-A accumulator/LDS state and yields cos=0 at large M (reproduces on rlcr/moe-persist-sbm alone; fp4 persist is correct to 32768 tok).
- Dispatcher: drop DeepSeekV4 (a8w4) from _PERSIST_MIN_TOK so select_pipe_config never emits persist for the fp8-A path (DSV3/Kimi fp4 keep persist).
- compile_gemm2_a4w4_port: raise a clear AssertionError on a8w4+persist so the manual knob can't silently ship garbage. Re-enable once fp8-A persist is fixed.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Land the consolidated MXFP4 MoE pipe (swiglu/reduce/valid_mask/graph,
cached-B, BM{16,32,64,128}, persist fenced for fp8-A, SBM, per-token CSV
dispatch, tok32768 crash fixes, a8w4-BM64 atomic fix, arith.minimumf) onto
the PR head, which had upstreamed the moegemm->mxmoe_gemm_v2 rename + helper
dedup via #765.
Conflict resolution (3 files):
- fp8_gemm_utils.py: kept PR-head/main (identical) version verbatim; #765
moved flat_buffer_view + lds_dma_atom_128 out of this module.
- mxmoe_gemm_v2.py / moe_dispatcher.py: took our superset (already deduped:
_raw, //, inline K//BK), which contains both the #765 rename/dedup and all
our features. Verified main touched none of these 3 files (mxmoe/dispatcher
are PR-only; fp8_gemm_utils identical in main vs PR head), so no main
content is dropped.
- Relocated flat_buffer_view + lds_dma_atom_128 into mxmoe_gemm_v2.py
(matching #765's placement) and dropped the fp8_gemm_utils import; added
`arith` to mxmoe imports for flat_buffer_view.
arith.py + test_moe_gemm.py applied cleanly.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…1 byte-identical)
Repartition the 256-thread (4-wave) block as num_n_waves x k_wave: each K-wave
group processes K/k_wave of the model_dim contraction into its own A-LDS region
and cshuffle slab, then the k_wave partials are summed in LDS before the shared
silu+quant epilogue. Threaded through compile_gemm1_a4w4_port/get_g1/
mxfp4_moe_gemm1 as a compile-time cache-key dim with a _kw{N} name suffix.
Guards (aiter-parity): k_wave in {1,2,4}, fp4-only, interleave-only, K/BK %
k_wave == 0, 4*eff_tile_n <= tile_k. Default k_wave=1 is compile-gated to emit
byte-identical IR (final-ISA md5 verified equal for BM16 and BM32).
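The k_wave repartition above amounts to a split-K reduction inside one block: each K-wave group reduces its own K/k_wave slice and the partials are summed afterwards. A minimal host-side sketch follows; it covers only the two guards that are plain arithmetic (k_wave in {1,2,4} and K/BK divisible by k_wave), and the function names are illustrative.

```python
def check_k_wave(k, bk, k_wave):
    """Arithmetic guards quoted above: k_wave in {1,2,4} and (K/BK) % k_wave == 0."""
    return k_wave in (1, 2, 4) and k % bk == 0 and (k // bk) % k_wave == 0


def split_k_dot(a, b, k_wave):
    """Each K-wave group reduces its K/k_wave slice; the partials are summed after."""
    k = len(a)
    assert k % k_wave == 0
    span = k // k_wave
    partials = []
    for g in range(k_wave):
        lo = g * span
        partials.append(sum(x * y for x, y in zip(a[lo:lo + span], b[lo:lo + span])))
    return sum(partials), partials
```

With k_wave=1 there is a single partial and no extra reduction step, which matches the intent that the default path stays unchanged.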
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… real LDS-fits guard
- s_asc_base must sit past ALL k_wave A-staging regions (was overlapping region 1
at kw>1, causing NaN); kw=1 unchanged (byte-identical re-verified equal md5).
- Replace aiter's per-group 4*tile_n<=tile_k scratch guard (inapplicable to this
full-[BM,BN]-slab reduction) with an actual gfx950 160KB LDS-fits check.
k_wave is correct (cos>0.99 at DSV3 fp4 m=2,4,8 for BM{16,32} kw{1,2,4}) and
byte-identical at the default kw=1, but a PERF NO-GO on the BN=256 v2 pipe (see
.humanize/kernel-agent/kwave.md): DSV3-fp4 small-M is block-count bound so
intra-block K-slice adds no occupancy and is monotonically slower.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… coverage
Parameterize the gemm1 fused gate|up N-tile BN in {64, 256} (compile-time cache-key
dim, _bn{N} name suffix). BN=64 gives N_OUT//64 = 4x more N-blocks than BN=256, the
block-count coverage lever for tiny-M DSV3 fp4 (where pure k_wave at BN=256 was a
NO-GO because small-M is block-count bound). Pairs with k_wave: BN=64 requires
num_n_waves<=2 (each N-wave needs an even NJ>=2 gate+up pair), i.e. k_wave in {2,4};
the tiny-M fix is BN=64 + k_wave=4 (nnw=1).
Parameterized sites (BN=256 compile-gated to stay byte-identical, AC-3 verified):
- NUM_N_BLOCKS = N_OUT//BN; grid num_n_blocks = N_OUT//BN (dispatcher + kernel).
- requant gate/up read: split at BN//2; col-group == n_lane (gate cols n_lane*8+ee).
- aqout / ascaleout stores predicated on n_lane < BN//16 (BN=64: wave_grp==0) so the
shrunk tile never writes a neighbouring n_block.
- ascaleout physical layout rewritten in terms of the ABSOLUTE 32-INTER-col scale
group g = n_block_idx*(BN//64)+wave_grp: ku=g>>3, ikxdl=(g>>2)&1, lane=(g&3);
reduces exactly to the old literals at BN=256, and at BN=64 (g==n_block_idx) yields
the IDENTICAL physical output layout gemm2 consumes (scale-group boundary == BN=64
n_block boundary). aqout byte layout n_block*(BN//4)+n_lane*4 is BN-independent.
- lds_bytes_for(BN=...) sizes the [BM,BN] cshuffle slab.
Guards: BN in {64,256}; BN!=256 is fp4-only + interleave-only + NJ even>=2.
Threaded through compile_gemm1_a4w4_port/get_g1/mxfp4_moe_gemm1 (+ MXFP4_BN test env).
AC-3: BN=256 default final-ISA md5 EQUAL to baseline (BM16 a3318f65..., BM32 d472bdb7...).
AC-1: cold cos>0.99 at DSV3 fp4 m=2,4,8 for BM{16,32} x kw{2,4} at BN=64.
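The ascaleout index rewrite above can be checked with plain integer arithmetic. The sketch below implements only the quoted decomposition (g, ku, ikxdl, lane); the function name is hypothetical and nothing about the surrounding store code is modelled.

```python
def scale_group_coords(n_block_idx, wave_grp, bn):
    """Decompose the absolute 32-INTER-column scale group as quoted in the message."""
    assert bn in (64, 256)
    g = n_block_idx * (bn // 64) + wave_grp
    ku = g >> 3
    ikxdl = (g >> 2) & 1
    lane = g & 3
    return g, ku, ikxdl, lane
```

At BN=256 with wave_grp in 0..3 this gives ku = n_block_idx >> 1, ikxdl = n_block_idx & 1 and lane = wave_grp, and BN=64 block 4*b + w lands on the same coordinates as BN=256 block b, wave group w, which is the layout equivalence the message claims.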
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…expert fp4) select_pipe_config now returns (BM, epilog, bm_stage1, persist, bn, k_wave): the high-expert small-inter fp4 families (DSV3 E257, Kimi E384) are block-count bound at tiny M, so m<=2 selects BN=64 (N_OUT//64=8 N-blocks, 4x coverage) + k_wave=4 (nnw=1) for the ~1.5x gemm1 win; BN=256+kw1 (byte-identical default) otherwise (crossover at m~=3-4). gemm2 is BN-independent (unchanged). Manual MXFP4_KW / MXFP4_BN env overrides preserved under MXFP4_DISPATCH=1. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…SBM resolution
Cleanup (behavior-preserving, AC-3 byte-identical-default verified fp4+a8w4):
- Remove inline_quant (always False; only existed to be rejected) from the gemm1 compile/get_g1/mxfp4_moe_gemm1 chain + cache key + test call.
- Remove D_INTER_REAL (always None; only existed to be rejected) from the gemm2 compile/get_g2/mxfp4_moe_gemm2 chain + cache key.
- Extract _norm_sbm(SBM, BM) for the 5 duplicated SBM None->BM normalizations.
SBM (SBM!=BM) is KEPT: it is required by the BM128-stage1 dispatch path (SBM must be lcm(bm_stage1, BM)); not dead.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
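The SBM handling named above reduces to two small pieces of arithmetic: a None -> BM normalisation and the lcm requirement for the BM128-stage1 path. The sketch below is an assumption about their shape, since the body of `_norm_sbm` is not shown on this page; both function names are placeholders.

```python
import math


def norm_sbm(sbm, bm):
    """Sketch of the None -> BM normalisation: fall back to BM when SBM is unset."""
    return bm if sbm is None else sbm


def shared_sbm(bm_stage1, bm):
    """Shared sort block size for mixed stage tiles: lcm(bm_stage1, BM)."""
    return bm_stage1 * bm // math.gcd(bm_stage1, bm)
```

For the default path (both tiles 32) the lcm equals BM, so the normalised value and the shared value coincide.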
…default byte-identical
Add a compile-gated has_pad variant to the v2 mx MoE gemm1/gemm2 that saves B-weight bandwidth on padded contraction dims via hardware OOB-load-skip (aiter-style), NOT the prior whole-tile compute-skip (a no-op for GPT-OSS pad=192 < BK=256).
Mechanism: size the per-16N-tile B-weight buffer resource (bq_view) to the REAL K extent (K - kpad) so the fully-pad 128-K weight half loads of the tail tile buffer-load OOB -> 0 (no HBM transaction). The K axis is K-major and the half stride (256 i32) exceeds every within-half sub-offset (255), so the num_records cut lands exactly on a half boundary: zeros the fully-pad half, keeps the partial-pad half (host zero-fill). B-scale is NOT shrunk (256-K granular + host 256-align; sub-256 pad saves 0 and risks NaN).
- gemm1 K=D_HIDDEN (kpad=model_dim_pad); gemm2 K=inter_dim (kpad=inter_dim_pad), both runtime.
- has_pad=False (default): no i32_kpad kernarg, pad math const-folds away. Verified gemm1+gemm2 GPT-OSS-3072 atomic final-ISA md5 UNCHANGED vs a5bb72a (AC-3).
- has_pad=True: distinct compile variant (_pad name tag) with the runtime i32_kpad kernarg; shared kernel body extracted to @flyc.jit _gemm{1,2}_kernel_body so the AST rewriter recurses into its scf dispatch.
- run wrappers: mxfp4_moe_gemm1(model_dim_pad=), mxfp4_moe_gemm2(inter_dim_pad=).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
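The K-pad OOB skip can be pictured with a deliberately simplified 1-D model: a bounded load returns 0 past `num_records`, and a 128-K half is dropped only when it lies entirely in the padding. The real kernel works on a swizzled K-major layout that is not reproduced here; the function names and the linear indexing are assumptions.

```python
HALF_K = 128  # the message describes 128-K weight halves inside a 256-K tile


def buffer_load(buf, idx, num_records):
    """Model of a bounds-checked buffer load: out-of-range reads return 0."""
    return buf[idx] if idx < num_records else 0


def halves_fetched(k_total, kpad):
    """Return (fetched, skipped) half indices for the simplified 1-D layout."""
    real_k = k_total - kpad
    fetched, skipped = [], []
    for h in range(k_total // HALF_K):
        # A half that starts at or past real_k is pure padding, so it loads OOB -> 0.
        (skipped if h * HALF_K >= real_k else fetched).append(h)
    return fetched, skipped
```

For the GPT-OSS case (K=3072, pad=192) only the last half is dropped; the half that straddles the 2880 boundary is still fetched and relies on the host zero-fill, as the message describes.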
Thread real (unpadded) dims through test_moe_gemm_2stage + run_mxfp4_moe_2stage and the CLI (--real_dim model,inter, --garbage_pad). The has_pad path host-zero-pads the fp32 weight/A/intermediate pad regions and passes model_dim_pad/inter_dim_pad so the kernel sizes the B-weight buffer resource to the real K (OOB-skips the fully-pad weight halves); the reference is over the full (zero-in-pad) tensors == a real-dim GEMM. --garbage_pad writes 1e30 into ONLY the fully-pad weight-contraction halves the OOB skip drops (partial-pad remainder in a kept half stays host-zero). A correct output then proves the OOB skip never fetches the padded weights into the accumulation. Verified (cold, gpt-oss real2880/pad192, t=256 e=128 k=4): fp4 cos 0.9892 (garbage 0.9893), a8w4 cos 0.9996 (garbage 0.9996) matching the naive padded-3072 baseline (fp4 0.9894). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add the gemm1 fused gate|up N-output pad-skip to the has_pad variant: skip loading w1 weight rows for the padded inter-N columns (both gate and up halves) via a per-16N-tile buffer-OOB skip. Additive to the existing K-pad OOB skip. Mechanism: each 16-N weight tile is an independent buffer resource (its N column is the resource BASE offset, not an in-view axis), so the K-axis num_records trick does NOT compose on N. Instead, per tile, compute its LOGICAL inter col base (interleave gate|up: gate cols [0,BN//2) and up cols [BN//2,BN) of each 256-block both map to inter [n_block*(BN//2),+BN//2)) and, when that base >= INTER_real (= i32_inter - i32_npad), size the tile buffer to 0 records so every weight load OOB -> 0 (no HBM fetch). Since INTER_real is 16-aligned, every 16-N tile is fully real or fully pad -> no partial-tile epilogue pairing problem: a skipped tile writes 0 into exactly its own gate-or-up output column (verified: per-j bq_frags -> per-J c_frags -> cshuffle lds_col are index-preserving, no cross-contamination), and B-scale (32-N granular, shared, unshrunk) is harmless because 0*scale=0. CORRECTNESS-SAFE with NO epilogue change: the pad-N output feeds gemm2's pad-K input, which gemm2 already OOB-skips (K-skip on this branch), so zeroed pad-N output is never read. - gemm1 threads a new runtime i32_npad kernarg (= inter_dim_pad) alongside i32_kpad, both only in the has_pad variant. has_pad now enables on model_dim_pad>0 OR inter_dim_pad>0. - run wrapper mxfp4_moe_gemm1(inter_dim_pad=). - has_pad=False (default): no i32_npad kernarg, N-skip math emitted only under const_expr(has_pad); the default gemm1 expression keeps its exact original add-associativity. Verified default gemm1 AND gemm2 GPT-OSS-3072 final-ISA md5 UNCHANGED vs base b5dfe2b (AC-3): gemm1 fb8141b6b6f3aee1c91f12a384392e29, gemm2 a2a7f6defe7a7f988625bc1e2424a269. 
- test: --garbage_pad now poisons w1's pad-N weight ROWS (N-skip drops them); thread inter_dim_pad through run_mxfp4_moe_2stage to gemm1. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…e|up) The initial N-skip used col_in_block<BN//2 to split gate vs up, but interleave mode selects gate/up by j PARITY (in_b=J%2, mfma_cluster) with n0 index J//2 -- gate and up are interleaved at 16-col granularity, NOT a contiguous [0,128)|[128,256) split. The wrong predicate skipped real tiles / kept pad tiles, regressing cos (fp4 0.9893->0.9835, a8w4 0.9996->0.9941). Corrected to match the cshuffle: logical_inter = n_block*(BN//2) + wave_n*gate_span + (j//2)*16, gate_span=(BN//2)//nnw. A gate tile (j even) and its sibling up tile (j+1) share this base -> skipped consistently. Restores correctness (cold, gpt-oss real2880/pad192, t256 e128 k4): fp4 structural cos 0.9893, garbage-N cos 0.9893 (== structural, N-skip proven) a8w4 structural cos 0.9996, garbage-N cos 0.9996 (== structural) matching the K-skip-only baseline. Default (no-pad) gemm1 final-ISA md5 UNCHANGED (fb8141b6b6f3aee1c91f12a384392e29 == base; AC-3). test: fix the garbage-N-pad harness. Poison ONLY w1's pad-N weight ROWS (1e30) and hold w2's pad-K contraction cols at 0; the reference now sums over the REAL inter extent (rI = INTER - inter_dim_pad) so the huge pad rows never enter the (finite) reference act. A full-INTER reference overflowed act on the 1e30 rows -> ref NaN (the kernel out was always finite/correct). rI==INTER when inter_dim_pad==0 (byte-identical full reference). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
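The corrected N-skip predicate can be enumerated on the host to see which 16-N weight tiles are dropped. The formula for `logical_inter` and `gate_span` is taken from the message above; the tile count per wave (BN/16 split evenly over the N-waves) and N_OUT = 2 * inter for the fused gate|up output are assumptions of this sketch, and the function name is hypothetical.

```python
def skipped_tiles(inter, inter_pad, bn=256, nnw=4):
    """List (n_block, wave_n, j) tiles whose logical inter base falls in the padding."""
    inter_real = inter - inter_pad
    gate_span = (bn // 2) // nnw
    tiles_per_wave = bn // 16 // nnw   # assumption: tiles split evenly over N-waves
    n_blocks = (2 * inter) // bn       # assumption: N_OUT = 2 * inter (gate|up fused)
    skipped = []
    for n_block in range(n_blocks):
        for wave_n in range(nnw):
            for j in range(tiles_per_wave):
                # j parity picks gate vs up; both share the same logical inter base.
                logical_inter = n_block * (bn // 2) + wave_n * gate_span + (j // 2) * 16
                if logical_inter >= inter_real:
                    skipped.append((n_block, wave_n, j))
    return skipped
```

For the gpt-oss real2880/pad192 shape this drops 12 gate tiles and their 12 sibling up tiles, so every skipped gate tile has its up partner skipped too.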
Formatting only, inside const_expr(has_pad); default gemm1 ISA md5 unchanged (fb8141b6b6f3aee1c91f12a384392e29). black --line-length 120 clean; ruff clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
At M with <=1 stage1 m-block per active expert (small/mid M, e.g. GPT-OSS M<=1024), each expert's w1 is single-use: caching it in L2 only allocates lines that are never re-hit (measured L2 hit ~3% either way) while paying the non-streaming cost. Stream (non-temporal) instead.
gemm1_use_nt(experts, topk, tokens, bm_stage1) returns the policy; the harness uses it under MXFP4_DISPATCH=1. Byte-identical: only flips the B-load coherence flag, HBM read bytes unchanged (605 vs 617 MB EA0 proxy, flyprof PMC). The shipped non-dispatch default stays use_nt=False (cached).
GPT-OSS 3072/3072 E128 k4 gfx950 gemm1, median-of-3+ (run_perftest device time):
M=128 a4w4 208->181us (-13%, cos 0.9909), a8w4 209.5->182.8us (cos 0.9996)
M=1024 217->192us (-12%) [<=1 block/expert: NT]
M=2048 236->249us, M=4096 309->385us [reuse: cached kept]
Crossover at m-blocks/expert==1, matching the existing "nt only helps when there is no reuse" comment.
This closes the bulk of the apparent gemm1 gap vs aiter cktile a16w4 stage1: apples-to-apples on aiter's own rocprofv3 kernel trace (stage1 median 181.4us, min 173.6us -- the prior 164.5us wall baseline is not reproducible), our NT gemm1 (a8w4 ~184us, dtype-matched to aiter's F8xMXF4) is at parity ~1.01x, both at ~75-76% of gfx950 peak HBM.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
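The reuse metric behind the NT policy (m-blocks per active expert, NT when <= 1) can be estimated on the host. The body of gemm1_use_nt is not shown on this page, so the sketch below is a guess at its shape under a uniform-routing assumption; `use_nt_sketch` and the helper are hypothetical names, and bm=32 is assumed for the worked values.

```python
def m_blocks_per_active_expert(experts, topk, tokens, bm):
    """Estimate stage1 M-blocks per active expert, assuming uniform routing."""
    rows = tokens * topk
    active = max(1, min(experts, rows))
    rows_per_expert = -(-rows // active)   # ceiling division
    return -(-rows_per_expert // bm)


def use_nt_sketch(experts, topk, tokens, bm):
    """Stream weights non-temporally when each expert's tile is used at most once."""
    return m_blocks_per_active_expert(experts, topk, tokens, bm) <= 1
```

With E=128, topk=4 and bm=32 this estimate picks NT at M=128 and M=1024 and the cached path at M=2048 and M=4096, in line with the crossover reported above.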
gemm2's w2 down-proj load was still CACHED at no-reuse M (the gemm1 NT fix only touched gemm1's B). At M with <=1 stage2 m-block per active expert (small/mid M, e.g. GPT-OSS M<=1024) each expert's w2 is single-use, so caching it in L2 only allocates lines that are never re-hit while paying the non-streaming cost. Stream (non-temporal) instead. gemm2_use_nt(experts, topk, tokens, bm_stage2) mirrors gemm1_use_nt exactly (reuse metric = m-blocks per active expert, nt when <=1); the harness uses it under MXFP4_DISPATCH=1, and MXFP4_G2_NT forces the policy for measurement. Byte-identical: only flips the w2-load coherence flag, HBM read bytes and correctness unchanged; the shipped non-dispatch default stays use_nt=False (cached, no _nt kernel tag). GPT-OSS 3072/3072 E128 k4 M=128 gfx950, identical-work parity harness, median-of-5 (run_perftest device time): a4w4 gemm2 114.6 -> 104.2us (-9.1%, cos 0.9910) a8w4 gemm2 121.0 -> 112.8us (-6.8%, cos 0.9996) Closes ~half the gemm2 gap vs aiter cktile a16w4 stage2 (92.3us): a4w4 1.24x -> 1.13x, a8w4 1.30x -> 1.22x. Post-NT residual is a pure weight-BW- efficiency gap at identical 551 MB weight traffic, dominated by occupancy (ours 32KB LDS -> 5 blk/CU vs aiter 16KB -> 10 blk/CU); see .humanize/kernel-agent/gptoss-gemm2-rootcause.md. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Mirror gemm1's N-skip into gemm2_body_v2: gemm2's N dimension is the model_dim output, and its w2 N-tiles map 1:1 to model_dim output columns (make_bq_view `col`). Add i32_npad, compute N_real = N_OUT - npad under const_expr(has_pad), and gate each 16-N weight tile's num_records to 0 when col >= N_real via (col < N_real).select(bq_num_records, 0). Fully-pad model_dim output tiles then load OOB -> 0 (no HBM fetch). PERF-ONLY: the pad model_dim output columns are unused (sliced off by the real_model_dim reference), so correctness holds even without the skip; this only drops the wasted w2 fetch. Thread i32_npad through the gemm2 compile/dispatch path (kernel body, has_pad kernel/launch signatures, mxfp4_moe_gemm2 model_dim_pad + pad_args) mirroring gemm1's kpad/npad threading. has_pad now enables on inter_dim_pad>0 OR model_dim_pad>0. Default path (has_pad=False) is byte-identical: all new logic under const_expr(has_pad); verified final ISA/LLVM-IR/MLIR md5 unchanged. test_moe_gemm: under garbage_pad, poison w2[:, real_H:, :] = 1e30 instead of structural 0, and slice the reference/output to :real_H so a correct cos PROVES the N-skip never fetched the garbage w2 output rows. Validation (gfx950, cold, cache disabled): - garbage-pad cos: fp4 0.9893, a8w4 0.9996 (thr 0.85/0.95) - default no-pad ISA/IR/MLIR md5 byte-identical - GPT-OSS same-input harness gemm2: 120.0 -> 116.3 us (-3.1%) Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Port gemm1's kStages=2 B-prefetch double-buffer into gemm2's runtime scf.for
K-loop, opt-in via MXFP4_G2_KSTAGES=2 (default 1 = byte-identical).
gemm2 was 1-deep: B weight + B-scale loaded SYNCHRONOUSLY per K-tile on the
compute path (stream_b_tile). gemm2 is HBM-weight-bound, so prefetching the next
K-tile's B one step ahead hides the weight-load latency.
Implementation: since the runtime scf.for cannot index a python stage list by
the runtime kt, a rotating single-buffer carries the prefetched B-weight (i32x4)
+ B-scale (i32x1) fragment VALUES through the loop state (alongside the existing
C/accm carry). Each iteration consumes the carried "current" B and prefetches
the next tile's B; the yield rotates prefetch->current. A prologue loads tile 0.
sched_barrier/s_setprio fence the MFMA chain from the B vmem loads, mirroring
gemm1_body_v2's main-loop fencing.
Default (g2_kstages=1) is proven byte-identical: all 20 IR stages + final binary
md5 match base for both gemm2 (default) and gemm1; the 2-stage variant carries a
distinct _g2ks2 kernel tag and cache key.
Measured (GPT-OSS M=128, NT gemm2, median-of-5, same_input_parity):
a4w4 gemm2: 105.4us -> 98.7us (-6.4%); 1.14x -> 1.07x vs aiter 92.3us; cos 0.9910
a8w4 gemm2: 109.8us -> 100.7us (-8.3%); 1.19x -> 1.09x vs aiter; cos 0.9996
VGPR a4w4 80->126, a8w4 98->146; NO spills; LDS 32768 unchanged (still
LDS-bound, so this stacks with the epilog-LDS/occupancy lever). gemm1 ISA +
stage1 unchanged.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
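The rotating single-buffer carry can be illustrated with scalar "tiles" in plain Python. This is a schematic of the loop structure only (prologue, consume current, prefetch next, rotate), not the kernel: all names are hypothetical, and clamping the final prefetch to the last valid tile is an assumption of this sketch.

```python
def load_b(b_tiles, kt, log):
    """Stand-in for the B weight + B-scale load; records which tile was fetched."""
    log.append(kt)
    return b_tiles[kt]


def gemm_sync(a_tiles, b_tiles):
    """1-deep loop: B is loaded on the compute path of every iteration."""
    acc = 0
    for kt in range(len(a_tiles)):
        acc += a_tiles[kt] * b_tiles[kt]
    return acc


def gemm_prefetch(a_tiles, b_tiles):
    """2-stage loop: the loop state carries the already-loaded 'current' B value."""
    log = []
    n = len(a_tiles)
    cur = load_b(b_tiles, 0, log)               # prologue loads tile 0
    acc = 0
    for kt in range(n):
        nxt_kt = min(kt + 1, n - 1)             # assumed clamp on the last iteration
        nxt = load_b(b_tiles, nxt_kt, log)      # prefetch issued before the math
        acc += a_tiles[kt] * cur                # consume the carried current B
        cur = nxt                               # yield rotates prefetch -> current
    return acc, log


if __name__ == "__main__":
    a, b = [1, 2, 3, 4], [10, 20, 30, 40]
    acc, log = gemm_prefetch(a, b)
    print(acc == gemm_sync(a, b), log)
```

Because values rather than buffer indices are carried, no stage list has to be indexed by the runtime loop counter, which is the constraint the message describes.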
ATT (rocprofv3 advanced thread trace) of the isolated GPT-OSS gemm2 at M=128
shows ours-BEST (BM32+reduce+G) is weight-VMEM-bound (83% VMEM stall) like aiter
(79.5%), but pays an extra ~5pp in barrier idle (11.3% vs 6.6%): the single
per-K-iteration gpu.barrier() guarding the A-LDS ring absorbs per-wave
VMEM-load latency imbalance. BM16-atomic (169us, VMEM-load back-pressure 45%)
confirms the pipeline-underfeed root cause for our wide 16x16x128 scaled-fp4
MFMA.
Fix: with the g2_kstages==2 B pipeline, the next-tile B weight+scale prefetch is
a GMEM->register stream with no LDS dependency, so it need not be gated by the
LDS barrier. Hoisting it ABOVE the barrier (opt-in g2_bhoist, env
MXFP4_G2_BHOIST=1) puts the long-latency weight loads in flight before the block
syncs, overlapping the barrier wait.
ATT after: VMEM-wait 52.6%->34.9%; total stall 10.46M->10.17M. a4w4 gemm2
97.8->95.6 us (-2.2%), 5.67->5.80 TB/s, ratio vs aiter 1.055x->1.028x. a8w4
neutral (heavier A path, not VMEM-wait-bound). cos unchanged (a4w4 0.9910, a8w4
0.9996).
Default (flag off) byte-identical: g2ks2 ISA md5 identical (18b5e4d7...); gemm1
untouched. Analysis in gemm2-att-analysis.md.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…vs aiter)
ISA-level diff of the 4 gemm2 memory-access classes (weight-B, B-scale, A-scale,
A, output) vs aiter cktile a16w4 MoeFlatmmKernel. Weight-B (93% of traffic) is
already instruction-for-instruction identical (buffer_load_dwordx4/128b/nt, same
per-unit-N count + coalescing, B prefetched 1 K-tile ahead above the barrier via
g2ks2+bhoist). The one under-issuing class: the A-scale was loaded SYNCHRONOUSLY
at the loop head and gated the whole MFMA cluster on its full VMEM latency with
s_waitcnt vmcnt(0) (also draining the fresh B prefetch), whereas aiter
prefetches its scale a K-block ahead into a register and never stalls the MFMA
on it.
Fix (deepdiff lever C): opt-in g2_ascale_pf prefetches the A-scale one K-tile
ahead through the existing g2ks2 scf.for carry (rotating single-buffer, same as
the B carry). ISA: the loop-head buffer_load_dword + vmcnt(0) is gone; the MFMA
cluster starts on the already-resident prefetched scale.
- Opt-in g2_ascale_pf param on gemm2_body_v2 (default False = byte-identical),
env MXFP4_G2_ASCALE_PF=1, gemm2 dispatcher cache key + _apf name tag. No-op
unless g2_kstages==2. gemm1 untouched.
- a4w4 97.8 -> 97.1 us (-0.7%), a8w4 102.8 -> 101.9 us (-0.9%) (median-of-5,
isolated, cold, same GPU3 session; aiter 95.24 same session -> ratio 1.027x ->
1.020x a4w4). Tighter variance both dtypes.
- cos unchanged: a4w4 0.9910 (>0.85), a8w4 0.9996 (>0.95).
- Byte-identical default (AC-3): apf-off ISA md5 4fb7bad1... == pre-change
(empty diff); g2_kstages=1 shipped default untouched.
Report: .humanize/kernel-agent/gemm2-mem-pattern-align.md
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…el balance, replaces xcd)
Port aiter's ACTUAL, XCD-count-INDEPENDENT gemm2 block->tile mapping — the
GemmSpatiallyLocalTilePartitioner grouped 2D rasterization — as our gemm2
block->(m_block_idx, n_block_idx) map, and DROP the non-portable explicit-8-XCD
swizzle (g2_xcd) as the channel-balance mechanism.
KEY FINDING: aiter's shipped gemm2 instantiates the partitioner with GroupNum=1,
M01=1 (moe_cktile2stages_common.cuh:58-59), which mathematically DEGENERATES to
the exact naive m-major linear map we already use (verified by hand + numeric
replay of GetOutputTileIndex, gemm_tile_partitioner.hpp:274-360). So the
partitioner grouping is NOT what balances aiter's channels for this instance; a
verbatim port would be a no-op. The productive port is the GENERAL
GetOutputTileIndex parameterized by (GroupNum,M01), with the grouping ENABLED.
- _spart_output_tile_index(): faithful DSL port of GetOutputTileIndex's
else-branch (M0=total_m_blocks runtime; N0=num_n_blocks/GroupNum/M01
compile-time). Bijection over [0, M0*N0) (no dropped/dup tiles), no hard-coded
XCD count.
- g2_spart opt-in param (env MXFP4_G2_SPART), encoded GroupNum*100+M01 (e.g. 402);
default 0 = off = byte-identical naive linear grid. Threaded through the gemm2
dispatcher cache key + _spart{G}x{M01} kernel-name tag. One-shot grid path only.
gemm1 and mxmoe_gemm_v2.py untouched.
MEASURED (rocprofv3 TCC_EA0_RDREQ, 128 channels; a4w4; GPU3 cold same session):
GroupNum=4,M01=2 (MXFP4_G2_SPART=402) is the winner — per-channel CV 9.68% ->
0.32% (BEATS prior xcd4's 0.64%, approaches aiter's 0.18%), max/min 1.28x ->
1.01x, total HBM reads 4.86M -> 4.50M (aiter parity 4.44M, was 9.5% over). The
portable spatial partitioner MATCHES-OR-BEATS the xcd swizzle on channel balance.
Perf (gemm2 isolated, median-of-5, cold, same GPU3 session, ~12% hotter than the
memalign doc): a4w4 109.4 -> 108.3 us (-1.0%), a8w4 112.0 -> 109.9 us (-1.9%);
aiter same-session 95.3 us.
Correctness (cold, real 2880 dims, spart402): a4w4 cos=0.9910 (>0.85), a8w4
cos=0.9996 (>0.95) — identical to baseline (bijective permutation, same math).
Byte-identical default: gemm2 default kernel ISA md5 8263440a... identical with
spart code present vs stashed (no _spart tag, no extra IR when off). Python style
gate passes.
Doc: .humanize/kernel-agent/gemm2-spatial-partitioner.md.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
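The `GroupNum*100+M01` encoding and the bijection requirement can be sketched as below. The rasterization here is a generic grouped 2D ordering written for illustration; it is NOT a port of CK's `GetOutputTileIndex`, and `decode_spart` / `grouped_tile_index` are hypothetical names. It assumes N0 is divisible by GroupNum.

```python
def decode_spart(code: int) -> tuple[int, int]:
    """Split the knob value into (GroupNum, M01), e.g. 402 -> (4, 2)."""
    return divmod(code, 100)


def grouped_tile_index(block_id: int, m0: int, n0: int, group_num: int, m01: int):
    """Map a linear block id to (m_block, n_block) with column groups and M01-row bands."""
    assert n0 % group_num == 0, "sketch assumes N0 divisible by GroupNum"
    n_per_group = n0 // group_num
    group, r = divmod(block_id, m0 * n_per_group)   # which column group
    band, r2 = divmod(r, m01 * n_per_group)         # which band of M01 rows
    band_rows = min(m01, m0 - band * m01)           # the last band may be shorter
    n_local, m_local = divmod(r2, band_rows)        # walk down the band, then across
    return band * m01 + m_local, group * n_per_group + n_local


if __name__ == "__main__":
    g, m01 = decode_spart(402)
    m0, n0 = 7, 12  # illustrative tile counts
    tiles = {grouped_tile_index(b, m0, n0, g, m01) for b in range(m0 * n0)}
    print(len(tiles) == m0 * n0)  # every tile visited exactly once
```

In this sketch, GroupNum=1 and M01=1 reduce the map to `(block_id // N0, block_id % N0)`, which mirrors the message's observation that the degenerate parameters give a plain linear grid.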
…n up
Wire the PR #753 gemm2 perf stack ON by default (previously opt-in env flags),
and trim the now-redundant exploration commentary / opt-in scaffolding.
Defaults flipped (env vars remain optional overrides, explicit args still win):
  MXFP4_G2_KSTAGES   1 -> 2   (2-stage B weight/scale prefetch)
  MXFP4_G2_BHOIST    0 -> 1   (hoist B prefetch above the LDS barrier)
  MXFP4_G2_ASCALE_PF 0 -> 1   (A-scale prefetch one K-tile ahead)
  MXFP4_G2_SPART     0 -> 402 (GroupNum=4,M01=2 spatial tile partitioner; HBM
                               channel balance, replaces the dropped g2_xcd)
gemm2-only: gemm1 default ISA is byte-identical (md5
b4a1f0d5...428cff376ce039bb894ad50d matches 7d40f77). DSV3 large-M uses the
persist grid (spart bypassed) and the compute-bound large-M shapes are neutral,
so no shape-gating is needed.
Measured (GPT-OSS M=128 gemm2, cold, same session): 115.6 -> 103 us (-11%) vs
opts-off. Correctness held (fp4>0.85, a8w4>0.95 thresholds unchanged).
Cleanup: collapse the four per-knob resolution blocks + duplicated `import os`,
condense the knob docstrings/tag comments and the K-loop inline commentary in
gemm2_body_v2. No kernel math changed.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
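The stated precedence (explicit argument, then env override, then the new default) can be captured by one small resolver. This is a sketch of the precedence rule only: `resolve_knob` and `G2_DEFAULTS` are hypothetical names, and the dispatcher's actual resolution code is not shown in the commit. The env names and default values are the ones listed above.

```python
import os

# New defaults listed in the commit message.
G2_DEFAULTS = {
    "MXFP4_G2_KSTAGES": 2,
    "MXFP4_G2_BHOIST": 1,
    "MXFP4_G2_ASCALE_PF": 1,
    "MXFP4_G2_SPART": 402,
}


def resolve_knob(env_name, explicit=None, environ=None):
    """Explicit argument wins, then the env var, then the built-in default."""
    if explicit is not None:
        return int(explicit)
    env = os.environ if environ is None else environ
    raw = env.get(env_name)
    if raw:  # unset or empty falls through to the default
        return int(raw)
    return G2_DEFAULTS[env_name]


if __name__ == "__main__":
    print(resolve_knob("MXFP4_G2_SPART", environ={}))                       # default
    print(resolve_knob("MXFP4_G2_SPART", environ={"MXFP4_G2_SPART": "0"}))  # env opt-out
    print(resolve_knob("MXFP4_G2_KSTAGES", explicit=1))                     # caller wins
```

Passing `environ` explicitly keeps the resolver testable without mutating the process environment.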
Aggressive readability cleanup of the PR #753 gemm code, strictly
behavior-preserving. Collapse every multi-line/multi-paragraph comment and
docstring essay to a single terse line (or delete pure restatements), drop
leftover opt-in rationale now that the gemm2 perf stack (G/BHOIST/APF/spart402)
is default-on, and remove dead code (self-assignment `aStages = aStages`, a
redundant local `import os`, duplicated stream_b_tile/issue_b_load_into body).
No kernel math, tile logic, dispatch decisions, numerics, or env overrides
changed. Verified: the default gemm1 (cached + nt) and gemm2 (atomic + reduce)
final ISA are byte-identical (md5) to 5856379; fp4/a8w4 stage2-standalone and
fp4/a8w4 e2e 2stage correctness pass cold, cos thresholds unchanged.
full-line comments: mxmoe_gemm_v2 244->99, moe_dispatcher 166->39
LOC: mxmoe_gemm_v2 1583->1417, moe_dispatcher 1421->1205
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…e keys)
Fold the runtime-HIDDEN (model_dim) port onto the cleaned PR #753 head so ALL
MoE dims are runtime, keeping the clean-head hygiene (<=1-line comments, no dead
code, merged stream_b_tile/issue_b_load_into).
gemm1 (contraction K): new i32_hidden operand; K/kc/K_TILES/KT_PER_KW/K_BYTES/
KH4 runtime SSA capped by HIDDEN_MAX. The compile-time range_constexpr K-loop is
now a runtime-trip scf.for carrying C/accm AND the kStages=2 one-stage-ahead
B/B-scale prefetch as loop-carried fragments (CUR/NXT, last-iter prefetch
clamped to the last valid K-tile). BHOIST fences + k_wave preserved.
gemm2 (N-output): new i32_hidden operand; N_OUT_rt/num_n_blocks/kbs_per_expert/
N_real/bq-col/epilog-stride runtime. The runtime-inter scf.for + G2 perf stack
(kstages2/BHOIST/APF/spart402) untouched, default-on.
dispatcher: D_HIDDEN cache-key dim -> HIDDEN_MAX cap in get_g1/get_g2; kernels
thread i32_hidden; kernel names h{dim} -> hmax{HIDDEN_MAX}; K%BK / K%k_wave
checks host-side; H_DEFAULT global removed. One compiled binary serves any
model_dim <= HIDDEN_MAX.
Re-cleaned the port's additions to the clean-head standard (every comment <=1
line, tight LOCs); dropped the port's dead aStages self-assign and re-used the
clean stream_b_tile/issue_b_load_into consolidation.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…mpile-time schedule
The runtime-K gemm1 main loop (single rolled scf.for, 1 K-tile/trip) lost the
compile-time straight-line, software-pipelined schedule: per-iter runtime
%kAStages (magic ÷3) A-LDS slot recompute, no cross-tile load/MMA overlap
(top-of-loop s_waitcnt vmcnt(0) drain), and a hot-path last-iter clamp branch.
Cost is exposed at latency-bound small M (DSV3 M=8 1.38x vs compile-time; M=4096
already parity). NOT spills: runtime .vgpr_count 95 vs compile-time 148, 0
spills both.
Fix: unroll the runtime K-loop by UNROLL=LCM(kAStages,2)=6. A multiple of
kAStages so every A-LDS slot is a compile-time constant across the group (no ÷3,
no runtime slot multiply); even so the B double-buffer parity returns to CUR at
each group boundary (carry stays CUR-only). Outer trip KT/UNROLL stays runtime
-> model_dim runtime, single binary. fp4 tail: statically-unrolled + per-tile
runtime-guarded (in-place c_frags survive scf.if) so the remainder is mod-free
too; fp8 tail keeps the rolled scf.for (accm carried SSA).
ISA: ÷3 magic 2->0, no spills. DSV3 M=8 gemm1 1.38x -> 1.15x vs compile-time;
M=4096 1.00x, GPT-OSS M=128 1.02x (parity). Residual ~1.15-1.28x only at extreme
small M = irreducible scf.for loop-carried C+B phi at group boundaries (cannot
fully unroll a runtime trip). cos: fp4 >=0.99, a8w4 0.9996, GPT-OSS pad 0.989;
atomic+reduce+graph + k_wave 1/2/4 all pass; gemm2 untouched.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
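Why UNROLL=LCM(kAStages,2) removes the per-iteration modulo can be checked with a small model. This is an illustrative sketch with hypothetical names; kAStages=3 is inferred from the "÷3" and LCM=6 in the message rather than stated outright.

```python
import math

K_A_STAGES = 3                                 # inferred: LCM(kAStages, 2) == 6 and "÷3"
B_STAGES = 2                                   # B double-buffer (CUR/NXT)
UNROLL = math.lcm(K_A_STAGES, B_STAGES)        # requires Python 3.9+


def schedule(k_tiles: int):
    """Return per-tile (kt, a_slot, b_parity) for the unrolled bulk, plus the tail length."""
    groups, tail = divmod(k_tiles, UNROLL)     # outer trip count stays runtime
    plan = []
    for g in range(groups):                    # runtime loop
        for j in range(UNROLL):                # statically unrolled body
            kt = g * UNROLL + j
            # j is a compile-time constant inside the unrolled body, so both
            # the A-LDS slot and the B parity are constants, not runtime '%'.
            plan.append((kt, j % K_A_STAGES, j % B_STAGES))
    return plan, tail


if __name__ == "__main__":
    plan, tail = schedule(20)
    ok = all(a == kt % K_A_STAGES and b == kt % B_STAGES for kt, a, b in plan)
    print(UNROLL, len(plan), tail, ok)
```

The check confirms that the constant per-position slots agree with the runtime `kt % kAStages` and `kt % 2` values, so every group starts at slot 0 with the B buffer back at CUR.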
…dentical ISA)
Light clean of the fixed-factor-unroll runtime-K gemm1 commit: collapse the
three multi-line block comments to one line each, matching the PR clean
standard. No logic change (comment-only diff), ISA unchanged.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…e-K loop
Split v_e = (voffset + kt_abs*KH_TILE_A)//4 into voffset//4 +
kt_abs*(KH_TILE_A//4). With runtime model_dim, K_BYTES makes voffset//4
divisibility unprovable, so the compiler emitted a per-tile signed-division
dance recomputed 6x per UNROLL=6 group. Both terms are provably multiples of 4,
so the split is exact: voffset//4 is now loop-invariant (hoisted) and
KH_TILE_A//4=32 is a compile-time constant, leaving only a cheap kt_abs*32 per
tile.
ISA (DSV3 fp4, cold): bulk-body v_subbrev 13->0, addr-arith ops 55->21, body
376->318 lines; vgpr 132->134, 0 spills. No loop-carried phi copies exist in any
variant (refutes the prior 'C+B phi floor'). gemm1 small-M recovers: fp4 M=1
1.19x->1.07x, fp4 M=8 1.12x->1.07x, a8w4 M=1 1.21x->1.12x vs compile-time;
M>=128 stays at parity. fp4 cos>=0.99, a8w4 cos>=0.9996, graph replay OK.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
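The exactness of the split can be checked numerically. A minimal sketch with hypothetical function names; KH_TILE_A=128 is implied by the message's KH_TILE_A//4=32, and the test offsets are non-negative multiples of 4 as the message states.

```python
KH_TILE_A = 128  # implied by KH_TILE_A // 4 == 32 in the commit message


def v_e_fused(voffset: int, kt_abs: int) -> int:
    """Original form: one division per tile on a runtime sum."""
    return (voffset + kt_abs * KH_TILE_A) // 4


def v_e_split(voffset: int, kt_abs: int) -> int:
    """Split form: voffset // 4 is loop-invariant, the per-tile part is kt_abs * 32."""
    base = voffset // 4                      # hoistable out of the K-loop
    return base + kt_abs * (KH_TILE_A // 4)  # constant multiplier


if __name__ == "__main__":
    same = all(
        v_e_fused(v, kt) == v_e_split(v, kt)
        for v in range(0, 4096, 4)
        for kt in range(64)
    )
    print(same)
```

Since both addends are multiples of 4, dividing each separately loses no remainder, so the two forms agree for every tested offset and tile index.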
…vide
The fp8-A (a8w4) gemm1 remainder tail is a rolled scf.for (accm is raw f32x4 SSA
that cannot escape an scf.if, unlike the fp4 fragment-backed C). It re-derived
the A-LDS read/write slot as (full_tiles+kt_iv)%kAStages every iteration, which
the compiler lowered to the s_mul_hi_i32 0x55555556 signed-div-by-3 dance (2x
per tail tile) that the compile-time fully-unrolled kernel never emits.
UNROLL is a multiple of kAStages, so full_tiles%kAStages==0 and the first tail
slot is 0 (read) / kStages%kAStages (write). Carry both slots as scf.for loop
values that advance +1 with a cheap compare-select wrap, so the tail selects the
slot instead of re-computing the modulo. This is the tail analogue of the fp4
bulk address-divide hoist (6158cee).
ISA (DSV3 a8w4, cold, gemm1): tail 0x55555556 slot-mod 2->0; .vgpr_count 168 (no
change), 0 spills; bulk loop and prologue/epilogue byte-unchanged; gemm2 ISA md5
identical to 16e3c0a (fix is confined to gemm1_body_v2). a8w4 cos=0.9996 (thr
0.95), fp4 cos=0.9894 (thr 0.85). test_moe_gemm -k "(fp4 or a8w4)": 381 passed /
0 failed (eager); the 6 graph-capture large_shape failures are pre-existing on
16e3c0a.
Note: this removes a scalar divide but does not move a8w4 small-M gemm1 latency
(M=1 1.12x, M=8 1.21x vs compile-time, unchanged) -- diagnosis showed the a8w4
bulk loop already has zero un-hoisted vector divides (the fp4 hoist covered the
shared A-gather), and the residual is the runtime-K rolled-loop back-edge
amplified by fp8's heavier per-tile body (more/larger ds_reads + mfma_scale),
not address arithmetic. fp4 unchanged (M=1 1.07x).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
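The carried-slot scheme can be compared against the modulo form in a few lines. This is a sketch with hypothetical names: kAStages=3 and kStages=2 are taken from the surrounding commits, and the write-slot formula `(full_tiles + kt + kStages) % kAStages` is an assumption chosen to match the stated first write slot of kStages%kAStages.

```python
K_A_STAGES = 3  # A-LDS ring depth (the divide-by-3 in the message)
K_STAGES = 2    # write slot leads the read slot by kStages (assumed)


def tail_slots_modulo(full_tiles: int, tail: int):
    """Reference: recompute both slots with '%' on every tail iteration."""
    return [
        ((full_tiles + kt) % K_A_STAGES, (full_tiles + kt + K_STAGES) % K_A_STAGES)
        for kt in range(tail)
    ]


def tail_slots_carried(tail: int):
    """Carry the slots through the loop; advance with +1 and a compare-select wrap."""
    rd, wr = 0, K_STAGES % K_A_STAGES  # valid because full_tiles % kAStages == 0
    out = []
    for _ in range(tail):
        out.append((rd, wr))
        rd = 0 if rd + 1 == K_A_STAGES else rd + 1  # select instead of modulo
        wr = 0 if wr + 1 == K_A_STAGES else wr + 1
    return out


if __name__ == "__main__":
    agree = all(
        tail_slots_modulo(full, tail) == tail_slots_carried(tail)
        for full in (0, 6, 12)   # multiples of UNROLL=6, hence of kAStages
        for tail in range(6)
    )
    print(agree)
```

The equivalence only holds when the bulk tile count is a multiple of kAStages, which is exactly the property the UNROLL=6 grouping guarantees.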
Collapse multi-line comment/docstring blocks added since the last cleanup
(runtime-K UNROLL=6 machinery, A-gather hoists, gemm2 perf-knob notes) to one
line each, and drop the redundant `aStagesC = aStages` alias in gemm2_body_v2.
Strictly behavior-preserving: default gemm1 and gemm2 final ISA md5 unchanged
(byte-identical to 0a0951e).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Summary
Adds a layout-API MXFP4 MoE gemm (a4w4 + a8w4)
🤖 Generated with Claude Code