Raise on unsupported head_dim in SM100 dense FMHA dispatch instead of silently returning uninitialized output #185
Open
toffee-desuwa wants to merge 1 commit into
FMHACutlassSM100FwdRun and FMHACutlassSM100BwdRun end their head_dim selector with an else branch that only printed to stdout and fell through without launching a kernel or raising. Because the output tensors (o/lse for fwd; dq/dk/dv for bwd) are pre-allocated by the caller in flash_mla_interface.py, an unsupported head_dim returned uninitialized memory as a silent wrong result instead of an error. This is inconsistent with the dtype branch in the same two functions, which already hard-fails via FLASH_MLA_ASSERT(false).

Replace the silent std::cout with TORCH_CHECK(false, ...) carrying the offending head_dim_qk/head_dim_vo, so an unsupported config raises a recoverable c10::Error instead of returning garbage. This matches the TORCH_CHECK(false, ...) idiom already used for unsupported-config errors across csrc/api/ (e.g. common.h).

Host-side only; the two supported (192,128)/(128,128) branches are unchanged.

Co-Authored-By: Claude <noreply@anthropic.com>
Problem
The two host-side dispatch entry points for the SM100 (Blackwell) dense prefill FMHA kernels, FMHACutlassSM100FwdRun and FMHACutlassSM100BwdRun (both in csrc/sm100/prefill/dense/), select a kernel instantiation based on the runtime (head_dim_qk, head_dim_vo) pair. Only two combinations are instantiated: (192, 128) and (128, 128). When the caller passes any other head-dim pair, the final else branch of the selector merely printed a line to stdout ("No kernel instantiated for head_dim_qk=...") and then fell through, returning normally without launching any kernel.

Because both functions operate on output tensors that the caller has already allocated (o/lse on the forward path; dq/dk/dv on the backward path), returning without dispatching leaves those buffers in whatever state they were allocated in, typically uninitialized device memory. The caller receives no error, no exception, and no non-zero status: control simply returns and the downstream code proceeds to consume garbage as if it were a valid result. A message on stdout is easy to miss and is not a programmatic signal, so the failure is silent from the caller's point of view.

Root cause
The head-dim else branch was written as a diagnostic print rather than as a hard failure. This is inconsistent with the surrounding validation logic in these very same functions: the dtype else branch, which handles the case where the input/output scalar types are not the supported BFloat16 pair, already fails hard via FLASH_MLA_ASSERT(false) (fmha_cutlass_fwd_sm100.cu:81 and fmha_cutlass_bwd_sm100.cu:81). So an unsupported dtype aborts, but an unsupported head_dim quietly no-ops: two unsupported-configuration paths in the same function with opposite behavior. The head_dim path is the odd one out, and it is the more dangerous of the two because it returns successfully with bad data instead of failing.

Fix
Replace the silent std::cout body of each head_dim else branch with TORCH_CHECK(false, ...), carrying the offending dimensions in the message. This makes the unsupported-head_dim path fail loudly and consistently with the existing dtype check, so neither unsupported configuration can return uninitialized output anymore.
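As a rough illustration of the before/after control flow, here is a host-only C++ model. It is not the repository's code: launch_stub, dispatch_before, and dispatch_after are hypothetical names, a std::vector stands in for the caller-allocated output tensors, std::runtime_error stands in for the c10::Error that TORCH_CHECK(false, ...) would raise, and the message strings are invented for the sketch.

```cpp
#include <cassert>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a kernel launch: overwrites the caller's buffer.
void launch_stub(std::vector<float>& out) {
    for (float& x : out) x = 1.0f;
}

// Pre-fix shape: the final else only prints, then returns normally,
// so the caller's buffer keeps whatever it held before the call.
void dispatch_before(int head_dim_qk, int head_dim_vo, std::vector<float>& out) {
    if (head_dim_qk == 192 && head_dim_vo == 128) {
        launch_stub(out);
    } else if (head_dim_qk == 128 && head_dim_vo == 128) {
        launch_stub(out);
    } else {
        // Abbreviated message; the real string is not reproduced here.
        std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << "\n";
    }
}

// Post-fix shape: the final else throws instead of printing.
// std::runtime_error is an assumption standing in for c10::Error.
void dispatch_after(int head_dim_qk, int head_dim_vo, std::vector<float>& out) {
    if (head_dim_qk == 192 && head_dim_vo == 128) {
        launch_stub(out);
    } else if (head_dim_qk == 128 && head_dim_vo == 128) {
        launch_stub(out);
    } else {
        std::ostringstream msg;
        msg << "unsupported head_dim pair: head_dim_qk=" << head_dim_qk
            << ", head_dim_vo=" << head_dim_vo;
        throw std::runtime_error(msg.str());
    }
}
```

In this model, calling dispatch_before(64, 64, buf) returns normally and leaves buf untouched, while dispatch_after(64, 64, buf) throws before the caller can read buf. The two supported pairs behave identically in both versions.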
TORCH_CHECK was chosen over FLASH_MLA_ASSERT(false) deliberately. FLASH_MLA_ASSERT (defined in csrc/sm100/prefill/dense/common/helper.h) calls std::abort(), which tears down the whole process and cannot be caught from Python. TORCH_CHECK(false, ...) instead raises a c10::Error, which surfaces as a normal, catchable Python exception and embeds the dimension values in the message, making it strictly more recoverable and more informative for a caller. It is also already the established validation idiom in the sibling launch headers these files include (fmha_cutlass_fwd_sm100.cuh and fmha_cutlass_bwd_sm100.cuh use TORCH_CHECK for their stride/contiguity preconditions), so this change introduces no new dependency and matches the local style. TORCH_CHECK is already transitively available in both translation units (each .cu includes common/utils.hpp, which includes <torch/extension.h>), so no new #include is required.

The change is intentionally minimal: only the body of the two else branches is touched; the selector conditions, the supported kernel paths, and everything else are untouched.

This is thematically aligned with the intent of open issue #161, which is about better input validation and surfacing errors rather than letting unsupported inputs slip through silent code paths. This PR does not claim to close that issue; it addresses one concrete instance of the silent-path problem in the SM100 dense FMHA dispatch.
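The abort-versus-exception trade-off discussed in this section can be sketched in plain C++. This is a minimal model under stated assumptions: fail_by_abort and fail_by_exception are hypothetical stand-ins for an abort-style assert and an exception-style check respectively, and std::runtime_error stands in for c10::Error. Neither function is taken from the repository.

```cpp
#include <cassert>
#include <cstdlib>
#include <stdexcept>
#include <string>

// Hypothetical abort-style failure: std::abort() ends the process,
// so no try/catch in the caller can intercept it. Shown for contrast only.
[[noreturn]] void fail_by_abort() {
    std::abort();
}

// Hypothetical exception-style failure: the caller decides what happens next.
[[noreturn]] void fail_by_exception(const std::string& why) {
    throw std::runtime_error(why);
}

// A caller that recovers from the exception-style failure and reports it.
std::string call_and_recover() {
    try {
        fail_by_exception("unsupported head_dim");
    } catch (const std::exception& e) {
        return std::string("recovered: ") + e.what();
    }
    return "no error";
}
```

Swapping fail_by_exception for fail_by_abort inside call_and_recover would terminate the process before the catch clause runs, which is the behavior the exception-based check avoids.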
Scope of verification
This is SM100 / Blackwell code. I do not have access to Blackwell hardware or an sm_100 build toolchain, so this change was verified by code inspection only; it was not compiled or run on Blackwell hardware.

What inspection establishes:

- The head_dim else branch returns without launching a kernel, leaving the caller-allocated output tensors (o/lse on the forward path; dq/dk/dv on the backward path) unwritten. I did not observe garbage output at runtime; the uninitialized-output consequence is inferred from the control flow (early return past the kernel launch), not from a runtime repro.
- The dtype else branch in both functions already hard-fails via FLASH_MLA_ASSERT(false), so making the head_dim else fail is a consistency fix, not a behavior change for any currently supported configuration. Supported head-dim pairs (192,128) and (128,128) still take exactly the same dispatch path as before.
- head_dim_qk and head_dim_vo are in scope at the edit site in both functions (they are ints computed near the top of each function and captured by reference into the dispatch lambda), so the TORCH_CHECK message arguments are valid.
- TORCH_CHECK resolves in both translation units via the existing common/utils.hpp → <torch/extension.h> include chain; the same macro is already used in the included .cuh headers.

I was not able to perform a compile or runtime test because of the hardware/toolchain constraint above; a maintainer with Blackwell CI can confirm compilation trivially.
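The property the inspection argues for can be stated as a small host-only check: for any head-dim pair, the dispatch either writes the output or raises, and never returns normally with the buffer untouched. The sketch below is a model of that property, not a test from the repository and not a substitute for a build on Blackwell. model_dispatch, classify, and no_silent_path are hypothetical names, and the sentinel-filled vector stands in for the caller-allocated tensors.

```cpp
#include <cassert>
#include <stdexcept>
#include <utility>
#include <vector>

enum class Outcome { Written, Raised, SilentNoOp };

// Hypothetical model of the fixed selector: supported pairs write the
// buffer, everything else throws (standing in for TORCH_CHECK(false, ...)).
void model_dispatch(int head_dim_qk, int head_dim_vo, std::vector<float>& out) {
    const bool supported = (head_dim_qk == 192 && head_dim_vo == 128) ||
                           (head_dim_qk == 128 && head_dim_vo == 128);
    if (!supported) throw std::invalid_argument("unsupported head_dim pair");
    for (float& x : out) x = 0.5f;
}

// Runs the model on a sentinel-filled buffer and reports what happened.
Outcome classify(int head_dim_qk, int head_dim_vo) {
    const float sentinel = -1.0f;
    std::vector<float> out(4, sentinel);
    try {
        model_dispatch(head_dim_qk, head_dim_vo, out);
    } catch (const std::exception&) {
        return Outcome::Raised;
    }
    for (float x : out) {
        if (x == sentinel) return Outcome::SilentNoOp;
    }
    return Outcome::Written;
}

// True when no candidate pair returns normally with an untouched buffer.
bool no_silent_path(const std::vector<std::pair<int, int>>& pairs) {
    for (const auto& p : pairs) {
        if (classify(p.first, p.second) == Outcome::SilentNoOp) return false;
    }
    return true;
}
```

A check of this shape against the real entry points would need a Blackwell build; here it only documents the invariant the change is meant to establish.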
AI-assistance disclosure: this change was AI-assisted and human-reviewed before submission. The bug and the fix were identified and confirmed by reading the source; the verification scope is code inspection only, with no compilation or execution on Blackwell hardware, as stated above. No test results or benchmarks are claimed.