[Kernel] conv2d: 8-wave double-buffered implicit-GEMM BF16 kernel(gfx950)#733
Open
jiacao-amd wants to merge 6 commits into
Open
[Kernel] conv2d: 8-wave double-buffered implicit-GEMM BF16 kernel(gfx950)#733jiacao-amd wants to merge 6 commits into
jiacao-amd wants to merge 6 commits into
Conversation
coderfeli
reviewed
Jun 24, 2026
| else: | ||
| valid_row = arith.cmpi(arith.CmpIPredicate.ult, global_m, npq_i32) | ||
| valid_if = scf.IfOp(valid_row, results_=[], has_else=False) | ||
| with ir.InsertionPoint(valid_if.then_block): |
coderfeli
reviewed
Jun 24, 2026
| else: | ||
| valid_row = arith.cmpi(arith.CmpIPredicate.ult, row, npq_i32) | ||
| valid_if = scf.IfOp(valid_row, results_=[], has_else=False) | ||
| with ir.InsertionPoint(valid_if.then_block): |
Collaborator
There was a problem hiding this comment.
use @flyc.jit to func and native if
coderfeli
reviewed
Jun 24, 2026
| ) | ||
|
|
||
| def buffer_load_to_lds(rsrc, lds_ptr, global_offset): | ||
| llvm.InlineAsmOp( |
Collaborator
There was a problem hiding this comment.
why need inline asm here?
Author
There was a problem hiding this comment.
I am doing a direct global load to LDS, is rocdl.raw_ptr_buffer_load_lds a better way to do so?
Collaborator
|
And CI failed |
e1781eb to
7802556
Compare
Replaces the prototype conv2d kernel (fixed K=64, no stride/padding, 4×1 wave layout) with a full-featured 8-wave BF16 implicit-GEMM kernel sharing the same 128×128×32 tile and 2×4 wave pipeline as conv3d_implicit_8wave. Key improvements over the old kernel: - Any K (was fixed K=64 only) - Full stride/padding support (was valid-only) - 8-wave 2×4 layout: TILE_N=128 (was 64), better CU utilization - Double-buffered software pipeline with sched hints (dsrd/dswr/vmem) - Tiled NCHW→NHWC transpose kernel (mirrors conv3d's) - Split-K with conservative heuristic (NPQ≥4096, base<num_cu//2) - Bias support Performance vs old prototype (K=64, 3×3 valid, MI355X gfx950): C=128 P=32: 1.38x C=128 P=64: 1.44x C=128 P=128: 1.28x C=256 P=32: 1.47x C=256 P=64: 1.48x C=256 P=128: 1.23x Performance vs MIOpen (K=C, 3×3 pad=1, MI355X gfx950): C=128 K=128 128²: 1.60x C=256 K=256 128²: 1.30x C=512 K=512 64²: 1.04x C=512 K=512 128²: 1.08x
54c39d4 to
7a40578
Compare
The local _run_compiled copy is identical to the canonical one in kernels/tensor_shim.py; import it instead of redefining.
655f9b4 to
cbaa3cd
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
8-wave BF16 implicit-GEMM conv2d kernel with double-buffered software pipeline.
sched_vmem/dsrd/dswrands_setpriohintsChanges
kernels/conv2d_implicit_mfma.pytests/kernels/test_conv2d_implicit_8wave.pytorch.nn.functional.conv2dPerformance (MI355X gfx950)
C=K=128, 3×3 pad=1 same-size conv, N=1 vs
torch.nn.functional.conv2d:Consistently faster than PyTorch across all spatial sizes from 16² to 1024².
Accuracy (MI355X gfx950, vs PyTorch BF16 reference)
BF16 kernel numerically matches PyTorch BF16 (only BF16 rounding error):
0.000% relative error and cosine=1.0 across all shapes — the kernel produces bit-identical results to PyTorch BF16.
Test plan
python -m pytest tests/kernels/test_conv2d_implicit_8wave.py -v