Skip to content

[Kernel] conv3d: 8-wave double-buffered FP8 implicit-GEMM kernel(gfx950)#794

Open
jiacao-amd wants to merge 8 commits into
ROCm:mainfrom
jiacao-amd:jiacao/conv3d-implicit-8wave-fp8
Open

[Kernel] conv3d: 8-wave double-buffered FP8 implicit-GEMM kernel(gfx950)#794
jiacao-amd wants to merge 8 commits into
ROCm:mainfrom
jiacao-amd:jiacao/conv3d-implicit-8wave-fp8

Conversation

@jiacao-amd

@jiacao-amd jiacao-amd commented Jul 1, 2026

Copy link
Copy Markdown

Summary

Adds two new conv3d kernels sharing the same 128×128×32 tile, 2×4 wave (WAVE_M=2, WAVE_N=4, 512 threads) layout, and double-buffered software-pipelined structure.

conv3d_implicit_8wave — BF16 (mfma_f32_16x16x32_bf16)

  • Implicit-GEMM with per-thread im2col address decode; no explicit im2col buffer
  • Two-step load: global → register → LDS, with OOB masked to zero in registers
  • Input layout conversion (NCDHW→NDHWC) done once via a tiled LDS transpose kernel
  • Weight repacked to KRSC order once and cached across calls
  • Double-buffered pipeline: global prefetch, LDS staging, and MFMA overlap across K tiles
  • Split-K, full stride/padding, and bias support

conv3d_implicit_8wave_fp8 — FP8 E4M3FN (mfma_f32_16x16x128_fp8, CDNA4 only)

  • Same public API as the BF16 kernel; BF16 inputs packed to FP8 once and cached
  • One-step load: global → LDS directly, bypassing the register file — reduces register pressure vs the BF16 path
  • 4× wider K tile (TILE_K=128 vs 32) from the 16×16×128 FP8 MFMA instruction
  • Requires gfx95x (CDNA4); C%128==0, CRS%128==0, NPQ%128==0

Changes

File Description
kernels/conv3d_implicit_8wave.py 8-wave implicit-GEMM conv3d
kernels/conv3d_implicit_8wave_fp8.py FP8 8-wave implicit-GEMM conv3d (CDNA4)
tests/kernels/test_conv3d_implicit_8wave.py Correctness vs torch.nn.functional.conv3d
tests/kernels/test_conv3d_implicit_8wave_fp8.py Correctness vs FP8-cast reference (gfx95x only)

Performance (MI355X gfx950)

3×3×3 kernel, stride=1, pad=1, C=K=128, D=6 (Do=6), N=1 vs PyTorch:

Shape NPQ BF16 (TF) FP8 (TF) PyTorch (TF) BF16/PT FP8/PT
H=40 9,600 184.5 258.8 138.1 1.34x 1.87x
H=56 18,816 364.0 586.8 270.4 1.35x 2.17x
H=72 31,104 499.6 845.5 367.3 1.36x 2.30x
H=104 64,896 606.7 1020.5 453.0 1.34x 2.25x
H=144 124,416 631.7 1108.6 447.3 1.41x 2.48x
H=168 169,344 604.0 1081.5 525.9 1.15x 2.06x

BF16 is consistently 1.15–1.41x faster than PyTorch. FP8 reaches 2.06–2.48x across the full NPQ range.

Accuracy (MI355X gfx950, vs PyTorch BF16 reference)

BF16 kernel — numerically matches PyTorch BF16 (BF16 rounding only):

Shape rel_err% cosine SNR (dB)
3×3×3 C128 H18 pad=0 0.000% 1.000000 103.4
3×3×3 C256 H18 pad=0 0.000% 1.000000 93.0
3×3×3 C128 H16 pad=1 0.000% 1.000000 89.5
3×3×3 C128 H40 pad=1 0.000% 1.000000 89.0
3×3×3 C128 H72 pad=1 0.000% 1.000000 88.3
3×3×3 C128 H144 pad=1 0.000% 1.000000 88.1

FP8 kernel — inherent E4M3FN quantization noise (direct cast, no per-tensor scale):

Shape rel_err% cosine SNR (dB)
3×3×3 C128 H18 pad=0 3.769% 0.999287 28.5
3×3×3 C256 H18 pad=0 3.803% 0.999273 28.4
3×3×3 C128 H16 pad=1 3.762% 0.999289 28.5
3×3×3 C128 H40 pad=1 3.771% 0.999287 28.5
3×3×3 C128 H72 pad=1 3.765% 0.999288 28.5
3×3×3 C128 H144 pad=1 3.769% 0.999287 28.5

The ~3.7% relative error and SNR ~28.5 dB are the inherent FP8 E4M3FN quantization floor for normally-distributed activations (3-bit mantissa), consistent across all shapes.

Test plan

  • python -m pytest tests/kernels/test_conv3d_implicit_8wave.py -v — any GPU
  • python -m pytest tests/kernels/test_conv3d_implicit_8wave_fp8.py -v — gfx95x only (auto-skipped otherwise)

🤖 Generated with Claude Code

Adds two conv3d kernels sharing the same 128×128×32 tile, 2×4 wave layout,
and double-buffered software-pipelined structure.

conv3d_implicit_8wave (BF16, mfma_f32_16x16x32_bf16):
  - Implicit GEMM with per-thread 3D im2col gather into LDS
  - Tiled NCDHW→NDHWC transpose kernel for zero-copy input layout
  - Cached weight permute (KCTRS→K,CRS)
  - Split-K with auto heuristic; stride/padding; bias

