
OneDNN BRGeMM Micro-Kernel Integration for BF16 MatMul #903

Open

bbhattar wants to merge 5 commits into google:dev from Intel-tensorflow:feature/onednn-brgemm

Conversation

@bbhattar

This PR integrates OneDNN BRGeMM (Batch-Reduced General Matrix Multiply) micro-kernels as an alternative compute path for BF16 MatMul on Intel Xeon platforms with AMX or AVX-512 BF16 support.

What

When enabled via the GEMMA_ONEDNN_BRGEMM compile-time flag, BF16×BF16 MatMul operations are dispatched to JIT-compiled BRGeMM kernels instead of the Highway SIMD path. This targets Gemma model workloads (FFW projections, attention) on Intel Xeon Scalable (SPR/EMR) processors. Support has been added to both the CMake and Bazel build systems.

How to Enable

# CMake
cmake -DGEMMA_ONEDNN_BRGEMM=ON ..

# Bazel
bazel build --define gemma_onednn_brgemm=1 ...

Runtime Fallback

When GEMMA_ONEDNN_BRGEMM is enabled at compile time, the BRGeMM path activates for BF16×BF16 operations whose dimensions meet AMX tile constraints (M, N, K ≥ 32 and K % 32 == 0). All other cases — non-BF16 types, smaller or non-aligned dimensions, mixed precision — fall through to the standard Highway SIMD MatMul path automatically.
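
For reference, a minimal sketch of that eligibility check, written as a standalone predicate; the actual UseOneDnnBrgemm() declared in ops/brgemm.h may take different parameters:

// Sketch only: mirrors the dimension rules described above.
inline bool UseOneDnnBrgemm(size_t M, size_t K, size_t N,
                            bool a_is_bf16, bool b_is_bf16) {
  if (!a_is_bf16 || !b_is_bf16) return false;    // BF16 x BF16 only.
  if (M < 32 || N < 32 || K < 32) return false;  // AMX tile minimums.
  return (K % 32) == 0;                          // K must be a multiple of 32.
}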

Changes

File                 Description
ops/brgemm.h         Types, caches, thread-local buffers, UseOneDnnBrgemm(), autotuning candidates
ops/brgemm-inl.h     DoMatMul_BRGeMM(): kernel JIT/caching, B-packing with hugepages, tiled parallel execution
ops/matmul-inl.h     BRGeMM dispatch block in MatMul() guarded by #if GEMMA_ONEDNN_BRGEMM
ops/matmul.h         #include "ops/brgemm.h", brgemm_autotune field in MMPerKey
ops/bench_matmul.cc  Check brgemm_autotune.Best() to avoid infinite loop when BRGeMM handles dispatch
CMakeLists.txt       GEMMA_ONEDNN_BRGEMM option, FetchContent for OneDNN v3.11, conditional target linking
BUILD.bazel          config_setting for gemma_onednn_brgemm, conditional OneDNN dep and defines for x86_64
MODULE.bazel         OneDNN v3.11 http_archive dependency
bazel/onednn.BUILD   Bazel build rules for OneDNN
util/zones.h         kBRGeMM caller enum for thread pool dispatch
util/zones.cc        CallerName mapping for kBRGeMM

Testing

  • matmul_test passes with and without GEMMA_ONEDNN_BRGEMM (all original test shapes, types, and correctness checks preserved)
  • bench_matmul runs successfully with BRGeMM enabled
  • No changes to existing tests; zero impact when OneDNN is not enabled or on non-x86 platforms

@google-cla

google-cla Bot commented Apr 28, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@bbhattar force-pushed the feature/onednn-brgemm branch from 629b569 to e072d70 on April 28, 2026 at 22:19

@jan-wassenberg left a comment


Very nice work :) Just some fairly minor suggestions:

Comment thread ops/brgemm.h

struct BRGeMMConfig {
  int64_t M_blk;
  int64_t N_blk;

We could set these to 32 directly, as member initializers? Possibly also make them const to make clear that they do not change.
Also, prefer size_t for all size-like things to prevent sign-conversion warnings.
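
A sketch of what that could look like, combining the member-initializer and size_t suggestions (the defaults of 32 are illustrative):

struct BRGeMMConfig {
  // Defaults documented in the type itself; size_t avoids sign-conversion
  // warnings when these feed size arithmetic.
  size_t M_blk = 32;
  size_t N_blk = 32;
};

Making them const would also signal immutability, at the cost of deleting copy assignment, which matters if configs are stored in containers such as the candidates vector.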

Comment thread ops/brgemm.h
// Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}.
inline std::vector<BRGeMMConfig> BRGeMMCandidates(size_t M, size_t K,
                                                  size_t N) {
  std::vector<BRGeMMConfig> out;

Let's .reserve with some estimate, also to document how many there will be?
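
For example, a sketch sized from the candidate grids mentioned in the comment above, at most one entry per (M_blk, batch_size) pair, i.e. 2 * 5 = 10 (std::size comes from <iterator>):

// Documents the upper bound on candidates and avoids reallocation.
out.reserve(std::size(kMBlkValues) * std::size(kBatchValues));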

Comment thread ops/brgemm.h
static constexpr int64_t kMBlkValues[] = {32, 64};
static constexpr int64_t kBatchValues[] = {16, 32, 64, 128, 256};

const int64_t k_chunks = static_cast<int64_t>(K) / kKBlk;

Should this round up? We have hwy::DivCeil.
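
If rounding up is indeed the intent, a sketch of the DivCeil form:

// Rounds up so a K that is not a multiple of kKBlk still gets a final chunk.
const size_t k_chunks = hwy::DivCeil(K, static_cast<size_t>(kKBlk));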

Comment thread ops/brgemm.h
  }
  madvise(ptr_, size_, MADV_HUGEPAGE);
  for (size_t off = 0; off < size_; off += kHugePageSize) {
    static_cast<volatile uint8_t*>(ptr_)[off] = 0;

Possibly safer/more portable: consider ptr_[off] = 0; hwy::PreventElision(ptr_[off]).
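
A sketch of that variant, assuming ptr_ is the void* returned by the allocation above:

// Touch one byte per huge page; PreventElision keeps the store from being
// optimized away without needing a volatile-qualified pointer.
uint8_t* const bytes = static_cast<uint8_t*>(ptr_);
for (size_t off = 0; off < size_; off += kHugePageSize) {
  bytes[off] = 0;
  hwy::PreventElision(bytes[off]);
}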

Comment thread ops/brgemm.h
// Kernel cache key: identifies a JIT-compiled kernel set.
struct BRGeMMKernelKey {
  size_t M, K, N;
  int64_t M_blk, N_blk, K_blk, batch_size;

Can these also be size_t? And below.

Comment thread ops/brgemm-inl.h
ke.M_blk =
    static_cast<int64_t>(std::min(static_cast<size_t>(cfg.M_blk), M));

ke.M_tail = M % ke.M_blk;

Do we want precomputed hwy::Divisor here to avoid actual division?
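
hwy::Divisor precomputes the reciprocal once so later remainders avoid a hardware divide; a sketch, assuming M and the block size fit in uint32_t:

// Built once per kernel setup, reused for every remainder/divide.
const hwy::Divisor div_m_blk(static_cast<uint32_t>(ke.M_blk));
ke.M_tail = div_m_blk.Remainder(static_cast<uint32_t>(M));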

Comment thread ops/brgemm-inl.h
const int64_t ldb_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk};
const int64_t ldc_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk};

// Create brgemm kernels for each (M-tile, N-tile) variant.

I think these are "do we have an M and N tail" variants, could the comment be rephrased to make that more clear?

Comment thread ops/brgemm-inl.h
auto& kern_cache = GetBRGeMMKernelCache();
auto kern_it = kern_cache.find(kern_key);

if (kern_it == kern_cache.end()) {

This block is quite big. Might help readability and codegen to put it into a HWY_NOINLINE helper function?

Comment thread ops/brgemm-inl.h
if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, ke.K_blk,
                ke.K_super_size, ke.lda, ldb_for[ni], ldc_for[ni],
                a_dt, b_dt, c_dt, false)) {
  return;

Should we HWY_WARN on failure? Or even HWY_ABORT? If failure can happen, should we fall back to the prior matmul?
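
A sketch of the warn-and-return option applied to the quoted call; whether the caller can then fall back to the existing Highway MatMul depends on how the surrounding dispatch is structured:

if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, ke.K_blk,
                ke.K_super_size, ke.lda, ldb_for[ni], ldc_for[ni],
                a_dt, b_dt, c_dt, false)) {
  // Surface the failure instead of returning silently; the caller could then
  // take the existing Highway SIMD path.
  HWY_WARN("BRGeMM kernel creation failed; falling back to Highway MatMul.");
  return;
}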

Comment thread ops/brgemm-inl.h
const auto va = hn::Load(df, add_row + n);
const auto result = hn::MulAdd(v, vscale, va);
if constexpr (hwy::IsSame<TC, float>()) {
  hn::Store(result, df, reinterpret_cast<float*>(C_row) + n);

Better to use HWY_RCAST_ALIGNED to tell the compiler this is element-aligned. (also below)
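
Applied to the quoted Store, a sketch:

// HWY_RCAST_ALIGNED performs the reinterpret_cast and tells the compiler the
// pointer is aligned to the destination element type.
hn::Store(result, df, HWY_RCAST_ALIGNED(float*, C_row) + n);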
