fix: IMA on indexer_topk_wrapper#312

Open
Hyaloid wants to merge 1 commit into NVIDIA:develop from Hyaloid:fix-topk-ima

@Hyaloid Hyaloid commented Jun 16, 2026

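cute_dsl lowers the per-block slice buffer = extra_buffer[bidx, ...] into a 32-bit-by-32-bit -> 32-bit multiply followed by a sign-extending widening accumulate. Once the product bidx * buffer_numbers * num_cols reaches 2**31, the 32-bit result wraps to a negative value, and the sign extension turns it into a negative 64-bit offset, which causes the illegal memory access (IMA). This PR chunks the launch in cute_dsl_topk_wrapper along the row dimension so that within each chunk bidx < chunk_rows and (chunk_rows − 1) * buffer_numbers * num_cols < 2**31.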
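
To make the failure mode concrete, the sketch below emulates a 32-bit multiply whose result is read back as a signed value, using the shapes from the repro script. `wrap_i32` is a hypothetical helper written for this illustration, and `buffer_numbers = 2` is an assumed value; the PR does not state the real one.

```python
def wrap_i32(value: int) -> int:
    """Keep the low 32 bits of value and reinterpret them as a signed int32."""
    value &= 0xFFFFFFFF
    return value - (1 << 32) if value >= (1 << 31) else value


num_cols = 262209       # S_kv from the repro script
num_rows = 4096         # S from the repro script
buffer_numbers = 2      # ASSUMPTION for illustration; not stated in the PR
row_stride = buffer_numbers * num_cols

# Offset of the last block's slice, computed exactly and with 32-bit wraparound.
last_bidx = num_rows - 1
exact_offset = last_bidx * row_stride
wrapped_offset = wrap_i32(exact_offset)

# The block just before it still fits in int32, so it is unaffected.
safe_offset = wrap_i32((last_bidx - 1) * row_stride)

print(f"exact={exact_offset} wrapped={wrapped_offset} previous={safe_offset}")
```

Under these assumed values only the last row index crosses 2**31, which matches the invariant the chunking enforces: keeping (chunk_rows − 1) * buffer_numbers * num_cols below 2**31 keeps every per-chunk offset non-negative.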

With this change, running python repro_indexer_topk_ima.py no longer triggers an IMA error.

# repro_indexer_topk_ima.py
import torch
from cudnn import DSA

S, S_kv, top_k, next_n = 4096, 262209, 2048, 1
device = torch.device("cuda")

g = torch.Generator(device=device).manual_seed(0)
logits   = torch.randn((S, S_kv), dtype=torch.float32, device=device, generator=g)
seq_lens = torch.full((S // next_n,), S_kv, dtype=torch.int32, device=device)

result = DSA.indexer_top_k_wrapper(
    logits, seq_lens,
    top_k=top_k, next_n=next_n, return_val=False,
)
torch.cuda.synchronize()  # raised cudaErrorIllegalAddress before this fix
print(result["indices"].sum().item())

No IMA occurs, and the output looks correct.
Running the script under compute-sanitizer as a double check reports 0 errors.

Summary by CodeRabbit

  • Bug Fixes
    • Enhanced variable-length sequence handling in sparse attention decoding operations with improved computation efficiency and resource allocation.

@coderabbitai

coderabbitai Bot commented Jun 16, 2026

Contributor

📝 Walkthrough

Adds a math import and replaces the single-launch decode-varlen execution path in cute_dsl_topk_wrapper with a chunked loop. The loop computes per-row element sizing, derives an align_rows alignment from input byte width, selects a chunk_rows value within int32 limits, slices tensors into row chunks, allocates per-chunk chunk_extra buffers, and invokes the compiled kernel once per chunk.

Changes

Chunked Decode-Varlen IMA Workaround

| Layer / File(s) | Summary |
| --- | --- |
| Alignment math and chunked kernel dispatch (python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py) | Adds math import; replaces the single-launch decode-varlen path with a chunked loop that computes align_rows from the input byte width, constrains chunk_rows to a multiple of lcm(next_n, align_rows) within int32 index limits, slices input_values/output_*/seq_lens per chunk, allocates a per-chunk chunk_extra int32 buffer, and invokes the compiled kernel per chunk with g_global_counter_torch=None; falls back to a single launch when chunk_rows >= num_rows. |
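
The per-chunk slicing described in this summary can be sketched as below. This is an illustration rather than the PR's code: `chunk_ranges` is a hypothetical helper, and the mapping batch_lo = row_lo // next_n assumes that each seq_lens entry covers next_n consecutive rows, as the repro's seq_lens shape (S // next_n,) suggests.

```python
def chunk_ranges(num_rows: int, chunk_rows: int, next_n: int):
    """Yield (row_lo, row_hi, batch_lo, batch_hi) for each row chunk.

    ASSUMPTION: seq_lens has one entry per group of next_n rows, so a chunk
    must start on a multiple of next_n to cover whole groups.
    """
    if chunk_rows <= 0 or chunk_rows % next_n != 0:
        raise ValueError("chunk_rows must be a positive multiple of next_n")
    for row_lo in range(0, num_rows, chunk_rows):
        row_hi = min(row_lo + chunk_rows, num_rows)
        batch_lo = row_lo // next_n
        batch_hi = -(-row_hi // next_n)  # ceiling division for a ragged tail
        yield row_lo, row_hi, batch_lo, batch_hi


for row_lo, row_hi, batch_lo, batch_hi in chunk_ranges(10, 4, 2):
    # A real dispatch would slice tensors here, for example
    # input_values[row_lo:row_hi] and seq_lens[batch_lo:batch_hi].
    print(row_lo, row_hi, batch_lo, batch_hi)
```

Requiring chunk_rows to be a multiple of next_n is what keeps every chunk boundary on a whole seq_lens group, so no group is split across two kernel launches.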

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐇 Hop hop, the rows grow tall,
Too many for int32 to call!
I slice them into chunks just right,
Aligned to lcm with all my might.
Each kernel launch a tiny leap—
No IMA gremlins in my sleep! 🌙

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title 'fix: IMA on indexer_topk_wrapper' directly references the specific technical issue being fixed (IMA, an illegal memory access) and identifies the affected component (indexer_topk_wrapper), aligning with the PR's primary objective. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py (1)

711-734: Add FE test coverage for the chunked dispatch path (lines 711–734) or extract the planner for isolated unit testing.

The chunked dispatch condition (0 < chunk_rows < num_rows) triggers when num_cols is large relative to num_rows due to the int32 buffer constraint. The current FE test parametrization (b=1, s_q=1024, s_kv=1024, next_n=1) does not trigger this path because max_chunk_rows (2,097,152) far exceeds num_rows (1,024).

Chunking activates only when chunk_rows falls below num_rows, which requires a large num_cols together with enough rows: with buffer_numbers=1 and num_cols=262,144, for example, num_rows must exceed 8,192. Shapes like this are realistic for KV cache scenarios. Add a parametrized FE test case that forces chunking, or extract the chunk planning logic (lines 693–710) into a helper function and unit test the boundary alignment and slicing math independently. Without this coverage, regressions in batch/row/seq_lens slicing could pass current tests.

Chunking trigger analysis

With num_cols=262144, num_rows=1024, buffer_numbers=1, next_n=1, align_rows=16:

  • max_chunk_rows = 8192
  • chunk_rows = 8192 (not < 1024, so no chunking)

But with num_cols=1000000:

  • max_chunk_rows = 2148
  • chunk_rows = 2144 (still > 1024, so no chunking either; deriving actual triggering configurations from real KV cache sizes is recommended)
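
The numbers in this analysis can be reproduced with a small planner sketch. `plan_chunk_rows` is a hypothetical helper, and its max_chunk_rows formula is inferred from the stated invariant (chunk_rows - 1) * buffer_numbers * num_cols < 2**31; the PR's actual computation is not shown here and may differ.

```python
import math

INT32_LIMIT = 2**31


def plan_chunk_rows(num_rows, num_cols, buffer_numbers, next_n, align_rows):
    """Return (max_chunk_rows, chunk_rows, is_chunked) for one launch plan."""
    stride = buffer_numbers * num_cols
    # Largest chunk_rows satisfying (chunk_rows - 1) * stride < 2**31.
    # ASSUMPTION: inferred from the invariant, not copied from the PR.
    max_chunk_rows = (INT32_LIMIT - 1) // stride + 1
    # Least common multiple of next_n and align_rows, as in the reviewed code.
    row_step = (next_n * align_rows) // math.gcd(next_n, align_rows)
    if max_chunk_rows < row_step:
        chunk_rows = num_rows  # mirrors the fallback that the review questions
    else:
        chunk_rows = (max_chunk_rows // row_step) * row_step
    return max_chunk_rows, chunk_rows, 0 < chunk_rows < num_rows


# The two configurations from the analysis, plus one with enough rows to chunk.
print(plan_chunk_rows(1024, 262144, 1, 1, 16))
print(plan_chunk_rows(1024, 1_000_000, 1, 1, 16))
print(plan_chunk_rows(16384, 262144, 1, 1, 16))
```

A helper of this shape is also what the review suggests extracting, since it can be unit tested without launching a kernel.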
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py`
around lines 711 - 734, The chunked dispatch path (triggered when `0 <
chunk_rows < num_rows`) lacks test coverage because the current FE test
parametrization does not generate large enough `num_cols` values to activate
chunking. Either add a parametrized FE test case with `num_cols` ≥ 262,144 that
exercises the chunked dispatch loop with `for row_lo in range(0, num_rows,
chunk_rows)` and its associated slicing of `chunk_input`, `chunk_seqlens`,
`chunk_oi`, `chunk_ov`, and `chunk_extra`, or extract the chunk planning logic
that calculates `max_chunk_rows` and `chunk_rows` into a separate helper
function and write isolated unit tests to verify the boundary alignment and
slicing math for `batch_lo`, `batch_hi`, `row_lo`, and `row_hi` calculations.

Source: Coding guidelines

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 61393c85-a95b-4a87-9bc5-1d6e125f3096

📥 Commits

Reviewing files that changed from the base of the PR and between 603ffab and c254063.

📒 Files selected for processing (1)
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py

Comment on lines +700 to +711
    # Alignment
    input_elem_bytes = dtype.width // 8
    align_rows = 32 // input_elem_bytes if input_elem_bytes > 0 else 1
    if align_rows < 1:
        align_rows = 1
    # chunk_rows must be a common multiple of next_n and align_rows.
    row_step = (next_n * align_rows) // math.gcd(next_n, align_rows)
    if max_chunk_rows < row_step:
        chunk_rows = num_rows
    else:
        chunk_rows = (max_chunk_rows // row_step) * row_step
    if 0 < chunk_rows < num_rows:

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Enforce the int32-safe invariant before falling back to a full launch.

Line 708 sets chunk_rows = num_rows when max_chunk_rows < row_step; if num_rows > max_chunk_rows, this re-enters the unsafe single-launch path and can violate (chunk_rows - 1) * buffer_numbers * num_cols < 2**31. Also compute alignment from the row stride, not just element size, so valid safe chunks are not rejected unnecessarily.

🐛 Proposed fix
-    # Alignment
+    # Preserve the 32-byte assumed alignment of input row slices:
+    # row_lo * row_stride_bytes must be 32-byte aligned.
     input_elem_bytes = dtype.width // 8
-    align_rows = 32 // input_elem_bytes if input_elem_bytes > 0 else 1
-    if align_rows < 1:
-        align_rows = 1
-    # chunk_rows must be a common multiple of next_n and align_rows.
+    row_stride_bytes = num_cols * input_elem_bytes
+    align_rows = 32 // math.gcd(32, row_stride_bytes) if row_stride_bytes > 0 else 1
+    # chunk_rows must be a common multiple of next_n (whole seq_lens groups)
+    # and align_rows (input slice alignment), while preserving the int32 IMA bound.
     row_step = (next_n * align_rows) // math.gcd(next_n, align_rows)
     if max_chunk_rows < row_step:
+        if num_rows > max_chunk_rows:
+            raise NotImplementedError(
+                "Cannot choose an int32-safe row chunk aligned to next_n "
+                "and the input row stride "
+                f"(max_chunk_rows={max_chunk_rows}, row_step={row_step}, num_rows={num_rows})."
+            )
         chunk_rows = num_rows
     else:
         chunk_rows = (max_chunk_rows // row_step) * row_step

As per coding guidelines, for python/cudnn/**: “Focus on documentation.”
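
To show why the two alignment rules can disagree, the sketch below compares the element-size rule from the current code with the row-stride rule from the proposed fix. The function names are hypothetical, and the 32-byte target is taken from the comment in the proposed fix.

```python
import math


def align_rows_from_elem(elem_bytes: int) -> int:
    """Current rule: derive row alignment from the element size only."""
    return max(32 // elem_bytes, 1) if elem_bytes > 0 else 1


def align_rows_from_stride(num_cols: int, elem_bytes: int) -> int:
    """Proposed rule: smallest row count whose byte offset is 32-byte aligned."""
    row_stride_bytes = num_cols * elem_bytes
    if row_stride_bytes <= 0:
        return 1
    return 32 // math.gcd(32, row_stride_bytes)


# float32 rows (4 bytes per element) at two widths.
for num_cols in (262144, 262209):
    by_elem = align_rows_from_elem(4)
    by_stride = align_rows_from_stride(num_cols, 4)
    print(num_cols, by_elem, by_stride)
```

When the row stride is already a multiple of 32 bytes, any row boundary is aligned, so the stride-based rule allows a row_step as small as next_n, while the element-size rule still demands multiples of 8 rows for float32.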

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py`
around lines 700 - 711, The code sets chunk_rows = num_rows when max_chunk_rows
< row_step, but this fallback can violate the int32-safe invariant (chunk_rows -
1) * buffer_numbers * num_cols < 2**31 if num_rows is too large. Additionally,
the align_rows calculation only considers element size but should factor in row
stride to avoid unnecessarily rejecting valid safe chunks. When max_chunk_rows <
row_step, enforce the int32 invariant by clamping chunk_rows to a safe value if
needed, and update the alignment computation to derive align_rows from both the
element width and the row stride constraints so that valid chunk sizes meeting
the safety invariant are not rejected. Include clear documentation explaining
the int32 safety constraint and why alignment from row stride matters.

Source: Coding guidelines
