Skip to content

DSA: add q causal offsets and Rubin SM100F support#316

Open
jiayus-nvidia wants to merge 4 commits into
NVIDIA:developfrom
jiayus-nvidia:jiayus/fix-dsa-tail-drop-ratio
Open

DSA: add q causal offsets and Rubin SM100F support#316
jiayus-nvidia wants to merge 4 commits into
NVIDIA:developfrom
jiayus-nvidia:jiayus/fix-dsa-tail-drop-ratio

Conversation

@jiayus-nvidia

@jiayus-nvidia jiayus-nvidia commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Add optional q_causal_offsets plumbing for DSA ratio-causal masking across indexer forward, dense score recompute, and dense indexer backward.
  • Replace bottom-right-derived causal formulas with offset-aware compressed-KV causal limits.
  • Include q_causal_offsets presence, not values, in compile cache keys.
  • Cherry-pick Rubin SM100F CuTe DSL support.

Summary by CodeRabbit

  • New Features
    • Added support for per-batch q_causal_offsets across indexer forward, dense indexer backward, and dense score recomputation, enabling ratio-causal masking and offset-aware causal boundaries.
  • Bug Fixes
    • Removed an overly restrictive dense backward sequence-length constraint in the affected path(s).
  • Improvements
    • Extended compiler support for Rubin and standardized kernel compilation options.
  • Documentation
    • Updated DSA API docs and examples for q_causal_offsets.
  • Tests
    • Enhanced reference and wrapper tests to validate offset-aware masking behavior.

@coderabbitai

coderabbitai Bot commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 15e139db-7a66-4bcd-8c82-1cd0331d3c8c

📥 Commits

Reviewing files that changed from the base of the PR and between 42fd79e and d4186bb.

📒 Files selected for processing (2)
  • docs/fe-oss-apis/dsa.md
  • python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py

📝 Walkthrough

Walkthrough

Adds optional per-batch q_causal_offsets (int32 tensor, shape (batch,)) to all dense DSA operations — indexer forward, indexer backward, and score recompute — on both SM90 and SM100 backends. Replaces the hardcoded bottom-right seqlen-derived causal formula with an explicit offset-based bound throughout CuTe device kernels. Also adds Rubin GPU arch mapping, centralizes compile_options(), and switches some kernel scratch tensors from make_fragment to make_rmem_tensor.

Changes

q_causal_offsets support across DSA forward, backward, and score recompute

Layer / File(s) Summary
Validation helper, compiler utilities, and kernel scratch fixes
python/cudnn/deepseek_sparse_attention/utils/runtime.py, python/cudnn/deepseek_sparse_attention/utils/compiler.py, python/cudnn/deepseek_sparse_attention/sparse_attention_backward/_interface_sm100.py, python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py, python/cudnn/deepseek_sparse_attention/indexer_top_k/compactify.py, python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_varlen_util.py, python/cudnn/deepseek_sparse_attention/utils/sm90/mma.py
Adds validate_q_causal_offsets enforcing dtype/shape/device/contiguity; extends _ARCH_MAP for Rubin sm_100f; replaces hardcoded --enable-tvm-ffi strings with compile_options(); switches kernel scratch allocations from make_fragment to make_rmem_tensor; fixes cute.ThrMma type annotation.
Indexer forward: q_causal_offsets API wiring and SM90/SM100 kernel changes
python/cudnn/deepseek_sparse_attention/indexer_forward/api.py, python/cudnn/deepseek_sparse_attention/indexer_forward/_interface.py, python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py, python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py, python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm100.py
indexer_forward_wrapper and both SM90/SM100 interfaces add q_causal_offsets with validation and compile-key separation; removes varlen seqlen_q > seqlen_k*ratio guard. SM90 kernel replaces q_global_start derivation with q_causal_offset in producer/consumer/epilogue and _compute_n_blocks. SM100 kernel threads q_causal_offset through all persistent warp loops and refactors _causal_num_n_blocks and _epilogue_warp.
Indexer backward: q_causal_offsets API wiring and SM90/SM100 kernel changes
python/cudnn/deepseek_sparse_attention/indexer_backward/api.py, python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py, python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py, python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
dense_indexer_backward_wrapper and DenseIndexerBackward add q_causal_offsets/has_q_causal_offsets; removes max_seqlen_q <= max_seqlen_k*ratio constraint. SM90 ScoreGradDenseSm90 and IndexerBackwardSm90 thread mQCausalOffsets into kernel/warpgroup/K-load paths. SM100 ScoreGradDense and DenseIndexerBackward2QGemmSm100 thread mQCausalOffsets; col_limit and max_kv_needed_raw use offset-based formula; factory compile cache keyed by has_q_causal_offsets.
Score recompute: q_causal_offsets API wiring and SM90/SM100 kernel changes
python/cudnn/deepseek_sparse_attention/score_recompute/api.py, python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py, python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py, python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py, python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py
DenseIndexerScoreRecompute, DenseAttnScoreRecompute, and wrapper functions add q_causal_offsets with compile-key separation. SM90 producer/consumer paths derive q_causal_offset per batch; _dense_compute_n_blocks and _postprocess_and_reduce replace q_global_start with offset formula. SM100 epilogues (_epilogue_indexer_dense, _epilogue_attention_dense) use q_causal_offset for col_limit.
Reference implementations and test harness updates
test/python/fe_api/dsa/dsa_reference.py, test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py, test/python/fe_api/dsa/test_DSA_dense_score_recompute.py, test/python/fe_api/dsa/test_DSA_indexer_forward.py, docs/fe-oss-apis/dsa.md
Adds _ratio_causal_mask and _batched_ratio_causal_mask reference helpers; updates all reference functions to apply offset-aware validity masks; test files create q_causal_offsets tensors, wire them into DSA wrapper calls and reference checkers; adds test_DSA_ratio_causal_mask_offsets_reference; documentation updated to describe q_causal_offsets parameter semantics and usage across forward, backward, and recompute paths.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • NVIDIA/cudnn-frontend#297: Modifies the dense indexer backward SM100 score-grad (ScoreGradDense) kernel logic in dense_indexer_backward_sm100.py—both PRs affect causal masking and K-block bounds in the same stage.

