diff --git a/kernels/fp8_gemm_utils.py b/kernels/fp8_gemm_utils.py index 4b3462320..d4b7efa71 100644 --- a/kernels/fp8_gemm_utils.py +++ b/kernels/fp8_gemm_utils.py @@ -207,6 +207,12 @@ def wait_barrier(count): class Mfma16x16x128: def __init__(self, n_tiles_a, n_tiles_b): self.atom = fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, fx.Float8E4M3FN)) + # The CDNA4 scaled-MFMA atom defaults its E8M0 block scales to 0 + # (== 2**-127 ~= 0), which annihilates every product. This kernel does + # the row/col dequant in the epilogue (StoreC), so the MMA itself wants + # the identity scale: each of the 4 per-32-block E8M0 bytes set to 127 + # (2**0 == 1.0), i.e. 0x7F7F7F7F for both A and B. + self.atom = self.atom.set_value({"scale_a": fx.Int32(0x7F7F7F7F), "scale_b": fx.Int32(0x7F7F7F7F)}) self.accum_type = Vec.make_type(4, fx.Float32) self.zero_value = Vec.filled(4, 0.0, fx.Float32) self.n_tiles_a = n_tiles_a