feat(moe): layout-API MXFP4 (a4w4/a8w4) MoE gemm (opus-sort only) #752

Closed
coderfeli wants to merge 2 commits into main from mxfp4-moe-gemm

@coderfeli coderfeli commented Jun 26, 2026

Collaborator

Summary

Adds a layout-API MXFP4 MoE gemm (a4w4 + a8w4) that consumes the standard
opus sort contract from moe_sorting_kernel directly — no fused-sort extras
(m_indices / reverse_sorted) needed:

  • gemm1 (up/gate) gathers A rows via sorted_token_ids & 0xFFFFFF
  • gemm2 (down-proj) scatters per sorted row via global.atomic.fadd weighted
    by sorted_weights
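
The gather/scatter contract above can be sketched as a plain-Python reference model. This is only an illustration, not the kernel code: `gather_rows`, `scatter_add`, and `TOKEN_MASK` are hypothetical names, and the sketch assumes that padding slots carry an out-of-range token id in their low 24 bits.

```python
# Hypothetical reference model of the opus sort contract (not the kernel code).
TOKEN_MASK = 0xFFFFFF  # low 24 bits of a sorted_token_ids entry select the token row


def gather_rows(a, sorted_token_ids):
    """gemm1 side: pick the A row for every sorted slot."""
    num_tokens = len(a)
    rows = []
    for packed in sorted_token_ids:
        tok = packed & TOKEN_MASK
        # Assumption: padding slots hold a token id >= num_tokens.
        rows.append(a[tok] if tok < num_tokens else None)
    return rows


def scatter_add(out, sorted_rows, sorted_token_ids, sorted_weights):
    """gemm2 side: accumulate each sorted row, scaled by its routing weight."""
    for packed, weight, row in zip(sorted_token_ids, sorted_weights, sorted_rows):
        tok = packed & TOKEN_MASK
        if row is None or tok >= len(out):
            continue  # skip padding slots
        for j, value in enumerate(row):
            # The kernel performs this accumulation with a global atomic fadd.
            out[tok][j] += weight * value
    return out
```

Because several sorted rows can target the same token (one per selected expert), the accumulation has to be atomic on the GPU; the sequential loop here sidesteps that.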

B and B-scale loads and the scaled-MFMA compute run through the FlyDSL layout API
(fx.copy + fx.gemm); the A-side LDS staging, e8m0 scale math, and atomic
epilogue are raw shared helpers. The non-temporal (nt) B-load uses the
cache_modifier copy-atom hint already in main.
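
For readers unfamiliar with the e8m0 scale math, here is a minimal host-side sketch of MXFP4 dequantization. It is not taken from the kernels: the function names are hypothetical, and it assumes the low nibble of each byte holds the earlier element. The E2M1 value table, the 32-element block size, and the e8m0 bias of 127 follow the OCP microscaling format.

```python
import math

# Magnitudes representable by an FP4 E2M1 element (sign is the fourth bit).
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
MX_BLOCK = 32  # elements that share one e8m0 scale


def e8m0_to_float(byte):
    """Decode an exponent-only scale: 2 ** (byte - 127); 0xFF encodes NaN."""
    if byte == 0xFF:
        return math.nan
    return math.ldexp(1.0, byte - 127)


def fp4_to_float(nibble):
    """Decode one 4-bit E2M1 element."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * FP4_E2M1[nibble & 0x7]


def dequant_mxfp4(packed, scales):
    """Dequantize packed FP4 bytes using one e8m0 scale per MX_BLOCK elements."""
    out = []
    for i, byte in enumerate(packed):
        # Assumption: low nibble first, then high nibble.
        for k, nibble in enumerate((byte & 0xF, byte >> 4)):
            index = 2 * i + k
            out.append(fp4_to_float(nibble) * e8m0_to_float(scales[index // MX_BLOCK]))
    return out
```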

Files (kernels/)

  • mxfp4_moe_layout.py — layout-API building blocks
  • mxfp4_moe_common.py — shared raw helpers / constants / K-derived size formulas / atomic bf16 epilogue
  • mxfp4_moe_gemm1.py — BM32 up/gate (a4w4 + a8w4, interleave + separated, nt/cached, out fp4/fp8)
  • mxfp4_moe_gemm2.py — BM32 atomic down-proj (a4w4 + a8w4 fp8 input)
  • mxfp4_moe_gemm_2stage.py — public API + host launchers

Test

Wires the a4w4/a8w4 cases of tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage
to the new pipeline (opus sort → gemm1 → gemm2 atomic) and checks it against an
independent dequant-MoE reference; a8w4 is added to the in_dtype matrix.

Validation (gfx950)

  • Chain cosine vs dequant reference: a4w4 = 0.988, a8w4 = 1.000 (interleave + separated)
  • test_moe_gemm_2stage fp4/a8w4 over FP4-S/M/L: pass
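
The cosine figures above are presumably a cosine similarity between the flattened pipeline output and the dequant reference. A minimal sketch of such a metric follows; the helper name is hypothetical and the test's actual comparison code may differ.

```python
import math


def cosine_similarity(x, y):
    """Cosine of the angle between two equally sized flat vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    if norm_x == 0.0 or norm_y == 0.0:
        # Degenerate case: treat two all-zero vectors as identical.
        return 1.0 if norm_x == norm_y else 0.0
    return dot / (norm_x * norm_y)
```

A value near 1.0 means the two outputs point in the same direction; the metric is insensitive to a uniform scale error, so it is usually paired with a tolerance check.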

Notes

  • Scope is the BM32 atomic surface (KIMI-style decode/mid). This is intended to
    complement, not replace, mixed_moe_gemm_2stage: a4w4 gemm1 is ~10% faster,
    but the block-sparse atomic gemm2 pays a per-expert padding cost on
    high-expert/low-token decode, where the dense mixed_moe gemm2 still wins — so
    mixed_moe_gemm_2stage is left in place.
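
The per-expert padding cost can be estimated with a small model. This is an assumption-laden sketch, not a measurement: it supposes each active expert's rows are rounded up to a multiple of the BM32 tile and that empty experts contribute no tiles; `padded_rows` and `utilization` are hypothetical helpers.

```python
def padded_rows(tokens_per_expert, block_m=32):
    """Rows processed when each active expert is padded up to a multiple of block_m."""
    # -(-n // block_m) is ceiling division on non-negative integers.
    return sum(-(-n // block_m) * block_m for n in tokens_per_expert if n > 0)


def utilization(tokens_per_expert, block_m=32):
    """Fraction of processed rows that carry real tokens."""
    total = padded_rows(tokens_per_expert, block_m)
    return sum(tokens_per_expert) / total if total else 1.0
```

Under this model, 64 active experts with one routed token each process 64 * 32 = 2048 rows for 64 real rows, which illustrates why a high-expert/low-token decode shape is the unfavorable case.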

🤖 Generated with Claude Code

coderfeli and others added 2 commits June 26, 2026 04:40
Add an optional cacheModifier param (DefaultValuedParameter, default 0) to
the cdna3.buffer_copy atom type, forwarded as the rocdl raw buffer
load/store `cachepolicy`/aux operand (0=cached, 2=non-temporal). Lets the
layout-API fx.copy express non-temporal B-weight loads, which previously
required a raw buffer_load(cache_modifier=2) fallback.

- CopyAtom.td: new param + optional assembly group (`<128>` still parses;
  `<128, cache = 2>` when set).
- CDNA3/CopyAtom.cpp: thread getCacheModifier() into the load+store aux.
- FlyROCDLExtension.cpp: BufferCopy get() gains cache_modifier=0 kwarg.
- universal.py: BufferCopy/BufferCopy{8,16,32,64,128}b accept cache_modifier.

Backward compatible (default 0 == prior behavior).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Add a layout-API MXFP4 MoE up/gate + down-proj gemm that consumes the
standard (opus) sort contract from moe_sorting_kernel directly -- gemm1
gathers A rows via sorted_token_ids & 0xFFFFFF; gemm2 scatters per sorted
row via global atomic add weighted by sorted_weights. No fused-sort extras
(m_indices / reverse_sorted) are needed.

  kernels/mxfp4_moe_layout.py    - layout-API building blocks (fx.copy B/B-scale,
                                   fx.gemm scaled-MFMA atoms)
  kernels/mxfp4_moe_common.py    - shared raw helpers / constants / K-derived size
                                   formulas / atomic bf16 epilogue
  kernels/mxfp4_moe_gemm1.py     - BM32 up/gate gemm (a4w4 + a8w4, interleave +
                                   separated, nt/cached, out fp4/fp8)
  kernels/mxfp4_moe_gemm2.py     - BM32 atomic down-proj (a4w4 + a8w4 fp8 input)
  kernels/mxfp4_moe_gemm_2stage.py - public API + host launchers

Wire a4w4/a8w4 of tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage to the
new pipeline (opus sort -> gemm1 -> gemm2 atomic) vs an independent dequant-MoE
reference; a8w4 added to the in_dtype matrix.

Validated on gfx950: chain cosine a4w4=0.988, a8w4=1.000 (interleave + separated);
test_moe_gemm_2stage fp4/a8w4 over FP4-S/M/L pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@coderfeli

Collaborator Author

Superseded by #753 (rebased onto main after the cache_modifier C++ landed; GitHub blocks reopening a force-pushed PR). The moe-gemm change is unchanged.
