DSA: add q causal offsets and Rubin SM100F support #316
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID:
📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough
Adds optional per-batch `q_causal_offsets` support across DSA forward, backward, and score recompute
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
python/cudnn/deepseek_sparse_attention/indexer_forward/api.py (1)
173-175: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Keep the APIBase path in sync with the new kernel argument.
Line 173 and Line 192 still pass only `cu_seqlens_q`, `cu_seqlens_k`, then `stream`. The SM100 kernel path now has `q_causal_offsets` between `cu_seqlens_k` and `stream`, so direct `IndexerForward` users can hit a compile/launch arity mismatch.
🐛 Proposed fix
      cutlass.Float32(self.sm_scale),
      None,
      None,
    + None,
      fake_stream,
      options=compile_options(),
      )
    @@
      cutlass.Float32(self.sm_scale),
      None,
      None,
    + None,
      stream,
      )
Also applies to: 192-194
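To make the arity mismatch concrete, here is a minimal Python sketch. `launch_kernel` is a hypothetical stand-in, not the real compiled kernel entry point; it only mirrors the argument order described in this comment (`cu_seqlens_q`, `cu_seqlens_k`, `q_causal_offsets`, `stream`).

```python
def launch_kernel(sm_scale, cu_seqlens_q, cu_seqlens_k, q_causal_offsets, stream):
    """Hypothetical kernel entry point with the new q_causal_offsets slot."""
    return {"q_causal_offsets": q_causal_offsets, "stream": stream}


def call_old_style(stream):
    # Pre-change call site: there is no placeholder for q_causal_offsets,
    # so `stream` lands in the wrong slot and one argument is missing.
    return launch_kernel(1.0, None, None, stream)


def call_new_style(stream):
    # Updated call site: an explicit None fills the q_causal_offsets slot.
    return launch_kernel(1.0, None, None, None, stream)


def old_style_fails(stream):
    """Return True if the old call site raises an arity error."""
    try:
        call_old_style(stream)
    except TypeError:
        return True
    return False
```

In this toy version the old call site fails with a `TypeError`; in the real code path the failure would presumably surface at compile or launch time instead.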
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py` around lines 173 - 175, The APIBase path in the IndexerForward method is missing the q_causal_offsets argument that was added to the SM100 kernel path. Update both argument passing locations (around lines 173-175 and lines 192-194) to insert q_causal_offsets between cu_seqlens_k and stream parameters to match the updated kernel signature. This will ensure the argument order is consistent across all kernel paths and prevent arity mismatch errors when calling IndexerForward.
python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py (1)
1315-1321: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Clear skipped causal-tail columns when pruning K blocks.
`_dense_compute_n_blocks` now prunes blocks using `q_causal_offset`, but `_postprocess_and_reduce` only writes zeros inside processed blocks. When a caller provides a preallocated `out`, any columns beyond `n_block_max * tile_n` keep stale values because the SM90 wrapper only zero-initializes `out` when it allocates it itself. Consider zeroing supplied outputs before launch or explicitly clearing the skipped tail.
🐛 One possible host-side fix
      if out is None:
          out = torch.zeros((batch_size, seqlen_q, seqlen_k), dtype=torch.float32, device=q.device)
      else:
          out = out if out.is_contiguous() else out.contiguous()
    +     out.zero_()
Also applies to: 1367-1374, 1486-1494
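The stale-tail behavior can be sketched in pure Python. `recompute_row` is a hypothetical toy, not the SM90 kernel: it writes only the first `n_block_max * tile_n` columns of one output row, and the `zero_supplied` flag (also an assumption) models the host-side fix of clearing a caller-supplied buffer.

```python
def recompute_row(seqlen_k, n_block_max, tile_n, out=None, zero_supplied=False):
    """Toy stand-in for one output row of the dense score recompute.

    Only the first n_block_max * tile_n columns are written, mimicking a
    kernel that skips pruned K blocks entirely.
    """
    if out is None:
        out = [0.0] * seqlen_k            # wrapper-allocated buffers start zeroed
    elif zero_supplied:
        for j in range(seqlen_k):         # host-side fix: clear the caller's buffer
            out[j] = 0.0
    processed = min(seqlen_k, n_block_max * tile_n)
    for j in range(processed):
        out[j] = 1.0                      # placeholder score for processed columns
    return out


# A preallocated buffer filled with 9.0 stands in for stale data.
stale = recompute_row(8, n_block_max=1, tile_n=4, out=[9.0] * 8)
fixed = recompute_row(8, n_block_max=1, tile_n=4, out=[9.0] * 8, zero_supplied=True)
```

Without the flag, the last four columns of `stale` keep the old 9.0 values; with it, the skipped tail of `fixed` is zero.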
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py` around lines 1315 - 1321, The kernel leaves stale values in output columns beyond the processed range when K blocks are pruned using q_causal_offset in _dense_compute_n_blocks, but _postprocess_and_reduce only writes zeros inside processed blocks. This affects preallocated output buffers since the SM90 wrapper only zero-initializes output when it allocates internally. Fix this by either zero-initializing the entire supplied output buffer before kernel launch (host-side), or explicitly clear the skipped tail columns in the kernel code across all affected locations in _postprocess_and_reduce and related functions. Apply the same fix pattern to the three mentioned ranges to ensure consistency across all similar code paths.
python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py (1)
999-1007: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
Guard the indexer LSE reduction for empty causal windows.
With `ratio > 1` and offset `0`, early queries can get `col_limit == 0`, so `local_max[qi]` remains `-inf` and `local_sum_exp[qi]` remains `0`. The later `local_max - global_max` becomes `-inf - -inf`, producing a NaN denominator.
🐛 Proposed fix
    - adjusted_sum = local_sum_exp[qi] * cute.math.exp2((local_max[qi] - global_max) * log2_e)
    + adjusted_sum = (
    +     local_sum_exp[qi] * cute.math.exp2((local_max[qi] - global_max) * log2_e)
    +     if local_sum_exp[qi] > Float32(0.0)
    +     else Float32(0.0)
    + )
      global_sum_exp, reduce_phase = self._intra_inter_warp_reduce_sum(
          sScoreAll_sum,
          reduce_sync_mbar_ptr,
          reduce_phase,
    @@
    - lse_val = global_max + cute.math.log2(global_sum_exp) * inv_log2_e
    + lse_val = (
    +     global_max + cute.math.log2(global_sum_exp) * inv_log2_e
    +     if global_sum_exp > Float32(0.0)
    +     else -Float32.inf
    + )
Also applies to: 1027-1037
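The failure mode can be reproduced with plain floats. The sketch below is an illustration only: it uses natural `exp`/`log` rather than the kernel's `exp2`/`log2` scaling, and the helper names `merge_lse_unguarded` and `merge_lse_guarded` are hypothetical. Each partial is a `(local_max, local_sum_exp)` pair, as in the comment above.

```python
import math

NEG_INF = float("-inf")


def merge_lse_unguarded(partials):
    """Combine (local_max, local_sum_exp) pairs without any empty-window guard."""
    global_max = max(m for m, _ in partials)
    # If every partial is empty, m - global_max is (-inf) - (-inf) = NaN.
    total = sum(s * math.exp(m - global_max) for m, s in partials)
    return global_max + math.log(total)


def merge_lse_guarded(partials):
    """Same reduction, but skip empty partials and return -inf for empty rows."""
    global_max = max(m for m, _ in partials)
    total = 0.0
    for m, s in partials:
        if s > 0.0:                      # mirrors the local_sum_exp > 0 guard
            total += s * math.exp(m - global_max)
    if total > 0.0:                      # mirrors the global_sum_exp > 0 guard
        return global_max + math.log(total)
    return NEG_INF


empty_row = [(NEG_INF, 0.0), (NEG_INF, 0.0)]   # col_limit == 0 for this query
nan_result = merge_lse_unguarded(empty_row)
safe_result = merge_lse_guarded(empty_row)
```

For a fully masked row the unguarded version yields NaN while the guarded one returns `-inf`; rows with at least one valid partial give the same result in both.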
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py` around lines 999 - 1007, When col_limit equals zero due to ratio greater than 1 with offset 0, no scores are accumulated for early queries, leaving local_max[qi] at negative infinity and local_sum_exp[qi] at 0. Later in the LSE reduction calculation, computing local_max[qi] minus global_max results in negative infinity minus negative infinity which produces NaN. Add a guard condition to check whether local_max[qi] is negative infinity before performing the LSE reduction operations that involve subtraction with local_max[qi]. Apply this same guard check to the similar code block also mentioned in the comment (lines 1027-1037 range).
🧹 Nitpick comments (12)
python/cudnn/deepseek_sparse_attention/indexer_backward/api.py (1)
526-555: ⚡ Quick win
Document `q_causal_offsets` on the dense backward wrapper.
The new public argument materially changes the ratio-causal mask, but the wrapper docstring does not mention its shape, dtype/device requirements, or that cache selection depends only on whether it is provided. As per coding guidelines, `python/cudnn/**`: “Focus on documentation.”
Suggested docstring addition
      """Dense full-KV indexer backward.

      Returns ``{'d_index_q', 'd_weights', 'd_index_k'}``.

      ``attn_score`` and ``index_score`` are raw dense scores from
      ``dense_attn_score_recompute_wrapper`` and
      ``dense_indexer_score_recompute_wrapper`` respectively. They are consumed
      in-place by the score-gradient precompute stage.
    +
    + Args:
    +     q_causal_offsets: Optional CUDA int32 tensor of shape ``(batch,)``.
    +         Each entry is the global uncompressed token index for local
    +         ``q[0]`` in that batch. When omitted, the legacy bottom-right
    +         ratio-causal alignment is used.
      """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/indexer_backward/api.py` around lines 526 - 555, The docstring for the dense_indexer_backward_wrapper function is missing documentation for the q_causal_offsets parameter, which is a new public argument that materially changes the ratio-causal mask behavior. Add documentation to the function's docstring that specifies the shape and dtype/device requirements for q_causal_offsets, and clearly document that cache selection depends only on whether this parameter is provided or None. Ensure the parameter documentation follows the existing docstring format and coding guidelines for the python/cudnn module.
Source: Coding guidelines
python/cudnn/deepseek_sparse_attention/score_recompute/api.py (1)
613-628: ⚡ Quick win
Add wrapper docs for the new offset-aware causal mode.
Both dense score recompute wrappers now expose `q_causal_offsets`, but neither documents the `(batch,)` CUDA `int32` contract or how it changes the ratio-causal bound. As per coding guidelines, `python/cudnn/**`: “Focus on documentation.”
Suggested docstring pattern
      def dense_indexer_score_recompute_wrapper(
          q: torch.Tensor,
          k: torch.Tensor,
          weights: torch.Tensor,
    @@
          q_causal_offsets: Optional[torch.Tensor] = None,
          stream: Optional[cuda.CUstream] = None,
      ) -> TupleDict:
    +     """Dense indexer score recompute over full KV.
    +
    +     Args:
    +         q_causal_offsets: Optional CUDA int32 tensor of shape ``(batch,)``.
    +             Each entry is the global uncompressed token index for local
    +             ``q[0]`` in that batch; presence selects the offset-aware compiled
    +             kernel variant.
    +
    +     Returns:
    +         ``TupleDict(out=..., denom=...)``.
    +     """
Apply the same wording to `dense_attn_score_recompute_wrapper`.
Also applies to: 824-839
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/score_recompute/api.py` around lines 613 - 628, Add comprehensive docstring documentation to the dense_indexer_score_recompute_wrapper function that clearly describes the q_causal_offsets parameter, including that it is a (batch,) sized CUDA int32 tensor and how it affects the ratio-causal bound behavior. Apply the same documentation pattern to the dense_attn_score_recompute_wrapper function as well, ensuring both wrappers have consistent and complete documentation for the offset-aware causal mode according to the cudnn coding guidelines that emphasize documentation.
Source: Coding guidelines
python/cudnn/deepseek_sparse_attention/utils/runtime.py (1)
20-31: ⚡ Quick win
Document the shared `q_causal_offsets` contract.
This helper is now the central contract for the new public parameter; add a short docstring covering `None`, CUDA `int32`, shape `(batch,)`, same-device requirement, and the contiguous copy behavior. As per coding guidelines, `python/cudnn/**`: “Focus on documentation.”
Suggested documentation
      def validate_q_causal_offsets(q_causal_offsets, batch: int, device: torch.device):
    +     """Validate optional per-batch q-causal offsets.
    +
    +     Args:
    +         q_causal_offsets: Optional CUDA int32 tensor with shape ``(batch,)``.
    +             Each entry is the global uncompressed token index for local
    +             query position 0 in that batch.
    +         batch: Expected batch size.
    +         device: Device expected to match the query tensor.
    +
    +     Returns:
    +         ``None`` or a contiguous tensor suitable for CuTe kernel arguments.
    +     """
          if q_causal_offsets is None:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/utils/runtime.py` around lines 20 - 31, The function validate_q_causal_offsets lacks documentation for this central contract parameter. Add a docstring to the function that documents the following contract requirements: that q_causal_offsets can be None and will be returned as-is, that it must be a CUDA int32 tensor, that it must have shape (batch,), that it must be on the same device as the input q, and that the function returns either the original tensor if already contiguous or a contiguous copy. This documentation should serve as the shared contract for this new public parameter.
Source: Coding guidelines
python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py (1)
514-538: ⚡ Quick win
Update SM90 dense entry-point docstrings for `q_causal_offsets`.
The signatures now accept `q_causal_offsets`, but the public SM90 backend docs still omit the offset-aware causal semantics and expected CUDA `int32` shape. As per coding guidelines, `python/cudnn/**`: “Focus on documentation.”
Also applies to: 577-598
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py` around lines 514 - 538, The docstring for the dense_indexer_score_recompute function is missing documentation for the q_causal_offsets parameter that is now part of the function signature. Update the docstring to document the q_causal_offsets parameter by adding a description explaining its purpose (offset-aware causal semantics), expected type (CUDA int32), and shape information. Apply the same docstring update to any related functions (such as other SM90 backend entry points) that also accept the q_causal_offsets parameter.
Source: Coding guidelines
test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py (1)
90-90: ⚡ Quick win
Use non-uniform offsets to exercise per-batch indexing.
A constant `8` verifies the “offsets present” path, but it would not catch a kernel that always reads `q_causal_offsets[0]` for every batch. Vary the values by batch so the reference comparison checks the per-batch contract. As per coding guidelines, `python/cudnn/**`: “Focus on whether there are test cases in test/python/fe_api.”
Suggested test input
    - q_causal_offsets = torch.full((b_cfg,), 8, dtype=torch.int32, device="cuda")
    + q_causal_offsets = (ratio * (1 + torch.arange(b_cfg, dtype=torch.int32, device="cuda") % 4)).contiguous()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py` at line 90, The q_causal_offsets initialization using torch.full() creates uniform offset values (all equal to 8), which only exercises the "offsets present" code path but fails to validate per-batch indexing since a buggy kernel could always read from index 0 regardless of batch. Replace torch.full() with torch.arange() or another method that creates varying offset values across different batch indices so that the reference comparison properly validates the per-batch indexing contract and catches kernels that incorrectly ignore per-batch offsets.
Source: Coding guidelines
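A pure-Python sketch of why uniform offsets cannot detect a batch-index bug. The bound `(offset + q + 1) // ratio` follows the `col_limit` formula quoted elsewhere in this review; `reference_limits` and `buggy_limits` are hypothetical helpers, and the buggy variant deliberately reads `offsets[0]` for every batch.

```python
def col_limit(offset, q_idx, ratio):
    """Ratio-causal column bound for one query row."""
    return (offset + q_idx + 1) // ratio


def reference_limits(offsets, seqlen_q, ratio):
    """Per-batch bounds: each batch uses its own offset."""
    return [[col_limit(off, q, ratio) for q in range(seqlen_q)] for off in offsets]


def buggy_limits(offsets, seqlen_q, ratio):
    """Deliberately wrong: every batch reads offsets[0]."""
    return [[col_limit(offsets[0], q, ratio) for q in range(seqlen_q)] for _ in offsets]


ratio, seqlen_q, batch = 4, 4, 3
uniform = [8] * batch
varied = [ratio * (1 + b % 4) for b in range(batch)]   # [4, 8, 12]

# Uniform offsets hide the bug; varied offsets expose it.
uniform_hides_bug = reference_limits(uniform, seqlen_q, ratio) == buggy_limits(uniform, seqlen_q, ratio)
varied_exposes_bug = reference_limits(varied, seqlen_q, ratio) != buggy_limits(varied, seqlen_q, ratio)
```

The `varied` list mirrors the suggested `ratio * (1 + arange(b) % 4)` pattern, so the first batch keeps the smallest offset while later batches diverge from it.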
python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py (1)
1029-1066: ⚡ Quick win
Document the new dense SM100 `q_causal_offsets` argument.
These public backend entry points now accept `q_causal_offsets`, but the docstrings still only describe `ratio`/THD behavior. Add the offset tensor contract and note that it selects offset-aware ratio-causal masking. As per coding guidelines, `python/cudnn/**`: “Focus on documentation.”
Also applies to: 1093-1128
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py` around lines 1029 - 1066, The docstring for the dense_indexer_score_recompute function is missing documentation for the q_causal_offsets parameter, which is now part of the function signature. Add q_causal_offsets to the Args section of the docstring, describing it as an optional offset tensor and noting that it enables offset-aware ratio-causal masking. Apply the same documentation update to the other affected functions mentioned in the range 1093-1128 that similarly accept the q_causal_offsets parameter.
Source: Coding guidelines
test/python/fe_api/dsa/test_DSA_dense_score_recompute.py (1)
117-117: ⚡ Quick win
Use non-uniform offsets in the wrapper test.
Line 117 gives every batch the same offset, so a kernel that always reads `q_causal_offsets[0]` can still pass. Vary the offsets per batch to exercise the new per-batch indexing path.
🧪 Proposed test update
    - q_causal_offsets = torch.full((cfg["b"],), 8, dtype=torch.int32, device=q.device)
    + q_causal_offsets = 8 + 4 * torch.arange(cfg["b"], dtype=torch.int32, device=q.device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@test/python/fe_api/dsa/test_DSA_dense_score_recompute.py` at line 117, The test currently creates uniform offsets for all batches using torch.full(), which means a kernel that only reads the first offset can still pass. Replace the q_causal_offsets initialization to use varying offset values across different batch elements instead of filling all with the same value of 8. This will properly exercise the per-batch indexing path by ensuring each batch element has a different offset value.
test/python/fe_api/dsa/test_DSA_indexer_forward.py (1)
54-54: ⚡ Quick win
Vary offsets across batches.
Line 54 uses the same offset for every batch, which would not catch a kernel that ignores the batch index when loading `q_causal_offsets`. Use a per-batch sequence while keeping the first row safely valid.
🧪 Proposed test update
    - q_causal_offsets = torch.full((cfg["b"],), 4, dtype=torch.int32, device=q.device)
    + q_causal_offsets = int(ratio) + int(ratio) * torch.arange(cfg["b"], dtype=torch.int32, device=q.device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@test/python/fe_api/dsa/test_DSA_indexer_forward.py` at line 54, The q_causal_offsets tensor is created with torch.full using the same offset value (4) for all batches, which would not catch a kernel bug that incorrectly ignores the batch index when loading offsets. Replace the torch.full call with a tensor that has varying offset values across batches (for example, using torch.arange or torch.tensor with a sequence), while ensuring all offset values remain valid and safe to use without causing out-of-bounds access.
python/cudnn/deepseek_sparse_attention/indexer_forward/api.py (1)
238-241: ⚡ Quick win
Document the `q_causal_offsets` tensor contract in the public wrapper.
The lower-level validator requires an int32 CUDA tensor shaped by batch; callers should not have to discover that from runtime errors.
📝 Proposed doc update
    - positions outside the valid KV range with -inf. ``q_causal_offsets`` may
    - specify the global uncompressed token index for each batch/THD segment's
    - local q[0].
    + positions outside the valid KV range with -inf. ``q_causal_offsets`` may
    + specify the global uncompressed token index for each batch/THD segment's
    + local q[0]. When provided, it must be an int32 CUDA tensor of shape
    + ``(B,)`` for BSHD input, or ``(cu_seqlens_q.numel() - 1,)`` for THD input;
    + non-contiguous inputs are copied by the lower-level validation path.
As per coding guidelines, `python/cudnn/**`: "Focus on documentation."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py` around lines 238 - 241, The `q_causal_offsets` parameter in the public wrapper function is not properly documented with its tensor contract requirements. Add clear documentation to the docstring of the public wrapper function (in the api.py file) that explicitly specifies that `q_causal_offsets` must be an int32 CUDA tensor with shape determined by batch dimensions. This documentation should be placed in the Parameters or Args section of the docstring so callers can understand the requirements without encountering runtime errors from the lower-level validator.
Source: Coding guidelines
python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py (1)
227-239: ⚡ Quick win
Document the `has_q_causal_offsets` runtime contract.
This factory now changes the compiled ABI and runtime assertions based on `has_q_causal_offsets`, but the docstring does not mention that callers must pass matching `QCausalOffsets` into the returned runner paths.
Suggested docstring update
      def dense_indexer_backward_sm90(
          batch,
          seqlen,
          seqlen_k,
    @@
          has_q_causal_offsets=False,
      ):
    -     """Factory for the dense indexer backward gradient kernel on SM90."""
    +     """Factory for the dense indexer backward gradient kernel on SM90.
    +
    +     Args:
    +         has_q_causal_offsets: When True, the returned runner, ``score_grad``,
    +             and ``gemm_only`` paths must receive a matching ``QCausalOffsets``
    +             tensor at runtime; when False, they must receive ``None``.
    +     """
As per coding guidelines, `python/cudnn/**`: "Focus on documentation."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py` around lines 227 - 239, The docstring for the dense_indexer_backward_sm90 function does not document the runtime contract for the has_q_causal_offsets parameter. Update the docstring to explicitly explain that when has_q_causal_offsets is True, callers must pass matching QCausalOffsets into the returned runner paths, and that this parameter affects the compiled ABI and runtime assertions. Include this documentation requirement as part of the function's parameter documentation.
Source: Coding guidelines
python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py (1)
201-213: ⚡ Quick win
Add inline documentation for the offset-based causal formula.
The new `mQCausalOffsets` value is threaded through several producer/consumer paths, but the file does not document its shape, default behavior, or why the bound is `(q_causal_offset + q + 1) // ratio`. A short comment at the entrypoint and helper would make the cross-path contract easier to audit. As per coding guidelines, `python/cudnn/**`: “Focus on documentation.”
Also applies to: 530-533, 716-719, 984-987, 1266-1268, 1486-1494
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py` around lines 201 - 213, The mQCausalOffsets parameter in the __call__ method lacks documentation about its shape, default behavior, and the reasoning behind the bound formula (q_causal_offset + q + 1) // ratio. Add inline comments at the __call__ entrypoint that clearly document what mQCausalOffsets represents, its expected shape, how it behaves by default, and why the bound calculation uses that specific formula. Additionally, add similar documentation comments at all helper functions and producer/consumer paths that use mQCausalOffsets (as indicated in the affected locations at lines 530-533, 716-719, 984-987, 1266-1268, and 1486-1494) to ensure the cross-path contract is clear and consistent throughout the file.
Source: Coding guidelines
python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py (1)
175-188: ⚡ Quick win
Document the offset-aware causal bound.
`mQCausalOffsets` is a new API-visible kernel argument, but its expected per-batch semantics are not described, and the epilogue comments still refer to the old “Bottom-right” mask. Please update these comments to describe `col_limit = floor((q_causal_offset + q_token_idx + 1) / ratio)` and the expected tensor shape/dtype. As per coding guidelines, `python/cudnn/**`: “Focus on documentation.”
Also applies to: 988-999, 1176-1184
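As a worked illustration of the bound, the following pure-Python sketch evaluates `col_limit` per query row and builds a boolean mask from it. The function names are hypothetical, and the visibility rule (column `k` is visible when `k < col_limit`, clamped to `seqlen_k`) is an assumption about how the bound is applied, not something confirmed by the kernel source.

```python
def ratio_causal_col_limits(q_causal_offset, seqlen_q, ratio):
    """col_limit = floor((q_causal_offset + q_token_idx + 1) / ratio) per row."""
    return [(q_causal_offset + q + 1) // ratio for q in range(seqlen_q)]


def ratio_causal_mask(q_causal_offset, seqlen_q, seqlen_k, ratio):
    """Boolean mask; assumes column k is visible when k < col_limit."""
    limits = ratio_causal_col_limits(q_causal_offset, seqlen_q, ratio)
    return [[k < min(limit, seqlen_k) for k in range(seqlen_k)] for limit in limits]


# With offset 0 and ratio 4, the first three queries see no compressed keys.
limits_no_offset = ratio_causal_col_limits(0, seqlen_q=6, ratio=4)
# With offset 8, every query already sees at least two compressed keys.
limits_offset_8 = ratio_causal_col_limits(8, seqlen_q=6, ratio=4)
```

The zero entries in `limits_no_offset` are the empty causal windows that the LSE guard discussed in this review has to handle.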
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py` around lines 175 - 188, The `mQCausalOffsets` parameter in the `__call__` method signature lacks documentation describing its per-batch semantics and expected shape/dtype. Add comprehensive docstring comments to explain that this parameter defines per-batch offsets for the causal bound calculation using the formula col_limit = floor((q_causal_offset + q_token_idx + 1) / ratio), and document its expected tensor shape and data type. Additionally, update the outdated epilogue comments that still reference the old "Bottom-right" mask terminology to reflect the new offset-aware causal masking behavior. Apply the same documentation updates to the related code sections at lines 988-999 and 1176-1184 that also reference the causal masking logic to maintain consistency throughout the file.
Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py`:
- Around line 10-11: The kernel docstring contains misleading information about
the output data type. In the docstring for the kernel function in
indexer_fwd_sm90.py, the phrase "Reduced BF16 scores are written directly from
registers to global memory" incorrectly references BF16 when the actual output
dtype is Float32 (FP32) to the mOut variable. Update the docstring to replace
"Reduced BF16 scores" with "Reduced FP32 scores" or rephrase it as "Reduced
scores are written to the FP32 output" to accurately reflect the actual output
dtype being written to global memory.
- Around line 298-302: Add a new regression test case in the
test/python/fe_api/dsa test directory that validates the behavior when
q_causal_offsets is None (triggering the default value of 0 on line 298)
combined with mismatched Q and K sequence lengths. Set up test parameters with
seqlen_q different from seqlen_k while leaving q_causal_offsets unspecified, and
verify that the ratio-causal mask computation through the _compute_n_blocks
method produces correct results for this fallback scenario, distinguishing it
from the existing test case that only covers the explicit offset value of 4.
In `@test/python/fe_api/dsa/dsa_reference.py`:
- Around line 628-630: The current code masks invalid positions to negative
infinity before applying softmax, which causes NaN values when all positions in
a row are masked (all-false valid rows). Between the masked_fill call on scores
and the torch.softmax call on the masked scores, add logic to identify rows
where all values are negative infinity, temporarily replace those with a finite
sentinel value (like 0), apply softmax to get valid probability distributions,
and then zero out the invalid positions in the resulting predict tensor to
handle the edge case where early queries have no valid positions due to small
offsets in the _batched_ratio_causal_mask function.
- Around line 376-392: The code currently maps the None case for
q_causal_offsets to _ratio_causal_mask on line 377, which changes the legacy
bottom-right causal alignment semantics. To preserve the prior behavior when
offsets are omitted, replace the call to _ratio_causal_mask with a call to
_bottom_right_causal_mask in the condition where q_causal_offsets is None,
ensuring that the function maintains backward compatibility with the original
bottom-right causal mask behavior.
---
Outside diff comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py`:
- Around line 173-175: The APIBase path in the IndexerForward method is missing
the q_causal_offsets argument that was added to the SM100 kernel path. Update
both argument passing locations (around lines 173-175 and lines 192-194) to
insert q_causal_offsets between cu_seqlens_k and stream parameters to match the
updated kernel signature. This will ensure the argument order is consistent
across all kernel paths and prevent arity mismatch errors when calling
IndexerForward.
In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py`:
- Around line 999-1007: When col_limit equals zero due to ratio greater than 1
with offset 0, no scores are accumulated for early queries, leaving
local_max[qi] at negative infinity and local_sum_exp[qi] at 0. Later in the LSE
reduction calculation, computing local_max[qi] minus global_max results in
negative infinity minus negative infinity which produces NaN. Add a guard
condition to check whether local_max[qi] is negative infinity before performing
the LSE reduction operations that involve subtraction with local_max[qi]. Apply
this same guard check to the similar code block also mentioned in the comment
(lines 1027-1037 range).
In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py`:
- Around line 1315-1321: The kernel leaves stale values in output columns beyond
the processed range when K blocks are pruned using q_causal_offset in
_dense_compute_n_blocks, but _postprocess_and_reduce only writes zeros inside
processed blocks. This affects preallocated output buffers since the SM90
wrapper only zero-initializes output when it allocates internally. Fix this by
either zero-initializing the entire supplied output buffer before kernel launch
(host-side), or explicitly clear the skipped tail columns in the kernel code
across all affected locations in _postprocess_and_reduce and related functions.
Apply the same fix pattern to the three mentioned ranges to ensure consistency
across all similar code paths.
---
Nitpick comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_backward/api.py`:
- Around line 526-555: The docstring for the dense_indexer_backward_wrapper
function is missing documentation for the q_causal_offsets parameter, which is a
new public argument that materially changes the ratio-causal mask behavior. Add
documentation to the function's docstring that specifies the shape and
dtype/device requirements for q_causal_offsets, and clearly document that cache
selection depends only on whether this parameter is provided or None. Ensure the
parameter documentation follows the existing docstring format and coding
guidelines for the python/cudnn module.
In
`@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py`:
- Around line 227-239: The docstring for the dense_indexer_backward_sm90
function does not document the runtime contract for the has_q_causal_offsets
parameter. Update the docstring to explicitly explain that when
has_q_causal_offsets is True, callers must pass matching QCausalOffsets into the
returned runner paths, and that this parameter affects the compiled ABI and
runtime assertions. Include this documentation requirement as part of the
function's parameter documentation.
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py`:
- Around line 238-241: The `q_causal_offsets` parameter in the public wrapper
function is not properly documented with its tensor contract requirements. Add
clear documentation to the docstring of the public wrapper function (in the
api.py file) that explicitly specifies that `q_causal_offsets` must be an int32
CUDA tensor with shape determined by batch dimensions. This documentation should
be placed in the Parameters or Args section of the docstring so callers can
understand the requirements without encountering runtime errors from the
lower-level validator.
In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py`:
- Around line 1029-1066: The docstring for the dense_indexer_score_recompute
function is missing documentation for the q_causal_offsets parameter, which is
now part of the function signature. Add q_causal_offsets to the Args section of
the docstring, describing it as an optional offset tensor and noting that it
enables offset-aware ratio-causal masking. Apply the same documentation update
to the other affected functions mentioned in the range 1093-1128 that similarly
accept the q_causal_offsets parameter.
In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py`:
- Around line 514-538: The docstring for the dense_indexer_score_recompute
function is missing documentation for the q_causal_offsets parameter that is now
part of the function signature. Update the docstring to document the
q_causal_offsets parameter by adding a description explaining its purpose
(offset-aware causal semantics), expected type (CUDA int32), and shape
information. Apply the same docstring update to any related functions (such as
other SM90 backend entry points) that also accept the q_causal_offsets
parameter.
In `@python/cudnn/deepseek_sparse_attention/score_recompute/api.py`:
- Around line 613-628: Add comprehensive docstring documentation to the
dense_indexer_score_recompute_wrapper function that clearly describes the
q_causal_offsets parameter, including that it is a (batch,) sized CUDA int32
tensor and how it affects the ratio-causal bound behavior. Apply the same
documentation pattern to the dense_attn_score_recompute_wrapper function as
well, ensuring both wrappers have consistent and complete documentation for the
offset-aware causal mode according to the cudnn coding guidelines that emphasize
documentation.
In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py`:
- Around line 175-188: The `mQCausalOffsets` parameter in the `__call__` method
signature lacks documentation describing its per-batch semantics and expected
shape/dtype. Add comprehensive docstring comments to explain that this parameter
defines per-batch offsets for the causal bound calculation using the formula
col_limit = floor((q_causal_offset + q_token_idx + 1) / ratio), and document its
expected tensor shape and data type. Additionally, update the outdated epilogue
comments that still reference the old "Bottom-right" mask terminology to reflect
the new offset-aware causal masking behavior. Apply the same documentation
updates to the related code sections at lines 988-999 and 1176-1184 that also
reference the causal masking logic to maintain consistency throughout the file.
In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py`:
- Around line 201-213: The mQCausalOffsets parameter in the __call__ method
lacks documentation about its shape, default behavior, and the reasoning behind
the bound formula (q_causal_offset + q + 1) // ratio. Add inline comments at the
__call__ entrypoint that clearly document what mQCausalOffsets represents, its
expected shape, how it behaves by default, and why the bound calculation uses
that specific formula. Additionally, add similar documentation comments at all
helper functions and producer/consumer paths that use mQCausalOffsets (as
indicated in the affected locations at lines 530-533, 716-719, 984-987,
1266-1268, and 1486-1494) to ensure the cross-path contract is clear and
consistent throughout the file.
In `@python/cudnn/deepseek_sparse_attention/utils/runtime.py`:
- Around line 20-31: The function validate_q_causal_offsets lacks documentation
for this central contract parameter. Add a docstring to the function that
documents the following contract requirements: that q_causal_offsets can be None
and will be returned as-is, that it must be a CUDA int32 tensor, that it must
have shape (batch,), that it must be on the same device as the input q, and that
the function returns either the original tensor if already contiguous or a
contiguous copy. This documentation should serve as the shared contract for this
new public parameter.
In `@test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py`:
- Line 90: The q_causal_offsets initialization using torch.full() creates
uniform offset values (all equal to 8), which only exercises the "offsets
present" code path but fails to validate per-batch indexing since a buggy kernel
could always read from index 0 regardless of batch. Replace torch.full() with
torch.arange() or another method that creates varying offset values across
different batch indices so that the reference comparison properly validates the
per-batch indexing contract and catches kernels that incorrectly ignore
per-batch offsets.
In `@test/python/fe_api/dsa/test_DSA_dense_score_recompute.py`:
- Line 117: The test currently creates uniform offsets for all batches using
torch.full(), which means a kernel that only reads the first offset can still
pass. Replace the q_causal_offsets initialization to use varying offset values
across different batch elements instead of filling all with the same value of 8.
This will properly exercise the per-batch indexing path by ensuring each batch
element has a different offset value.
In `@test/python/fe_api/dsa/test_DSA_indexer_forward.py`:
- Line 54: The q_causal_offsets tensor is created with torch.full using the same
offset value (4) for all batches, which would not catch a kernel bug that
incorrectly ignores the batch index when loading offsets. Replace the torch.full
call with a tensor that has varying offset values across batches (for example,
using torch.arange or torch.tensor with a sequence), while ensuring all offset
values remain valid and safe to use without causing out-of-bounds access.
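To illustrate why uniform offsets cannot catch a per-batch indexing bug, here is a minimal pure-Python sketch. It assumes the bound quoted in the comments above, `col_limit = floor((q_causal_offset + q + 1) / ratio)`, clamped to `[0, s_k]`. The `col_limits` helper and its `buggy` flag are hypothetical stand-ins for a kernel and are not part of the cuDNN frontend API.

```python
def col_limits(offsets, s_q, s_k, ratio, buggy=False):
    """Return per-batch causal column limits for every query position.

    buggy=True models a hypothetical kernel that always reads offsets[0]
    instead of offsets[batch_idx].
    """
    limits = []
    for batch_idx in range(len(offsets)):
        offset = offsets[0] if buggy else offsets[batch_idx]
        row = [min(max((offset + q + 1) // ratio, 0), s_k) for q in range(s_q)]
        limits.append(row)
    return limits


s_q, s_k, ratio = 4, 8, 2

uniform = [8, 8, 8]  # analogous to torch.full((batch,), 8)
varying = [0, 4, 8]  # analogous to torch.arange(batch) * 4

# Uniform offsets: buggy and correct results match, so the bug is invisible.
uniform_hides_bug = (
    col_limits(uniform, s_q, s_k, ratio)
    == col_limits(uniform, s_q, s_k, ratio, buggy=True)
)

# Varying offsets: results differ for every batch after the first.
varying_exposes_bug = (
    col_limits(varying, s_q, s_k, ratio)
    != col_limits(varying, s_q, s_k, ratio, buggy=True)
)

print(uniform_hides_bug, varying_exposes_bug)  # True True
```

In the real tests the varying offsets would still need to stay within the range the kernels treat as valid; the values above are chosen only for illustration.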
📒 Files selected for processing (25)
- python/cudnn/deepseek_sparse_attention/indexer_backward/api.py
- python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
- python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
- python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py
- python/cudnn/deepseek_sparse_attention/indexer_forward/_interface.py
- python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py
- python/cudnn/deepseek_sparse_attention/indexer_forward/api.py
- python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm100.py
- python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
- python/cudnn/deepseek_sparse_attention/indexer_top_k/compactify.py
- python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py
- python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_varlen_util.py
- python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py
- python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py
- python/cudnn/deepseek_sparse_attention/score_recompute/api.py
- python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py
- python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py
- python/cudnn/deepseek_sparse_attention/sparse_attention_backward/_interface_sm100.py
- python/cudnn/deepseek_sparse_attention/utils/compiler.py
- python/cudnn/deepseek_sparse_attention/utils/runtime.py
- python/cudnn/deepseek_sparse_attention/utils/sm90/mma.py
- test/python/fe_api/dsa/dsa_reference.py
- test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py
- test/python/fe_api/dsa/test_DSA_dense_score_recompute.py
- test/python/fe_api/dsa/test_DSA_indexer_forward.py
| with ratio causal masking against compressed-KV positions. Reduced BF16 scores | ||
| are written directly from registers to global memory. |
Fix the output dtype wording in the kernel docstring.
The epilogue writes Float32 scores to mOut, so “Reduced BF16 scores” is misleading here; use “Reduced FP32 scores” or “Reduced scores are written to the FP32 output.” As per coding guidelines, python/cudnn/**: focus on documentation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py`
around lines 10 - 11, The kernel docstring contains misleading information about
the output data type. In the docstring for the kernel function in
indexer_fwd_sm90.py, the phrase "Reduced BF16 scores are written directly from
registers to global memory" incorrectly references BF16 when the actual output
dtype is Float32 (FP32) to the mOut variable. Update the docstring to replace
"Reduced BF16 scores" with "Reduced FP32 scores" or rephrase it as "Reduced
scores are written to the FP32 output" to accurately reflect the actual output
dtype being written to global memory.
Source: Coding guidelines
| q_causal_offset = Int32(0) if const_expr(mQCausalOffsets is None) else mQCausalOffsets[batch_idx] | ||
| num_m_blocks = cute.ceil_div(seqlen.seqlen_q * self.qhead_per_kvhead, self.tile_m) | ||
| if block_x < num_m_blocks: | ||
| m_block = num_m_blocks - Int32(1) - block_x | ||
| n_block_max = self._compute_n_blocks(m_block, seqlen.seqlen_q, seqlen.seqlen_k) | ||
| n_block_max = self._compute_n_blocks(m_block, seqlen.seqlen_q, seqlen.seqlen_k, q_causal_offset) |
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Inspect DSA FE API tests for no-offset q_causal_offsets coverage
# where seqlen_q and seqlen_k are not exactly related by ratio.
set -euo pipefail
rg -n -C3 'q_causal_offsets|q_causal_offset|ratio|seqlen_q|seqlen_k' test/python/fe_api/dsa || true
Repository: NVIDIA/cudnn-frontend
Length of output: 34073
🏁 Script executed:
# First, let me examine the specific file and lines mentioned
cat -n python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py | sed -n '290,310p'
Repository: NVIDIA/cudnn-frontend
Length of output: 1150
🏁 Script executed:
# Search for "bottom-right" references in the codebase to understand legacy behavior
rg -i "bottom.right|bottom_right" --type python
Repository: NVIDIA/cudnn-frontend
Length of output: 96
🏁 Script executed:
# Check for any tests that explicitly test q_causal_offsets=None with different seqlen_q/seqlen_k ratios
rg -n "q_causal_offsets.*None|None.*q_causal" test/python/fe_api/dsa --type python -A5 -B5
Repository: NVIDIA/cudnn-frontend
Length of output: 96
🏁 Script executed:
# Search for "bottom-right" references in the codebase to understand legacy behavior
rg -i "bottom.right|bottom_right"
Repository: NVIDIA/cudnn-frontend
Length of output: 32630
🏁 Script executed:
# Check for any tests that explicitly test q_causal_offsets=None with different seqlen_q/seqlen_k ratios
rg -n "q_causal_offsets.*None|None.*q_causal" test/python/fe_api/dsa -A5 -B5
Repository: NVIDIA/cudnn-frontend
Length of output: 8414
🏁 Script executed:
# Look for documentation or comments about the legacy behavior and bottom-right masking
rg -i "legacy|bottom.right|offset.*semantic" python/cudnn/ -A3 -B3
Repository: NVIDIA/cudnn-frontend
Length of output: 16143
🏁 Script executed:
# Search for _compute_n_blocks to understand how it uses the offset
rg -n "_compute_n_blocks" python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py -A10
Repository: NVIDIA/cudnn-frontend
Length of output: 2041
🏁 Script executed:
# Search for tests that explicitly pass q_causal_offsets=None
rg -n "q_causal_offsets\s*=\s*None" test/python/fe_api/dsa
Repository: NVIDIA/cudnn-frontend
Length of output: 47
🏁 Script executed:
# Check if there are any tests for the indexer forward with None offsets
rg -n "test.*indexer.*forward" test/python/fe_api/dsa -A20 | grep -E "(def test|q_causal_offsets)" | head -30
Repository: NVIDIA/cudnn-frontend
Length of output: 157
🏁 Script executed:
# Look at the complete test_DSA_indexer_forward function
cat -n test/python/fe_api/dsa/test_DSA_indexer_forward.py
Repository: NVIDIA/cudnn-frontend
Length of output: 2606
🏁 Script executed:
# Check if there's any test that passes q_causal_offsets as None (not just omitting it)
rg -B5 -A10 "def test_DSA" test/python/fe_api/dsa/test_DSA_indexer_forward.py
Repository: NVIDIA/cudnn-frontend
Length of output: 360
Add regression test for q_causal_offsets=None with mismatched Q/K sequence lengths.
Line 298 defaults missing q_causal_offsets to 0. The current test suite only covers the explicit offset case (value 4); add a test case under test/python/fe_api/dsa with q_causal_offsets=None and mismatched seqlen_q and seqlen_k to ensure the ratio-causal mask behavior is correct for this fallback scenario.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py`
around lines 298 - 302, Add a new regression test case in the
test/python/fe_api/dsa test directory that validates the behavior when
q_causal_offsets is None (triggering the default value of 0 on line 298)
combined with mismatched Q and K sequence lengths. Set up test parameters with
seqlen_q different from seqlen_k while leaving q_causal_offsets unspecified, and
verify that the ratio-causal mask computation through the _compute_n_blocks
method produces correct results for this fallback scenario, distinguishing it
from the existing test case that only covers the explicit offset value of 4.
Source: Coding guidelines
| if q_causal_offsets is None: | ||
| return _ratio_causal_mask(s_q, s_k, ratio, device).unsqueeze(0).expand(batch, -1, -1) | ||
| offsets = q_causal_offsets.to(device=device, dtype=torch.int64) | ||
| q = torch.arange(s_q, device=device, dtype=torch.int64).view(1, s_q) | ||
| k_pos = torch.arange(s_k, device=device, dtype=torch.int64).view(1, 1, s_k) | ||
| col_limit = torch.div(offsets.view(batch, 1) + q + 1, ratio, rounding_mode="floor").clamp(0, s_k) | ||
| return k_pos < col_limit.unsqueeze(-1) | ||
|
|
||
|
|
||
| def _bottom_right_causal_mask( | ||
| s_q: int, | ||
| s_k: int, | ||
| ratio: int, | ||
| device: torch.device, | ||
| ) -> torch.Tensor: | ||
| """Compatibility alias for the offset-0 ratio-causal mask.""" | ||
| return _ratio_causal_mask(s_q, s_k, ratio, device) |
Preserve legacy bottom-right behavior when offsets are omitted.
Line 377 maps q_causal_offsets=None to offset 0, and Line 392 makes _bottom_right_causal_mask do the same. That changes the no-offset reference semantics instead of preserving the prior bottom-right causal alignment.
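As a worked illustration of the difference, the sketch below compares the offset-0 mask with the bottom-right alignment. It assumes the bound `col_limit = floor((q_causal_offset + q + 1) / ratio)` clamped to `[0, s_k]`, and it takes the legacy offset to be `max(s_k * ratio - s_q, 0)`, as the proposed fix in this comment does. `ratio_causal_limits` is a hypothetical helper written for this example, not the reference implementation.

```python
def ratio_causal_limits(s_q, s_k, ratio, q_causal_offset=0):
    """Number of visible compressed-KV columns for each query row."""
    return [
        min(max((q_causal_offset + q + 1) // ratio, 0), s_k)
        for q in range(s_q)
    ]


s_q, s_k, ratio = 4, 4, 2

# Offset 0: the first query sees no keys at all.
offset_zero = ratio_causal_limits(s_q, s_k, ratio)

# Bottom-right alignment: the last query sees all s_k keys.
bottom_right_offset = max(s_k * ratio - s_q, 0)
bottom_right = ratio_causal_limits(s_q, s_k, ratio, bottom_right_offset)

print(offset_zero)   # [0, 1, 1, 2]
print(bottom_right)  # [2, 3, 3, 4]
```

The two masks only coincide when `s_k * ratio <= s_q`, which is why the no-offset path deserves its own test coverage.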
🐛 Proposed fix
"""Return ``(B, S_q, S_k)`` ratio-causal masks with per-batch offsets."""
if q_causal_offsets is None:
- return _ratio_causal_mask(s_q, s_k, ratio, device).unsqueeze(0).expand(batch, -1, -1)
+ return _bottom_right_causal_mask(s_q, s_k, ratio, device).unsqueeze(0).expand(batch, -1, -1)
@@
def _bottom_right_causal_mask(
@@
) -> torch.Tensor:
- """Compatibility alias for the offset-0 ratio-causal mask."""
- return _ratio_causal_mask(s_q, s_k, ratio, device)
+ """Legacy bottom-right ratio-causal mask used when no q offsets are provided."""
+ q_causal_offset = max(s_k * ratio - s_q, 0)
+ return _ratio_causal_mask(s_q, s_k, ratio, device, q_causal_offset=q_causal_offset)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@test/python/fe_api/dsa/dsa_reference.py` around lines 376 - 392, The code
currently maps the None case for q_causal_offsets to _ratio_causal_mask on line
377, which changes the legacy bottom-right causal alignment semantics. To
preserve the prior behavior when offsets are omitted, replace the call to
_ratio_causal_mask with a call to _bottom_right_causal_mask in the condition
where q_causal_offsets is None, ensuring that the function maintains backward
compatibility with the original bottom-right causal mask behavior.
| valid = _batched_ratio_causal_mask(s_q, s_k, ratio, q_indexer.device, b, q_causal_offsets) | ||
| scores = scores.masked_fill(~valid, float("-inf")) | ||
| predict = torch.softmax(scores, dim=-1) |
Guard zero-valid rows before softmax.
With small offsets, valid can be all false for early queries; after Line 629 those rows are all -inf, and torch.softmax returns NaNs. Mask those rows to a finite sentinel before softmax, then zero invalid positions afterward.
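The failure mode and the guard can be reproduced in isolation. The following is a minimal sketch with made-up `scores` and `valid` tensors (they are assumptions for illustration, not data from the reference), using only `torch.softmax`, `Tensor.masked_fill`, and `Tensor.any`.

```python
import torch

scores = torch.tensor([[0.5, 1.5, -0.5],
                       [2.0, 0.0, 1.0]])
valid = torch.tensor([[False, False, False],  # early query: no visible keys
                      [True, True, False]])

masked = scores.masked_fill(~valid, float("-inf"))

# Naive path: a row that is entirely -inf turns into NaN after softmax.
naive = torch.softmax(masked, dim=-1)

# Guarded path: give fully masked rows a finite value first, then zero
# every invalid position in the result.
has_valid = valid.any(dim=-1, keepdim=True)
safe_scores = masked.masked_fill(~has_valid, 0.0)
guarded = torch.softmax(safe_scores, dim=-1).masked_fill(~valid, 0.0)

print(torch.isnan(naive[0]).all().item())  # True
print(torch.isnan(guarded).any().item())   # False
print(guarded[0].tolist())                 # [0.0, 0.0, 0.0]
```

Rows with at least one valid key are unaffected by the guard, so their probabilities still sum to 1.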
🐛 Proposed fix
valid = _batched_ratio_causal_mask(s_q, s_k, ratio, q_indexer.device, b, q_causal_offsets)
scores = scores.masked_fill(~valid, float("-inf"))
- predict = torch.softmax(scores, dim=-1)
+ has_valid = valid.any(dim=-1, keepdim=True)
+ scores = scores.masked_fill(~has_valid, 0.0)
+ predict = torch.softmax(scores, dim=-1).masked_fill(~valid, 0.0)
return (predict, scores) if return_scores else predict
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@test/python/fe_api/dsa/dsa_reference.py` around lines 628 - 630, The current
code masks invalid positions to negative infinity before applying softmax, which
causes NaN values when all positions in a row are masked (all-false valid rows).
Between the masked_fill call on scores and the torch.softmax call on the masked
scores, add logic to identify rows where all values are negative infinity,
temporarily replace those with a finite sentinel value (like 0), apply softmax
to get valid probability distributions, and then zero out the invalid positions
in the resulting predict tensor to handle the edge case where early queries have
no valid positions due to small offsets in the _batched_ratio_causal_mask
function.
Summary
- `q_causal_offsets` plumbing for DSA ratio-causal masking across indexer forward, dense score recompute, and dense indexer backward.
- `q_causal_offsets` presence, not values, in compile cache keys.
Summary by CodeRabbit
- `q_causal_offsets` across indexer forward, dense indexer backward, and dense score recomputation, enabling ratio-causal masking and offset-aware causal boundaries.
- `q_causal_offsets`.
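The idea of keying the compile cache on the presence of `q_causal_offsets` rather than on its values can be sketched as follows. This is an assumption-laden toy, not the PR's implementation: `get_runner`, `_compile_cache`, and the string standing in for a compiled kernel are all hypothetical names.

```python
_compile_cache = {}
compile_count = 0


def get_runner(ratio, q_causal_offsets=None):
    """Return a cached runner, compiling only when the key is new."""
    global compile_count
    # Only presence enters the key; the offset values are runtime data.
    has_q_causal_offsets = q_causal_offsets is not None
    key = (ratio, has_q_causal_offsets)
    if key not in _compile_cache:
        compile_count += 1  # stands in for an expensive kernel compile
        _compile_cache[key] = (
            f"runner(ratio={ratio}, has_q_causal_offsets={has_q_causal_offsets})"
        )
    return _compile_cache[key]


get_runner(4, [0, 4, 8])
get_runner(4, [1, 2, 3])  # different values, same key: cache hit
get_runner(4, None)       # presence changed: new compile
print(compile_count)  # 2
```

Keying on presence keeps per-step offset changes from triggering recompilation, while still separating the two kernel signatures.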