
[Fix] Set identity block scales for CDNA4 MFMA_Scale in fp8 row-scale…#744

Open
amd-songpiao wants to merge 1 commit into ROCm:main from amd-songpiao:fix/mfma-scale-identity-block-scales



@amd-songpiao


GEMM

The 8-wave fp8 row-scale GEMM (kernels/fp8_gemm_8wave.py) drives the CDNA4 scaled MFMA through the layout API via kernels/fp8_gemm_utils.Mfma16x16x128 (fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16,16,128,...)) + mma_atom_call_ssa). The MFMA_Scale atom defaults its E8M0 block scales to 0, which decodes to 2**-127 (effectively zero) and so wipes out every partial product. As a result, the kernel emitted scrambled / near-zero output.

This kernel applies the row/column dequant in the epilogue (StoreC), so the MMA itself needs the identity scale: each of the four per-32-block E8M0 bytes set to 127 (2**0 == 1.0), i.e. 0x7F7F7F7F for both A and B. This change sets it explicitly on the atom via set_value.
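For reference, the arithmetic behind the two scale values can be checked with a few lines of plain Python. This is only an illustrative sketch: `e8m0_decode` and `pack_e8m0_scales` are hypothetical helper names, not part of the kernel or the layout API, and the low-byte-first packing order is an assumption (it does not matter here, since all four bytes are equal).

```python
def e8m0_decode(byte: int) -> float:
    """Decode an E8M0 scale byte: 8 exponent bits, no sign, no mantissa, bias 127."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 scale must fit in one byte")
    if byte == 0xFF:
        return float("nan")  # 0xFF is reserved for NaN in the OCP MX E8M0 encoding
    return 2.0 ** (byte - 127)


def pack_e8m0_scales(scales) -> int:
    """Pack four E8M0 bytes into one 32-bit word (assumed order: byte 0 in the low bits)."""
    if len(scales) != 4:
        raise ValueError("expected one scale byte per 32-element block of K=128")
    word = 0
    for i, s in enumerate(scales):
        word |= (s & 0xFF) << (8 * i)
    return word


default_scale = e8m0_decode(0)      # 2**-127, the atom's default per the description
identity_scale = e8m0_decode(127)   # 2**0 == 1.0
identity_word = pack_e8m0_scales([127, 127, 127, 127])

print(hex(identity_word), identity_scale, default_scale)
```

The packed word is the same 0x7F7F7F7F constant the fix assigns to both the A and B scales.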

Verified with tests/kernels/test_fp8_gemm_rowscale.py::test_fp8_gemm_8wave (verify_output rtol=0.1 atol=0.1) across multiple shapes, static and dynamic weight/scale paths.

Copilot AI review requested due to automatic review settings June 25, 2026 11:11

Copilot AI left a comment


Pull request overview

Fixes incorrect near-zero/scrambled output in the CDNA4 FP8 row-scale GEMM by explicitly setting the CDNA4 MFMA_Scale atom’s E8M0 block-scale state to the identity scale (rather than the default zero scale), aligning with the kernel’s design of applying row/column dequantization in the epilogue.

Changes:

  • Set scale_a and scale_b to 0x7F7F7F7F (E8M0 exponent=127 → 2**0) on the MFMA_Scale mma atom used by Mfma16x16x128.
  • Add an explanatory comment documenting why identity scales are required for these row/col-dequant-in-epilogue kernels.
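To illustrate why the identity scale is the right choice when dequantization happens in the epilogue, here is a simplified scalar model of a block-scaled dot product over K=128 with one E8M0 scale per 32-element block. It is a sketch under stated assumptions, not the hardware's exact behavior: `scaled_dot`, the input values, and the row/column scales are all made up for illustration, and accumulation is done in Python double precision rather than fp32.

```python
def e8m0(byte: int) -> float:
    # E8M0: exponent-only scale with bias 127
    return 2.0 ** (byte - 127)


def scaled_dot(a, b, scale_a_bytes, scale_b_bytes, block=32):
    """Hypothetical model: each 32-element partial sum is multiplied by both block scales."""
    acc = 0.0
    for k, (sa, sb) in enumerate(zip(scale_a_bytes, scale_b_bytes)):
        lo, hi = k * block, (k + 1) * block
        partial = sum(x * y for x, y in zip(a[lo:hi], b[lo:hi]))
        acc += e8m0(sa) * e8m0(sb) * partial
    return acc


K = 128
a = [1.0] * K                       # stand-in for one quantized row of A
b = [2.0] * K                       # stand-in for one quantized column of B
row_scale, col_scale = 0.5, 0.25    # illustrative per-row / per-column dequant scales

identity = [127] * 4                # the four bytes of 0x7F7F7F7F
default = [0] * 4                   # the atom's default scales per the PR description

# Epilogue-style dequant applied after the MMA, as in StoreC
good = scaled_dot(a, b, identity, identity) * row_scale * col_scale
bad = scaled_dot(a, b, default, default) * row_scale * col_scale

print(good, bad)
```

With identity scales the MMA result passes through unchanged and the epilogue applies the only scaling. With the zero bytes, every partial product is multiplied by 2**-254 in this model, which is why the output collapses toward zero.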


@amd-songpiao amd-songpiao force-pushed the fix/mfma-scale-identity-block-scales branch from 9f8d90a to d617e9e Compare June 25, 2026 11:27
@coderfeli coderfeli requested a review from amd-cgilli June 26, 2026 10:40

2 participants