Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 104 additions & 30 deletions kernels/rmsnorm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma

Two paths:
- Fast path (N % tile_cols == 0): buffer_load/store vectorised access.
- Generic path (arbitrary N): scalar copy_atom_call.
Three paths:
- Fast path (N % tile_cols == 0, f16/bf16): fully vectorised buffer_load/store.
- Vector-generic path (arbitrary N, f16/bf16): vectorised bulk + scalar tail.
- Scalar-generic path (f32 or very small N): scalar copy_atom_call.
"""

import math
Expand Down Expand Up @@ -39,6 +40,14 @@ class SharedStorage:
return SharedStorage


def _make_single_reduction_storage(red_slots: int):
    """Build a struct type that holds one Float32 reduction scratch array.

    The returned class is decorated with ``fx.struct`` and declares a single
    field, ``s_red``, typed as an ``fx.Array`` of ``fx.Float32`` with
    ``red_slots`` entries.

    Args:
        red_slots: Number of Float32 entries in ``s_red``.

    Returns:
        The ``SharedStorage`` struct class, ready to be handed to an allocator.
    """

    @fx.struct
    class SharedStorage:
        # NOTE(review): the trailing 16 is presumably the array alignment in
        # bytes — confirm against the fx.Array definition.
        s_red: fx.Array[fx.Float32, red_slots, 16]

    return SharedStorage


def _load_scalar(copy_atom, elem_dtype, divided_tensor, index):
view = fx.slice(divided_tensor, (None, index))
r = fx.make_rmem_tensor(1, elem_dtype)
Expand Down Expand Up @@ -110,7 +119,13 @@ def build_rmsnorm_module(M: int, N: int, dtype_str: str):
RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE)
elem_bits = 32 if dtype_str == "f32" else 16

SharedStorage = _make_reduction_storage(RED_SLOTS)
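# Vector-generic path: vectorised bulk tiles + scalar tail.
# Selected only for 16-bit dtypes when N is not a multiple of tile_cols and at least one full tile fits in a row.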
full_vec_tiles = N // tile_cols
vec_covered = full_vec_tiles * tile_cols
scalar_tail = N - vec_covered
use_vec_generic = elem_bits <= 16 and (N % tile_cols != 0) and full_vec_tiles > 0

SharedStorage = _make_single_reduction_storage(RED_SLOTS)

@flyc.kernel
def rmsnorm_kernel(
Expand All @@ -129,7 +144,6 @@ def rmsnorm_kernel(

lds = fx.SharedAllocator().allocate(SharedStorage).peek()
s_red = lds.s_red.view(fx.make_layout(RED_SLOTS, 1))
s_red2 = lds.s_red2.view(fx.make_layout(RED_SLOTS, 1))

def wave_reduce_add(x):
w = x
Expand All @@ -140,48 +154,36 @@ def wave_reduce_add(x):
return w

def block_reduce_add(val):
dummy = fx.Float32(0.0)
r0, _ = block_reduce_add2(val, dummy)
return r0

def block_reduce_add2(val0, val1):
if const_expr(RED_SLOTS == 1):
return wave_reduce_add(val0), wave_reduce_add(val1)
return wave_reduce_add(val)

lane = tid % WARP_SIZE
wave = tid // WARP_SIZE

w0 = wave_reduce_add(val0)
w1 = wave_reduce_add(val1)
w = wave_reduce_add(val)

if lane == 0:
fx.memref_store(w0, s_red, wave)
fx.memref_store(w1, s_red2, wave)
fx.memref_store(w, s_red, wave)
gpu.barrier()

if wave == 0:
in_range = lane < RED_SLOTS
lane_safe = in_range.select(lane, 0)
v0 = fx.memref_load(s_red, lane_safe)
v1 = fx.memref_load(s_red2, lane_safe)
ww0 = in_range.select(v0, 0.0)
ww1 = in_range.select(v1, 0.0)
ww0 = wave_reduce_add(ww0)
ww1 = wave_reduce_add(ww1)
v = fx.memref_load(s_red, lane_safe)
ww = in_range.select(v, 0.0)
ww = wave_reduce_add(ww)

if lane == 0:
fx.memref_store(ww0, s_red, 0)
fx.memref_store(ww1, s_red2, 0)
fx.memref_store(ww, s_red, 0)
gpu.barrier()

return fx.memref_load(s_red, 0), fx.memref_load(s_red2, 0)
return fx.memref_load(s_red, 0)

# ==================================================================
# Fast path: N is a multiple of tile_cols
# Fast path: N is a multiple of tile_cols (fully aligned, f16/bf16)
# ==================================================================
if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16):
num_tiles = N // tile_cols
# ── Layout API: buffer-backed tensors + tiled access ─────
Input_buf = fx.rocdl.make_buffer_tensor(Input)
Output_buf = fx.rocdl.make_buffer_tensor(Output)
Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma)
Expand All @@ -197,10 +199,9 @@ def block_reduce_add2(val0, val1):

c_zero_f = fx.Float32(0.0)
thread_sumsq = c_zero_f
thread_dummy = c_zero_f
in_local = []

# Pass 1: load + cache + sumsq
# Pass 1: load + cache input + sumsq
for tile_i in range_constexpr(num_tiles):
idx = tid + tile_i * BLOCK_THREADS
vec = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx)
Expand All @@ -211,7 +212,7 @@ def block_reduce_add2(val0, val1):
red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast)
thread_sumsq = thread_sumsq + red2

_, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq)
sum_sq = block_reduce_add(thread_sumsq)
mean_sq = sum_sq / n_float
ms_eps = mean_sq + eps_c
rrms = ms_eps.rsqrt(fastmath=fm_fast)
Expand All @@ -229,9 +230,82 @@ def block_reduce_add2(val0, val1):
out_idx = tid + tile_i * BLOCK_THREADS
_store_vec(copy_atom, VEC_WIDTH, elem_dtype, out_e, out_div, out_idx)

elif const_expr(use_vec_generic):
# ==============================================================
# Vector-generic path: vectorised bulk + scalar tail (f16/bf16)
# ==============================================================
Input_buf = fx.rocdl.make_buffer_tensor(Input)
Output_buf = fx.rocdl.make_buffer_tensor(Output)
Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma)

row_in = fx.slice(Input_buf, (bid, None))
row_out = fx.slice(Output_buf, (bid, None))

# Vectorised access for the bulk
in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1))
out_div = fx.logical_divide(row_out, fx.make_layout(VEC_WIDTH, 1))
gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1))
copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits)

# Scalar access for the tail
row_div_s = fx.logical_divide(row_in, fx.make_layout(1, 1))
gamma_div_s = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1))
out_div_s = fx.logical_divide(row_out, fx.make_layout(1, 1))
copy_atom_s = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), elem_bits)

c_zero_f = fx.Float32(0.0)
thread_sumsq = c_zero_f
in_local = []

# Pass 1a: vectorised bulk — sumsq + cache input
for tile_i in range_constexpr(full_vec_tiles):
idx = tid + tile_i * BLOCK_THREADS
vec = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx)
in_local.append(vec)
x = vec.to(fx.Float32)
x2 = x * x
red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast)
thread_sumsq = thread_sumsq + red2

# Pass 1b: scalar tail — sumsq
for tail_off in range_constexpr(0, scalar_tail, BLOCK_THREADS):
idx = tid + vec_covered + tail_off
is_valid = idx < N
idx_safe = is_valid.select(idx, 0)
x_e = _load_scalar(copy_atom_s, elem_dtype, row_div_s, idx_safe)
x = x_e.to(fx.Float32)
x2 = x * x
thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f)

sum_sq = block_reduce_add(thread_sumsq)
mean_sq = sum_sq / n_float
ms_eps = mean_sq + eps_c
rrms = ms_eps.rsqrt(fastmath=fm_fast)

# Pass 2a: vectorised bulk — normalize + store (reuse cached input)
for tile_i in range_constexpr(full_vec_tiles):
idx = tid + tile_i * BLOCK_THREADS
g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32)
x = in_local[tile_i].to(fx.Float32)
y = (x * rrms) * g
out_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y)
_store_vec(copy_atom, VEC_WIDTH, elem_dtype, out_e, out_div, idx)

# Pass 2b: scalar tail — normalize + store
for tail_off in range_constexpr(0, scalar_tail, BLOCK_THREADS):
idx = tid + vec_covered + tail_off
if idx < N:
x_e = _load_scalar(copy_atom_s, elem_dtype, row_div_s, idx)
g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div_s, idx)
x = x_e.to(fx.Float32)
g = g_e.to(fx.Float32)
y = (x * rrms) * g
y_e = _to_elem_scalar(dtype_str, elem_dtype, y)
_store_scalar(copy_atom_s, elem_dtype, elem_dtype, out_div_s, idx, y_e)

else:
# ==============================================================
# Generic path: scalar 2-pass for arbitrary N
# Scalar-generic path: f32 or very small N
# ==============================================================
Input_buf = fx.rocdl.make_buffer_tensor(Input)
Output_buf = fx.rocdl.make_buffer_tensor(Output)
Expand Down
Loading