diff --git a/kernels/mxfp4_preshuffle.py b/kernels/mxfp4_preshuffle.py index d5944d538..8f567316c 100644 --- a/kernels/mxfp4_preshuffle.py +++ b/kernels/mxfp4_preshuffle.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -"""MXFP4 (E2M1) preshuffle GEMM, per-32 E8M0 scales consumed inside a scaled -16x16x128 MFMA. Data layout matches ``tests/kernels/utils/fp4_utils`` (CK weight -preshuffle ``shuffle_weight_w4(.,16)`` + ``shuffle_scale_w4``). +"""MXFP4 (E2M1) and MXFP6 (E2M3) preshuffle GEMM, per-32 E8M0 scales consumed +inside a scaled 16x16x128 MFMA. Data layout matches +``tests/kernels/utils/fp4_utils`` (CK weight preshuffle ``shuffle_weight_w4(.,16)`` ++ ``shuffle_scale_w4``). The MMA runs via ``fx.gemm`` over rank-1 register fragments: the per-32 E8M0 word -rides ``scale_a=/scale_b=`` and the ``(opsel_a, opsel_b)`` atom selects the packed byte. +rides ``scale_a=/scale_b=`` and the ``(opsel_a, opsel_b)`` atom selects the packed +byte. """ from typing import Optional @@ -16,7 +18,7 @@ from flydsl._mlir import ir from flydsl._mlir.dialects import fly from flydsl.expr import arith, buffer_ops, const_expr, gpu, range_constexpr, rocdl -from flydsl.expr.typing import BFloat16, Float4E2M1FN, Float16, Float32, Int8, Int32, T +from flydsl.expr.typing import BFloat16, Float4E2M1FN, Float6E2M3FN, Float16, Float32, Int8, Int32, T from flydsl.expr.typing import Vector as Vec @@ -26,10 +28,17 @@ def _raw(v): return v -def _scale_mma_atoms(): - """16 (opsel_a, opsel_b) scaled-MFMA atoms (opsel is a type param).""" +def _scale_mma_atoms(a_dtype: str = "fp4"): + """16 (opsel_a, opsel_b) scaled-MFMA atoms (opsel is a type param). + + a_dtype='fp4': fp4×fp4 (Float4E2M1FN for both A and B). + a_dtype='fp6': fp6×fp4 (Float6E2M3FN for A, Float4E2M1FN for B). + """ + elem_a = Float6E2M3FN if a_dtype == "fp6" else Float4E2M1FN return { - (osa, osb): fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, Float4E2M1FN, opsel_a=osa, opsel_b=osb)) + (osa, osb): fx.make_mma_atom( + fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, elem_a, Float4E2M1FN, opsel_a=osa, opsel_b=osb) + ) for osa in range(4) for osb in range(4) } @@ -47,32 +56,46 @@ def _bq_view(arg_bq_addr, row_elems, KH4, k_tiles, k_halves): return fx.rocdl.make_buffer_tensor(view, max_size=False) -def compile_mxfp4_gemm( +def _compile_mxfp_blockscale_gemm( *, N: int, K: int, - tile_m: int, - tile_n: int, - tile_k: int, + BM: int, + BN: int, + BK: int, + a_dtype: str, out_dtype: str = "bf16", waves_per_eu: Optional[int] = None, enable_scheduler: Optional[bool] = None, + use_async_copy: bool = True, dsrd_preload: int = -1, dvmem_preload: int = -1, - use_async_copy: bool = False, ): - """Compile MXFP4 (A4W4) preshuffle GEMM -> fn(C, A, B, scale_a, scale_b, bias, M, N, stream). + """Shared implementation for MXFP4 (a_dtype='fp4') and MXFP6 (a_dtype='fp6') preshuffle GEMM. - A: MXFP4 (E2M1), 2 codes/byte. B: CK-preshuffled MXFP4. scale_a/scale_b are e8m0; - C is (M, N) out_dtype; bias unused (parity). + Returns fn(C, A, B, scale_a, scale_b, bias, M, N, stream). """ - BM, BN, BK = tile_m, tile_n, tile_k if BK not in (128, 256) or K % BK != 0: raise ValueError(f"tile_k must be 128 or 256 dividing K; got tile_k={BK}, K={K}") if K % 256 != 0: raise ValueError(f"K must be a multiple of 256 (e8m0 scale chunk); got K={K}") out_elem = BFloat16 if out_dtype == "bf16" else Float16 + # A dtype-specific row sizes + if a_dtype == "fp6": + # FP8-padded fp6: 1 byte per code + a_row_bytes = K + A_ROW_B = BK + else: + # fp4: 2 codes/byte + a_row_bytes = K // 2 # A bytes per full M-row + A_ROW_B = BK // 2 # A bytes per row in a K-tile + + # Cooperative LDS A tile (row-major [m][col]) shared by the 4 N-waves -> no 4x + # redundant A gmem reads. fp4/fp6 = 2/1 codes/byte. + A_LDS_B = BM * A_ROW_B # bytes per LDS A buffer + A_ROW_I32 = A_ROW_B // 4 + K_HALF = K // 2 KH4 = K_HALF // 4 K_TILES = K // BK @@ -85,11 +108,6 @@ def compile_mxfp4_gemm( _scale_chunk_dw = (K // 32 // 4 // 2) * 64 # e8m0 strides (dwords), per shuffle_scale_w4 _scale_k0_dw = 64 - # Cooperative LDS A tile (row-major [m][col]) shared by the 4 N-waves -> no 4x - # redundant A gmem reads. fp4 = 2 codes/byte. - a_row_bytes = K // 2 # A bytes per full M-row - A_ROW_B = BK // 2 # A bytes per row in a K-tile - A_LDS_B = BM * A_ROW_B # bytes per LDS A buffer n_coop = A_LDS_B // 256 // 16 # 16B cooperative loads per thread n_pairs = max(1, num_acc_n // 2) @@ -97,7 +115,11 @@ def compile_mxfp4_gemm( # Scheduler counts (sched_group_barrier interleave), per loop iter. sched_mfma_total = k_halves * m_chunks * num_acc_n - sched_num_ds_load = m_chunks * k_halves # A LDS reads/thread (read_a) + # fp6: two 128-bit LDS reads per (mi, kh); fp4: one + if a_dtype == "fp6": + sched_num_ds_load = m_chunks * k_halves * 2 # A LDS reads/thread (read_a) + else: + sched_num_ds_load = m_chunks * k_halves # A LDS reads/thread (read_a) sched_num_gmem = n_coop + num_acc_n * k_halves + m_pairs + n_pairs # A coop + B + scales sched_num_a_dswr = 0 if use_async_copy else n_coop # A LDS writes/thread (none for DMA) @@ -127,7 +149,7 @@ def kernel_gemm( i32_m: fx.Int32, i32_n: fx.Int32, ): - scale_atoms = _scale_mma_atoms() + scale_atoms = _scale_mma_atoms(a_dtype) tid = fx.thread_idx.x bid_x, bid_y, _ = fx.block_idx @@ -151,12 +173,16 @@ def kernel_gemm( a_flat_div = fx.logical_divide(a_flat, fx.make_layout(1, 1)) lds = fx.SharedAllocator().allocate(SharedA).peek() # A-LDS modeled as i32 (16B = 4 i32): fx.copy is dtype-agnostic, only the MMA - # cares about sub-byte semantics. Store + fp4 read go through fx.copy. + # cares about sub-byte semantics. Store + fp4/fp6 read go through fx.copy. sA0_i32 = fx.recast_iter(Int32, lds.a0.ptr) lds_db = fx.Int32(fx.ptrtoint(lds.a1.ptr)) - fx.Int32(fx.ptrtoint(lds.a0.ptr)) # ping/pong byte stride lds_db_i32 = lds_db // fx.Int32(4) lds_copy = fx.make_copy_atom(fx.UniversalCopy128b(), Int32) - A_ROW_I32 = A_ROW_B // 4 + + if const_expr(use_async_copy): + dma_atom = fx.make_copy_atom(fx.rocdl.BufferCopyLDS128b(), 128) + _i8s = fx.PointerType.get(Int8.ir_type, fx.AddressSpace.Shared, 512) + sA0_i8 = fx.recast_iter(_i8s, lds.a0.ptr) def _iter_of(parity): # parity in {0,1} (runtime) -> i32 LDS iterator return fx.add_offset(sA0_i32, parity * lds_db_i32) @@ -177,11 +203,6 @@ def coop_load_a(kt, base_iter): # Async A: direct gmem->LDS DMA (buffer_load_lds), same row-major LDS layout as # coop_load_a. Issued after the B/scale loads so it overlaps the MFMAs. - if const_expr(use_async_copy): - dma_atom = fx.make_copy_atom(fx.rocdl.BufferCopyLDS128b(), 128) - _i8s = fx.PointerType.get(Int8.ir_type, fx.AddressSpace.Shared, 512) - sA0_i8 = fx.recast_iter(_i8s, lds.a0.ptr) - def dma_a_to_lds(kt, parity): base_off = rocdl.readfirstlane(T.i32, parity * lds_db + wave * fx.Int32(64 * 16)) lds_ptr = fx.add_offset(sA0_i8, base_off) @@ -204,17 +225,48 @@ def _read16(base_iter, off_i32): return t def read_a(parity): - # Each lane's K=128 A operand = 16 B (k-group strides 16 B, 128-K half 64 B). base_iter = _iter_of(parity) av = [] - for mi in range_constexpr(m_chunks): - for kh in range_constexpr(k_halves): - off = ( - (fx.Int32(mi * 16) + lane_mod_16) * fx.Int32(A_ROW_I32) - + fx.Int32(kh * 16) - + lane_div_16 * fx.Int32(4) - ) - av.append(_read16(base_iter, off)) + if const_expr(a_dtype == "fp6"): + # Read 8 DWORDs per (mi, kh), store first 6 into i32[6] (discard zero-pad). + for mi in range_constexpr(m_chunks): + for kh in range_constexpr(k_halves): + off = ( + (fx.Int32(mi * 16) + lane_mod_16) * fx.Int32(A_ROW_I32) + + fx.Int32(kh * 32) + + lane_div_16 * fx.Int32(8) + ) + t_lo = fx.make_rmem_tensor(4, Int32) + t_hi = fx.make_rmem_tensor(4, Int32) + fx.copy(lds_copy, _lds_view(base_iter, off), t_lo) + fx.copy(lds_copy, _lds_view(base_iter, off + fx.Int32(4)), t_hi) + v_lo = Vec(fx.memref_load_vec(t_lo)) + v_hi = Vec(fx.memref_load_vec(t_hi)) + t = fx.make_rmem_tensor(6, Int32) + t.store( + Vec.from_elements( + [ + _raw(v_lo[0]), + _raw(v_lo[1]), + _raw(v_lo[2]), + _raw(v_lo[3]), + _raw(v_hi[0]), + _raw(v_hi[1]), + ] + ) + ) + av.append(t) + else: + # fp4: one 128-bit LDS read per (mi, kh) -> i32[4] + # Each lane's K=128 A operand = 16 B (k-group strides 16 B, 128-K half 64 B). + for mi in range_constexpr(m_chunks): + for kh in range_constexpr(k_halves): + off = ( + (fx.Int32(mi * 16) + lane_mod_16) * fx.Int32(A_ROW_I32) + + fx.Int32(kh * 16) + + lane_div_16 * fx.Int32(4) + ) + av.append(_read16(base_iter, off)) return av n_col_base = by_n + wave * fx.Int32(BN // 4) @@ -301,7 +353,7 @@ def compute(accs, av, bv, sa_v, sb_v, scale_shift=None): # kh OUTERMOST: consecutive MFMAs write distinct accumulators (dense issue), # spacing the per-acc accumulation dependency across the (mi,ni) grid. Each # scaled MFMA = fx.gemm over the rank-1 i32[4] A/B fragments (one MmaAtomCall); - # the atom bitcasts to fp4 and the e8m0 word rides scale_a=/scale_b=. + # the atom bitcasts to fp4/fp6 and the e8m0 word rides scale_a=/scale_b=. c_frags = [fx.make_rmem_tensor(4, Float32) for _ in range_constexpr(n_acc)] for idx in range_constexpr(n_acc): c_frags[idx].store(Vec(accs[idx])) @@ -473,3 +525,242 @@ def launch_gemm( ).launch(grid=(gx, gy, 1), block=(256, 1, 1), stream=stream) return launch_gemm + + +def compile_mxfp4_gemm( + *, + N: int, + K: int, + tile_m: int, + tile_n: int, + tile_k: int, + out_dtype: str = "bf16", + waves_per_eu: Optional[int] = None, + enable_scheduler: Optional[bool] = None, + use_async_copy: bool = True, + dsrd_preload: int = -1, + dvmem_preload: int = -1, +): + """Compile MXFP4 (A4W4) preshuffle GEMM -> fn(C, A, B, scale_a, scale_b, bias, M, N, stream). + + A: MXFP4 (E2M1), 2 codes/byte. B: CK-preshuffled MXFP4. scale_a/scale_b are e8m0; + C is (M, N) out_dtype; bias unused (parity). + """ + return _compile_mxfp_blockscale_gemm( + N=N, + K=K, + BM=tile_m, + BN=tile_n, + BK=tile_k, + a_dtype="fp4", + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + enable_scheduler=enable_scheduler, + use_async_copy=use_async_copy, + dsrd_preload=dsrd_preload, + dvmem_preload=dvmem_preload, + ) + + +# --------------------------------------------------------------------------- +# compile_mxfp6_gemm — MXFP6 (E2M3) A × MXFP4 (E2M1) B preshuffle GEMM +# --------------------------------------------------------------------------- + + +# Per-shape tile/knob overrides for compile_mxfp6_gemm. Starting from the +# A6W4_TUNED_CONFIGS shape set, adapted for compile_mxfp6_gemm constraints +# (tile_k ∈ {128, 256}; no lds_stage / k_batch yet). +# Re-tune when the kernel changes significantly. +MXFP6_TUNED_CONFIGS: dict[tuple[int, int, int], dict] = { + (32, 7168, 4608): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (32, 9216, 7168): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (32, 5120, 5120): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (32, 12288, 4096): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (32, 8192, 7168): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (32, 8192, 8192): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (32, 10240, 8192): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (32, 14336, 8192): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (32, 16384, 8192): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (32, 8192, 28672): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (64, 7168, 4608): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (64, 9216, 7168): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (64, 5120, 5120): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (64, 12288, 4096): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (64, 8192, 7168): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (64, 8192, 8192): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (64, 10240, 8192): {"tile_m": 32, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (64, 14336, 8192): {"tile_m": 32, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (64, 16384, 8192): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (64, 8192, 28672): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (128, 7168, 4608): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (128, 9216, 7168): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (128, 5120, 5120): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (128, 12288, 4096): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (128, 8192, 7168): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (128, 8192, 8192): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (128, 10240, 8192): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": None}, + (128, 14336, 8192): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (128, 16384, 8192): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (128, 8192, 28672): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (256, 7168, 4608): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (256, 9216, 7168): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (256, 5120, 5120): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (256, 12288, 4096): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (256, 8192, 7168): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (256, 8192, 8192): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (256, 10240, 8192): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (256, 14336, 8192): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (256, 16384, 8192): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (256, 8192, 28672): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (512, 7168, 4608): {"tile_m": 128, "tile_n": 128, "tile_k": 128, "waves_per_eu": 1}, + (512, 9216, 7168): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (512, 5120, 5120): {"tile_m": 128, "tile_n": 128, "tile_k": 128, "waves_per_eu": 1}, + (512, 12288, 4096): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (512, 8192, 7168): {"tile_m": 128, "tile_n": 128, "tile_k": 128, "waves_per_eu": 1}, + (512, 8192, 8192): {"tile_m": 128, "tile_n": 128, "tile_k": 128, "waves_per_eu": 1}, + (512, 10240, 8192): {"tile_m": 128, "tile_n": 128, "tile_k": 128, "waves_per_eu": 1}, + (512, 14336, 8192): {"tile_m": 128, "tile_n": 128, "tile_k": 128, "waves_per_eu": None}, + (512, 16384, 8192): {"tile_m": 128, "tile_n": 128, "tile_k": 128, "waves_per_eu": 1}, + (512, 8192, 28672): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (1024, 7168, 4608): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (1024, 9216, 7168): {"tile_m": 64, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (1024, 5120, 5120): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (1024, 12288, 4096): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (1024, 8192, 7168): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (1024, 8192, 8192): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (1024, 10240, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (1024, 14336, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (1024, 16384, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (1024, 8192, 28672): {"tile_m": 128, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (2048, 7168, 4608): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (2048, 9216, 7168): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": None}, + (2048, 5120, 5120): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": None}, + (2048, 12288, 4096): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (2048, 8192, 7168): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (2048, 8192, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (2048, 10240, 8192): {"tile_m": 64, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (2048, 14336, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (2048, 16384, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (2048, 8192, 28672): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 7168, 4608): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 9216, 7168): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 5120, 5120): {"tile_m": 64, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 12288, 4096): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 8192, 7168): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (4096, 8192, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 10240, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 14336, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 16384, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (4096, 8192, 28672): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (8192, 7168, 4608): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (8192, 9216, 7168): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (8192, 5120, 5120): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (8192, 12288, 4096): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (8192, 8192, 7168): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (8192, 8192, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (8192, 10240, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (8192, 14336, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (8192, 16384, 8192): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (8192, 8192, 28672): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (32, 6144, 4096): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (32, 4096, 4096): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (64, 6144, 4096): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (64, 4096, 4096): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (128, 6144, 4096): {"tile_m": 32, "tile_n": 256, "tile_k": 256, "waves_per_eu": None}, + (256, 6144, 4096): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": None}, + (256, 4096, 4096): {"tile_m": 64, "tile_n": 256, "tile_k": 256, "waves_per_eu": 1}, + (512, 6144, 4096): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (1024, 6144, 4096): {"tile_m": 64, "tile_n": 128, "tile_k": 128, "waves_per_eu": 1}, + (2048, 6144, 4096): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (4096, 6144, 4096): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": None}, + (128, 8192, 5120): {"tile_m": 32, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (256, 8192, 5120): {"tile_m": 64, "tile_n": 128, "tile_k": 256, "waves_per_eu": None}, + (512, 8192, 5120): {"tile_m": 128, "tile_n": 128, "tile_k": 256, "waves_per_eu": 1}, + (1024, 8192, 5120): {"tile_m": 64, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, + (2048, 8192, 5120): {"tile_m": 128, "tile_n": 256, "tile_k": 128, "waves_per_eu": 1}, +} + + +def _pick_mxfp6_tiles(M: int, N: int, K: int) -> tuple[int, int, int]: + """Heuristic fallback tile selection for compile_mxfp6_gemm. + + Prefer MXFP6_TUNED_CONFIGS for known shapes; use this as the fallback. + """ + if M >= 512 and M % 128 == 0: + tile_m = 128 + elif M >= 256 and M % 64 == 0: + tile_m = 64 + else: + tile_m = 32 + tile_n = 256 if N % 256 == 0 and (M >= 512 or N >= 14336) else 128 + tile_k = 128 if M >= 512 and K < 28672 else 256 + return tile_m, tile_n, tile_k + + +def _pick_mxfp6_config(M: int, N: int, K: int) -> dict: + """Return the best known compile_mxfp6_gemm knobs for (M, N, K). + + On gfx950 consults MXFP6_TUNED_CONFIGS first (entries are gfx950-specific); + on other architectures falls straight through to _pick_mxfp6_tiles. + Returns a dict with keys: tile_m, tile_n, tile_k, waves_per_eu. + """ + from flydsl.runtime.device import get_rocm_arch + + tile_m, tile_n, tile_k = _pick_mxfp6_tiles(M, N, K) + cfg = {"tile_m": tile_m, "tile_n": tile_n, "tile_k": tile_k, "waves_per_eu": None} + if str(get_rocm_arch()).startswith("gfx950"): + cfg.update(MXFP6_TUNED_CONFIGS.get((M, N, K), {})) + return cfg + + +def compile_mxfp6_gemm( + *, + N: int, + K: int, + M_hint: int, + tile_m: Optional[int] = None, + tile_n: Optional[int] = None, + tile_k: Optional[int] = None, + out_dtype: str = "bf16", + waves_per_eu: Optional[int] = None, + enable_scheduler: Optional[bool] = None, + use_async_copy: bool = True, + dsrd_preload: int = -1, + dvmem_preload: int = -1, +): + """Compile MXFP6×MXFP4 (A6W4) preshuffle GEMM. + + Same signature as compile_mxfp4_gemm: + fn(C, A, B, scale_a, scale_b, bias, M, N, stream) + + A: MXFP6 E2M3, tight-packed fp6 (pack_fp6_e2m3 layout, 24 B per K=32 + chunk) + 8 B zero pad = 32 B per chunk. scale_a/scale_b: E8M0 per-32. + B: CK-preshuffled MXFP4 E2M1. bias unused (parity with compile_mxfp4_gemm). + + M_hint is used for tile selection when tile_m/tile_n/tile_k are not given. + Tile defaults come from MXFP6_TUNED_CONFIGS (falling back to _pick_mxfp6_tiles). + Only supported on gfx950 (CDNA4, has mfma.scale.f32.16x16x128.f8f6f4). + MXFP6_TUNED_CONFIGS entries are gfx950-specific; other architectures fall + back to the _pick_mxfp6_tiles heuristic. + """ + if tile_m is None or tile_n is None or tile_k is None: + cfg = _pick_mxfp6_config(M_hint, N, K) + tile_m = tile_m if tile_m is not None else cfg["tile_m"] + tile_n = tile_n if tile_n is not None else cfg["tile_n"] + tile_k = tile_k if tile_k is not None else cfg["tile_k"] + if waves_per_eu is None: + waves_per_eu = cfg["waves_per_eu"] + return _compile_mxfp_blockscale_gemm( + N=N, + K=K, + BM=tile_m, + BN=tile_n, + BK=tile_k, + a_dtype="fp6", + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + enable_scheduler=enable_scheduler, + use_async_copy=use_async_copy, + dsrd_preload=dsrd_preload, + dvmem_preload=dvmem_preload, + ) diff --git a/tests/kernels/test_preshuffle_gemm.py b/tests/kernels/test_preshuffle_gemm.py index 531900177..86d564c78 100644 --- a/tests/kernels/test_preshuffle_gemm.py +++ b/tests/kernels/test_preshuffle_gemm.py @@ -30,7 +30,7 @@ sys.path.insert(0, _PYFLYDSL_SRC) from flydsl.runtime.device import get_rocm_arch # noqa: E402 -from kernels.mxfp4_preshuffle import compile_mxfp4_gemm # noqa: E402 +from kernels.mxfp4_preshuffle import compile_mxfp4_gemm, compile_mxfp6_gemm # noqa: E402 from kernels.preshuffle_gemm import compile_preshuffle_gemm # noqa: E402 from tests.kernels.utils import fp4_utils # noqa: E402 from tests.test_common import run_perftest, verify_output # noqa: E402 @@ -325,7 +325,6 @@ def test_mfma_w4_flyc_preshuffle( bench_iters: int = DEFAULT_BENCH_ITERS, bench_warmup: int = DEFAULT_BENCH_WARMUP, waves_per_eu: int = 0, - use_async_copy: bool = False, ): """FP4 (MXFP4) preshuffle GEMM (layout-API v2) — gfx950 only.""" if get_rocm_arch() != "gfx950": @@ -347,9 +346,8 @@ def test_mfma_w4_flyc_preshuffle( tile_k=tile_k, out_dtype=out_dtype, waves_per_eu=_wpe, - use_async_copy=bool(use_async_copy), ) - print(f"✓ Compiled (async_copy={use_async_copy}, waves_per_eu={_wpe})") + print(f"✓ Compiled (waves_per_eu={_wpe})") device = torch.device("cuda") M_align_32 = (M + 31) // 32 * 32 @@ -437,6 +435,131 @@ def launch_kernel(c, a, b, sa, sb): print(f"[flyc] Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {tbps:.3f} TB/s") +# ── W4A6: MXFP6 (E2M3) A × MXFP4 (E2M1) B ───────────────────────────────── + + +@pytest.mark.parametrize("out_dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k", + [ + (64, 8192, 8192, 64, 128, 128), + (32, 8192, 8192, 32, 128, 256), + pytest.param(128, 8192, 8192, 64, 128, 256, marks=pytest.mark.large_shape), + pytest.param(1024, 8192, 8192, 64, 256, 256, marks=pytest.mark.large_shape), + pytest.param(256, 4096, 14336, 128, 256, 256, marks=pytest.mark.large_shape), + ], +) +@pytest.mark.l2_device +@pytest.mark.rocm_lower +def test_mfma_a6w4_preshuffle( + out_dtype, + M, + N, + K, + tile_m, + tile_n, + tile_k, + *, + bench_iters: int = DEFAULT_BENCH_ITERS, + bench_warmup: int = DEFAULT_BENCH_WARMUP, + waves_per_eu: int = 0, +): + """W4A6: MXFP6 (E2M3) A × MXFP4 (E2M1) B preshuffle GEMM — gfx950 only.""" + if get_rocm_arch() != "gfx950": + pytest.skip(f"FP6/FP4 GEMM requires gfx950, got {get_rocm_arch()}") + + print("=" * 80) + print(f"MFMA W4A6 (MXFP6 A × MXFP4 B) GEMM Test (Tile: {tile_m}x{tile_n}x{tile_k})") + print("=" * 80) + + _wpe = int(waves_per_eu) if waves_per_eu else 0 + _wpe = None if _wpe <= 0 else _wpe + launch_fn = compile_mxfp6_gemm( + N=N, + K=K, + M_hint=M, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + out_dtype=out_dtype, + waves_per_eu=_wpe, + ) + print(f"✓ Compiled (waves_per_eu={_wpe})") + + device = torch.device("cuda") + M_align_32 = (M + 31) // 32 * 32 + N_align_32 = (N + 31) // 32 * 32 + + a_fp32 = torch.randn(M, K, device=device, dtype=torch.float32) + b_fp32 = torch.randn(N, K, device=device, dtype=torch.float32) + a_fp32_padded = torch.zeros(M_align_32, K, device=device, dtype=torch.float32) + b_fp32_padded = torch.zeros(N_align_32, K, device=device, dtype=torch.float32) + a_fp32_padded[:M] = a_fp32 + b_fp32_padded[:N] = b_fp32 + + # A: MXFP6 E2M3, FP8-padded (1 byte/code). + a_pad, scale_a_orig, a_unpacked = fp4_utils.per_1x32_f6_quant(a_fp32_padded) + a_codes = a_pad[:M] + scale_a = fp4_utils.shuffle_scale_w4(scale_a_orig, 1, False) + + # B: MXFP4 E2M1, identical to test_mfma_w4_flyc_preshuffle. + b_q, scale_b, _ = fp4_utils.per_1x32_f4_quant(b_fp32_padded) + b_q = b_q[:N] + b_shuffled = fp4_utils.shuffle_weight_w4(b_q, 16, False, False) + scale_b_shuffled = fp4_utils.shuffle_scale_w4(scale_b, 1, False) + + # Reference: dequant(A) @ dequant(B).T in fp32. + a_deq = fp4_utils.fp6_e2m3_to_f32(a_unpacked) * fp4_utils.e8m0_to_f32(scale_a_orig[:M].repeat_interleave(32, dim=1)) + b_deq = fp4_utils.mxfp4_to_f32(b_q) * fp4_utils.e8m0_to_f32(scale_b[:N].repeat_interleave(32, dim=1)) + c_ref = torch.mm(a_deq, b_deq.T).to(torch.float32) + + torch_out_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 + c_out = torch.zeros((M, N), dtype=torch_out_dtype, device=device) + _dummy_bias = torch.empty(0, dtype=torch.bfloat16, device=device) + + def _to_bytes(t): + return t if t.dtype in (torch.uint8, torch.int8) else t.view(torch.uint8) + + def _a6w4_args(c, a, b, sa, sb): + return ( + c.contiguous().view(-1), + _to_bytes(a).contiguous().view(-1), + _to_bytes(b).contiguous().view(-1), + _to_bytes(sa).contiguous().view(-1), + _to_bytes(sb).contiguous().view(-1), + _dummy_bias, + M, + N, + torch.cuda.current_stream(), + ) + + compiled_fn = flyc.compile(launch_fn, *_a6w4_args(c_out, a_codes, b_shuffled, scale_a, scale_b_shuffled)) + + def launch_kernel(c, a, b, sa, sb): + compiled_fn(*_a6w4_args(c, a, b, sa, sb)) + + bench_iters = max(2, int(bench_iters)) + _, us = run_perftest( + launch_kernel, + c_out, + a_codes, + b_shuffled, + scale_a, + scale_b_shuffled, + num_iters=bench_iters, + num_warmup=int(bench_warmup), + ) + torch.cuda.synchronize() + + assert verify_output(c_out.to(torch.float32), c_ref, rtol=0.1, atol=0.1) + + # A: 1 byte/code (FP8-padded); B: 0.5 byte/code (MXFP4). + bytes_moved = M * K + (N * K) // 2 + M * N * 2 + (M + N) * (K // 32) + tflops = (2 * M * N * K) / (us / 1e6) / 1e12 + tbps = bytes_moved / 1e12 / (us / 1e6) + print(f"[flyc] W4A6 Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {tbps:.3f} TB/s") + + if __name__ == "__main__": import argparse @@ -611,7 +734,7 @@ def _args(c, a, b, sa, sb): # ── Verify ── max_diff = (ref - graph_result).abs().max().item() assert graph_result.abs().max().item() > 0, ( - f"CUDAGraph replay produced all zeros — kernel was NOT captured! " f"ref max={ref.abs().max().item():.4f}" + f"CUDAGraph replay produced all zeros — kernel was NOT captured! ref max={ref.abs().max().item():.4f}" ) assert torch.allclose(ref, graph_result, atol=1e-2), ( f"CUDAGraph result mismatch: max_diff={max_diff:.6f}, " @@ -699,7 +822,7 @@ def _args(c, a, b, sa, sb, bs): # error is bounded by ~K * eps_bf16 ~ 8192 * 2^-7 ~= 64 ULP. We use # rtol=0.05 (5%) and atol=2.0 (covers small-magnitude outputs). assert not torch.isnan(c_out).any(), ( - f"Epilogue {epilogue}: kernel produced NaN(s) " f"(count={int(torch.isnan(c_out).sum().item())})" + f"Epilogue {epilogue}: kernel produced NaN(s) (count={int(torch.isnan(c_out).sum().item())})" ) assert not torch.isinf(c_out).any(), f"Epilogue {epilogue}: kernel produced Inf(s)" atol = 2.0