conv3d_implicit_8wave_fp8 (FP8, mfma_f32_16x16x128_fp8, CDNA4 only):
  - Drop-in companion to the BF16 kernel with the same public API
  - Packs BF16 inputs once to FP8 (E4M3FN) before the GEMM loop;
    weight is cached in FP8 across calls
  - 4× wider TILE_K (128 vs 32) → higher MFMA throughput per cycle
  - Requires gfx95x (CDNA4), C%128==0, CRS%128==0, NPQ%128==0

Performance vs MIOpen (tuned, 3×3×3 kernel, N=1 D=8 H=W=32, MI355X gfx950):
  BF16: C=K=128 181 TF (1.31x)   C=K=256 306 TF (0.82x)   C=K=512 644 TF (1.16x)
  FP8:  C=K=128 221 TF (1.18x)   C=K=256 596 TF (1.60x)   C=K=512 1122 TF (2.02x)

Performance vs MIOpen (default, Qwen-VL patch-embed 2×14×14, C=16 K=1152):
  448p-16f (NPQ=8192):   BF16 1.04x  FP8 1.59x
  448p-64f (NPQ=32768):  BF16 0.88x  FP8 1.26x
  720p-32f (NPQ=76544):  BF16 0.75x  FP8 1.12x
  1080p-32f (NPQ=165k):  BF16 0.75x  FP8 1.12x
@jiacao-amd jiacao-amd changed the title conv3d: 8-wave double-buffered BF16 + FP8 implicit-GEMM kernels conv3d: 8-wave double-buffered FP8 implicit-GEMM kernel Jul 1, 2026
H=W=18 pad=1 gives NPQ=972 which fails the alignment check.
Use H=W=16 pad=1 giving NPQ=768 (768%128==0).
@jiacao-amd jiacao-amd force-pushed the jiacao/conv3d-implicit-8wave-fp8 branch from dd08dfa to 448c50a Compare July 1, 2026 20:40
@jiacao-amd jiacao-amd changed the title conv3d: 8-wave double-buffered FP8 implicit-GEMM kernel [Kernel] conv3d: 8-wave double-buffered FP8 implicit-GEMM kernel(gfx950) Jul 1, 2026
Comment thread kernels/conv3d_implicit_8wave.py Outdated
LDS_B_SIZE = STAGES * TILE_N * TILE_K


def _run_compiled(exe, *args):

@coderfeli coderfeli Jul 2, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already this in utils.

from flydsl.expr.typing import T

TILE_M = 128
TILE_N = 128

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only support all fixed tile config?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this kernel only supports one fixed tile config. I can add autotuning if needed or put it into a future PR.

Relax the aligned-shape constraints so both the BF16 and FP8 8-wave conv3d
kernels handle arbitrary N/C/D/H/W/K:

- Partial M/N/K tiles are masked instead of asserted: grid_m/grid_n/k_tiles
  round up; OOB activation loads are zeroed (k_abs < crs), OOB stores masked
  (col < k), and split-K atomics guarded with scf.if (hardware OOB suppression
  does not apply to atomics).
- Epilogue writes NCDHW directly (col*dhw+row for n==1, general offset for n>1)
  so the output is contiguous, avoiding a downstream contiguous() copy under
  nn.GroupNorm consumers. split-K keeps the NDHWC+permute path.
- FP8 _resolve_splitk caps tiles_per_split at 54 to avoid LLVM JIT OOM from
  range_constexpr unrolling large K-tile counts (C=384/512).
- Deduplicate _run_compiled by importing from kernels.tensor_shim.

Tests: add partial-tile and K=32 shapes; use a 5% FP8 threshold for
CRS%128!=0 cases (partial-K region is zeroed vs the FP8-cast reference).
@jiacao-amd jiacao-amd force-pushed the jiacao/conv3d-implicit-8wave-fp8 branch from 679173f to 04def00 Compare July 3, 2026 06:10
cc = c0 + cv
valid = arith.andi(ss < dhw, cc < c)
store_if = scf.IfOp(valid, results_=[], has_else=False)
with ir.InsertionPoint(store_if.then_block):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use if : but scf.Ifop?

if const_expr(use_splitk):
# Atomics ignore hardware OOB suppression; guard explicitly.
valid = arith.andi(col < fx.Index(k), row < fx.Index(npq))
atom_if = scf.IfOp(valid, results_=[], has_else=False)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if?

scalars = [lds_load_scalar((cv + j) * _TR_LDS_S + rs) for j in range_constexpr(TR_VEC)]
vv = Vec.from_elements(scalars, dtype=elem_ty)
valid = arith.andi(ss < s, cc < c)
store_if = scf.IfOp(valid, results_=[], has_else=False)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try not use scf.

cc = c0 + cv
scalars = [lds_load_scalar((cv + j) * _TR_LDS_S + rs) for j in range_constexpr(TR_VEC)]
vv = Vec.from_elements(scalars, dtype=elem_ty)
valid = arith.andi(ss < s, cc < c)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try not use arith but use &&

def conv3d_8wave_kernel(y: fx.Tensor, x: fx.Tensor, weight: fx.Tensor, bias: fx.Tensor):
x_rsrc = buffer_ops.create_buffer_resource(x, max_size=True)
w_rsrc = buffer_ops.create_buffer_resource(weight, max_size=True)
y_rsrc = buffer_ops.create_buffer_resource(y, max_size=True)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no oob needed?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants