Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions kernels/mixed_moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,16 @@ def compile_mixed_moe_gemm1(
allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1")
_state = {}

if a_dtype not in ("fp8", "fp16", "int8", "fp4"):
raise ValueError(f"a_dtype must be one of ('fp8','fp16','int8','fp4'), got {a_dtype!r}")
if a_dtype not in ("fp8", "fp16", "int8", "fp4", "fp6"):
raise ValueError(f"a_dtype must be one of ('fp8','fp16','int8','fp4','fp6'), got {a_dtype!r}")
if b_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"):
raise ValueError(f"b_dtype must be one of ('fp8','fp16','int8','int4','fp4'), got {b_dtype!r}")

is_f16_a = a_dtype == "fp16"
is_f16_b = b_dtype == "fp16"
is_f8_a = a_dtype == "fp8"
is_f4_a = a_dtype == "fp4"
is_f6_a = a_dtype == "fp6"
is_f4_b = b_dtype == "fp4"

sort_block_m = max(32, tile_m)
Expand All @@ -161,8 +162,11 @@ def compile_mixed_moe_gemm1(
a_elem_bytes = 2 if is_f16_a else 1
b_elem_bytes = 1
tile_k_bytes = int(tile_k) * int(a_elem_bytes)
# fp4: 2 elements per byte (packed); fp6: FP8-padded layout = 1 byte/element (like fp8).
a_elem_vec_pack = 2 if is_f4_a else 1
cbsz = 0 if is_f8_a else 4
# fp6: 32 B per K=32 chunk (24 B packed codes + 8 B zero pad); fp4/fp8: 16 B.
a_per_lane_kpack_bytes = 32 if is_f6_a else 16
cbsz = 0 if is_f8_a else (2 if is_f6_a else 4)
blgp = 4

if (tile_k_bytes % 64) != 0:
Expand Down Expand Up @@ -409,7 +413,7 @@ def x_lds_elem():
_pp_b_loads = [p["b_loads"] for p in _pipe_phases]
_pp_has_scale = [p["has_scale"] for p in _pipe_phases]

fp4_ratio = 2 if a_dtype == "fp4" else 1
fp4_ratio = 2 if is_f4_a else 1 # fp6 uses FP8-padded layout (1 byte/elem), same count as fp8
gui_ratio = 1 if gate_up_interleave else 2
_vmcnt_before_barrier = tile_m // 32 // fp4_ratio + tile_n // 32 * gui_ratio

Expand Down Expand Up @@ -736,7 +740,7 @@ def load_x_tile(base_k):
lane_div_16 = layout_get(coord_l16, 0)
lane_mod_16 = layout_get(coord_l16, 1)
row_a_lds = lane_mod_16
col_offset_base = lane_div_16 * arith.constant(16, index=True)
col_offset_base = lane_div_16 * arith.constant(a_per_lane_kpack_bytes, index=True)

num_acc_n = n_per_wave // 16
c_n_per_wave = arith.constant(n_per_wave, index=True)
Expand Down Expand Up @@ -1034,6 +1038,11 @@ def prefetch_full_a_from_lds(lds_buffer, ku_limit=k_unroll):
if const_expr(is_f8_a):
a2, a3 = lds_load_packs_k64(curr_row, col_base + 64, lds_buffer)
a_regs.append((a0, a1, a2, a3))
elif const_expr(is_f6_a):
# fp6: 32B FP8-padded slot; 3rd 16B chunk carries codes,
# 4th 8B is zero pad (ignored by cbsz=2 MFMA).
a2, _ = lds_load_packs_k64(curr_row, col_base + 16, lds_buffer)
a_regs.append((a0, a1, a2))
else:
a_regs.append((a0, a1))
return a_regs
Expand Down Expand Up @@ -1156,6 +1165,10 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3):
if const_expr(is_f8_a):
a0, a1, a2, a3 = a_tile_regs[_a_reg_idx]
a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3)
elif const_expr(is_f6_a):
# fp6: 3x16B loads; 4th slot is zero pad (cbsz=2 ignores it)
a0, a1, a2 = a_tile_regs[_a_reg_idx]
a128 = pack_i64x4_to_i32x8(a0, a1, a2, c0_i64)
else:
a0, a1 = a_tile_regs[_a_reg_idx]
a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64)
Expand Down Expand Up @@ -1200,6 +1213,10 @@ def load_a_subtile(k_idx, mi_idx, lds_buffer):
if const_expr(is_f8_a):
a2, a3 = lds_load_packs_k64(curr_row, col_base + 64, lds_buffer)
return (a0, a1, a2, a3)
elif const_expr(is_f6_a):
# fp6: 3x16B loads; 4th slot is zero pad (cbsz=2 ignores it)
a2, _ = lds_load_packs_k64(curr_row, col_base + 16, lds_buffer)
return (a0, a1, a2)
else:
return (a0, a1)

