Skip to content

Raise on unsupported head_dim in SM100 dense FMHA dispatch instead of silently returning uninitialized output#185

Open
toffee-desuwa wants to merge 1 commit into
deepseek-ai:mainfrom
toffee-desuwa:fix-sm100-dense-fmha-unsupported-headdim
Open

Raise on unsupported head_dim in SM100 dense FMHA dispatch instead of silently returning uninitialized output#185
toffee-desuwa wants to merge 1 commit into
deepseek-ai:mainfrom
toffee-desuwa:fix-sm100-dense-fmha-unsupported-headdim

Conversation

@toffee-desuwa

Copy link
Copy Markdown

Problem

The two host-side dispatch entry points for the SM100 (Blackwell) dense prefill FMHA kernels, FMHACutlassSM100FwdRun and FMHACutlassSM100BwdRun (both in csrc/sm100/prefill/dense/), select a kernel instantiation based on the runtime (head_dim_qk, head_dim_vo) pair. Only two combinations are instantiated: (192, 128) and (128, 128). When the caller passes any other head-dim pair, the final else branch of the selector merely printed a line to stdout ("No kernel instantiated for head_dim_qk=...") and then fell through, returning normally without launching any kernel.

Because both functions operate on output tensors that the caller has already allocated (o/lse on the forward path; dq/dk/dv on the backward path), returning early without dispatching leaves those buffers in whatever state they were allocated in — typically uninitialized device memory. The caller receives no error, no exception, and no non-zero status: control simply returns and the downstream code proceeds to consume garbage as if it were a valid result. A message on stdout is easy to miss and is not a programmatic signal, so the failure is silent from the caller's point of view.

Root cause

The head-dim else branch was written as a diagnostic print rather than as a hard failure. This is inconsistent with the surrounding validation logic in these very same functions: the dtype else branch, which handles the case where the input/output scalar types are not the supported BFloat16 pair, already fails hard via FLASH_MLA_ASSERT(false) (fmha_cutlass_fwd_sm100.cu:81 and fmha_cutlass_bwd_sm100.cu:81). So an unsupported dtype aborts, but an unsupported head_dim quietly no-ops — two unsupported-configuration paths in the same function with opposite behavior. The head_dim path is the odd one out, and it is the more dangerous of the two because it returns successfully with bad data instead of failing.

Fix

Replace the silent std::cout body of each head_dim else branch with TORCH_CHECK(false, ...), carrying the offending dimensions in the message:

TORCH_CHECK(false, "No kernel instantiated for head_dim_qk=", head_dim_qk,
            " head_dim_vo=", head_dim_vo);

This makes the unsupported-head_dim path fail loudly and consistently with the existing dtype check, so neither unsupported configuration can return uninitialized output anymore.

TORCH_CHECK was chosen over FLASH_MLA_ASSERT(false) deliberately. FLASH_MLA_ASSERT (defined in csrc/sm100/prefill/dense/common/helper.h) calls std::abort(), which tears down the whole process and cannot be caught from Python. TORCH_CHECK(false, ...) instead raises a c10::Error, which surfaces as a normal, catchable Python exception and embeds the dimension values in the message — strictly more recoverable and more informative for a caller. It is also already the established validation idiom in the sibling launch headers these files include (fmha_cutlass_fwd_sm100.cuh and fmha_cutlass_bwd_sm100.cuh use TORCH_CHECK for their stride/contiguity preconditions), so this change introduces no new dependency and matches the local style. TORCH_CHECK is already transitively available in both translation units (each .cu includes common/utils.hpp, which includes <torch/extension.h>), so no new #include is required.

The change is intentionally minimal: only the body of the two else branches is touched; the selector conditions, the supported kernel paths, and everything else are untouched.

This is thematically aligned with the intent of open issue #161, which is about better input validation and surfacing errors rather than letting unsupported inputs slip through silent code paths. This PR does not claim to close that issue; it addresses one concrete instance of the silent-path problem in the SM100 dense FMHA dispatch.

Scope of verification

This is SM100 / Blackwell code. I do not have access to Blackwell hardware or an sm_100 build toolchain, so this change was verified by code inspection only; it was not compiled or run on Blackwell hardware.

What inspection establishes:

  • The bug mechanism is read directly from the source: both functions take caller-allocated output tensors, and the original head_dim else branch returns without launching a kernel, leaving those tensors unwritten. I did not observe garbage output at runtime — the uninitialized-output consequence is inferred from the control flow (early return past the kernel launch), not from a runtime repro.
  • The dtype else branch in both functions already hard-fails via FLASH_MLA_ASSERT(false), so making the head_dim else fail is a consistency fix, not a behavior change for any currently supported configuration. Supported head-dim pairs (192,128) and (128,128) still take exactly the same dispatch path as before.
  • head_dim_qk and head_dim_vo are in scope at the edit site in both functions (they are ints computed near the top of each function and captured by reference into the dispatch lambda), so the TORCH_CHECK message arguments are valid.
  • TORCH_CHECK resolves in both translation units via the existing common/utils.hpp<torch/extension.h> include chain; the same macro is already used in the included .cuh headers.

I was not able to perform a compile or runtime test because of the hardware/toolchain constraint above; a maintainer with Blackwell CI can confirm compilation trivially.


AI-assistance disclosure: this change was AI-assisted and human-reviewed before submission. The bug and the fix were identified and confirmed by reading the source; the verification scope is code inspection only, with no compilation or execution on Blackwell hardware, as stated above. No test results or benchmarks are claimed.

FMHACutlassSM100FwdRun and FMHACutlassSM100BwdRun end their head_dim
selector with an else branch that only printed to stdout and fell
through without launching a kernel or raising. Because the output
tensors (o/lse for fwd; dq/dk/dv for bwd) are pre-allocated by the
caller in flash_mla_interface.py, an unsupported head_dim returned
uninitialized memory as a silent wrong result instead of an error.

This is inconsistent with the dtype branch in the same two functions,
which already hard-fails via FLASH_MLA_ASSERT(false). Replace the
silent std::cout with TORCH_CHECK(false, ...) carrying the offending
head_dim_qk/head_dim_vo, so an unsupported config raises a recoverable
c10::Error instead of returning garbage. This matches the
TORCH_CHECK(false, ...) idiom already used for unsupported-config
errors across csrc/api/ (e.g. common.h). Host-side only; the two
supported (192,128)/(128,128) branches are unchanged.

Co-Authored-By: Claude <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant