Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""SM90+ CuTe DSL indexer top-K decode kernel."""

import math

import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
Expand Down Expand Up @@ -687,6 +689,50 @@ def cute_dsl_topk_wrapper(
buffer_numbers = 2
else:
buffer_numbers = 1

# Decode-varlen IMA workaround.
elems_per_row = buffer_numbers * num_cols
int32_max = (1 << 31) - 1
if elems_per_row > 0:
max_chunk_rows = int32_max // elems_per_row + 1
else:
max_chunk_rows = num_rows
# Alignment
input_elem_bytes = dtype.width // 8
align_rows = 32 // input_elem_bytes if input_elem_bytes > 0 else 1
if align_rows < 1:
align_rows = 1
# chunk_rows must be a common multiple of next_n and align_rows.
row_step = (next_n * align_rows) // math.gcd(next_n, align_rows)
if max_chunk_rows < row_step:
chunk_rows = num_rows
else:
chunk_rows = (max_chunk_rows // row_step) * row_step
if 0 < chunk_rows < num_rows:
Comment on lines +700 to +711

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Enforce the int32-safe invariant before falling back to a full launch.

Line 708 sets chunk_rows = num_rows when max_chunk_rows < row_step; if num_rows > max_chunk_rows, this re-enters the unsafe single-launch path and can violate (chunk_rows - 1) * buffer_numbers * num_cols < 2**31. Also compute alignment from the row stride, not just element size, so valid safe chunks are not rejected unnecessarily.

🐛 Proposed fix
-    # Alignment
+    # Preserve the 32-byte assumed alignment of input row slices:
+    # row_lo * row_stride_bytes must be 32-byte aligned.
     input_elem_bytes = dtype.width // 8
-    align_rows = 32 // input_elem_bytes if input_elem_bytes > 0 else 1
-    if align_rows < 1:
-        align_rows = 1
-    # chunk_rows must be a common multiple of next_n and align_rows.
+    row_stride_bytes = num_cols * input_elem_bytes
+    align_rows = 32 // math.gcd(32, row_stride_bytes) if row_stride_bytes > 0 else 1
+    # chunk_rows must be a common multiple of next_n (whole seq_lens groups)
+    # and align_rows (input slice alignment), while preserving the int32 IMA bound.
     row_step = (next_n * align_rows) // math.gcd(next_n, align_rows)
     if max_chunk_rows < row_step:
+        if num_rows > max_chunk_rows:
+            raise NotImplementedError(
+                "Cannot choose an int32-safe row chunk aligned to next_n "
+                "and the input row stride "
+                f"(max_chunk_rows={max_chunk_rows}, row_step={row_step}, num_rows={num_rows})."
+            )
         chunk_rows = num_rows
     else:
         chunk_rows = (max_chunk_rows // row_step) * row_step

As per coding guidelines, for python/cudnn/**: “Focus on documentation.”

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py`
around lines 700 - 711, The code sets chunk_rows = num_rows when max_chunk_rows
< row_step, but this fallback can violate the int32-safe invariant (chunk_rows -
1) * buffer_numbers * num_cols < 2**31 if num_rows is too large. Additionally,
the align_rows calculation only considers element size but should factor in row
stride to avoid unnecessarily rejecting valid safe chunks. When max_chunk_rows <
row_step, enforce the int32 invariant by clamping chunk_rows to a safe value if
needed, and update the alignment computation to derive align_rows from both the
element width and the row stride constraints so that valid chunk sizes meeting
the safety invariant are not rejected. Include clear documentation explaining
the int32 safety constraint and why alignment from row stride matters.

Source: Coding guidelines

g_global_counter_torch = None
for row_lo in range(0, num_rows, chunk_rows):
row_hi = min(row_lo + chunk_rows, num_rows)
batch_lo = row_lo // next_n
batch_hi = row_hi // next_n
chunk_input = input_values[row_lo:row_hi]
chunk_seqlens = seq_lens[batch_lo:batch_hi]
chunk_oi = output_indices_torch[row_lo:row_hi]
chunk_ov = output_values_torch[row_lo:row_hi] if return_val else None
chunk_extra = torch.empty(
row_hi - row_lo, buffer_numbers, num_cols,
dtype=torch.int32, device="cuda",
)
compiled_kernel(
chunk_input,
None,
chunk_extra,
g_global_counter_torch,
chunk_seqlens,
chunk_oi,
chunk_ov,
)
return output_indices_torch, output_values_torch

# Note: zeros will trigger an elementwise_add kernel.
buffer_torch = torch.empty(num_rows, buffer_numbers, num_cols, dtype=torch.int32, device="cuda")
g_global_counter_torch = None
Expand Down