Skip to content

feat(moe): layout-API MXFP4 (a4w4/a8w4) MoE gemm#753

Open
coderfeli wants to merge 75 commits into
mainfrom
mxfp4-moe-gemm
Open

feat(moe): layout-API MXFP4 (a4w4/a8w4) MoE gemm#753
coderfeli wants to merge 75 commits into
mainfrom
mxfp4-moe-gemm

Conversation

@coderfeli

@coderfeli coderfeli commented Jun 26, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds a layout-API MXFP4 MoE gemm (a4w4 + a8w4)

🤖 Generated with Claude Code

Add a layout-API MXFP4 MoE up/gate + down-proj gemm that consumes the
standard (opus) sort contract from moe_sorting_kernel directly -- gemm1
gathers A rows via sorted_token_ids & 0xFFFFFF; gemm2 scatters per sorted
row via global atomic add weighted by sorted_weights. No fused-sort extras
(m_indices / reverse_sorted) are needed.

  kernels/mxfp4_moe_layout.py    - layout-API building blocks (fx.copy B/B-scale,
                                   fx.gemm scaled-MFMA atoms)
  kernels/mxfp4_moe_common.py    - shared raw helpers / constants / K-derived size
                                   formulas / atomic bf16 epilogue
  kernels/mxfp4_moe_gemm1.py     - BM32 up/gate gemm (a4w4 + a8w4, interleave +
                                   separated, nt/cached, out fp4/fp8)
  kernels/mxfp4_moe_gemm2.py     - BM32 atomic down-proj (a4w4 + a8w4 fp8 input)
  kernels/mxfp4_moe_gemm_2stage.py - public API + host launchers

Wire a4w4/a8w4 of tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage to the
new pipeline (opus sort -> gemm1 -> gemm2 atomic) vs an independent dequant-MoE
reference; a8w4 added to the in_dtype matrix.

Validated on gfx950: chain cosine a4w4=0.988, a8w4=1.000 (interleave + separated);
test_moe_gemm_2stage fp4/a8w4 over FP4-S/M/L pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@coderfeli coderfeli changed the title feat(moe): layout-API MXFP4 (a4w4/a8w4) MoE gemm (opus-sort only) feat(moe): layout-API MXFP4 (a4w4/a8w4) MoE gemm Jun 26, 2026
…oe_dispatcher

Collapse the 5 partly-duplicated MXFP4 MoE layout-API modules into 3:

  * utils.py         -- basics (shape constants, K-size formulas, pointer/LDS
                        helpers, e8m0/SwiGLU quant math); no mma/load
  * moegemm.py       -- gemm1 + gemm2 device bodies in one file, plus the shared
                        B/B-scale fx.copy + fx.gemm layout primitives, the A-LDS
                        loaders, and the atomic bf16 epilogue
  * moe_dispatcher.py-- compile_gemm1/2_a4w4_port + launch dispatch
                        (mxfp4_moe_gemm1/2, caches, gemm1_grid)

Removes the helper duplication (gemm1 previously re-defined _raw/_lds_ptr3/
_silu_mul_batch/... already in common) and makes the stack self-contained (no
imports from the v1 modules). Pure code-motion -- no behavioral change.

Deletes mxfp4_moe_{common,layout,gemm1,gemm2,gemm_2stage}.py and rewires the
test import (kernels.mxfp4_moe_gemm_2stage -> kernels.moe_dispatcher).

Validated on gfx950: test_moe_gemm_2stage layout-API path (gemm1->gemm2,
opus-sort), a4w4 + a8w4, S/M/L -- numeric gate (mean_row_cos > 0.85) passes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Drop the redundant BM32 constants: N_LOAD_WAVES (only fed ROWS_PER_WAVE),
  ROWS_PER_WAVE -> BM//4, BN_INT -> inline BN//4, M_REPS -> kMChunks (both BM//16).
  Keep BM, kAStages, kSubBlocks, kMChunks.
- Condense the module docstrings and the verbose inline comments across all three
  files (keep the load-bearing ones).

Numerically identical substitutions; GPU re-run (a4w4 + a8w4, gemm1->gemm2) still
passes the numeric gate.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
coderfeli and others added 23 commits June 26, 2026 08:24
…lpers

Globals: collapse the gemm1/gemm2 value-duplicate pairs into one name each --
NE_DEFAULT/NE -> NE, K_DEFAULT/N_OUT -> H_DEFAULT (model_dim), INTER_DEFAULT/K ->
INTER_DEFAULT (inter_dim).

Funcs: the per-body issue_a_ds_read and the per-J MMA cluster were duplicated
between gemm1 and gemm2; hoist them to module-level _issue_a_ds_read_dt /
_mma_one_j and call from both. Drop the now-unused _gemm_mma closure.

Trace-time-only change (identical emitted IR); GPU re-run (a4w4 + a8w4,
gemm1->gemm2, S+M) still passes the numeric gate.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…smem)

Replace the raw flydsl.utils.smem_allocator SmemAllocator/SmemPtr usage with the
layout-API shared-memory idiom used by the main gemm kernels (preshuffle_gemm_v2,
flash_attn_gfx950):

  @fx.struct
  class SharedStorage:
      buf: fx.Array[Int8, lds_bytes, 16]
  lds = fx.SharedAllocator().allocate(SharedStorage).peek()
  lds_base_i32 = fx.Int32(fx.ptrtoint(lds.buf.ptr))

The single i8 buffer mirrors the prior union LDS (s_aq | s_asc carved from the
front; the f32 acc aliases the whole region); the bodies now take one i32 base and
derive s_aq/s_asc/acc sub-bases by offset. All raw buffer_load_lds / ds-read
addressing is unchanged. Drops the manual allocator.finalize() in the launch jit
(SharedAllocator finalizes automatically) and the SmemPtr _view_cache resets.

utils: _lds_base_ptr3(view) -> _lds_base3(base_i32); drop memref_dialect.

GPU re-run (a4w4 + a8w4, gemm1->gemm2, S+M) passes the numeric gate.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Migrate the activation gather (a), gemm1 activation-scale (ascale), and the
output/ascaleout stores of the layout-API MXFP4 MoE GEMM from raw rocdl/llvm
intrinsics to fx.copy copy-atoms, matching the existing b/bscale paths so all
global memory traffic goes through one layout-API abstraction.

Direct-to-LDS loads and global stores port 1:1 via BufferCopy / BufferCopyLDS
atoms, including the data-dependent A-gather with bounds-checked OOB-zero for
padded rows. The gemm2 ascale load (a scalar global->register feed consumed
inline by the MFMA) is kept raw: routing it through a register fragment
regresses gemm2 ~4-5%, with no perf-neutral fx.copy equivalent.

A single _flat_buf_view helper backs all four buffer views (fold/no-fold,
bounds/max-size via args). Correctness unchanged (fp4 cos~0.988, a8w4
cos~0.9996); per-shape perf at parity on the stable t=4096 shape (interleaved
A/B). Adds an MXFP4_BENCH-gated run_perftest harness; the layout-API MXFP4 path
had no built-in benchmark.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…pers