Expand Down Expand Up @@ -1250,6 +1267,9 @@ def _pack(x0, x1, x2, x3):

if const_expr(is_f8_a):
a128 = _pack(a_reg[0], a_reg[1], a_reg[2], a_reg[3])
elif const_expr(is_f6_a):
# fp6: 3x16B loads; 4th slot is zero pad (cbsz=2 ignores it)
a128 = _pack(a_reg[0], a_reg[1], a_reg[2], c0_i64)
else:
a128 = _pack(a_reg[0], a_reg[1], c0_i64, c0_i64)

Expand Down Expand Up @@ -2479,8 +2499,8 @@ def compile_mixed_moe_gemm2(
allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1")
_state = {}

if a_dtype not in ("fp8", "fp16", "int8", "fp4"):
raise ValueError(f"a_dtype must be one of ('fp8','fp16','int8','fp4'), got {a_dtype!r}")
if a_dtype not in ("fp8", "fp16", "int8", "fp4", "fp6"):
raise ValueError(f"a_dtype must be one of ('fp8','fp16','int8','fp4','fp6'), got {a_dtype!r}")
if b_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"):
raise ValueError(f"b_dtype must be one of ('fp8','fp16','int8','int4','fp4'), got {b_dtype!r}")

Expand All @@ -2489,6 +2509,8 @@ def compile_mixed_moe_gemm2(

is_f8_a = a_dtype == "fp8"
is_f4_a = a_dtype == "fp4"
is_f6_a = a_dtype == "fp6"
is_f4_or_f6_a = is_f4_a or is_f6_a
is_f4_b = b_dtype == "fp4"

_scale_pack_m = 2 # physical mn_pack in preshuffle microscale layout
Expand All @@ -2505,8 +2527,11 @@ def compile_mixed_moe_gemm2(
b_elem_bytes = 1
tile_k_bytes = int(tile_k) * int(a_elem_bytes)

# fp4: 2 elements per byte (packed); fp6: FP8-padded layout = 1 byte/element (like fp8).
a_elem_vec_pack = 2 if is_f4_a else 1
cbsz = 0 if is_f8_a else 4
# fp6: 32 B per K=32 chunk (24 B packed codes + 8 B zero pad); fp4/fp8: 16 B.
a_per_lane_kpack_bytes = 32 if is_f6_a else 16
cbsz = 0 if is_f8_a else (2 if is_f6_a else 4)
blgp = 4

# ---- Static B preshuffle strides (compile-time) ----
Expand Down Expand Up @@ -2548,7 +2573,7 @@ def compile_mixed_moe_gemm2(
mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr(rocdl, "mfma_i32_16x16x32_i8", None)
if mfma_i32_k32 is None:
raise AttributeError(
"INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " "(or `rocdl.mfma_i32_16x16x32_i8`)."
"INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` (or `rocdl.mfma_i32_16x16x32_i8`)."
)

def _x_elem_type():
Expand Down Expand Up @@ -2819,7 +2844,7 @@ def check_c_k_valid_gate(base_k):
sx_rsrc = 1
sw_rsrc = 1
if const_expr(not is_f16_a):
if const_expr(is_f4_a or is_f8_a):
if const_expr(is_f4_or_f6_a or is_f8_a):
# A2 microscale: e8m0 in sorted layout [sorted_size, K/32].
# Caller must pre-scatter a2_scale via moe_mxfp4_sort.
kblk = _div_pow2(k_in, 32)
Expand Down Expand Up @@ -3083,7 +3108,7 @@ def load_x_tile(base_k):

row_a_lds = lane_mod_16

col_offset_base = lane_div_16 * arith.constant(16, index=True)
col_offset_base = lane_div_16 * arith.constant(a_per_lane_kpack_bytes, index=True)

# Dynamic N tiling within block.
num_waves = 4
Expand Down Expand Up @@ -3448,6 +3473,10 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3):
col_base1 = col_base + 64
a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_buffer)
a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3)
elif const_expr(is_f6_a):
# fp6: 3x16B loads; 4th slot is zero pad (cbsz=2 ignores it)
a2, _ = lds_load_packs_k64(curr_row_a_lds, col_base0 + 16, lds_buffer)
a128 = pack_i64x4_to_i32x8(a0, a1, a2, c0_i64)
else:
a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64)

Expand Down
Loading
Loading