[Kernel] conv2d: 8-wave double-buffered implicit-GEMM BF16 kernel(gfx950) #733

Open
jiacao-amd wants to merge 6 commits into ROCm:main from jiacao-amd:jiacao/conv2d-implicit-mfma

@jiacao-amd jiacao-amd commented Jun 24, 2026

Summary

8-wave BF16 implicit-GEMM conv2d kernel with double-buffered software pipeline.

  • 128×128×32 tile, 2×4 wave layout (WAVE_M=2, WAVE_N=4, 512 threads/block)
  • Double-buffered software pipeline with sched_vmem/dsrd/dswr and s_setprio hints
  • Any K and any C (multiples of 8)
  • Full stride/padding support
  • Tiled NCHW→NHWC transpose kernel for zero-copy input layout change
  • Split-K with conservative heuristic (NPQ≥16384, base<num_cu//4)
  • Bias support
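
As a rough illustration of the tiling above, the sketch below maps a conv2d problem onto implicit-GEMM dimensions (M = N·P·Q, N = K, reduction = C·R·S) and counts 128×128 output tiles. The function names are hypothetical and not taken from the kernel source; the wave size of 64 is inferred from 512 threads across 2×4 waves.

```python
import math

TILE_M, TILE_N, TILE_K = 128, 128, 32   # tile sizes quoted in the summary
WAVE_M, WAVE_N = 2, 4                   # wave layout quoted in the summary
WAVE_SIZE = 64                          # inferred: 512 threads / (2 * 4) waves


def conv2d_gemm_shape(n, c, h, w, k, r, s, stride=1, pad=0):
    """Return (P, Q, gemm_m, gemm_n, gemm_k) for conv2d viewed as a GEMM.

    Hypothetical helper for illustration; not part of the PR.
    """
    p = (h + 2 * pad - r) // stride + 1   # output height
    q = (w + 2 * pad - s) // stride + 1   # output width
    return p, q, n * p * q, k, c * r * s


def tile_grid(gemm_m, gemm_n):
    """Number of output tiles along M and N; partial (tail) tiles round up."""
    return math.ceil(gemm_m / TILE_M), math.ceil(gemm_n / TILE_N)


# Benchmark shape from the PR: C=K=128, 3x3, pad=1, N=1, H=W=128.
p, q, gm, gn, gk = conv2d_gemm_shape(1, 128, 128, 128, 128, 3, 3, stride=1, pad=1)
print(p, q, gm, gn, gk)                    # 128 128 16384 128 1152
print(tile_grid(gm, gn))                   # (128, 1)
print(WAVE_M * WAVE_N * WAVE_SIZE)         # 512 threads per block
print(TILE_M // WAVE_M, TILE_N // WAVE_N)  # 64 32: per-wave share of the tile
```

With this shape the reduction dimension of 1152 splits into 36 steps of TILE_K=32, which is the loop the double-buffered pipeline would iterate over.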

Changes

| File | Description |
| --- | --- |
| kernels/conv2d_implicit_mfma.py | 8-wave implicit-GEMM conv2d kernel |
| tests/kernels/test_conv2d_implicit_8wave.py | Correctness tests vs torch.nn.functional.conv2d |

Performance (MI355X gfx950)

C=K=128, 3×3 pad=1 same-size conv, N=1 vs torch.nn.functional.conv2d:

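| H=W | 8wave (µs) | 8wave (TF) | PyTorch (µs) | PyTorch (TF) | Speedup |
| --- | --- | --- | --- | --- | --- |
| 16² | 19.7 | 3.82 | 26.0 | 2.90 | 1.32x |
| 32² | 19.3 | 15.65 | 25.5 | 11.82 | 1.32x |
| 64² | 18.9 | 63.81 | 24.7 | 48.91 | 1.30x |
| 128² | 20.6 | 234.48 | 33.1 | 145.77 | 1.61x |
| 256² | 46.0 | 420.47 | 54.8 | 352.41 | 1.19x |
| 512² | 141.2 | 547.69 | 188.4 | 410.34 | 1.33x |
| 1024² | 587.2 | 526.59 | 843.6 | 366.58 | 1.44x |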

Across all tested spatial sizes from 16² to 1024², the kernel is 1.19x to 1.61x faster than PyTorch.
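
The TF figures appear consistent with counting 2·N·P·Q·K·C·R·S floating-point operations per convolution and dividing by the measured time. The PR does not state the formula, so the sketch below is an assumption that happens to reproduce the reported values to within about 1%; `conv2d_tflops` is a hypothetical name.

```python
def conv2d_tflops(n, c, k, p, q, r, s, time_us):
    """Throughput in TFLOP/s, counting one multiply-add as 2 FLOPs.

    Hypothetical helper; the FLOP-counting convention is an assumption.
    """
    flops = 2 * n * p * q * k * c * r * s
    return flops / (time_us * 1e-6) / 1e12


# (H=W, 8wave time in µs, reported 8wave TF) taken from the table above.
rows = [(16, 19.7, 3.82), (128, 20.6, 234.48), (1024, 587.2, 526.59)]
for hw, t_us, reported in rows:
    # pad=1 with a 3x3 filter and stride 1 keeps P=Q=H=W.
    est = conv2d_tflops(1, 128, 128, hw, hw, 3, 3, t_us)
    print(f"{hw}x{hw}: estimated {est:.2f} TF, reported {reported} TF")
```

The small residual differences are what one would expect if the reported times were rounded to one decimal place after the throughput was computed.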

Accuracy (MI355X gfx950, vs PyTorch BF16 reference)

BF16 kernel numerically matches PyTorch BF16 (only BF16 rounding error):

| Shape | rel_err% | cosine | SNR (dB) |
| --- | --- | --- | --- |
| 3×3 C=32 H=34 pad=0 | 0.000% | 1.000000 | 92.6 |
| 3×3 C=32 H=32 pad=1 | 0.000% | 1.000000 | 86.6 |
| 3×3 C=32 H=34 K=96 (K-tail) | 0.000% | 1.000000 | 91.4 |
| 3×3 C=32 H=17 (M-tail) | 0.000% | 1.000000 | 92.3 |
| 3×3 C=32 H=66 stride=2 | 0.000% | 1.000000 | 117.8 |
| 3×3 C=64 H=34 K=256 | 0.000% | 1.000000 | 92.9 |
| 1×1 C=128 H=32 K=256 | 0.000% | 1.000000 | 110.9 |

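Relative error rounds to 0.000% and cosine similarity to 1.000000 for every shape. The finite SNR values (86.6 to 117.8 dB) suggest the outputs are not strictly bit-identical to PyTorch BF16, with any differences limited to BF16 rounding.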
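
The PR does not define rel_err, cosine, or SNR. A minimal sketch of one common set of definitions (relative L2 error, cosine similarity, and signal-to-noise ratio in dB) is shown below on flat Python lists; the definitions and the name `accuracy_metrics` are assumptions and may differ from what the test script computes.

```python
import math


def accuracy_metrics(out, ref):
    """Return (rel_err_percent, cosine, snr_db) of `out` against `ref`.

    Assumed definitions, for illustration only:
      rel_err = ||out - ref||_2 / ||ref||_2
      cosine  = <out, ref> / (||out||_2 * ||ref||_2)
      SNR     = 10 * log10(||ref||^2 / ||out - ref||^2)
    """
    err_sq = sum((o - r) ** 2 for o, r in zip(out, ref))
    ref_sq = sum(r * r for r in ref)
    out_sq = sum(o * o for o in out)
    dot = sum(o * r for o, r in zip(out, ref))
    rel_err_pct = 100.0 * math.sqrt(err_sq / ref_sq)
    cosine = dot / math.sqrt(out_sq * ref_sq)
    # Identical outputs have zero noise power, hence infinite SNR.
    snr_db = math.inf if err_sq == 0 else 10.0 * math.log10(ref_sq / err_sq)
    return rel_err_pct, cosine, snr_db


print(accuracy_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # (0.0, 1.0, inf)
```

Under these definitions an exact match yields an infinite SNR, so any finite SNR implies at least one differing element.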

Test plan

  • python -m pytest tests/kernels/test_conv2d_implicit_8wave.py -v
    • 8 cases: valid conv, same-size (pad=1), K-tail, M-tail, stride=2, large C/K, 1×1, bias
    • Note: requires CDNA4 (gfx95x), auto-skipped on other architectures
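
The PR does not show how the architecture skip is implemented. One possible approach, sketched below, is to match the reported GPU architecture string against gfx95x. The helper `is_cdna4` is hypothetical, and the idea of reading the string from `torch.cuda.get_device_properties(0).gcnArchName` on a ROCm build of PyTorch is an assumption about the test environment.

```python
import re


def is_cdna4(arch_name: str) -> bool:
    """True for gfx95x targets such as 'gfx950' or 'gfx950:sramecc+:xnack-'.

    Hypothetical helper; not taken from the PR's test file.
    """
    base = arch_name.split(":", 1)[0]  # drop feature suffixes like ':xnack-'
    return re.fullmatch(r"gfx95[0-9a-f]", base) is not None


# Assumed usage inside a test module (requires pytest and a ROCm PyTorch build):
#   arch = torch.cuda.get_device_properties(0).gcnArchName
#   pytestmark = pytest.mark.skipif(not is_cdna4(arch), reason="requires CDNA4")

print(is_cdna4("gfx950:sramecc+:xnack-"))  # True
print(is_cdna4("gfx942"))                  # False
```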

Comment thread kernels/conv2d_implicit_mfma.py Outdated
else:
    valid_row = arith.cmpi(arith.CmpIPredicate.ult, global_m, npq_i32)
    valid_if = scf.IfOp(valid_row, results_=[], has_else=False)
    with ir.InsertionPoint(valid_if.then_block):

Collaborator

use native if?

Author

done

Comment thread kernels/conv2d_implicit_mfma.py Outdated
else:
    valid_row = arith.cmpi(arith.CmpIPredicate.ult, row, npq_i32)
    valid_if = scf.IfOp(valid_row, results_=[], has_else=False)
    with ir.InsertionPoint(valid_if.then_block):

Collaborator

Use @flyc.jit on the function and a native if.

Author

done

Comment thread kernels/conv2d_implicit_mfma.py Outdated
)

def buffer_load_to_lds(rsrc, lds_ptr, global_offset):
    llvm.InlineAsmOp(

Collaborator

Why is inline asm needed here?

Author

I am doing a direct global-to-LDS load. Would rocdl.raw_ptr_buffer_load_lds be a better way to do that?

@coderfeli

Collaborator

And CI failed

jiacao-amd changed the title from "Add bf16 implicit-MFMA conv2d kernel (F.conv2d-style API)" to "Add implicit-MFMA conv kernels: bf16 conv2d, bf16 & FP8 8-wave conv3d" on Jun 30, 2026
jiacao-amd force-pushed the jiacao/conv2d-implicit-mfma branch from e1781eb to 7802556 on July 1, 2026 18:30
apicciau changed the title from "Add implicit-MFMA conv kernels: bf16 conv2d, bf16 & FP8 8-wave conv3d" to "conv2d: 8-wave double-buffered implicit-GEMM BF16 kernel" on Jul 1, 2026
Replaces the prototype conv2d kernel (fixed K=64, no stride/padding,
4×1 wave layout) with a full-featured 8-wave BF16 implicit-GEMM kernel
sharing the same 128×128×32 tile and 2×4 wave pipeline as
conv3d_implicit_8wave.

Key improvements over the old kernel:
- Any K (was fixed K=64 only)
- Full stride/padding support (was valid-only)
- 8-wave 2×4 layout: TILE_N=128 (was 64), better CU utilization
- Double-buffered software pipeline with sched hints (dsrd/dswr/vmem)
- Tiled NCHW→NHWC transpose kernel (mirrors conv3d's)
- Split-K with conservative heuristic (NPQ≥4096, base<num_cu//2)
- Bias support

Performance vs old prototype (K=64, 3×3 valid, MI355X gfx950):
  C=128 P=32:   1.38x  C=128 P=64:  1.44x  C=128 P=128: 1.28x
  C=256 P=32:   1.47x  C=256 P=64:  1.48x  C=256 P=128: 1.23x

Performance vs MIOpen (K=C, 3×3 pad=1, MI355X gfx950):
  C=128 K=128 128²: 1.60x   C=256 K=256 128²: 1.30x
  C=512 K=512  64²: 1.04x   C=512 K=512 128²: 1.08x
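
For readers unfamiliar with the implicit-GEMM formulation this commit describes, the sketch below is a slow pure-Python reference. It decodes each GEMM row into an output position (n, p, q) and each reduction index into a filter tap (r, s, c), so the im2col matrix is never materialized and stride and padding fall out of the index arithmetic. The NHWC activation layout follows the transpose kernel mentioned above; the KRSC weight layout and the function name are assumptions, not taken from the kernel source.

```python
def conv2d_implicit_gemm(x, wgt, n, h, w, c, k, r, s, stride=1, pad=0):
    """Reference conv2d computed as a GEMM without building im2col.

    x:   flat NHWC input, length n*h*w*c
    wgt: flat KRSC weights, length k*r*s*c (assumed layout)
    Returns the flat NPQK output as a list of floats.
    """
    p = (h + 2 * pad - r) // stride + 1
    q = (w + 2 * pad - s) // stride + 1
    gemm_m, gemm_k = n * p * q, r * s * c
    out = [0.0] * (gemm_m * k)
    for m in range(gemm_m):                  # GEMM row -> (n_i, p_i, q_i)
        n_i, pq = divmod(m, p * q)
        p_i, q_i = divmod(pq, q)
        for kk in range(gemm_k):             # reduction index -> (r_i, s_i, c_i)
            rs, c_i = divmod(kk, c)
            r_i, s_i = divmod(rs, s)
            ih = p_i * stride - pad + r_i
            iw = q_i * stride - pad + s_i
            if not (0 <= ih < h and 0 <= iw < w):
                continue                     # padding region contributes zero
            a = x[((n_i * h + ih) * w + iw) * c + c_i]
            for k_i in range(k):
                out[m * k + k_i] += a * wgt[k_i * gemm_k + kk]
    return out


# 3x3 input of ones, 3x3 filter of ones, pad=1: each output counts its valid taps.
print(conv2d_implicit_gemm([1.0] * 9, [1.0] * 9, 1, 3, 3, 1, 1, 3, 3, pad=1))
# [4.0, 6.0, 4.0, 6.0, 9.0, 6.0, 4.0, 6.0, 4.0]
```

A GPU kernel would tile these loops into 128×128×32 blocks and guard rows with m < N·P·Q, but the index decoding is the same idea.
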
jiacao-amd force-pushed the jiacao/conv2d-implicit-mfma branch from 54c39d4 to 7a40578 on July 1, 2026 20:40
jiacao-amd marked this pull request as ready for review on July 1, 2026 21:35
jiacao-amd changed the title from "conv2d: 8-wave double-buffered implicit-GEMM BF16 kernel" to "[Kernel] conv2d: 8-wave double-buffered implicit-GEMM BF16 kernel(gfx950)" on Jul 1, 2026
The local _run_compiled copy is identical to the canonical one in
kernels/tensor_shim.py; import it instead of redefining.
jiacao-amd force-pushed the jiacao/conv2d-implicit-mfma branch from 655f9b4 to cbaa3cd on July 3, 2026 06:10

2 participants