Hoist the flat buffer-tensor view and the BufferCopyLDS128b DMA atom out of
moegemm.py into fp8_gemm_utils.py as the single source of truth, removing the
duplicated copies the fx.copy migration had introduced:

- flat_buffer_view(arg, base_elems, elem_ty, *, align, elem_bytes, fold,
  num_records_bytes): the flat (1,1) buffer view from a raw i64 address
  (fold/no-fold, max-size/num-records bounds). All 5 moegemm call sites use it.
- lds_dma_atom_128(): the shared BufferCopyLDS128b atom, now also used by
  G2SLoader (was inlined identically in both).

StoreC's typed-buffer view path and G2SLoader's hardcoded-stride load loop are
left as-is: they start from typed kernel args / fixed LDS strides and do not fit
the MoE raw-address, swizzled, data-dependent-gather case.

Behavior byte-identical: cos unchanged (fp4 ~0.988, a8w4 ~0.9996), per-shape
perf at parity (interleaved A/B). fp8 gemm rowscale test still passes (G2SLoader
touched).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…format

Remove the raw-path leftovers ruff flagged after the migration: the unused
`_lds_ptr3` import and the now-dead `aq_rsrc` / `ascale_rsrc` (+ its `ascale_num`)
buffer resources and `out_row` (the fx.copy views carry this addressing). Apply
black/ruff (line-length 120). No behavior change: cos unchanged (fp4 ~0.988,
a8w4 ~0.9996).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…onflict)

main already ignores .humanize/ (plus benchmark CSV + .rocprofv3 patterns); the
migration's incidental .humanize* line conflicted with that. Net .gitignore diff
vs base is now empty.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the raw llvm.load / llvm.StoreOp at the 7 scalar i32/f32 sites
(expert-id and sorted-token-id global reads; the gemm1/gemm2 f32 accumulator
LDS store + read-back) with FlyDSL's idiomatic typed-pointer indexing via two
small helpers (_global_typed_ptr / _lds_typed_ptr). Element-indexing also drops
the manual *4 byte arithmetic.

Vector LDS ds-reads, the invariant=True metadata loads (no invariant flag on
ptr_load), and the swizzled ds-read paths stay on llvm.* by design. Identical
IR: cos unchanged (fp4 ~0.988, a8w4 ~0.9996), perf parity (interleaved A/B).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…calar reads

Replace the raw llvm.load LDS reads (vec<2xi64>/vec<4xi32> A ds-reads, the
scalar A-scale ds-read, and the vec<2xf32> epilogue read-back) with a reusable
_lds_vec_load helper built on fx.ptr_load(result_type=...). Same load
instruction, same byte offsets, swizzle math untouched. Drops the now-unused
_gep3 / _lds_base3 byte-GEP helpers.

Only llvm.* left are the invariant=True metadata loads (ptr_load has no
invariant flag) and the atomic fadd accumulate (no typed-ptr equivalent).
Identical IR: cos unchanged (fp4 ~0.988, a8w4 ~0.9996), perf parity
(interleaved A/B, re-confirmed at 200 iters).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The mi355 CI job drives test_moe_gemm.py as a script with --gemm2_mode both.
For the layout-API MXFP4 path (fp4/a8w4) the reduce-mode combo calls
pytest.skip(), which outside a pytest session raises an uncaught Skipped and
crashes the runner (exit 1) despite all correctness checks passing. Catch
pytest.skip.Exception in run_one and report it as a printed skip.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Collapse the module + helper docstrings, the multi-line inline comments, and the
3-line `# ===` section banners to one/two lines across moegemm.py + utils.py.
Hard-won rationale kept tersely (B-load waterfall warning, OOB-zero, e8m0 path,
AddressSpace.Shared=2 gotcha). Comment-only; cos unchanged (fp4 ~0.988 / a8w4 0.9996).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Token-level rename (NAME tokens only; strings/comments untouched) of 86
leading-underscore identifiers across moegemm.py / utils.py / moe_dispatcher.py /
fp8_gemm_utils.py: module funcs, MoE-only helpers, and locals. Two locals that
would shadow module funcs renamed distinctly (_b_copy_atom->b_catom,
_scale_mma_atoms->mma_atoms). Kept flydsl._mlir module path. Dropped dead
_lds_ptr3/_PTR3 and the unused gemm1 asc_per_mb (surfaced by F841 post-rename).
Pure rename: cos unchanged (fp4 0.988 / a8w4 0.9997), ruff/black clean.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the last raw arith ops with fx spellings; all operands are wave-uniform
non-negative values (<2^31), so signed/unsigned and zero/sign-extend are
bit-identical here:
  ExtUIOp(i64)        -> fx.Int64(x)        (col/scale base extend)
  TruncIOp(i16)       -> fx.Int16(x)        (e8m0 pair pack)
  index_cast(i32)     -> fx.Int32(thread_id)
  index_cast(index)   -> fx.Index(x)        (buffer num_records)
  cmpi(ult)+select    -> (bexp < 254).select(bexp, 254)
  divui               -> // (udiv helper)
Drops the now-unused `arith` import from both files. cos unchanged
(fp4 0.9881 / a8w4 0.9996); interleaved A/B parity held (within noise).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…m load

- Comments/docstrings collapsed to a single line across moegemm/utils/
  moe_dispatcher/fp8_gemm_utils (incl the long flat_buffer_view + module
  docstrings and the `# ===` banners); key rationale kept tersely.
- run_compiled: drop moe_dispatcher's duplicate; reuse the shared
  tensor_shim._run_compiled (aliased), folding in the ir.Context-leak cleanup
  and fixing its exe.cf -> exe._cf cache bug. Call sites pass *args.
- cumsum0: raw `llvm.load(T.i32, global_ptr1(...))` -> typed
  `global_typed_ptr(arg_cumsum, T.i32)[0]`; drop now-dead global_ptr1 +
  unused llvm/ir imports.
cos unchanged (fp4 0.9881 / a8w4 0.9997); ruff/black clean.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…pecific)

kernels/utils.py was 100% MoE-specific (created by the MoE consolidation, only
imported by the MoE files). A generic name for MoE config + kernel geometry +
quant helpers is misleading, so fold it all into moegemm.py where it's used:
- shape constants (NE/TOPK/H/INTER defaults, BN/BK/KH_TILE/kStages/...),
  K-derived size formulas, raw/pointer/LDS helpers, e8m0/SwiGLU math.
- inline the two identical *_c_k1_for wrappers (0 external users).
- moe_dispatcher now imports these from .moegemm; decouple fp8_gemm_utils by
  taking `raw` from flydsl.expr.arith (_to_raw) instead of the MoE module.
- delete kernels/utils.py.
Pure move (no logic change): cos unchanged (fp4 0.9882 / a8w4 0.9996); ruff/black clean.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rs, dedup raw/udiv

Rename kernels/moegemm.py to kernels/mxmoe_gemm_v2.py and update the
moe_dispatcher import + fp8_gemm_utils comment.

Remove the K-/N-derived size helper defs (k_half_for, k_tiles_total_for,
kunroll_for, kbs_stride_n0_dw_for, kas_per_chunk_dw_for, num_n_blocks_for,
kbs_per_expert_dw_for, kmchunks) and inline their compile-time arithmetic at
the gemm2 + epilogue call sites, mirroring gemm1's existing inline style.

Drop the local raw()/udiv() helpers: use fx._to_raw (imported as _raw) and
plain // (signed floordiv, matching the old udiv).

Correctness unchanged: tests/kernels/test_moe_gemm.py fp4+a8w4 subset
69 passed / 30 skipped (gfx950), identical to baseline.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rs, dedup raw/udiv (#765)

Rename kernels/moegemm.py to kernels/mxmoe_gemm_v2.py and update the
moe_dispatcher import + fp8_gemm_utils comment.

Remove the K-/N-derived size helper defs (k_half_for, k_tiles_total_for,
kunroll_for, kbs_stride_n0_dw_for, kas_per_chunk_dw_for, num_n_blocks_for,
kbs_per_expert_dw_for, kmchunks) and inline their compile-time arithmetic at
the gemm2 + epilogue call sites, mirroring gemm1's existing inline style.

Drop the local raw()/udiv() helpers: use fx._to_raw (imported as _raw) and
plain // (signed floordiv, matching the old udiv).

Correctness unchanged: tests/kernels/test_moe_gemm.py fp4+a8w4 subset
69 passed / 30 skipped (gfx950), identical to baseline.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
topk only feeds gemm1_grid host-side block sizing and never enters the
compiled kernel (not in the kernel name, not in the body). Removing it from
the get_g1 cache key and compile signature collapses all topk values for a
given shape onto a single compiled kernel.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
NE is host-only: it caps the active-expert count in gemm1_grid but never
enters either compiled body (it was an unused kwarg on gemm1_body_v2 /
gemm2_body_v2 and an ne{NE} segment in both kernel names). Remove it from the
kernel names, the get_g1/get_g2 cache keys, and the compile signatures so all
expert counts for a given shape share one compiled kernel. Public
mxfp4_moe_gemm1/2 keep NE for the host grid calc.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
inter_dim is the gemm2 down-proj contraction. It was baked into the kernel
name (i{INTER}) and the get_g2 cache key, forcing one compiled kernel per
inter_dim. Convert it to a runtime fx.Int32 (i32_inter):

- The contraction K-loop becomes an scf.for over a runtime trip count
  (K_TILES = inter_dim/256), carrying the C accumulators (fp4 fragments /
  fp8 accm) as loop-carried state.
- B and B-scale are now streamed one K-tile per iteration into fresh
  fragments instead of preloading all K_TILES tiles into registers; the
  bq_view/bscale_view strides are K-independent constants so only the
  wave-uniform base offsets become runtime. The view shape K-axis and the
  triple-buffered A LDS are bounded by a compile-time cap INTER_MAX.
- A->LDS is streamed inside the loop for K_TILES > kStages.

gemm2 kernel name now carries imax{INTER_MAX} (not i{INTER}); get_g2 keys on
INTER_MAX, so all inter_dim values <= cap share one compiled kernel. The
public mxfp4_moe_gemm2 API is unchanged.

Perf (gemm2, median-of-5, gfx950): fp4 t8192/7168,512 909->901us;
fp4 t16384/5120,1536 1042->994us; a8w4 t8192/7168,512 922->905us;
a8w4 t16384/5120,1536 1855->1070us (streaming B removes the
register-resident-B spill at larger K). Cold correctness unchanged
(69 passed).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
inter_dim is the gemm1 up/gate-proj N-output dim. It has no contraction
K-loop dependency (the K-loop is hidden_dim), so converting it to a runtime
fx.Int32 (i32_inter) only turns N-block sizing, B/B-scale base offsets, and
the epilogue output strides (N_OUT, NUM_N_BLOCKS, kBS_per_expert_dw,
OUT_AS_PER_CHUNK_DW, K_G2_BYTES) into runtime i32 ops. gemm1 LDS and the
B-view K-axis depend only on hidden_dim, so no compile-time cap is needed.

The gemm1 kernel name drops i{INTER}; get_g1 drops inter_dim from its cache
key. Together with the gemm2 conversion, inter_dim is now fully runtime:
all inter_dim values share ONE gemm1 and ONE gemm2 compiled kernel. The
public mxfp4_moe_gemm1 API is unchanged.

Perf (median-of-5, gfx950) at parity vs base 9f359ee: gemm1 fp4
t8192/7168,512 777->736us, fp4 t16384/5120,1536 ~2527us (base ~2527 stable;
the earlier 2314 was a transient high-clock sample), a8w4 ~870/~2643us.
Cold correctness unchanged (69 passed).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
coderfeli and others added 30 commits July 1, 2026 07:58
…ipe_config

Extend the per-(shape,token) dispatcher to emit (BM, epilog, bm_stage1, persist):
- bm_stage1=128 for high-expert small-inter families (DSV3/Kimi/DSV4) at large M
  (>=4096 tok, allow_bm128-gated); stage2 tile stays <=64 (BM128 is stage1-only,
  regresses stage2/GPT-OSS per F1).
- persist ON for those same high-expert large-M families (F2 +5-17%); OFF for
  GPT-OSS (byte-identical one-shot grid).
Harness uses BM_S1 for gemm1 + A-scale shuffle, BM for gemm2, and a shared
SBM=lcm(BM_S1,BM). Default path (BM_S1==BM==32) stays byte-identical.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The fp8-A (a8w4) gemm2 persist path is a known-broken F2 combo: the multi-iteration
grid-stride corrupts the fp8-A accumulator/LDS state and yields cos=0 at large M
(reproduces on rlcr/moe-persist-sbm alone; fp4 persist is correct to 32768 tok).
- Dispatcher: drop DeepSeekV4 (a8w4) from _PERSIST_MIN_TOK so select_pipe_config
  never emits persist for the fp8-A path (DSV3/Kimi fp4 keep persist).
- compile_gemm2_a4w4_port: raise a clear AssertionError on a8w4+persist so the
  manual knob can't silently ship garbage. Re-enable once fp8-A persist is fixed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Land the consolidated MXFP4 MoE pipe (swiglu/reduce/valid_mask/graph,
cached-B, BM{16,32,64,128}, persist fenced for fp8-A, SBM, per-token CSV
dispatch, tok32768 crash fixes, a8w4-BM64 atomic fix, arith.minimumf) onto
the PR head, which had upstreamed the moegemm->mxmoe_gemm_v2 rename + helper
dedup via #765.

Conflict resolution (3 files):
- fp8_gemm_utils.py: kept PR-head/main (identical) version verbatim; #765
  moved flat_buffer_view + lds_dma_atom_128 out of this module.
- mxmoe_gemm_v2.py / moe_dispatcher.py: took our superset (already deduped:
  _raw, //, inline K//BK), which contains both the #765 rename/dedup and all
  our features. Verified main touched none of these 3 files (mxmoe/dispatcher
  are PR-only; fp8_gemm_utils identical in main vs PR head), so no main
  content is dropped.
- Relocated flat_buffer_view + lds_dma_atom_128 into mxmoe_gemm_v2.py
  (matching #765's placement) and dropped the fp8_gemm_utils import; added
  `arith` to mxmoe imports for flat_buffer_view.

arith.py + test_moe_gemm.py applied cleanly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…1 byte-identical)

Repartition the 256-thread (4-wave) block as num_n_waves x k_wave: each K-wave
group processes K/k_wave of the model_dim contraction into its own A-LDS region
and cshuffle slab, then the k_wave partials are summed in LDS before the shared
silu+quant epilogue. Threaded through compile_gemm1_a4w4_port/get_g1/
mxfp4_moe_gemm1 as a compile-time cache-key dim with a _kw{N} name suffix.

Guards (aiter-parity): k_wave in {1,2,4}, fp4-only, interleave-only, K/BK %
k_wave == 0, 4*eff_tile_n <= tile_k. Default k_wave=1 is compile-gated to emit
byte-identical IR (final-ISA md5 verified equal for BM16 and BM32).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… real LDS-fits guard

- s_asc_base must sit past ALL k_wave A-staging regions (was overlapping region 1
  at kw>1, causing NaN); kw=1 unchanged (byte-identical re-verified equal md5).
- Replace aiter's per-group 4*tile_n<=tile_k scratch guard (inapplicable to this
  full-[BM,BN]-slab reduction) with an actual gfx950 160KB LDS-fits check.

k_wave is correct (cos>0.99 at DSV3 fp4 m=2,4,8 for BM{16,32} kw{1,2,4}) and
byte-identical at the default kw=1, but a PERF NO-GO on the BN=256 v2 pipe (see
.humanize/kernel-agent/kwave.md): DSV3-fp4 small-M is block-count bound so
intra-block K-slice adds no occupancy and is monotonically slower.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… coverage

Parameterize the gemm1 fused gate|up N-tile BN in {64, 256} (compile-time cache-key
dim, _bn{N} name suffix). BN=64 gives N_OUT//64 = 4x more N-blocks than BN=256, the
block-count coverage lever for tiny-M DSV3 fp4 (where pure k_wave at BN=256 was a
NO-GO because small-M is block-count bound). Pairs with k_wave: BN=64 requires
num_n_waves<=2 (each N-wave needs an even NJ>=2 gate+up pair), i.e. k_wave in {2,4};
the tiny-M fix is BN=64 + k_wave=4 (nnw=1).

Parameterized sites (BN=256 compile-gated to stay byte-identical, AC-3 verified):
- NUM_N_BLOCKS = N_OUT//BN; grid num_n_blocks = N_OUT//BN (dispatcher + kernel).
- requant gate/up read: split at BN//2; col-group == n_lane (gate cols n_lane*8+ee).
- aqout / ascaleout stores predicated on n_lane < BN//16 (BN=64: wave_grp==0) so the
  shrunk tile never writes a neighbouring n_block.
