fix: IMA on indexer_topk_wrapper#312
Conversation
📝 WalkthroughWalkthroughAdds a ChangesChunked Decode-Varlen IMA Workaround
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py (1)
711-734: Add FE test coverage for the chunked dispatch path (lines 711–734) or extract the planner for isolated unit testing.The chunked dispatch condition (
0 < chunk_rows < num_rows) triggers whennum_colsis large relative tonum_rowsdue to the int32 buffer constraint. The current FE test parametrization (b=1, s_q=1024, s_kv=1024, next_n=1) does not trigger this path becausemax_chunk_rows(2,097,152) far exceedsnum_rows(1,024).Chunking would activate with larger
num_cols(e.g., ≥ 262,144), which is realistic for KV cache scenarios. Add a parametrized FE test case that forces chunking, or extract the chunk planning logic (lines 693–710) into a helper function and unit test the boundary alignment and slicing math independently. Without this coverage, regressions in batch/row/seq_lens slicing could pass current tests.Chunking trigger analysis
With
num_cols=262144, num_rows=1024, buffer_numbers=1, next_n=1, align_rows=16:
max_chunk_rows = 8192chunk_rows = 8192(not < 1024, so no chunking)But with
num_cols=1000000:
max_chunk_rows = 2148chunk_rows = 2144(> 1024, but extraction of actual triggering configs from real KV cache sizes recommended)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py` around lines 711 - 734, The chunked dispatch path (triggered when `0 < chunk_rows < num_rows`) lacks test coverage because the current FE test parametrization does not generate large enough `num_cols` values to activate chunking. Either add a parametrized FE test case with `num_cols` ≥ 262,144 that exercises the chunked dispatch loop with `for row_lo in range(0, num_rows, chunk_rows)` and its associated slicing of `chunk_input`, `chunk_seqlens`, `chunk_oi`, `chunk_ov`, and `chunk_extra`, or extract the chunk planning logic that calculates `max_chunk_rows` and `chunk_rows` into a separate helper function and write isolated unit tests to verify the boundary alignment and slicing math for `batch_lo`, `batch_hi`, `row_lo`, and `row_hi` calculations.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In
`@python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py`:
- Around line 700-711: The code sets chunk_rows = num_rows when max_chunk_rows <
row_step, but this fallback can violate the int32-safe invariant (chunk_rows -
1) * buffer_numbers * num_cols < 2**31 if num_rows is too large. Additionally,
the align_rows calculation only considers element size but should factor in row
stride to avoid unnecessarily rejecting valid safe chunks. When max_chunk_rows <
row_step, enforce the int32 invariant by clamping chunk_rows to a safe value if
needed, and update the alignment computation to derive align_rows from both the
element width and the row stride constraints so that valid chunk sizes meeting
the safety invariant are not rejected. Include clear documentation explaining
the int32 safety constraint and why alignment from row stride matters.
---
Nitpick comments:
In
`@python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py`:
- Around line 711-734: The chunked dispatch path (triggered when `0 < chunk_rows
< num_rows`) lacks test coverage because the current FE test parametrization
does not generate large enough `num_cols` values to activate chunking. Either
add a parametrized FE test case with `num_cols` ≥ 262,144 that exercises the
chunked dispatch loop with `for row_lo in range(0, num_rows, chunk_rows)` and
its associated slicing of `chunk_input`, `chunk_seqlens`, `chunk_oi`,
`chunk_ov`, and `chunk_extra`, or extract the chunk planning logic that
calculates `max_chunk_rows` and `chunk_rows` into a separate helper function and
write isolated unit tests to verify the boundary alignment and slicing math for
`batch_lo`, `batch_hi`, `row_lo`, and `row_hi` calculations.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 61393c85-a95b-4a87-9bc5-1d6e125f3096
📒 Files selected for processing (1)
python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py
| # Alignment | ||
| input_elem_bytes = dtype.width // 8 | ||
| align_rows = 32 // input_elem_bytes if input_elem_bytes > 0 else 1 | ||
| if align_rows < 1: | ||
| align_rows = 1 | ||
| # chunk_rows must be a common multiple of next_n and align_rows. | ||
| row_step = (next_n * align_rows) // math.gcd(next_n, align_rows) | ||
| if max_chunk_rows < row_step: | ||
| chunk_rows = num_rows | ||
| else: | ||
| chunk_rows = (max_chunk_rows // row_step) * row_step | ||
| if 0 < chunk_rows < num_rows: |
There was a problem hiding this comment.
Enforce the int32-safe invariant before falling back to a full launch.
Line 708 sets chunk_rows = num_rows when max_chunk_rows < row_step; if num_rows > max_chunk_rows, this re-enters the unsafe single-launch path and can violate (chunk_rows - 1) * buffer_numbers * num_cols < 2**31. Also compute alignment from the row stride, not just element size, so valid safe chunks are not rejected unnecessarily.
🐛 Proposed fix
- # Alignment
+ # Preserve the 32-byte assumed alignment of input row slices:
+ # row_lo * row_stride_bytes must be 32-byte aligned.
input_elem_bytes = dtype.width // 8
- align_rows = 32 // input_elem_bytes if input_elem_bytes > 0 else 1
- if align_rows < 1:
- align_rows = 1
- # chunk_rows must be a common multiple of next_n and align_rows.
+ row_stride_bytes = num_cols * input_elem_bytes
+ align_rows = 32 // math.gcd(32, row_stride_bytes) if row_stride_bytes > 0 else 1
+ # chunk_rows must be a common multiple of next_n (whole seq_lens groups)
+ # and align_rows (input slice alignment), while preserving the int32 IMA bound.
row_step = (next_n * align_rows) // math.gcd(next_n, align_rows)
if max_chunk_rows < row_step:
+ if num_rows > max_chunk_rows:
+ raise NotImplementedError(
+ "Cannot choose an int32-safe row chunk aligned to next_n "
+ "and the input row stride "
+ f"(max_chunk_rows={max_chunk_rows}, row_step={row_step}, num_rows={num_rows})."
+ )
chunk_rows = num_rows
else:
chunk_rows = (max_chunk_rows // row_step) * row_stepAs per coding guidelines, for python/cudnn/**: “Focus on documentation.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py`
around lines 700 - 711, The code sets chunk_rows = num_rows when max_chunk_rows
< row_step, but this fallback can violate the int32-safe invariant (chunk_rows -
1) * buffer_numbers * num_cols < 2**31 if num_rows is too large. Additionally,
the align_rows calculation only considers element size but should factor in row
stride to avoid unnecessarily rejecting valid safe chunks. When max_chunk_rows <
row_step, enforce the int32 invariant by clamping chunk_rows to a safe value if
needed, and update the alignment computation to derive align_rows from both the
element width and the row stride constraints so that valid chunk sizes meeting
the safety invariant are not rejected. Include clear documentation explaining
the int32 safety constraint and why alignment from row stride matters.
Source: Coding guidelines
cute_dsl lowers the per-block slice buffer = extra_buffer[bidx, ...] into a 32-bit-by-32-bit -> 32-bit multiply followed by a sign-extending widening accumulate. This pr chunk the launch in cute_dsl_topk_wrapper along the row dimension so that within each chunk bidx < chunk_rows and
(chunk_rows − 1) * buffer_numbers * num_cols < 2**31.When run python repro_indexer_topk_ima.py, an IMA error does not exist.
No IMA occurs. And the output looks good.
When use compute-sanitizer to double check, the output show 0 errors.
Summary by CodeRabbit