[Fix] Set identity block scales for CDNA4 MFMA_Scale in fp8 row-scale…#744
Open
amd-songpiao wants to merge 1 commit into
Pull request overview
Fixes incorrect near-zero/scrambled output in the CDNA4 FP8 row-scale GEMM by explicitly setting the CDNA4 MFMA_Scale atom’s E8M0 block-scale state to the identity scale (rather than the default zero scale), aligning with the kernel’s design of applying row/column dequantization in the epilogue.
Changes:
- Set scale_a and scale_b to 0x7F7F7F7F (E8M0 exponent=127 → 2**0) on the MFMA_Scale mma atom used by Mfma16x16x128.
- Add an explanatory comment documenting why identity scales are required for these row/col-dequant-in-epilogue kernels.
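The arithmetic behind the 0x7F7F7F7F constant can be checked with a small sketch. It assumes the usual E8M0 convention (8 exponent bits, bias 127, no sign or mantissa, 0xFF reserved for NaN); the helper names `e8m0_decode` and `pack_e8m0x4` are hypothetical and not part of the kernel or the atom's API.

```python
def e8m0_decode(byte: int) -> float:
    """Decode one E8M0 byte as 2**(byte - 127).

    Assumption: 0xFF is treated as NaN, as in the OCP MX convention.
    """
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 value must fit in one byte")
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def pack_e8m0x4(b0: int, b1: int, b2: int, b3: int) -> int:
    """Pack four E8M0 bytes (one per 32-element block) into a 32-bit word."""
    for b in (b0, b1, b2, b3):
        if not 0 <= b <= 0xFF:
            raise ValueError("E8M0 value must fit in one byte")
    return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)


# Identity scale: every block multiplies by 2**0 == 1.0.
IDENTITY_SCALES = pack_e8m0x4(127, 127, 127, 127)

# Default scale described in the PR: every byte 0, i.e. 2**-127 per block.
DEFAULT_SCALES = pack_e8m0x4(0, 0, 0, 0)

print(hex(IDENTITY_SCALES), e8m0_decode(127))
print(hex(DEFAULT_SCALES), e8m0_decode(0))
```

Because all four bytes are identical, the packed identity value does not depend on byte order, so the sketch makes no claim about how the hardware assigns bytes to blocks.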
… GEMM
The 8-wave fp8 row-scale GEMM (kernels/fp8_gemm_8wave.py) drives the CDNA4 scaled MFMA through the layout API via kernels/fp8_gemm_utils.Mfma16x16x128 (fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16,16,128,...)) + mma_atom_call_ssa). The MFMA_Scale atom defaults its E8M0 block scales to 0, which decodes to 2**-127 ~= 0 and annihilates every partial product, so the kernel emitted scrambled / near-zero output.
This kernel applies the row/column dequant in the epilogue (StoreC), so the MMA itself wants the identity scale: each of the four per-32-block E8M0 bytes set to 127 (2**0 == 1.0), i.e. 0x7F7F7F7F for both A and B. Set it explicitly on the atom via set_value.
Verified with tests/kernels/test_fp8_gemm_rowscale.py::test_fp8_gemm_8wave (verify_output rtol=0.1 atol=0.1) across multiple shapes, static and dynamic weight/scale paths.
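The effect of the default scale on a single output element can be illustrated with a simplified numerical model. This is a sketch under stated assumptions, not the instruction's exact semantics: it models one K=128 dot product as four 32-element blocks, each multiplied by the decoded A and B block scales, and it ignores fp8 rounding and the per-lane scale layout. The names `scaled_dot`, `epilogue_dequant`, and `to_fp32` are hypothetical.

```python
import struct


def e8m0_decode(byte: int) -> float:
    """Decode an E8M0 byte as 2**(byte - 127) (NaN encoding not modeled)."""
    return 2.0 ** (byte - 127)


def to_fp32(x: float) -> float:
    """Round a Python float to fp32, as an fp32 accumulator would store it."""
    return struct.unpack("f", struct.pack("f", x))[0]


def scaled_dot(a, b, scale_a_bytes, scale_b_bytes, block=32):
    """Block-scaled dot product: sum over blocks of sa * sb * dot(a_blk, b_blk)."""
    assert len(a) == len(b) == block * len(scale_a_bytes) == block * len(scale_b_bytes)
    acc = 0.0
    for blk, (sa_byte, sb_byte) in enumerate(zip(scale_a_bytes, scale_b_bytes)):
        lo, hi = blk * block, (blk + 1) * block
        partial = sum(x * y for x, y in zip(a[lo:hi], b[lo:hi]))
        acc += e8m0_decode(sa_byte) * e8m0_decode(sb_byte) * partial
    return acc


def epilogue_dequant(acc: float, row_scale: float, col_scale: float) -> float:
    """Row/column dequantization applied after the MMA, as the PR describes."""
    return acc * row_scale * col_scale


a = [1.0] * 128
b = [0.5] * 128

with_identity = scaled_dot(a, b, [127] * 4, [127] * 4)  # scales of 2**0
with_default = scaled_dot(a, b, [0] * 4, [0] * 4)       # scales of 2**-127

print(with_identity, to_fp32(with_default))
print(epilogue_dequant(with_identity, row_scale=0.25, col_scale=2.0))
```

In this model the default scales multiply every block by 2**-254, which is far below the smallest fp32 subnormal, so no epilogue row/column scale can recover the result. With identity scales the accumulator carries the raw product and the epilogue dequant alone determines the final magnitude.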
9f8d90a to d617e9e