From a7c1f40605850f67d4572bf74643a25bc84fee1b Mon Sep 17 00:00:00 2001 From: kudomcho Date: Fri, 5 Jun 2026 19:07:56 +0000 Subject: [PATCH] perf(rmsnorm): vectorize generic path and simplify block reduction Replace the scalar-only generic path with a vector-generic path that uses vectorised buffer_load/store for the bulk of elements and falls back to scalar operations only for the tail (N % tile_cols remainder). This improves throughput on non-aligned hidden dimensions like N=2880 (GPT-2 XL) by ~21% at M=16384. Also replace the dual block_reduce_add2 with a direct single-value block_reduce_add, halving shared memory usage and removing one unnecessary reduction slot. Benchmark (MI300X, bf16, GPU profiling, 50 warmup + 500 iters): (4096, 2880) vec-gen: 14.50 -> 13.13 us (+9.4%) (16384, 2880) vec-gen: 57.06 -> 45.30 us (+20.6%) Fast-path shapes: neutral Co-Authored-By: Claude Opus 4 (1M context) --- kernels/rmsnorm_kernel.py | 134 +++++++++++++++++++++++++++++--------- 1 file changed, 104 insertions(+), 30 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index ce4bd0a98..ea9d58e15 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -5,9 +5,10 @@ RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma -Two paths: - - Fast path (N % tile_cols == 0): buffer_load/store vectorised access. - - Generic path (arbitrary N): scalar copy_atom_call. +Three paths: + - Fast path (N % tile_cols == 0, f16/bf16): fully vectorised buffer_load/store. + - Vector-generic path (arbitrary N, f16/bf16): vectorised bulk + scalar tail. + - Scalar-generic path (f32 or very small N): scalar copy_atom_call. """ import math @@ -39,6 +40,14 @@ class SharedStorage: return SharedStorage +def _make_single_reduction_storage(red_slots: int): + @fx.struct + class SharedStorage: + s_red: fx.Array[fx.Float32, red_slots, 16] + + return SharedStorage + + def _load_scalar(copy_atom, elem_dtype, divided_tensor, index): view = fx.slice(divided_tensor, (None, index)) r = fx.make_rmem_tensor(1, elem_dtype) @@ -110,7 +119,13 @@ def build_rmsnorm_module(M: int, N: int, dtype_str: str): RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 - SharedStorage = _make_reduction_storage(RED_SLOTS) + # Vector-generic path: vectorised bulk tiles + scalar tail + full_vec_tiles = N // tile_cols + vec_covered = full_vec_tiles * tile_cols + scalar_tail = N - vec_covered + use_vec_generic = elem_bits <= 16 and (N % tile_cols != 0) and full_vec_tiles > 0 + + SharedStorage = _make_single_reduction_storage(RED_SLOTS) @flyc.kernel def rmsnorm_kernel( @@ -129,7 +144,6 @@ def rmsnorm_kernel( lds = fx.SharedAllocator().allocate(SharedStorage).peek() s_red = lds.s_red.view(fx.make_layout(RED_SLOTS, 1)) - s_red2 = lds.s_red2.view(fx.make_layout(RED_SLOTS, 1)) def wave_reduce_add(x): w = x @@ -140,48 +154,36 @@ def wave_reduce_add(x): return w def block_reduce_add(val): - dummy = fx.Float32(0.0) - r0, _ = block_reduce_add2(val, dummy) - return r0 - - def block_reduce_add2(val0, val1): if const_expr(RED_SLOTS == 1): - return wave_reduce_add(val0), wave_reduce_add(val1) + return wave_reduce_add(val) lane = tid % WARP_SIZE wave = tid // WARP_SIZE - w0 = wave_reduce_add(val0) - w1 = wave_reduce_add(val1) + w = wave_reduce_add(val) if lane == 0: - fx.memref_store(w0, s_red, wave) - fx.memref_store(w1, s_red2, wave) + fx.memref_store(w, s_red, wave) gpu.barrier() if wave == 0: in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, 0) - v0 = fx.memref_load(s_red, lane_safe) - v1 = fx.memref_load(s_red2, lane_safe) - ww0 = in_range.select(v0, 0.0) - ww1 = in_range.select(v1, 0.0) - ww0 = wave_reduce_add(ww0) - ww1 = wave_reduce_add(ww1) + v = fx.memref_load(s_red, lane_safe) + ww = in_range.select(v, 0.0) + ww = wave_reduce_add(ww) if lane == 0: - fx.memref_store(ww0, s_red, 0) - fx.memref_store(ww1, s_red2, 0) + fx.memref_store(ww, s_red, 0) gpu.barrier() - return fx.memref_load(s_red, 0), fx.memref_load(s_red2, 0) + return fx.memref_load(s_red, 0) # ================================================================== - # Fast path: N is a multiple of tile_cols + # Fast path: N is a multiple of tile_cols (fully aligned, f16/bf16) # ================================================================== if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): num_tiles = N // tile_cols - # ── Layout API: buffer-backed tensors + tiled access ───── Input_buf = fx.rocdl.make_buffer_tensor(Input) Output_buf = fx.rocdl.make_buffer_tensor(Output) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -197,10 +199,9 @@ def block_reduce_add2(val0, val1): c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f - thread_dummy = c_zero_f in_local = [] - # Pass 1: load + cache + sumsq + # Pass 1: load + cache input + sumsq for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS vec = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx) @@ -211,7 +212,7 @@ def block_reduce_add2(val0, val1): red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sumsq = thread_sumsq + red2 - _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) + sum_sq = block_reduce_add(thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c rrms = ms_eps.rsqrt(fastmath=fm_fast) @@ -229,9 +230,82 @@ def block_reduce_add2(val0, val1): out_idx = tid + tile_i * BLOCK_THREADS _store_vec(copy_atom, VEC_WIDTH, elem_dtype, out_e, out_div, out_idx) + elif const_expr(use_vec_generic): + # ============================================================== + # Vector-generic path: vectorised bulk + scalar tail (f16/bf16) + # ============================================================== + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + + # Vectorised access for the bulk + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(VEC_WIDTH, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + + # Scalar access for the tail + row_div_s = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div_s = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div_s = fx.logical_divide(row_out, fx.make_layout(1, 1)) + copy_atom_s = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), elem_bits) + + c_zero_f = fx.Float32(0.0) + thread_sumsq = c_zero_f + in_local = [] + + # Pass 1a: vectorised bulk — sumsq + cache input + for tile_i in range_constexpr(full_vec_tiles): + idx = tid + tile_i * BLOCK_THREADS + vec = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx) + in_local.append(vec) + x = vec.to(fx.Float32) + x2 = x * x + red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + red2 + + # Pass 1b: scalar tail — sumsq + for tail_off in range_constexpr(0, scalar_tail, BLOCK_THREADS): + idx = tid + vec_covered + tail_off + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div_s, idx_safe) + x = x_e.to(fx.Float32) + x2 = x * x + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) + + sum_sq = block_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = ms_eps.rsqrt(fastmath=fm_fast) + + # Pass 2a: vectorised bulk — normalize + store (reuse cached input) + for tile_i in range_constexpr(full_vec_tiles): + idx = tid + tile_i * BLOCK_THREADS + g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32) + x = in_local[tile_i].to(fx.Float32) + y = (x * rrms) * g + out_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y) + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, out_e, out_div, idx) + + # Pass 2b: scalar tail — normalize + store + for tail_off in range_constexpr(0, scalar_tail, BLOCK_THREADS): + idx = tid + vec_covered + tail_off + if idx < N: + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div_s, idx) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div_s, idx) + x = x_e.to(fx.Float32) + g = g_e.to(fx.Float32) + y = (x * rrms) * g + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, out_div_s, idx, y_e) + else: # ============================================================== - # Generic path: scalar 2-pass for arbitrary N + # Scalar-generic path: f32 or very small N # ============================================================== Input_buf = fx.rocdl.make_buffer_tensor(Input) Output_buf = fx.rocdl.make_buffer_tensor(Output)