Suggested labels

mod-frontend, cat-enhancements, orig-nv-eng, cat-infra, cat-doc

Suggested reviewers

  • saltyminty

Poem

🐇 A batch of offsets, fresh from the GPU,
No more "bottom-right"—we compute anew!
Each q_causal_offset shapes the causal bound,
Through SM90 and SM100, the kernels are found.
With validation, caching, and epilogue shifts so true,
The rabbit hops forward and backward—offset too! 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 44.79% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely summarizes the main changes: adding q causal offsets and Rubin SM100F support to DSA, which are the primary objectives of the pull request.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
python/cudnn/deepseek_sparse_attention/indexer_forward/api.py (1)

173-175: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep the APIBase path in sync with the new kernel argument.

Line 173 and Line 192 still pass only cu_seqlens_q, cu_seqlens_k, then stream. The SM100 kernel path now has q_causal_offsets between cu_seqlens_k and stream, so direct IndexerForward users can hit a compile/launch arity mismatch.

🐛 Proposed fix
             cutlass.Float32(self.sm_scale),
             None,
             None,
+            None,
             fake_stream,
             options=compile_options(),
         )
@@
                 cutlass.Float32(self.sm_scale),
                 None,
                 None,
+                None,
                 stream,
             )

Also applies to: 192-194

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py` around lines
173 - 175, The APIBase path in the IndexerForward method is missing the
q_causal_offsets argument that was added to the SM100 kernel path. Update both
argument passing locations (around lines 173-175 and lines 192-194) to insert
q_causal_offsets between cu_seqlens_k and stream parameters to match the updated
kernel signature. This will ensure the argument order is consistent across all
kernel paths and prevent arity mismatch errors when calling IndexerForward.
python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py (1)

1315-1321: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Clear skipped causal-tail columns when pruning K blocks.

_dense_compute_n_blocks now prunes blocks using q_causal_offset, but _postprocess_and_reduce only writes zeros inside processed blocks. When a caller provides a preallocated out, any columns beyond n_block_max * tile_n keep stale values because the SM90 wrapper only zero-initializes out when it allocates it itself. Consider zeroing supplied outputs before launch or explicitly clearing the skipped tail.

🐛 One possible host-side fix
 if out is None:
     out = torch.zeros((batch_size, seqlen_q, seqlen_k), dtype=torch.float32, device=q.device)
 else:
     out = out if out.is_contiguous() else out.contiguous()
+    out.zero_()

Also applies to: 1367-1374, 1486-1494

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py`
around lines 1315 - 1321, The kernel leaves stale values in output columns
beyond the processed range when K blocks are pruned using q_causal_offset in
_dense_compute_n_blocks, but _postprocess_and_reduce only writes zeros inside
processed blocks. This affects preallocated output buffers since the SM90
wrapper only zero-initializes output when it allocates internally. Fix this by
either zero-initializing the entire supplied output buffer before kernel launch
(host-side), or explicitly clear the skipped tail columns in the kernel code
across all affected locations in _postprocess_and_reduce and related functions.
Apply the same fix pattern to the three mentioned ranges to ensure consistency
across all similar code paths.
python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py (1)

999-1007: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Guard the indexer LSE reduction for empty causal windows.

With ratio > 1 and offset 0, early queries can get col_limit == 0, so local_max[qi] remains -inf and local_sum_exp[qi] remains 0. The later local_max - global_max becomes -inf - -inf, producing a NaN denominator.