- ascaleout physical layout rewritten in terms of the ABSOLUTE 32-INTER-col scale
  group g = n_block_idx*(BN//64)+wave_grp: ku=g>>3, ikxdl=(g>>2)&1, lane=(g&3);
  reduces exactly to the old literals at BN=256, and at BN=64 (g==n_block_idx) yields
  the IDENTICAL physical output layout gemm2 consumes (scale-group boundary == BN=64
  n_block boundary). aqout byte layout n_block*(BN//4)+n_lane*4 is BN-independent.
- lds_bytes_for(BN=...) sizes the [BM,BN] cshuffle slab.

Guards: BN in {64,256}; BN!=256 is fp4-only + interleave-only + NJ even>=2.
Threaded through compile_gemm1_a4w4_port/get_g1/mxfp4_moe_gemm1 (+ MXFP4_BN test env).

AC-3: BN=256 default final-ISA md5 EQUAL to baseline (BM16 a3318f65..., BM32 d472bdb7...).
AC-1: cold cos>0.99 at DSV3 fp4 m=2,4,8 for BM{16,32} x kw{2,4} at BN=64.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…expert fp4)

select_pipe_config now returns (BM, epilog, bm_stage1, persist, bn, k_wave):
the high-expert small-inter fp4 families (DSV3 E257, Kimi E384) are block-count
bound at tiny M, so m<=2 selects BN=64 (N_OUT//64=8 N-blocks, 4x coverage) +
k_wave=4 (nnw=1) for the ~1.5x gemm1 win; BN=256+kw1 (byte-identical default)
otherwise (crossover at m~=3-4). gemm2 is BN-independent (unchanged). Manual
MXFP4_KW / MXFP4_BN env overrides preserved under MXFP4_DISPATCH=1.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…SBM resolution

Cleanup (behavior-preserving, AC-3 byte-identical-default verified fp4+a8w4):
- Remove inline_quant (always False; only existed to be rejected) from the gemm1
  compile/get_g1/mxfp4_moe_gemm1 chain + cache key + test call.
- Remove D_INTER_REAL (always None; only existed to be rejected) from the gemm2
  compile/get_g2/mxfp4_moe_gemm2 chain + cache key.
- Extract _norm_sbm(SBM, BM) for the 5 duplicated SBM None->BM normalizations.
SBM (SBM!=BM) is KEPT: it is required by the BM128-stage1 dispatch path (SBM must
be lcm(bm_stage1, BM)); not dead.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…default byte-identical

Add a compile-gated has_pad variant to the v2 mx MoE gemm1/gemm2 that saves B-weight
bandwidth on padded contraction dims via hardware OOB-load-skip (aiter-style), NOT the
prior whole-tile compute-skip (a no-op for GPT-OSS pad=192 < BK=256).

Mechanism: size the per-16N-tile B-weight buffer resource (bq_view) to the REAL K extent
(K - kpad) so the fully-pad 128-K weight half loads of the tail tile buffer-load
OOB -> 0 (no HBM transaction). The K axis is K-major and the half stride (256 i32)
exceeds every within-half sub-offset (255), so the num_records cut lands exactly on a half
boundary: zeros the fully-pad half, keeps the partial-pad half (host zero-fill). B-scale is
NOT shrunk (256-K granular + host 256-align; sub-256 pad saves 0 and risks NaN).

- gemm1 K=D_HIDDEN (kpad=model_dim_pad); gemm2 K=inter_dim (kpad=inter_dim_pad), both runtime.
- has_pad=False (default): no i32_kpad kernarg, pad math const-folds away. Verified
  gemm1+gemm2 GPT-OSS-3072 atomic final-ISA md5 UNCHANGED vs a5bb72a (AC-3).
- has_pad=True: distinct compile variant (_pad name tag) with the runtime i32_kpad kernarg;
  shared kernel body extracted to @flyc.jit _gemm{1,2}_kernel_body so the AST rewriter
  recurses into its scf dispatch.
- run wrappers: mxfp4_moe_gemm1(model_dim_pad=), mxfp4_moe_gemm2(inter_dim_pad=).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Thread real (unpadded) dims through test_moe_gemm_2stage + run_mxfp4_moe_2stage and the
CLI (--real_dim model,inter, --garbage_pad). The has_pad path host-zero-pads the fp32
weight/A/intermediate pad regions and passes model_dim_pad/inter_dim_pad so the kernel
sizes the B-weight buffer resource to the real K (OOB-skips the fully-pad weight halves);
the reference is over the full (zero-in-pad) tensors == a real-dim GEMM.

--garbage_pad writes 1e30 into ONLY the fully-pad weight-contraction halves the OOB skip
drops (partial-pad remainder in a kept half stays host-zero). A correct output then proves
the OOB skip never fetches the padded weights into the accumulation.

Verified (cold, gpt-oss real2880/pad192, t=256 e=128 k=4):
  fp4  cos 0.9892 (garbage 0.9893), a8w4 cos 0.9996 (garbage 0.9996)
matching the naive padded-3072 baseline (fp4 0.9894).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add the gemm1 fused gate|up N-output pad-skip to the has_pad variant: skip loading
w1 weight rows for the padded inter-N columns (both gate and up halves) via a per-16N-tile
buffer-OOB skip. Additive to the existing K-pad OOB skip.

Mechanism: each 16-N weight tile is an independent buffer resource (its N column is the
resource BASE offset, not an in-view axis), so the K-axis num_records trick does NOT compose
on N. Instead, per tile, compute its LOGICAL inter col base (interleave gate|up: gate cols
[0,BN//2) and up cols [BN//2,BN) of each 256-block both map to inter [n_block*(BN//2),+BN//2))
and, when that base >= INTER_real (= i32_inter - i32_npad), size the tile buffer to 0 records
so every weight load OOB -> 0 (no HBM fetch). Since INTER_real is 16-aligned, every 16-N tile
is fully real or fully pad -> no partial-tile epilogue pairing problem: a skipped tile writes 0
into exactly its own gate-or-up output column (verified: per-j bq_frags -> per-J c_frags ->
cshuffle lds_col are index-preserving, no cross-contamination), and B-scale (32-N granular,
shared, unshrunk) is harmless because 0*scale=0.

CORRECTNESS-SAFE with NO epilogue change: the pad-N output feeds gemm2's pad-K input, which
gemm2 already OOB-skips (K-skip on this branch), so zeroed pad-N output is never read.

- gemm1 threads a new runtime i32_npad kernarg (= inter_dim_pad) alongside i32_kpad, both only
  in the has_pad variant. has_pad now enables on model_dim_pad>0 OR inter_dim_pad>0.
- run wrapper mxfp4_moe_gemm1(inter_dim_pad=).
- has_pad=False (default): no i32_npad kernarg, N-skip math emitted only under const_expr(has_pad);
  the default gemm1  expression keeps its exact original add-associativity. Verified default
  gemm1 AND gemm2 GPT-OSS-3072 final-ISA md5 UNCHANGED vs base b5dfe2b (AC-3):
    gemm1 fb8141b6b6f3aee1c91f12a384392e29, gemm2 a2a7f6defe7a7f988625bc1e2424a269.
- test: --garbage_pad now poisons w1's pad-N weight ROWS (N-skip drops them); thread inter_dim_pad
  through run_mxfp4_moe_2stage to gemm1.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…e|up)

The initial N-skip used col_in_block<BN//2 to split gate vs up, but interleave mode
selects gate/up by j PARITY (in_b=J%2, mfma_cluster) with n0 index J//2 -- gate and up
are interleaved at 16-col granularity, NOT a contiguous [0,128)|[128,256) split. The
wrong predicate skipped real tiles / kept pad tiles, regressing cos (fp4 0.9893->0.9835,
a8w4 0.9996->0.9941). Corrected to match the cshuffle:
  logical_inter = n_block*(BN//2) + wave_n*gate_span + (j//2)*16, gate_span=(BN//2)//nnw.
A gate tile (j even) and its sibling up tile (j+1) share this base -> skipped consistently.

Restores correctness (cold, gpt-oss real2880/pad192, t256 e128 k4):
  fp4  structural cos 0.9893, garbage-N cos 0.9893 (== structural, N-skip proven)
  a8w4 structural cos 0.9996, garbage-N cos 0.9996 (== structural)
matching the K-skip-only baseline. Default (no-pad) gemm1 final-ISA md5 UNCHANGED
(fb8141b6b6f3aee1c91f12a384392e29 == base; AC-3).

test: fix the garbage-N-pad harness. Poison ONLY w1's pad-N weight ROWS (1e30) and hold
w2's pad-K contraction cols at 0; the reference now sums over the REAL inter extent (rI =
INTER - inter_dim_pad) so the huge pad rows never enter the (finite) reference act. A
full-INTER reference overflowed act on the 1e30 rows -> ref NaN (the kernel out was always
finite/correct). rI==INTER when inter_dim_pad==0 (byte-identical full reference).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Formatting only, inside const_expr(has_pad); default gemm1 ISA md5 unchanged
(fb8141b6b6f3aee1c91f12a384392e29). black --line-length 120 clean; ruff clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
At M with <=1 stage1 m-block per active expert (small/mid M, e.g. GPT-OSS
M<=1024), each expert's w1 is single-use: caching it in L2 only allocates lines
that are never re-hit (measured L2 hit ~3% either way) while paying the
non-streaming cost. Stream (non-temporal) instead.

gemm1_use_nt(experts, topk, tokens, bm_stage1) returns the policy; the harness
uses it under MXFP4_DISPATCH=1. Byte-identical: only flips the B-load coherence
flag, HBM read bytes unchanged (605 vs 617 MB EA0 proxy, flyprof PMC). The
shipped non-dispatch default stays use_nt=False (cached).

GPT-OSS 3072/3072 E128 k4 gfx950 gemm1, median-of-3+ (run_perftest device time):
  M=128  a4w4 208->181us (-13%, cos 0.9909), a8w4 209.5->182.8us (cos 0.9996)
  M=1024 217->192us (-12%)   [<=1 block/expert: NT]
  M=2048 236->249us, M=4096 309->385us  [reuse: cached kept]
Crossover at m-blocks/expert==1, matching the existing "nt only helps when there
is no reuse" comment.

This closes the bulk of the apparent gemm1 gap vs aiter cktile a16w4 stage1:
apples-to-apples on aiter's own rocprofv3 kernel trace (stage1 median 181.4us,
min 173.6us -- the prior 164.5us wall baseline is not reproducible), our NT
gemm1 (a8w4 ~184us, dtype-matched to aiter's F8xMXF4) is at parity ~1.01x, both
at ~75-76% of gfx950 peak HBM.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
gemm2's w2 down-proj load was still CACHED at no-reuse M (the gemm1 NT fix
only touched gemm1's B). At M with <=1 stage2 m-block per active expert
(small/mid M, e.g. GPT-OSS M<=1024) each expert's w2 is single-use, so caching
it in L2 only allocates lines that are never re-hit while paying the
non-streaming cost. Stream (non-temporal) instead.

gemm2_use_nt(experts, topk, tokens, bm_stage2) mirrors gemm1_use_nt exactly
(reuse metric = m-blocks per active expert, nt when <=1); the harness uses it
under MXFP4_DISPATCH=1, and MXFP4_G2_NT forces the policy for measurement.
Byte-identical: only flips the w2-load coherence flag, HBM read bytes and
correctness unchanged; the shipped non-dispatch default stays use_nt=False
(cached, no _nt kernel tag).

GPT-OSS 3072/3072 E128 k4 M=128 gfx950, identical-work parity harness,
median-of-5 (run_perftest device time):
  a4w4 gemm2 114.6 -> 104.2us (-9.1%, cos 0.9910)
  a8w4 gemm2 121.0 -> 112.8us (-6.8%, cos 0.9996)
Closes ~half the gemm2 gap vs aiter cktile a16w4 stage2 (92.3us): a4w4
1.24x -> 1.13x, a8w4 1.30x -> 1.22x. Post-NT residual is a pure weight-BW-
efficiency gap at identical 551 MB weight traffic, dominated by occupancy
(ours 32KB LDS -> 5 blk/CU vs aiter 16KB -> 10 blk/CU); see
.humanize/kernel-agent/gptoss-gemm2-rootcause.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Mirror gemm1's N-skip into gemm2_body_v2: gemm2's N dimension is the
model_dim output, and its w2 N-tiles map 1:1 to model_dim output columns
(make_bq_view `col`). Add i32_npad, compute N_real = N_OUT - npad under
const_expr(has_pad), and gate each 16-N weight tile's num_records to 0 when
col >= N_real via (col < N_real).select(bq_num_records, 0). Fully-pad
model_dim output tiles then load OOB -> 0 (no HBM fetch).

PERF-ONLY: the pad model_dim output columns are unused (sliced off by the
real_model_dim reference), so correctness holds even without the skip; this
only drops the wasted w2 fetch.

Thread i32_npad through the gemm2 compile/dispatch path (kernel body,
has_pad kernel/launch signatures, mxfp4_moe_gemm2 model_dim_pad + pad_args)
mirroring gemm1's kpad/npad threading. has_pad now enables on
inter_dim_pad>0 OR model_dim_pad>0.

Default path (has_pad=False) is byte-identical: all new logic under
const_expr(has_pad); verified final ISA/LLVM-IR/MLIR md5 unchanged.

test_moe_gemm: under garbage_pad, poison w2[:, real_H:, :] = 1e30 instead
of structural 0, and slice the reference/output to :real_H so a correct cos
PROVES the N-skip never fetched the garbage w2 output rows.

Validation (gfx950, cold, cache disabled):
- garbage-pad cos: fp4 0.9893, a8w4 0.9996 (thr 0.85/0.95)
- default no-pad ISA/IR/MLIR md5 byte-identical
- GPT-OSS same-input harness gemm2: 120.0 -> 116.3 us (-3.1%)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Port gemm1's kStages=2 B-prefetch double-buffer into gemm2's runtime scf.for
K-loop, opt-in via MXFP4_G2_KSTAGES=2 (default 1 = byte-identical).

gemm2 was 1-deep: B weight + B-scale loaded SYNCHRONOUSLY per K-tile on the
compute path (stream_b_tile). gemm2 is HBM-weight-bound, so prefetching the next
K-tile's B one step ahead hides the weight-load latency.

Implementation: since the runtime scf.for cannot index a python stage list by the
runtime kt, a rotating single-buffer carries the prefetched B-weight (i32x4) +
B-scale (i32x1) fragment VALUES through the loop state (alongside the existing
C/accm carry). Each iteration consumes the carried "current" B and prefetches the
next tile's B; the yield rotates prefetch->current. A prologue loads tile 0.
sched_barrier/s_setprio fence the MFMA chain from the B vmem loads, mirroring
gemm1_body_v2's main-loop fencing.

Default (g2_kstages=1) is proven byte-identical: all 20 IR stages + final binary
md5 match base for both gemm2 (default) and gemm1; the 2-stage variant carries a
distinct _g2ks2 kernel tag and cache key.

Measured (GPT-OSS M=128, NT gemm2, median-of-5, same_input_parity):
  a4w4 gemm2: 105.4us -> 98.7us (-6.4%); 1.14x -> 1.07x vs aiter 92.3us; cos 0.9910
  a8w4 gemm2: 109.8us -> 100.7us (-8.3%); 1.19x -> 1.09x vs aiter; cos 0.9996
VGPR a4w4 80->126, a8w4 98->146; NO spills; LDS 32768 unchanged (still LDS-bound,
so this stacks with the epilog-LDS/occupancy lever). gemm1 ISA + stage1 unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ATT (rocprofv3 advanced thread trace) of the isolated GPT-OSS gemm2 at M=128
shows ours-BEST (BM32+reduce+G) is weight-VMEM-bound (83% VMEM stall) like aiter
(79.5%), but pays an extra ~5pp in barrier idle (11.3% vs 6.6%): the single
per-K-iteration gpu.barrier() guarding the A-LDS ring absorbs per-wave VMEM-load
latency imbalance. BM16-atomic (169us, VMEM-load back-pressure 45%) confirms the
pipeline-underfeed root cause for our wide 16x16x128 scaled-fp4 MFMA.

Fix: with the g2_kstages==2 B pipeline, the next-tile B weight+scale prefetch is a
GMEM->register stream with no LDS dependency, so it need not be gated by the LDS
barrier. Hoisting it ABOVE the barrier (opt-in g2_bhoist, env MXFP4_G2_BHOIST=1)
puts the long-latency weight loads in flight before the block syncs, overlapping
the barrier wait. ATT after: VMEM-wait 52.6%->34.9%; total stall 10.46M->10.17M.

a4w4 gemm2 97.8->95.6 us (-2.2%), 5.67->5.80 TB/s, ratio vs aiter 1.055x->1.028x.
a8w4 neutral (heavier A path, not VMEM-wait-bound). cos unchanged (a4w4 0.9910,
a8w4 0.9996). Default (flag off) byte-identical: g2ks2 ISA md5 identical
(18b5e4d7...); gemm1 untouched. Analysis in gemm2-att-analysis.md.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…vs aiter)

ISA-level diff of the 4 gemm2 memory-access classes (weight-B, B-scale, A-scale,
A, output) vs aiter cktile a16w4 MoeFlatmmKernel. Weight-B (93% of traffic) is
already instruction-for-instruction identical (buffer_load_dwordx4/128b/nt, same
per-unit-N count + coalescing, B prefetched 1 K-tile ahead above the barrier via
g2ks2+bhoist). The one under-issuing class: the A-scale was loaded SYNCHRONOUSLY
at the loop head and gated the whole MFMA cluster on its full VMEM latency with
s_waitcnt vmcnt(0) (also draining the fresh B prefetch), whereas aiter prefetches
its scale a K-block ahead into a register and never stalls the MFMA on it.

Fix (deepdiff lever C): opt-in g2_ascale_pf prefetches the A-scale one K-tile
ahead through the existing g2ks2 scf.for carry (rotating single-buffer, same as
the B carry). ISA: the loop-head buffer_load_dword + vmcnt(0) is gone; the MFMA
cluster starts on the already-resident prefetched scale.

- Opt-in g2_ascale_pf param on gemm2_body_v2 (default False = byte-identical),
  env MXFP4_G2_ASCALE_PF=1, gemm2 dispatcher cache key + _apf name tag. No-op
  unless g2_kstages==2. gemm1 untouched.
- a4w4 97.8 -> 97.1 us (-0.7%), a8w4 102.8 -> 101.9 us (-0.9%) (median-of-5,
  isolated, cold, same GPU3 session; aiter 95.24 same session -> ratio
  1.027x -> 1.020x a4w4). Tighter variance both dtypes.
- cos unchanged: a4w4 0.9910 (>0.85), a8w4 0.9996 (>0.95).
- Byte-identical default (AC-3): apf-off ISA md5 4fb7bad1... == pre-change
  (empty diff); g2_kstages=1 shipped default untouched.

Report: .humanize/kernel-agent/gemm2-mem-pattern-align.md

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…el balance, replaces xcd)

Port aiter's ACTUAL, XCD-count-INDEPENDENT gemm2 block->tile mapping — the
GemmSpatiallyLocalTilePartitioner grouped 2D rasterization — as our gemm2
block->(m_block_idx, n_block_idx) map, and DROP the non-portable explicit-8-XCD
swizzle (g2_xcd) as the channel-balance mechanism.

KEY FINDING: aiter's shipped gemm2 instantiates the partitioner with GroupNum=1,
M01=1 (moe_cktile2stages_common.cuh:58-59), which mathematically DEGENERATES to
the exact naive m-major linear map we already use (verified by hand + numeric
replay of GetOutputTileIndex, gemm_tile_partitioner.hpp:274-360). So the
partitioner grouping is NOT what balances aiter's channels for this instance; a
verbatim port would be a no-op. The productive port is the GENERAL
GetOutputTileIndex parameterized by (GroupNum,M01), with the grouping ENABLED.

- _spart_output_tile_index(): faithful DSL port of GetOutputTileIndex's
  else-branch (M0=total_m_blocks runtime; N0=num_n_blocks/GroupNum/M01
  compile-time). Bijection over [0, M0*N0) (no dropped/dup tiles), no hard-coded
  XCD count.
- g2_spart opt-in param (env MXFP4_G2_SPART), encoded GroupNum*100+M01 (e.g. 402);
  default 0 = off = byte-identical naive linear grid. Threaded through the gemm2
  dispatcher cache key + _spart{G}x{M01} kernel-name tag. One-shot grid path only.
  gemm1 and mxmoe_gemm_v2.py untouched.

MEASURED (rocprofv3 TCC_EA0_RDREQ, 128 channels; a4w4; GPU3 cold same session):
GroupNum=4,M01=2 (MXFP4_G2_SPART=402) is the winner — per-channel CV 9.68% ->
0.32% (BEATS prior xcd4's 0.64%, approaches aiter's 0.18%), max/min 1.28x ->
1.01x, total HBM reads 4.86M -> 4.50M (aiter parity 4.44M, was 9.5% over). The
portable spatial partitioner MATCHES-OR-BEATS the xcd swizzle on channel balance.

Perf (gemm2 isolated, median-of-5, cold, same GPU3 session, ~12% hotter than the
memalign doc): a4w4 109.4 -> 108.3 us (-1.0%), a8w4 112.0 -> 109.9 us (-1.9%);
aiter same-session 95.3 us.

Correctness (cold, real 2880 dims, spart402): a4w4 cos=0.9910 (>0.85), a8w4
cos=0.9996 (>0.95) — identical to baseline (bijective permutation, same math).
Byte-identical default: gemm2 default kernel ISA md5 8263440a... identical with
spart code present vs stashed (no _spart tag, no extra IR when off). Python style
gate passes.

Doc: .humanize/kernel-agent/gemm2-spatial-partitioner.md.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…n up

Wire the PR #753 gemm2 perf stack ON by default (previously opt-in env flags),
and trim the now-redundant exploration commentary / opt-in scaffolding.

Defaults flipped (env vars remain optional overrides, explicit args still win):
  MXFP4_G2_KSTAGES   1 -> 2   (2-stage B weight/scale prefetch)
  MXFP4_G2_BHOIST    0 -> 1   (hoist B prefetch above the LDS barrier)
  MXFP4_G2_ASCALE_PF 0 -> 1   (A-scale prefetch one K-tile ahead)
  MXFP4_G2_SPART     0 -> 402 (GroupNum=4,M01=2 spatial tile partitioner; HBM
                               channel balance, replaces the dropped g2_xcd)

gemm2-only: gemm1 default ISA is byte-identical (md5 b4a1f0d5...428cff376ce039bb894ad50d
matches 7d40f77). DSV3 large-M uses the persist grid (spart bypassed) and the
compute-bound large-M shapes are neutral, so no shape-gating is needed.

Measured (GPT-OSS M=128 gemm2, cold, same session): 115.6 -> 103 us (-11%) vs
opts-off. Correctness held (fp4>0.85, a8w4>0.95 thresholds unchanged).

Cleanup: collapse the four per-knob resolution blocks + duplicated `import os`,
condense the knob docstrings/tag comments and the K-loop inline commentary in
gemm2_body_v2. No kernel math changed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Aggressive readability cleanup of the PR #753 gemm code, strictly behavior-
preserving. Collapse every multi-line/multi-paragraph comment and docstring
essay to a single terse line (or delete pure restatements), drop leftover
opt-in rationale now that the gemm2 perf stack (G/BHOIST/APF/spart402) is
default-on, and remove dead code (self-assignment `aStages = aStages`, a
redundant local `import os`, duplicated stream_b_tile/issue_b_load_into body).

No kernel math, tile logic, dispatch decisions, numerics, or env overrides
changed. Verified: the default gemm1 (cached + nt) and gemm2 (atomic + reduce)
final ISA are byte-identical (md5) to 5856379; fp4/a8w4 stage2-standalone and
fp4/a8w4 e2e 2stage correctness pass cold, cos thresholds unchanged.

full-line comments: mxmoe_gemm_v2 244->99, moe_dispatcher 166->39
LOC: mxmoe_gemm_v2 1583->1417, moe_dispatcher 1421->1205

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…e keys)

Fold the runtime-HIDDEN (model_dim) port onto the cleaned PR #753 head so ALL
MoE dims are runtime, keeping the clean-head hygiene (<=1-line comments, no dead
code, merged stream_b_tile/issue_b_load_into).

gemm1 (contraction K): new i32_hidden operand; K/kc/K_TILES/KT_PER_KW/K_BYTES/
KH4 runtime SSA capped by HIDDEN_MAX. The compile-time range_constexpr K-loop is
now a runtime-trip scf.for carrying C/accm AND the kStages=2 one-stage-ahead
B/B-scale prefetch as loop-carried fragments (CUR/NXT, last-iter prefetch clamped
to the last valid K-tile). BHOIST fences + k_wave preserved.

gemm2 (N-output): new i32_hidden operand; N_OUT_rt/num_n_blocks/kbs_per_expert/
N_real/bq-col/epilog-stride runtime. The runtime-inter scf.for + G2 perf stack
(kstages2/BHOIST/APF/spart402) untouched, default-on.

dispatcher: D_HIDDEN cache-key dim -> HIDDEN_MAX cap in get_g1/get_g2; kernels
thread i32_hidden; kernel names h{dim} -> hmax{HIDDEN_MAX}; K%BK / K%k_wave checks
host-side; H_DEFAULT global removed. One compiled binary serves any model_dim
<= HIDDEN_MAX.

Re-cleaned the port's additions to the clean-head standard (every comment <=1
line, tight LOCs); dropped the port's dead aStages self-assign and re-used the
clean stream_b_tile/issue_b_load_into consolidation.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…mpile-time schedule

The runtime-K gemm1 main loop (single rolled scf.for, 1 K-tile/trip) lost the
compile-time straight-line, software-pipelined schedule: per-iter runtime %kAStages
(magic ÷3) A-LDS slot recompute, no cross-tile load/MMA overlap (top-of-loop
s_waitcnt vmcnt(0) drain), and a hot-path last-iter clamp branch. Cost is exposed at
latency-bound small M (DSV3 M=8 1.38x vs compile-time; M=4096 already parity). NOT
spills: runtime .vgpr_count 95 vs compile-time 148, 0 spills both.

Fix: unroll the runtime K-loop by UNROLL=LCM(kAStages,2)=6. A multiple of kAStages so
every A-LDS slot is a compile-time constant across the group (no ÷3, no runtime slot
multiply); even so the B double-buffer parity returns to CUR at each group boundary
(carry stays CUR-only). Outer trip KT/UNROLL stays runtime -> model_dim runtime,
single binary. fp4 tail: statically-unrolled + per-tile runtime-guarded (in-place
c_frags survive scf.if) so the remainder is mod-free too; fp8 tail keeps the rolled
scf.for (accm carried SSA). ISA: ÷3 magic 2->0, no spills.

DSV3 M=8 gemm1 1.38x -> 1.15x vs compile-time; M=4096 1.00x, GPT-OSS M=128 1.02x
(parity). Residual ~1.15-1.28x only at extreme small M = irreducible scf.for
loop-carried C+B phi at group boundaries (cannot fully unroll a runtime trip). cos:
fp4 >=0.99, a8w4 0.9996, GPT-OSS pad 0.989; atomic+reduce+graph + k_wave 1/2/4 all
pass; gemm2 untouched.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…dentical ISA)

Light clean of the fixed-factor-unroll runtime-K gemm1 commit: collapse the three
multi-line block comments to one line each, matching the PR clean standard. No logic
change (comment-only diff), ISA unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…e-K loop

Split v_e = (voffset + kt_abs*KH_TILE_A)//4 into voffset//4 + kt_abs*(KH_TILE_A//4).
With runtime model_dim, K_BYTES makes voffset//4 divisibility unprovable, so the
compiler emitted a per-tile signed-division dance recomputed 6x per UNROLL=6 group.
Both terms are provably multiples of 4, so the split is exact: voffset//4 is now
loop-invariant (hoisted) and KH_TILE_A//4=32 is a compile-time constant, leaving only
a cheap kt_abs*32 per tile.

ISA (DSV3 fp4, cold): bulk-body v_subbrev 13->0, addr-arith ops 55->21, body 376->318
lines; vgpr 132->134, 0 spills. No loop-carried phi copies exist in any variant
(refutes the prior 'C+B phi floor'). gemm1 small-M recovers: fp4 M=1 1.19x->1.07x,
fp4 M=8 1.12x->1.07x, a8w4 M=1 1.21x->1.12x vs compile-time; M>=128 stays at parity.
fp4 cos>=0.99, a8w4 cos>=0.9996, graph replay OK.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…vide

The fp8-A (a8w4) gemm1 remainder tail is a rolled scf.for (accm is raw f32x4 SSA
that cannot escape an scf.if, unlike the fp4 fragment-backed C). It re-derived the
A-LDS read/write slot as (full_tiles+kt_iv)%kAStages every iteration, which the
compiler lowered to the s_mul_hi_i32 0x55555556 signed-div-by-3 dance (2x per tail
tile) that the compile-time fully-unrolled kernel never emits.

UNROLL is a multiple of kAStages, so full_tiles%kAStages==0 and the first tail slot
is 0 (read) / kStages%kAStages (write). Carry both slots as scf.for loop values that
advance +1 with a cheap compare-select wrap, so the tail selects the slot instead of
re-computing the modulo. This is the tail analogue of the fp4 bulk address-divide
hoist (6158cee).

ISA (DSV3 a8w4, cold, gemm1): tail 0x55555556 slot-mod 2->0; .vgpr_count 168 (no
change), 0 spills; bulk loop and prologue/epilogue byte-unchanged; gemm2 ISA md5
identical to 16e3c0a (fix is confined to gemm1_body_v2). a8w4 cos=0.9996 (thr 0.95),
fp4 cos=0.9894 (thr 0.85). test_moe_gemm -k "(fp4 or a8w4)": 381 passed / 0 failed
(eager); the 6 graph-capture large_shape failures are pre-existing on 16e3c0a.

Note: this removes a scalar divide but does not move a8w4 small-M gemm1 latency
(M=1 1.12x, M=8 1.21x vs compile-time, unchanged) -- diagnosis showed the a8w4
bulk loop already has zero un-hoisted vector divides (the fp4 hoist covered the
shared A-gather), and the residual is the runtime-K rolled-loop back-edge amplified
by fp8's heavier per-tile body (more/larger ds_reads + mfma_scale), not address
arithmetic. fp4 unchanged (M=1 1.07x).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Collapse multi-line comment/docstring blocks added since the last cleanup
(runtime-K UNROLL=6 machinery, A-gather hoists, gemm2 perf-knob notes) to
one line each, and drop the redundant `aStagesC = aStages` alias in
gemm2_body_v2. Strictly behavior-preserving: default gemm1 and gemm2 final
ISA md5 unchanged (byte-identical to 0a0951e).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant