diff --git a/python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py b/python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py index d8b21341..24d66447 100644 --- a/python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py +++ b/python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py @@ -14,6 +14,8 @@ """SM90+ CuTe DSL indexer top-K decode kernel.""" +import math + import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute @@ -687,6 +689,50 @@ def cute_dsl_topk_wrapper( buffer_numbers = 2 else: buffer_numbers = 1 + + # Decode-varlen IMA workaround. + elems_per_row = buffer_numbers * num_cols + int32_max = (1 << 31) - 1 + if elems_per_row > 0: + max_chunk_rows = int32_max // elems_per_row + 1 + else: + max_chunk_rows = num_rows + # Alignment + input_elem_bytes = dtype.width // 8 + align_rows = 32 // input_elem_bytes if input_elem_bytes > 0 else 1 + if align_rows < 1: + align_rows = 1 + # chunk_rows must be a common multiple of next_n and align_rows. + row_step = (next_n * align_rows) // math.gcd(next_n, align_rows) + if max_chunk_rows < row_step: + chunk_rows = num_rows + else: + chunk_rows = (max_chunk_rows // row_step) * row_step + if 0 < chunk_rows < num_rows: + g_global_counter_torch = None + for row_lo in range(0, num_rows, chunk_rows): + row_hi = min(row_lo + chunk_rows, num_rows) + batch_lo = row_lo // next_n + batch_hi = row_hi // next_n + chunk_input = input_values[row_lo:row_hi] + chunk_seqlens = seq_lens[batch_lo:batch_hi] + chunk_oi = output_indices_torch[row_lo:row_hi] + chunk_ov = output_values_torch[row_lo:row_hi] if return_val else None + chunk_extra = torch.empty( + row_hi - row_lo, buffer_numbers, num_cols, + dtype=torch.int32, device="cuda", + ) + compiled_kernel( + chunk_input, + None, + chunk_extra, + g_global_counter_torch, + chunk_seqlens, + chunk_oi, + chunk_ov, + ) + return output_indices_torch, output_values_torch + # Note: zeros will trigger an elementwise_add kernel. buffer_torch = torch.empty(num_rows, buffer_numbers, num_cols, dtype=torch.int32, device="cuda") g_global_counter_torch = None