feat(moe): layout-API MXFP4 (a4w4/a8w4) MoE gemm (opus-sort only) #752
Closed
coderfeli wants to merge 2 commits into
Add an optional cacheModifier param (DefaultValuedParameter, default 0) to
the cdna3.buffer_copy atom type, forwarded as the rocdl raw buffer
load/store `cachepolicy`/aux operand (0=cached, 2=non-temporal). Lets the
layout-API fx.copy express non-temporal B-weight loads, which previously
required a raw buffer_load(cache_modifier=2) fallback.
- CopyAtom.td: new param + optional assembly group (`<128>` still parses;
`<128, cache = 2>` when set).
- CDNA3/CopyAtom.cpp: thread getCacheModifier() into the load+store aux.
- FlyROCDLExtension.cpp: BufferCopy get() gains cache_modifier=0 kwarg.
- universal.py: BufferCopy/BufferCopy{8,16,32,64,128}b accept cache_modifier.
Backward compatible (default 0 == prior behavior).
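The optional assembly group described above can be modelled outside MLIR. The following is a minimal Python sketch, not the TableGen or C++ from this change: `format_buffer_copy_params` and `parse_buffer_copy_params` are hypothetical names, and the only behavior taken from the commit message is that `<128>` still parses and that `<128, cache = 2>` is the form used when the modifier is set.

```python
import re

# Hypothetical model of the printed parameter list; not the real CopyAtom.td.
_PARAMS = re.compile(r"^<\s*(\d+)\s*(?:,\s*cache\s*=\s*(\d+)\s*)?>$")


def format_buffer_copy_params(bits, cache_modifier=0):
    """Print the parameter list, omitting the cache group at its default."""
    if cache_modifier == 0:
        return f"<{bits}>"
    return f"<{bits}, cache = {cache_modifier}>"


def parse_buffer_copy_params(text):
    """Parse either form; a missing cache group means the default of 0."""
    match = _PARAMS.match(text)
    if match is None:
        raise ValueError(f"unrecognized parameter list: {text!r}")
    bits = int(match.group(1))
    cache_modifier = int(match.group(2)) if match.group(2) is not None else 0
    return bits, cache_modifier
```

Because the default is elided when printing, text written before the parameter existed round-trips unchanged, which is the backward-compatibility property claimed above.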
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add a layout-API MXFP4 MoE up/gate + down-proj gemm that consumes the
standard (opus) sort contract from moe_sorting_kernel directly -- gemm1
gathers A rows via sorted_token_ids & 0xFFFFFF; gemm2 scatters per sorted
row via global atomic add weighted by sorted_weights. No fused-sort extras
(m_indices / reverse_sorted) are needed.
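As a host-side illustration of the sort contract described above, here is a minimal pure-Python sketch of the gather and the weighted scatter. It is not the kernel code. Two things are assumptions that the text does not state: that the bits above the low 24 carry other per-slot data (for example a top-k slot index), and that padding slots are marked by a token id at or beyond the number of tokens. The function names are hypothetical.

```python
TOKEN_ID_MASK = 0xFFFFFF  # low 24 bits, as in `sorted_token_ids & 0xFFFFFF`


def gather_rows(a, sorted_token_ids):
    """gemm1 side: pick the A row for each sorted slot (None for padding)."""
    num_tokens = len(a)
    rows = []
    for packed in sorted_token_ids:
        token = packed & TOKEN_ID_MASK
        # Assumption: token >= num_tokens marks a padding slot.
        rows.append(a[token] if token < num_tokens else None)
    return rows


def scatter_add(out, rows, sorted_token_ids, sorted_weights):
    """gemm2 side: accumulate weight * row into the owning token's output."""
    num_tokens = len(out)
    for row, packed, weight in zip(rows, sorted_token_ids, sorted_weights):
        token = packed & TOKEN_ID_MASK
        if token >= num_tokens:
            continue  # padding slot contributes nothing
        for j, value in enumerate(row):
            # On the GPU this accumulation is a global atomic add, since
            # several sorted rows can target the same token.
            out[token][j] += weight * value
    return out
```

Feeding the gathered rows straight into `scatter_add` (an identity stand-in for the two gemms) shows why the accumulation must be atomic: a token routed to several experts appears in several sorted rows.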
kernels/mxfp4_moe_layout.py - layout-API building blocks (fx.copy B/B-scale,
fx.gemm scaled-MFMA atoms)
kernels/mxfp4_moe_common.py - shared raw helpers / constants / K-derived size
formulas / atomic bf16 epilogue
kernels/mxfp4_moe_gemm1.py - BM32 up/gate gemm (a4w4 + a8w4, interleave +
separated, nt/cached, out fp4/fp8)
kernels/mxfp4_moe_gemm2.py - BM32 atomic down-proj (a4w4 + a8w4 fp8 input)
kernels/mxfp4_moe_gemm_2stage.py - public API + host launchers
Wire a4w4/a8w4 of tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage to the
new pipeline (opus sort -> gemm1 -> gemm2 atomic) vs an independent dequant-MoE
reference; a8w4 added to the in_dtype matrix.
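The dequant reference itself is not shown in this PR text. For orientation, a minimal sketch of MXFP4 dequantization under the OCP Microscaling format (E2M1 elements, one E8M0 scale per block of 32) could look like the following. The nibble order (low nibble first) and the function names are assumptions, and the real weights may additionally be shuffled or pre-permuted in ways not modelled here.

```python
# E2M1 magnitudes indexed by the low 3 bits; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 32  # elements sharing one E8M0 scale


def decode_e2m1(nibble):
    """Decode one 4-bit E2M1 value."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e8m0(byte):
    """Decode an E8M0 scale: 2**(byte - 127), with 0xFF reserved for NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def dequant_mxfp4(packed, scales):
    """Expand packed fp4 bytes to floats, applying one scale per 32 values."""
    values = []
    for byte in packed:
        # Assumption: the low nibble holds the earlier element.
        for nibble in (byte & 0xF, byte >> 4):
            scale = decode_e8m0(scales[len(values) // BLOCK])
            values.append(decode_e2m1(nibble) * scale)
    return values
```

A K-element row therefore occupies K/2 bytes of packed values plus K/32 scale bytes.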
Validated on gfx950: chain cosine a4w4=0.988, a8w4=1.000 (interleave + separated);
test_moe_gemm_2stage fp4/a8w4 over FP4-S/M/L pass.
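The exact metric behind the "chain cosine" figures is not shown here. Assuming it is the plain cosine similarity between the flattened kernel output and the flattened reference output, a minimal sketch is below (`cosine_similarity` is a hypothetical helper, not the test's own function).

```python
import math


def cosine_similarity(x, y):
    """Cosine of the angle between two equal-length flat sequences."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    if norm_x == 0.0 or norm_y == 0.0:
        # Convention chosen for this sketch: two all-zero vectors match.
        return 1.0 if norm_x == norm_y else 0.0
    return dot / (norm_x * norm_y)
```

Cosine similarity ignores a uniform scale error, so a value near 1.0 says the direction of the output matches the reference, not that every element is within a tolerance.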
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Force-pushed from 4117bc9 to 9d18d29.
Collaborator
Author
Superseded by #753 (rebased onto main after the cache_modifier C++ landed; GitHub blocks reopening a force-pushed PR). The moe-gemm change is unchanged.
Summary

Adds a layout-API MXFP4 MoE gemm (a4w4 + a8w4) that consumes the standard opus sort contract from `moe_sorting_kernel` directly; no fused-sort extras (`m_indices` / `reverse_sorted`) are needed:

- gemm1 gathers A rows via `sorted_token_ids & 0xFFFFFF`
- gemm2 scatters per sorted row via `global.atomic.fadd`, weighted by `sorted_weights`

Data movement (B / B-scale load, scaled-MFMA) runs through the FlyDSL layout API (`fx.copy` + `fx.gemm`); the A-side LDS staging, e8m0 scale math, and atomic epilogue are raw shared helpers. The nt B-load uses the `cache_modifier` copy-atom hint already in `main`.

Files (`kernels/`)

- `mxfp4_moe_layout.py` - layout-API building blocks
- `mxfp4_moe_common.py` - shared raw helpers / constants / K-derived size formulas / atomic bf16 epilogue
- `mxfp4_moe_gemm1.py` - BM32 up/gate (a4w4 + a8w4, interleave + separated, nt/cached, out fp4/fp8)
- `mxfp4_moe_gemm2.py` - BM32 atomic down-proj (a4w4 + a8w4 fp8 input)
- `mxfp4_moe_gemm_2stage.py` - public API + host launchers

Test

Wires the a4w4/a8w4 cases of `tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage` to the new pipeline (opus sort → gemm1 → gemm2 atomic) vs an independent dequant-MoE reference; `a8w4` added to the `in_dtype` matrix.

Validation (gfx950)

- `test_moe_gemm_2stage` fp4/a8w4 over FP4-S/M/L: pass

Notes

- Meant to complement, not replace, `mixed_moe_gemm_2stage`: a4w4 gemm1 is ~10% faster, but the block-sparse atomic gemm2 pays a per-expert padding cost on high-expert/low-token decode, where the dense `mixed_moe` gemm2 still wins, so `mixed_moe_gemm_2stage` is left in place.

🤖 Generated with Claude Code
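The per-expert padding cost in the note above can be made concrete with a little arithmetic. This is a minimal sketch, assuming each expert's sorted rows are rounded up to a multiple of the BM32 tile and that experts with no routed rows occupy no tile. The workload numbers in the usage note are hypothetical, not measurements from this PR.

```python
BM = 32  # M-tile size, from the "BM32" kernels named above


def padded_rows(rows_per_expert, block_m=BM):
    """Rows actually processed once each expert is rounded up to a tile."""
    # Assumption: an expert with zero routed rows launches no tile.
    return sum(-(-n // block_m) * block_m for n in rows_per_expert if n > 0)


def padding_overhead(rows_per_expert, block_m=BM):
    """Ratio of processed rows to useful rows (1.0 means no waste)."""
    useful = sum(rows_per_expert)
    return padded_rows(rows_per_expert, block_m) / useful
```

For a hypothetical decode step where 32 routed rows land on 32 different experts, `padding_overhead([1] * 32)` is 32.0: every tile carries one useful row. For a hypothetical prefill with 512 rows on each of 8 experts, `padding_overhead([512] * 8)` is 1.0, which is consistent with the block-sparse path only losing out in the high-expert/low-token regime.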