🐛 Proposed fix
-            adjusted_sum = local_sum_exp[qi] * cute.math.exp2((local_max[qi] - global_max) * log2_e)
+            adjusted_sum = (
+                local_sum_exp[qi] * cute.math.exp2((local_max[qi] - global_max) * log2_e)
+                if local_sum_exp[qi] > Float32(0.0)
+                else Float32(0.0)
+            )

             global_sum_exp, reduce_phase = self._intra_inter_warp_reduce_sum(
                 sScoreAll_sum,
                 reduce_sync_mbar_ptr,
                 reduce_phase,
@@
-            lse_val = global_max + cute.math.log2(global_sum_exp) * inv_log2_e
+            lse_val = (
+                global_max + cute.math.log2(global_sum_exp) * inv_log2_e
+                if global_sum_exp > Float32(0.0)
+                else -Float32.inf
+            )

Also applies to: 1027-1037

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py`
around lines 999 - 1007, When col_limit equals zero due to ratio greater than 1
with offset 0, no scores are accumulated for early queries, leaving
local_max[qi] at negative infinity and local_sum_exp[qi] at 0. Later in the LSE
reduction calculation, computing local_max[qi] minus global_max results in
negative infinity minus negative infinity which produces NaN. Add a guard
condition to check whether local_max[qi] is negative infinity before performing
the LSE reduction operations that involve subtraction with local_max[qi]. Apply
this same guard check to the similar code block also mentioned in the comment
(lines 1027-1037 range).
🧹 Nitpick comments (12)
python/cudnn/deepseek_sparse_attention/indexer_backward/api.py (1)

526-555: ⚡ Quick win

Document q_causal_offsets on the dense backward wrapper.

The new public argument materially changes the ratio-causal mask, but the wrapper docstring does not mention its shape, dtype/device requirements, or that cache selection depends only on whether it is provided. As per coding guidelines, python/cudnn/**: “Focus on documentation.”

Suggested docstring addition
     """Dense full-KV indexer backward. Returns ``{'d_index_q', 'd_weights', 'd_index_k'}``.
 
     ``attn_score`` and ``index_score`` are raw dense scores from
     ``dense_attn_score_recompute_wrapper`` and
     ``dense_indexer_score_recompute_wrapper`` respectively. They are consumed
     in-place by the score-gradient precompute stage.
+
+    Args:
+        q_causal_offsets: Optional CUDA int32 tensor of shape ``(batch,)``.
+            Each entry is the global uncompressed token index for local
+            ``q[0]`` in that batch. When omitted, the legacy bottom-right
+            ratio-causal alignment is used.
     """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/indexer_backward/api.py` around lines
526 - 555, The docstring for the dense_indexer_backward_wrapper function is
missing documentation for the q_causal_offsets parameter, which is a new public
argument that materially changes the ratio-causal mask behavior. Add
documentation to the function's docstring that specifies the shape and
dtype/device requirements for q_causal_offsets, and clearly document that cache
selection depends only on whether this parameter is provided or None. Ensure the
parameter documentation follows the existing docstring format and coding
guidelines for the python/cudnn module.

Source: Coding guidelines

python/cudnn/deepseek_sparse_attention/score_recompute/api.py (1)

613-628: ⚡ Quick win

Add wrapper docs for the new offset-aware causal mode.

Both dense score recompute wrappers now expose q_causal_offsets, but neither documents the (batch,) CUDA int32 contract or how it changes the ratio-causal bound. As per coding guidelines, python/cudnn/**: “Focus on documentation.”

Suggested docstring pattern
 def dense_indexer_score_recompute_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
     weights: torch.Tensor,
@@
     q_causal_offsets: Optional[torch.Tensor] = None,
     stream: Optional[cuda.CUstream] = None,
 ) -> TupleDict:
+    """Dense indexer score recompute over full KV.
+
+    Args:
+        q_causal_offsets: Optional CUDA int32 tensor of shape ``(batch,)``.
+            Each entry is the global uncompressed token index for local
+            ``q[0]`` in that batch; presence selects the offset-aware compiled
+            kernel variant.
+
+    Returns:
+        ``TupleDict(out=..., denom=...)``.
+    """

Apply the same wording to dense_attn_score_recompute_wrapper.

Also applies to: 824-839

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/score_recompute/api.py` around lines
613 - 628, Add comprehensive docstring documentation to the
dense_indexer_score_recompute_wrapper function that clearly describes the
q_causal_offsets parameter, including that it is a (batch,) sized CUDA int32
tensor and how it affects the ratio-causal bound behavior. Apply the same
documentation pattern to the dense_attn_score_recompute_wrapper function as
well, ensuring both wrappers have consistent and complete documentation for the
offset-aware causal mode according to the cudnn coding guidelines that emphasize
documentation.

Source: Coding guidelines

python/cudnn/deepseek_sparse_attention/utils/runtime.py (1)

20-31: ⚡ Quick win

Document the shared q_causal_offsets contract.

This helper is now the central contract for the new public parameter; add a short docstring covering None, CUDA int32, shape (batch,), same-device requirement, and the contiguous copy behavior. As per coding guidelines, python/cudnn/**: “Focus on documentation.”

Suggested documentation
 def validate_q_causal_offsets(q_causal_offsets, batch: int, device: torch.device):
+    """Validate optional per-batch q-causal offsets.
+
+    Args:
+        q_causal_offsets: Optional CUDA int32 tensor with shape ``(batch,)``.
+            Each entry is the global uncompressed token index for local
+            query position 0 in that batch.
+        batch: Expected batch size.
+        device: Device expected to match the query tensor.
+
+    Returns:
+        ``None`` or a contiguous tensor suitable for CuTe kernel arguments.
+    """
     if q_causal_offsets is None:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/utils/runtime.py` around lines 20 -
31, The function validate_q_causal_offsets lacks documentation for this central
contract parameter. Add a docstring to the function that documents the following
contract requirements: that q_causal_offsets can be None and will be returned
as-is, that it must be a CUDA int32 tensor, that it must have shape (batch,),
that it must be on the same device as the input q, and that the function returns
either the original tensor if already contiguous or a contiguous copy. This
documentation should serve as the shared contract for this new public parameter.

Source: Coding guidelines

python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py (1)

514-538: ⚡ Quick win

Update SM90 dense entry-point docstrings for q_causal_offsets.

The signatures now accept q_causal_offsets, but the public SM90 backend docs still omit the offset-aware causal semantics and expected CUDA int32 shape. As per coding guidelines, python/cudnn/**: “Focus on documentation.”

Also applies to: 577-598

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py`
around lines 514 - 538, The docstring for the dense_indexer_score_recompute
function is missing documentation for the q_causal_offsets parameter that is now
part of the function signature. Update the docstring to document the
q_causal_offsets parameter by adding a description explaining its purpose
(offset-aware causal semantics), expected type (CUDA int32), and shape
information. Apply the same docstring update to any related functions (such as
other SM90 backend entry points) that also accept the q_causal_offsets
parameter.

Source: Coding guidelines

test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py (1)

90-90: ⚡ Quick win

Use non-uniform offsets to exercise per-batch indexing.

A constant 8 verifies the “offsets present” path, but it would not catch a kernel that always reads q_causal_offsets[0] for every batch. Vary the values by batch so the reference comparison checks the per-batch contract. As per coding guidelines, python/cudnn/**: “Focus on whether there are test cases in test/python/fe_api.”

Suggested test input
-    q_causal_offsets = torch.full((b_cfg,), 8, dtype=torch.int32, device="cuda")
+    q_causal_offsets = (ratio * (1 + torch.arange(b_cfg, dtype=torch.int32, device="cuda") % 4)).contiguous()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py` at line 90, The
q_causal_offsets initialization using torch.full() creates uniform offset values
(all equal to 8), which only exercises the "offsets present" code path but fails
to validate per-batch indexing since a buggy kernel could always read from index
0 regardless of batch. Replace torch.full() with torch.arange() or another
method that creates varying offset values across different batch indices so that
the reference comparison properly validates the per-batch indexing contract and
catches kernels that incorrectly ignore per-batch offsets.

Source: Coding guidelines

python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py (1)

1029-1066: ⚡ Quick win

Document the new dense SM100 q_causal_offsets argument.

These public backend entry points now accept q_causal_offsets, but the docstrings still only describe ratio/THD behavior. Add the offset tensor contract and note that it selects offset-aware ratio-causal masking. As per coding guidelines, python/cudnn/**: “Focus on documentation.”

Also applies to: 1093-1128

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py`
around lines 1029 - 1066, The docstring for the dense_indexer_score_recompute
function is missing documentation for the q_causal_offsets parameter, which is
now part of the function signature. Add q_causal_offsets to the Args section of
the docstring, describing it as an optional offset tensor and noting that it
enables offset-aware ratio-causal masking. Apply the same documentation update
to the other affected functions mentioned in the range 1093-1128 that similarly
accept the q_causal_offsets parameter.

Source: Coding guidelines

test/python/fe_api/dsa/test_DSA_dense_score_recompute.py (1)

117-117: ⚡ Quick win

Use non-uniform offsets in the wrapper test.

Line 117 gives every batch the same offset, so a kernel that always reads q_causal_offsets[0] can still pass. Vary the offsets per batch to exercise the new per-batch indexing path.

🧪 Proposed test update
-    q_causal_offsets = torch.full((cfg["b"],), 8, dtype=torch.int32, device=q.device)
+    q_causal_offsets = 8 + 4 * torch.arange(cfg["b"], dtype=torch.int32, device=q.device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/fe_api/dsa/test_DSA_dense_score_recompute.py` at line 117, The
test currently creates uniform offsets for all batches using torch.full(), which
means a kernel that only reads the first offset can still pass. Replace the
q_causal_offsets initialization to use varying offset values across different
batch elements instead of filling all with the same value of 8. This will
properly exercise the per-batch indexing path by ensuring each batch element has
a different offset value.
test/python/fe_api/dsa/test_DSA_indexer_forward.py (1)

54-54: ⚡ Quick win

Vary offsets across batches.

Line 54 uses the same offset for every batch, which would not catch a kernel that ignores the batch index when loading q_causal_offsets. Use a per-batch sequence while keeping the first row safely valid.

🧪 Proposed test update
-    q_causal_offsets = torch.full((cfg["b"],), 4, dtype=torch.int32, device=q.device)
+    q_causal_offsets = int(ratio) + int(ratio) * torch.arange(cfg["b"], dtype=torch.int32, device=q.device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/fe_api/dsa/test_DSA_indexer_forward.py` at line 54, The
q_causal_offsets tensor is created with torch.full using the same offset value
(4) for all batches, which would not catch a kernel bug that incorrectly ignores
the batch index when loading offsets. Replace the torch.full call with a tensor
that has varying offset values across batches (for example, using torch.arange
or torch.tensor with a sequence), while ensuring all offset values remain valid
and safe to use without causing out-of-bounds access.
python/cudnn/deepseek_sparse_attention/indexer_forward/api.py (1)

238-241: ⚡ Quick win

Document the q_causal_offsets tensor contract in the public wrapper.

The lower-level validator requires an int32 CUDA tensor shaped by batch; callers should not have to discover that from runtime errors.

📝 Proposed doc update
-    positions outside the valid KV range with -inf. ``q_causal_offsets`` may
-    specify the global uncompressed token index for each batch/THD segment's
-    local q[0].
+    positions outside the valid KV range with -inf. ``q_causal_offsets`` may
+    specify the global uncompressed token index for each batch/THD segment's
+    local q[0]. When provided, it must be an int32 CUDA tensor of shape
+    ``(B,)`` for BSHD input, or ``(cu_seqlens_q.numel() - 1,)`` for THD input;
+    non-contiguous inputs are copied by the lower-level validation path.

As per coding guidelines, python/cudnn/**: "Focus on documentation."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py` around lines
238 - 241, The `q_causal_offsets` parameter in the public wrapper function is
not properly documented with its tensor contract requirements. Add clear
documentation to the docstring of the public wrapper function (in the api.py
file) that explicitly specifies that `q_causal_offsets` must be an int32 CUDA
tensor with shape determined by batch dimensions. This documentation should be
placed in the Parameters or Args section of the docstring so callers can
understand the requirements without encountering runtime errors from the
lower-level validator.

Source: Coding guidelines

python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py (1)

227-239: ⚡ Quick win

Document the has_q_causal_offsets runtime contract.

This factory now changes the compiled ABI and runtime assertions based on has_q_causal_offsets, but the docstring does not mention that callers must pass matching QCausalOffsets into the returned runner paths.

Suggested docstring update
 def dense_indexer_backward_sm90(
     batch,
     seqlen,
     seqlen_k,
@@
     has_q_causal_offsets=False,
 ):
-    """Factory for the dense indexer backward gradient kernel on SM90."""
+    """Factory for the dense indexer backward gradient kernel on SM90.
+
+    Args:
+        has_q_causal_offsets: When True, the returned runner, ``score_grad``,
+            and ``gemm_only`` paths must receive a matching ``QCausalOffsets``
+            tensor at runtime; when False, they must receive ``None``.
+    """

As per coding guidelines, python/cudnn/**: "Focus on documentation."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py`
around lines 227 - 239, The docstring for the dense_indexer_backward_sm90
function does not document the runtime contract for the has_q_causal_offsets
parameter. Update the docstring to explicitly explain that when
has_q_causal_offsets is True, callers must pass matching QCausalOffsets into the
returned runner paths, and that this parameter affects the compiled ABI and
runtime assertions. Include this documentation requirement as part of the
function's parameter documentation.

Source: Coding guidelines

python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py (1)

201-213: ⚡ Quick win

Add inline documentation for the offset-based causal formula.

The new mQCausalOffsets value is threaded through several producer/consumer paths, but the file does not document its shape, default behavior, or why the bound is (q_causal_offset + q + 1) // ratio. A short comment at the entrypoint and helper would make the cross-path contract easier to audit. As per coding guidelines, python/cudnn/**: “Focus on documentation.”

Also applies to: 530-533, 716-719, 984-987, 1266-1268, 1486-1494

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py`
around lines 201 - 213, The mQCausalOffsets parameter in the __call__ method
lacks documentation about its shape, default behavior, and the reasoning behind
the bound formula (q_causal_offset + q + 1) // ratio. Add inline comments at the
__call__ entrypoint that clearly document what mQCausalOffsets represents, its
expected shape, how it behaves by default, and why the bound calculation uses
that specific formula. Additionally, add similar documentation comments at all
helper functions and producer/consumer paths that use mQCausalOffsets (as
indicated in the affected locations at lines 530-533, 716-719, 984-987,
1266-1268, and 1486-1494) to ensure the cross-path contract is clear and
consistent throughout the file.

Source: Coding guidelines

python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py (1)

175-188: ⚡ Quick win

Document the offset-aware causal bound.

mQCausalOffsets is a new API-visible kernel argument, but its expected per-batch semantics are not described, and the epilogue comments still refer to the old “Bottom-right” mask. Please update these comments to describe col_limit = floor((q_causal_offset + q_token_idx + 1) / ratio) and the expected tensor shape/dtype. As per coding guidelines, python/cudnn/**: “Focus on documentation.”

Also applies to: 988-999, 1176-1184

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py`
around lines 175 - 188, The `mQCausalOffsets` parameter in the `__call__` method
signature lacks documentation describing its per-batch semantics and expected
shape/dtype. Add comprehensive docstring comments to explain that this parameter
defines per-batch offsets for the causal bound calculation using the formula
col_limit = floor((q_causal_offset + q_token_idx + 1) / ratio), and document its
expected tensor shape and data type. Additionally, update the outdated epilogue
comments that still reference the old "Bottom-right" mask terminology to reflect
the new offset-aware causal masking behavior. Apply the same documentation
updates to the related code sections at lines 988-999 and 1176-1184 that also
reference the causal masking logic to maintain consistency throughout the file.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py`:
- Around line 10-11: The kernel docstring contains misleading information about
the output data type. In the docstring for the kernel function in
indexer_fwd_sm90.py, the phrase "Reduced BF16 scores are written directly from
registers to global memory" incorrectly references BF16 when the actual output
dtype is Float32 (FP32) to the mOut variable. Update the docstring to replace
"Reduced BF16 scores" with "Reduced FP32 scores" or rephrase it as "Reduced
scores are written to the FP32 output" to accurately reflect the actual output
dtype being written to global memory.
- Around line 298-302: Add a new regression test case in the
test/python/fe_api/dsa test directory that validates the behavior when
q_causal_offsets is None (triggering the default value of 0 on line 298)
combined with mismatched Q and K sequence lengths. Set up test parameters with
seqlen_q different from seqlen_k while leaving q_causal_offsets unspecified, and
verify that the ratio-causal mask computation through the _compute_n_blocks
method produces correct results for this fallback scenario, distinguishing it
from the existing test case that only covers the explicit offset value of 4.

In `@test/python/fe_api/dsa/dsa_reference.py`:
- Around line 628-630: The current code masks invalid positions to negative
infinity before applying softmax, which causes NaN values when all positions in
a row are masked (all-false valid rows). Between the masked_fill call on scores
and the torch.softmax call on the masked scores, add logic to identify rows
where all values are negative infinity, temporarily replace those with a finite
sentinel value (like 0), apply softmax to get valid probability distributions,
and then zero out the invalid positions in the resulting predict tensor to
handle the edge case where early queries have no valid positions due to small
offsets in the _batched_ratio_causal_mask function.
- Around line 376-392: The code currently maps the None case for
q_causal_offsets to _ratio_causal_mask on line 377, which changes the legacy
bottom-right causal alignment semantics. To preserve the prior behavior when
offsets are omitted, replace the call to _ratio_causal_mask with a call to
_bottom_right_causal_mask in the condition where q_causal_offsets is None,
ensuring that the function maintains backward compatibility with the original
bottom-right causal mask behavior.

---

Outside diff comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py`:
- Around line 173-175: The APIBase path in the IndexerForward method is missing
the q_causal_offsets argument that was added to the SM100 kernel path. Update
both argument passing locations (around lines 173-175 and lines 192-194) to
insert q_causal_offsets between cu_seqlens_k and stream parameters to match the
updated kernel signature. This will ensure the argument order is consistent
across all kernel paths and prevent arity mismatch errors when calling
IndexerForward.

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py`:
- Around line 999-1007: When col_limit equals zero due to ratio greater than 1
with offset 0, no scores are accumulated for early queries, leaving
local_max[qi] at negative infinity and local_sum_exp[qi] at 0. Later in the LSE
reduction calculation, computing local_max[qi] minus global_max results in
negative infinity minus negative infinity which produces NaN. Add a guard
condition to check whether local_max[qi] is negative infinity before performing
the LSE reduction operations that involve subtraction with local_max[qi]. Apply
this same guard check to the similar code block also mentioned in the comment
(lines 1027-1037 range).

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py`:
- Around line 1315-1321: The kernel leaves stale values in output columns beyond
the processed range when K blocks are pruned using q_causal_offset in
_dense_compute_n_blocks, but _postprocess_and_reduce only writes zeros inside
processed blocks. This affects preallocated output buffers since the SM90
wrapper only zero-initializes output when it allocates internally. Fix this by
either zero-initializing the entire supplied output buffer before kernel launch
(host-side), or explicitly clear the skipped tail columns in the kernel code
across all affected locations in _postprocess_and_reduce and related functions.
Apply the same fix pattern to the three mentioned ranges to ensure consistency
across all similar code paths.

---

Nitpick comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_backward/api.py`:
- Around line 526-555: The docstring for the dense_indexer_backward_wrapper
function is missing documentation for the q_causal_offsets parameter, which is a
new public argument that materially changes the ratio-causal mask behavior. Add
documentation to the function's docstring that specifies the shape and
dtype/device requirements for q_causal_offsets, and clearly document that cache
selection depends only on whether this parameter is provided or None. Ensure the
parameter documentation follows the existing docstring format and coding
guidelines for the python/cudnn module.

In
`@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py`:
- Around line 227-239: The docstring for the dense_indexer_backward_sm90
function does not document the runtime contract for the has_q_causal_offsets
parameter. Update the docstring to explicitly explain that when
has_q_causal_offsets is True, callers must pass matching QCausalOffsets into the
returned runner paths, and that this parameter affects the compiled ABI and
runtime assertions. Include this documentation requirement as part of the
function's parameter documentation.

In `@python/cudnn/deepseek_sparse_attention/indexer_forward/api.py`:
- Around line 238-241: The `q_causal_offsets` parameter in the public wrapper
function is not properly documented with its tensor contract requirements. Add
clear documentation to the docstring of the public wrapper function (in the
api.py file) that explicitly specifies that `q_causal_offsets` must be an int32
CUDA tensor with shape determined by batch dimensions. This documentation should
be placed in the Parameters or Args section of the docstring so callers can
understand the requirements without encountering runtime errors from the
lower-level validator.

In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py`:
- Around line 1029-1066: The docstring for the dense_indexer_score_recompute
function is missing documentation for the q_causal_offsets parameter, which is
now part of the function signature. Add q_causal_offsets to the Args section of
the docstring, describing it as an optional offset tensor and noting that it
enables offset-aware ratio-causal masking. Apply the same documentation update
to the other affected functions mentioned in the range 1093-1128 that similarly
accept the q_causal_offsets parameter.

In `@python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py`:
- Around line 514-538: The docstring for the dense_indexer_score_recompute
function is missing documentation for the q_causal_offsets parameter that is now
part of the function signature. Update the docstring to document the
q_causal_offsets parameter by adding a description explaining its purpose
(offset-aware causal semantics), expected type (CUDA int32), and shape
information. Apply the same docstring update to any related functions (such as
other SM90 backend entry points) that also accept the q_causal_offsets
parameter.

In `@python/cudnn/deepseek_sparse_attention/score_recompute/api.py`:
- Around line 613-628: Add comprehensive docstring documentation to the
dense_indexer_score_recompute_wrapper function that clearly describes the
q_causal_offsets parameter, including that it is a (batch,) sized CUDA int32
tensor and how it affects the ratio-causal bound behavior. Apply the same
documentation pattern to the dense_attn_score_recompute_wrapper function as
well, ensuring both wrappers have consistent and complete documentation for the
offset-aware causal mode according to the cudnn coding guidelines that emphasize
documentation.

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py`:
- Around line 175-188: The `mQCausalOffsets` parameter in the `__call__` method
signature lacks documentation describing its per-batch semantics and expected
shape/dtype. Add comprehensive docstring comments to explain that this parameter
defines per-batch offsets for the causal bound calculation using the formula
col_limit = floor((q_causal_offset + q_token_idx + 1) / ratio), and document its
expected tensor shape and data type. Additionally, update the outdated epilogue
comments that still reference the old "Bottom-right" mask terminology to reflect
the new offset-aware causal masking behavior. Apply the same documentation
updates to the related code sections at lines 988-999 and 1176-1184 that also
reference the causal masking logic to maintain consistency throughout the file.

In
`@python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py`:
- Around line 201-213: The mQCausalOffsets parameter in the __call__ method
lacks documentation about its shape, default behavior, and the reasoning behind
the bound formula (q_causal_offset + q + 1) // ratio. Add inline comments at the
__call__ entrypoint that clearly document what mQCausalOffsets represents, its
expected shape, how it behaves by default, and why the bound calculation uses
that specific formula. Additionally, add similar documentation comments at all
helper functions and producer/consumer paths that use mQCausalOffsets (as
indicated in the affected locations at lines 530-533, 716-719, 984-987,
1266-1268, and 1486-1494) to ensure the cross-path contract is clear and
consistent throughout the file.

In `@python/cudnn/deepseek_sparse_attention/utils/runtime.py`:
- Around line 20-31: The function validate_q_causal_offsets lacks documentation
for this central contract parameter. Add a docstring to the function that
documents the following contract requirements: that q_causal_offsets can be None
and will be returned as-is, that it must be a CUDA int32 tensor, that it must
have shape (batch,), that it must be on the same device as the input q, and that
the function returns either the original tensor if already contiguous or a
contiguous copy. This documentation should serve as the shared contract for this
new public parameter.

In `@test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py`:
- Line 90: The q_causal_offsets initialization using torch.full() creates
uniform offset values (all equal to 8), which only exercises the "offsets
present" code path but fails to validate per-batch indexing since a buggy kernel
could always read from index 0 regardless of batch. Replace torch.full() with
torch.arange() or another method that creates varying offset values across
different batch indices so that the reference comparison properly validates the
per-batch indexing contract and catches kernels that incorrectly ignore
per-batch offsets.

In `@test/python/fe_api/dsa/test_DSA_dense_score_recompute.py`:
- Line 117: The test currently creates uniform offsets for all batches using
torch.full(), which means a kernel that only reads the first offset can still
pass. Replace the q_causal_offsets initialization to use varying offset values
across different batch elements instead of filling all with the same value of 8.
This will properly exercise the per-batch indexing path by ensuring each batch
element has a different offset value.

In `@test/python/fe_api/dsa/test_DSA_indexer_forward.py`:
- Line 54: The q_causal_offsets tensor is created with torch.full using the same
offset value (4) for all batches, which would not catch a kernel bug that
incorrectly ignores the batch index when loading offsets. Replace the torch.full
call with a tensor that has varying offset values across batches (for example,
using torch.arange or torch.tensor with a sequence), while ensuring all offset
values remain valid and safe to use without causing out-of-bounds access.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a67a4e2a-c158-4cd1-9138-944efb43531c

📥 Commits

Reviewing files that changed from the base of the PR and between 603ffab and 42fd79e.

📒 Files selected for processing (25)
  • python/cudnn/deepseek_sparse_attention/indexer_backward/api.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/_interface.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/api.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm100.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/compactify.py
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_varlen_util.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/api.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm100.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py
  • python/cudnn/deepseek_sparse_attention/sparse_attention_backward/_interface_sm100.py
  • python/cudnn/deepseek_sparse_attention/utils/compiler.py
  • python/cudnn/deepseek_sparse_attention/utils/runtime.py
  • python/cudnn/deepseek_sparse_attention/utils/sm90/mma.py
  • test/python/fe_api/dsa/dsa_reference.py
  • test/python/fe_api/dsa/test_DSA_dense_indexer_backward.py
  • test/python/fe_api/dsa/test_DSA_dense_score_recompute.py
  • test/python/fe_api/dsa/test_DSA_indexer_forward.py

Comment on lines +10 to +11
with ratio causal masking against compressed-KV positions. Reduced BF16 scores
are written directly from registers to global memory.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix the output dtype wording in the kernel docstring.

The epilogue writes Float32 scores to mOut, so “Reduced BF16 scores” is misleading here; use “Reduced FP32 scores” or “Reduced scores are written to the FP32 output.” As per coding guidelines, python/cudnn/**: focus on documentation.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py`
around lines 10 - 11, The kernel docstring contains misleading information about
the output data type. In the docstring for the kernel function in
indexer_fwd_sm90.py, the phrase "Reduced BF16 scores are written directly from
registers to global memory" incorrectly references BF16 when the actual output
dtype is Float32 (FP32) to the mOut variable. Update the docstring to replace
"Reduced BF16 scores" with "Reduced FP32 scores" or rephrase it as "Reduced
scores are written to the FP32 output" to accurately reflect the actual output
dtype being written to global memory.

Source: Coding guidelines

Comment on lines +298 to +302
q_causal_offset = Int32(0) if const_expr(mQCausalOffsets is None) else mQCausalOffsets[batch_idx]
num_m_blocks = cute.ceil_div(seqlen.seqlen_q * self.qhead_per_kvhead, self.tile_m)
if block_x < num_m_blocks:
m_block = num_m_blocks - Int32(1) - block_x
n_block_max = self._compute_n_blocks(m_block, seqlen.seqlen_q, seqlen.seqlen_k)
n_block_max = self._compute_n_blocks(m_block, seqlen.seqlen_q, seqlen.seqlen_k, q_causal_offset)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Inspect DSA FE API tests for no-offset q_causal_offsets coverage
# where seqlen_q and seqlen_k are not exactly related by ratio.

set -euo pipefail

rg -n -C3 'q_causal_offsets|q_causal_offset|ratio|seqlen_q|seqlen_k' test/python/fe_api/dsa || true

Repository: NVIDIA/cudnn-frontend

Length of output: 34073


🏁 Script executed:

# First, let me examine the specific file and lines mentioned
cat -n python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py | sed -n '290,310p'

Repository: NVIDIA/cudnn-frontend

Length of output: 1150


🏁 Script executed:

# Search for "bottom-right" references in the codebase to understand legacy behavior
rg -i "bottom.right|bottom_right" --type python

Repository: NVIDIA/cudnn-frontend

Length of output: 96


🏁 Script executed:

# Check for any tests that explicitly test q_causal_offsets=None with different seqlen_q/seqlen_k ratios
rg -n "q_causal_offsets.*None|None.*q_causal" test/python/fe_api/dsa --type python -A5 -B5

Repository: NVIDIA/cudnn-frontend

Length of output: 96


🏁 Script executed:

# Search for "bottom-right" references in the codebase to understand legacy behavior
rg -i "bottom.right|bottom_right"

Repository: NVIDIA/cudnn-frontend

Length of output: 32630


🏁 Script executed:

# Check for any tests that explicitly test q_causal_offsets=None with different seqlen_q/seqlen_k ratios
rg -n "q_causal_offsets.*None|None.*q_causal" test/python/fe_api/dsa -A5 -B5

Repository: NVIDIA/cudnn-frontend

Length of output: 8414


🏁 Script executed:

# Look for documentation or comments about the legacy behavior and bottom-right masking
rg -i "legacy|bottom.right|offset.*semantic" python/cudnn/ -A3 -B3

Repository: NVIDIA/cudnn-frontend

Length of output: 16143


🏁 Script executed:

# Search for _compute_n_blocks to understand how it uses the offset
rg -n "_compute_n_blocks" python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py -A10

Repository: NVIDIA/cudnn-frontend

Length of output: 2041


🏁 Script executed:

# Search for tests that explicitly pass q_causal_offsets=None
rg -n "q_causal_offsets\s*=\s*None" test/python/fe_api/dsa

Repository: NVIDIA/cudnn-frontend

Length of output: 47


🏁 Script executed:

# Check if there are any tests for the indexer forward with None offsets
rg -n "test.*indexer.*forward" test/python/fe_api/dsa -A20 | grep -E "(def test|q_causal_offsets)" | head -30

Repository: NVIDIA/cudnn-frontend

Length of output: 157


🏁 Script executed:

# Look at the complete test_DSA_indexer_forward function
cat -n test/python/fe_api/dsa/test_DSA_indexer_forward.py

Repository: NVIDIA/cudnn-frontend

Length of output: 2606


🏁 Script executed:

# Check if there's any test that passes q_causal_offsets as None (not just omitting it)
rg -B5 -A10 "def test_DSA" test/python/fe_api/dsa/test_DSA_indexer_forward.py

Repository: NVIDIA/cudnn-frontend

Length of output: 360


Add regression test for q_causal_offsets=None with mismatched Q/K sequence lengths.

Line 298 defaults missing q_causal_offsets to 0. The current test suite only covers the explicit offset case (value 4); add a test case under test/python/fe_api/dsa with q_causal_offsets=None and mismatched seqlen_q and seqlen_k to ensure the ratio-causal mask behavior is correct for this fallback scenario.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py`
around lines 298 - 302, Add a new regression test case in the
test/python/fe_api/dsa test directory that validates the behavior when
q_causal_offsets is None (triggering the default value of 0 on line 298)
combined with mismatched Q and K sequence lengths. Set up test parameters with
seqlen_q different from seqlen_k while leaving q_causal_offsets unspecified, and
verify that the ratio-causal mask computation through the _compute_n_blocks
method produces correct results for this fallback scenario, distinguishing it
from the existing test case that only covers the explicit offset value of 4.

Source: Coding guidelines

Comment on lines +376 to +392
if q_causal_offsets is None:
return _ratio_causal_mask(s_q, s_k, ratio, device).unsqueeze(0).expand(batch, -1, -1)
offsets = q_causal_offsets.to(device=device, dtype=torch.int64)
q = torch.arange(s_q, device=device, dtype=torch.int64).view(1, s_q)
k_pos = torch.arange(s_k, device=device, dtype=torch.int64).view(1, 1, s_k)
col_limit = torch.div(offsets.view(batch, 1) + q + 1, ratio, rounding_mode="floor").clamp(0, s_k)
return k_pos < col_limit.unsqueeze(-1)


def _bottom_right_causal_mask(
s_q: int,
s_k: int,
ratio: int,
device: torch.device,
) -> torch.Tensor:
"""Compatibility alias for the offset-0 ratio-causal mask."""
return _ratio_causal_mask(s_q, s_k, ratio, device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve legacy bottom-right behavior when offsets are omitted.

Line 377 maps q_causal_offsets=None to offset 0, and Line 392 makes _bottom_right_causal_mask do the same. That changes the no-offset reference semantics instead of preserving the prior bottom-right causal alignment.

🐛 Proposed fix
     """Return ``(B, S_q, S_k)`` ratio-causal masks with per-batch offsets."""
     if q_causal_offsets is None:
-        return _ratio_causal_mask(s_q, s_k, ratio, device).unsqueeze(0).expand(batch, -1, -1)
+        return _bottom_right_causal_mask(s_q, s_k, ratio, device).unsqueeze(0).expand(batch, -1, -1)
@@
 def _bottom_right_causal_mask(
@@
 ) -> torch.Tensor:
-    """Compatibility alias for the offset-0 ratio-causal mask."""
-    return _ratio_causal_mask(s_q, s_k, ratio, device)
+    """Legacy bottom-right ratio-causal mask used when no q offsets are provided."""
+    q_causal_offset = max(s_k * ratio - s_q, 0)
+    return _ratio_causal_mask(s_q, s_k, ratio, device, q_causal_offset=q_causal_offset)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/fe_api/dsa/dsa_reference.py` around lines 376 - 392, The code
currently maps the None case for q_causal_offsets to _ratio_causal_mask on line
377, which changes the legacy bottom-right causal alignment semantics. To
preserve the prior behavior when offsets are omitted, replace the call to
_ratio_causal_mask with a call to _bottom_right_causal_mask in the condition
where q_causal_offsets is None, ensuring that the function maintains backward
compatibility with the original bottom-right causal mask behavior.

Comment on lines +628 to 630
valid = _batched_ratio_causal_mask(s_q, s_k, ratio, q_indexer.device, b, q_causal_offsets)
scores = scores.masked_fill(~valid, float("-inf"))
predict = torch.softmax(scores, dim=-1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard zero-valid rows before softmax.

With small offsets, valid can be all false for early queries; after Line 629 those rows are all -inf, and torch.softmax returns NaNs. Mask those rows to a finite sentinel before softmax, then zero invalid positions afterward.

🐛 Proposed fix
     valid = _batched_ratio_causal_mask(s_q, s_k, ratio, q_indexer.device, b, q_causal_offsets)
     scores = scores.masked_fill(~valid, float("-inf"))
-    predict = torch.softmax(scores, dim=-1)
+    has_valid = valid.any(dim=-1, keepdim=True)
+    scores = scores.masked_fill(~has_valid, 0.0)
+    predict = torch.softmax(scores, dim=-1).masked_fill(~valid, 0.0)
     return (predict, scores) if return_scores else predict
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/fe_api/dsa/dsa_reference.py` around lines 628 - 630, The current
code masks invalid positions to negative infinity before applying softmax, which
causes NaN values when all positions in a row are masked (all-false valid rows).
Between the masked_fill call on scores and the torch.softmax call on the masked
scores, add logic to identify rows where all values are negative infinity,
temporarily replace those with a finite sentinel value (like 0), apply softmax
to get valid probability distributions, and then zero out the invalid positions
in the resulting predict tensor to handle the edge case where early queries have
no valid positions due to small offsets in the _batched_ratio_causal_mask
function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants