[Kernel] conv3d: 8-wave double-buffered FP8 implicit-GEMM kernel(gfx950) #794
Open
jiacao-amd wants to merge 8 commits into
Adds two conv3d kernels sharing the same 128×128×32 tile, 2×4 wave layout,
and double-buffered software-pipelined structure.
conv3d_implicit_8wave (BF16, mfma_f32_16x16x32_bf16):
- Implicit GEMM with per-thread 3D im2col gather into LDS
- Tiled NCDHW→NDHWC transpose kernel for zero-copy input layout
- Cached weight permute (KCTRS→K,CRS)
- Split-K with auto heuristic; stride/padding; bias
conv3d_implicit_8wave_fp8 (FP8, mfma_f32_16x16x128_fp8, CDNA4 only):
- Drop-in companion to the BF16 kernel with the same public API
- Packs BF16 inputs once to FP8 (E4M3FN) before the GEMM loop;
weight is cached in FP8 across calls
- 4× wider TILE_K (128 vs 32) → higher MFMA throughput per cycle
- Requires gfx95x (CDNA4), C%128==0, CRS%128==0, NPQ%128==0
Performance vs MIOpen (tuned, 3×3×3 kernel, N=1 D=8 H=W=32, MI355X gfx950):
BF16: C=K=128 181 TF (1.31x) C=K=256 306 TF (0.82x) C=K=512 644 TF (1.16x)
FP8: C=K=128 221 TF (1.18x) C=K=256 596 TF (1.60x) C=K=512 1122 TF (2.02x)
Performance vs MIOpen (default, Qwen-VL patch-embed 2×14×14, C=16 K=1152):
448p-16f (NPQ=8192): BF16 1.04x FP8 1.59x
448p-64f (NPQ=32768): BF16 0.88x FP8 1.26x
720p-32f (NPQ=76544): BF16 0.75x FP8 1.12x
1080p-32f (NPQ=165k): BF16 0.75x FP8 1.12x
H=W=18 pad=1 gives NPQ=972 which fails the alignment check. Use H=W=16 pad=1 giving NPQ=768 (768%128==0).
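The NPQ arithmetic above can be checked with a short sketch. It assumes N=1, D=3, a 3×3×3 filter, stride 1, and pad=1 in every spatial dimension (the depth and stride are not stated in the commit message; they are inferred from 972 = 3·18·18). The helper names `conv_out` and `npq` are hypothetical and are not part of the kernel's API.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Standard convolution output extent along one dimension."""
    return (size + 2 * pad - kernel) // stride + 1


def npq(n, d, h, w, kernel=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1)):
    """Number of GEMM rows: batch times output depth, height and width."""
    do, ho, wo = (
        conv_out(s, k, st, p)
        for s, k, st, p in zip((d, h, w), kernel, stride, pad)
    )
    return n * do * ho * wo


ALIGN = 128  # NPQ alignment required by the FP8 kernel in this commit

for hw in (18, 16):
    rows = npq(1, 3, hw, hw)  # D=3 is an assumption, see the note above
    print(hw, rows, rows % ALIGN == 0)
# 18 972 False
# 16 768 True
```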
dd08dfa to 448c50a
coderfeli
reviewed
Jul 2, 2026
LDS_B_SIZE = STAGES * TILE_N * TILE_K


def _run_compiled(exe, *args):
Collaborator
This helper already exists in utils.
coderfeli
reviewed
Jul 2, 2026
from flydsl.expr.typing import T


TILE_M = 128
TILE_N = 128
Collaborator
Does this only support a single fixed tile config?
Author
Yes, this kernel only supports one fixed tile config. I can add autotuning if needed or put it into a future PR.
Relax the aligned-shape constraints so both the BF16 and FP8 8-wave conv3d kernels handle arbitrary N/C/D/H/W/K:
- Partial M/N/K tiles are masked instead of asserted: grid_m/grid_n/k_tiles round up; OOB activation loads are zeroed (k_abs < crs), OOB stores masked (col < k), and split-K atomics guarded with scf.if (hardware OOB suppression does not apply to atomics).
- Epilogue writes NCDHW directly (col*dhw+row for n==1, general offset for n>1) so the output is contiguous, avoiding a downstream contiguous() copy under nn.GroupNorm consumers. split-K keeps the NDHWC+permute path.
- FP8 _resolve_splitk caps tiles_per_split at 54 to avoid LLVM JIT OOM from range_constexpr unrolling large K-tile counts (C=384/512).
- Deduplicate _run_compiled by importing from kernels.tensor_shim.
Tests: add partial-tile and K=32 shapes; use a 5% FP8 threshold for CRS%128!=0 cases (partial-K region is zeroed vs the FP8-cast reference).
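As a rough host-side illustration of the round-up and masking logic described in this commit, the sketch below computes the grid for a partial-tile shape and the smallest split count that respects the 54-tile cap. All function names are hypothetical; in particular, `min_splits_for_cap` is only one plausible way to honor the cap and is not taken from `_resolve_splitk`.

```python
def ceil_div(a: int, b: int) -> int:
    """Round-up integer division, so a partial tile still gets a block."""
    return (a + b - 1) // b


def plan_tiles(npq, k, crs, tile_m=128, tile_n=128, tile_k=32):
    """Return (grid_m, grid_n, k_tiles) with every dimension rounded up."""
    return ceil_div(npq, tile_m), ceil_div(k, tile_n), ceil_div(crs, tile_k)


def store_is_valid(row, col, npq, k):
    """Epilogue mask: only rows < npq and columns < k are real outputs."""
    return row < npq and col < k


def min_splits_for_cap(k_tiles, max_tiles_per_split=54):
    """Hypothetical: smallest split-K factor keeping each split under the cap."""
    return ceil_div(k_tiles, max_tiles_per_split)


# BF16 tile (TILE_K=32), partial in M, N and K: NPQ=972, K=32, C=16 -> CRS=432.
print(plan_tiles(npq=972, k=32, crs=16 * 27))           # (8, 1, 14)
print(store_is_valid(row=1000, col=5, npq=972, k=32))   # False

# FP8 tile (TILE_K=128), C=512 with a 3x3x3 filter: CRS=13824 -> 108 K-tiles.
_, _, k_tiles = plan_tiles(npq=8192, k=512, crs=512 * 27, tile_k=128)
print(k_tiles, min_splits_for_cap(k_tiles))             # 108 2
```

Rows and columns that fall in the rounded-up region are computed but never stored, which is why the atomics in the split-K path need an explicit guard rather than relying on buffer bounds.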
679173f to 04def00
coderfeli
reviewed
Jul 3, 2026
cc = c0 + cv
valid = arith.andi(ss < dhw, cc < c)
store_if = scf.IfOp(valid, results_=[], has_else=False)
with ir.InsertionPoint(store_if.then_block):
Collaborator
Why use scf.IfOp here instead of a plain `if`?
coderfeli
reviewed
Jul 3, 2026
if const_expr(use_splitk):
    # Atomics ignore hardware OOB suppression; guard explicitly.
    valid = arith.andi(col < fx.Index(k), row < fx.Index(npq))
    atom_if = scf.IfOp(valid, results_=[], has_else=False)
coderfeli
reviewed
Jul 3, 2026
scalars = [lds_load_scalar((cv + j) * _TR_LDS_S + rs) for j in range_constexpr(TR_VEC)]
vv = Vec.from_elements(scalars, dtype=elem_ty)
valid = arith.andi(ss < s, cc < c)
store_if = scf.IfOp(valid, results_=[], has_else=False)
coderfeli
reviewed
Jul 3, 2026
cc = c0 + cv
scalars = [lds_load_scalar((cv + j) * _TR_LDS_S + rs) for j in range_constexpr(TR_VEC)]
vv = Vec.from_elements(scalars, dtype=elem_ty)
valid = arith.andi(ss < s, cc < c)
Collaborator
Try to avoid arith.andi here and use && instead.
coderfeli
reviewed
Jul 3, 2026
def conv3d_8wave_kernel(y: fx.Tensor, x: fx.Tensor, weight: fx.Tensor, bias: fx.Tensor):
    x_rsrc = buffer_ops.create_buffer_resource(x, max_size=True)
    w_rsrc = buffer_ops.create_buffer_resource(weight, max_size=True)
    y_rsrc = buffer_ops.create_buffer_resource(y, max_size=True)
Summary
Adds two new conv3d kernels sharing the same 128×128×32 tile, 2×4 wave (WAVE_M=2, WAVE_N=4, 512 threads) layout, and double-buffered software-pipelined structure.
- conv3d_implicit_8wave: BF16 (mfma_f32_16x16x32_bf16)
- conv3d_implicit_8wave_fp8: FP8 E4M3FN (mfma_f32_16x16x128_fp8, CDNA4 only)
Changes
- kernels/conv3d_implicit_8wave.py
- kernels/conv3d_implicit_8wave_fp8.py
- tests/kernels/test_conv3d_implicit_8wave.py (torch.nn.functional.conv3d)
- tests/kernels/test_conv3d_implicit_8wave_fp8.py
Performance (MI355X gfx950)
3×3×3 kernel, stride=1, pad=1, C=K=128, D=6 (Do=6), N=1 vs PyTorch:
BF16 is consistently 1.15–1.41x faster than PyTorch. FP8 reaches 2.06–2.48x across the full NPQ range.
Accuracy (MI355X gfx950, vs PyTorch BF16 reference)
BF16 kernel — numerically matches PyTorch BF16 (BF16 rounding only):
FP8 kernel — inherent E4M3FN quantization noise (direct cast, no per-tensor scale):
The ~3.7% relative error and SNR ~28.5 dB are the inherent FP8 E4M3FN quantization floor for normally-distributed activations (3-bit mantissa), consistent across all shapes.
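The two accuracy figures quoted above are roughly consistent with each other under one assumption the PR does not state: that "relative error" means RMS error divided by RMS signal. In that case SNR in dB is simply -20·log10 of the relative error. A minimal sketch of that conversion (the function names are hypothetical):

```python
import math


def snr_db_from_relative_error(rel_err: float) -> float:
    """SNR in dB, assuming rel_err = rms(error) / rms(signal)."""
    return -20.0 * math.log10(rel_err)


def relative_error_from_snr_db(snr_db: float) -> float:
    """Inverse of the conversion above, under the same assumption."""
    return 10.0 ** (-snr_db / 20.0)


print(round(snr_db_from_relative_error(0.037), 1))  # 28.6
print(relative_error_from_snr_db(28.5))             # a little under 0.038
```

If the PR's relative error is instead a mean or max of per-element ratios, the two numbers would not be expected to line up this closely.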
Test plan
- python -m pytest tests/kernels/test_conv3d_implicit_8wave.py -v (any GPU)
- python -m pytest tests/kernels/test_conv3d_implicit_8wave_fp8.py -v (gfx95x only; auto-skipped otherwise)
🤖 Generated with Claude Code