Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions python/cudnn/deepseek_sparse_attention/indexer_backward/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cudnn.api_base import APIBase, TupleDict
from cudnn.deepseek_sparse_attention.utils.runtime import (
torch_stream_context as _torch_stream_context,
validate_q_causal_offsets,
)

from .dense_indexer_backward_sm100 import dense_indexer_backward_sm100
Expand Down Expand Up @@ -275,6 +276,7 @@ def __init__(
batch: Optional[int] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
has_q_causal_offsets: bool = False,
):
super().__init__()
self.iq_desc = self._make_tensor_desc(sample_index_q, name="sample_index_q")
Expand All @@ -292,6 +294,7 @@ def __init__(
self.block_I = int(block_I)
self.ratio = int(ratio)
self.is_thd = bool(is_thd)
self.has_q_causal_offsets = bool(has_q_causal_offsets)

if self.is_thd:
total_q, heads, head_dim = sample_index_q.shape
Expand Down Expand Up @@ -333,10 +336,6 @@ def check_support(self) -> bool:
self._value_error_if(self.block_I <= 0, f"block_I must be positive, got {self.block_I}")
self._value_error_if(self.ratio < 1, f"ratio must be >= 1, got {self.ratio}")
self._value_error_if(self.heads < 64, f"DenseIndexerBackward requires heads >= 64, got {self.heads}")
self._value_error_if(
self.max_seqlen_q > self.max_seqlen_k * self.ratio,
"DenseIndexerBackward requires S_q <= S_k * ratio for bottom-right causal alignment",
)
self._is_supported = True
return True

Expand All @@ -357,6 +356,7 @@ def compile(self) -> None:
block_I=self.block_I,
ratio=self.ratio,
is_varlen=self.is_thd,
has_q_causal_offsets=self.has_q_causal_offsets,
)

def execute(
Expand All @@ -375,6 +375,7 @@ def execute(
grad_loss: Union[float, torch.Tensor] = 1.0,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
q_causal_offsets: Optional[torch.Tensor] = None,
current_stream: Optional[cuda.CUstream] = None,
) -> None:
backend_stream = None if self._uses_current_stream_pipeline else current_stream
Expand Down Expand Up @@ -404,6 +405,7 @@ def execute(
grad_scale,
cu_seqlens_q,
cu_seqlens_k,
q_causal_offsets,
backend_stream,
)

Expand Down Expand Up @@ -538,6 +540,7 @@ def dense_indexer_backward_wrapper(
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
q_causal_offsets: Optional[torch.Tensor] = None,
d_index_q: Optional[torch.Tensor] = None,
d_weights: Optional[torch.Tensor] = None,
d_index_k: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -587,6 +590,7 @@ def dense_indexer_backward_wrapper(
max_seqlen_q,
max_seqlen_k,
)
q_causal_offsets = validate_q_causal_offsets(q_causal_offsets, int(batch), index_q_exec.device)

if d_index_q is None:
d_index_q = torch.empty_like(index_q_exec)
Expand Down Expand Up @@ -618,6 +622,7 @@ def dense_indexer_backward_wrapper(
float(sm_scale),
int(block_I),
int(ratio),
q_causal_offsets is not None,
)
obj = _cache_of_DenseIndexerBackwardObjects.get(key)
if obj is None:
Expand All @@ -639,6 +644,7 @@ def dense_indexer_backward_wrapper(
batch=batch,
max_seqlen_q=max_q,
max_seqlen_k=max_k,
has_q_causal_offsets=q_causal_offsets is not None,
)
assert obj.check_support()
obj.compile()
Expand All @@ -659,6 +665,7 @@ def dense_indexer_backward_wrapper(
grad_loss=grad_loss,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
q_causal_offsets=q_causal_offsets,
current_stream=backend_stream,
)
with _torch_stream_context(backend_stream):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@
GradSignal / Score : (T_q, max_seqlen_k) [second dim is batch-local k]
LSE / L1Norm : (T_q,)

Bottom-right ratio causal:
q_start_b = seqlen_k_b * ratio - seqlen_q_b
col_limit = min(seqlen_k_b, (q_start_b + q_local + 1) // ratio)
When seqlen_q_b == seqlen_k_b * ratio, q_start_b == 0 (legacy behavior).
Constraint per batch: seqlen_q_b <= seqlen_k_b * ratio.
Ratio causal mask:
q_abs = q_causal_offset_b + q_local
col_limit = clamp((q_abs + 1) // ratio, 0, seqlen_k_b)
valid if k_local < col_limit
q_causal_offset_b defaults to 0 when no offset tensor is provided.
"""

from __future__ import annotations
Expand Down Expand Up @@ -144,6 +144,7 @@ def __call__(
mDenomTarget,
mCuSeqlensQ,
mCuSeqlensK,
mQCausalOffsets,
grad_scale: Float32,
max_seqlen_q: Int32,
max_seqlen_k: Int32,
Expand Down Expand Up @@ -188,6 +189,7 @@ def __call__(
mDenomTarget,
mCuSeqlensQ,
mCuSeqlensK,
mQCausalOffsets,
grad_scale,
seqlen_k_pad,
max_seqlen_q,
Expand All @@ -207,6 +209,7 @@ def kernel_score_grad(
mDenomTarget,
mCuSeqlensQ,
mCuSeqlensK,
mQCausalOffsets,
grad_scale: Float32,
seqlen_k_pad: Int32,
seqlen_q_static: Int32,
Expand All @@ -227,6 +230,7 @@ def kernel_score_grad(
seqlen_q_static,
seqlen_k_static,
)
q_causal_offset_b = Int32(0) if const_expr(mQCausalOffsets is None) else mQCausalOffsets[batch_idx]

# Out-of-range CTAs (seq_local >= seqlen_q_b in this batch): skip work.
# CuTe DSL forbids early `return`, so wrap the body in an `if` block.
Expand Down Expand Up @@ -261,14 +265,10 @@ class SharedStorage:

LOG2E = Float32(1.4426950408889634)

# Bottom-right ratio causal:
# q_start_b = seqlen_k_b * ratio - seqlen_q_b
# col_limit = min(seqlen_k_b, (q_start_b + q_local + 1) // ratio)
# Equals legacy ((q_local+1)//ratio) when seqlen_q_b == seqlen_k_b * ratio.
ratio = Int32(self.ratio)
q_start_b = seqlen_k_b * ratio - seqlen_q_b
col_limit_raw = (q_start_b + Int32(seq_local) + Int32(1)) // ratio
col_limit_raw = (q_causal_offset_b + Int32(seq_local) + Int32(1)) // ratio
col_limit = col_limit_raw if col_limit_raw < seqlen_k_b else seqlen_k_b
col_limit = col_limit if col_limit > Int32(0) else Int32(0)

# --- Phase 1: Accumulate sum_grad ---
local_sum = Float32(0.0)
Expand Down Expand Up @@ -420,6 +420,7 @@ def __call__(
mGradSignal: cute.Tensor,
mCuSeqlensQ,
mCuSeqlensK,
mQCausalOffsets,
sm_scale: Float32 | float,
max_seqlen_q: Int32,
max_seqlen_k: Int32,
Expand Down Expand Up @@ -592,6 +593,7 @@ def __call__(
mGradSignal,
mCuSeqlensQ,
mCuSeqlensK,
mQCausalOffsets,
sm_scale,
Int32(max_seqlen_q),
Int32(max_seqlen_k),
Expand Down Expand Up @@ -635,6 +637,7 @@ def kernel_gemm_dense_2q(
mGradSignal,
mCuSeqlensQ,
mCuSeqlensK,
mQCausalOffsets,
sm_scale: Float32 | float,
seqlen_q_static: Int32,
seqlen_k_static: Int32,
Expand Down Expand Up @@ -674,6 +677,7 @@ def kernel_gemm_dense_2q(
seqlen_q_static,
seqlen_k_static,
)
q_causal_offset_b = Int32(0) if const_expr(mQCausalOffsets is None) else mQCausalOffsets[batch_idx]

# 2Q pair indices are batch-local and scheduled from largest q to
# smallest q. For odd seqlen_q_b this puts the singleton CTA at the
Expand Down Expand Up @@ -713,11 +717,10 @@ def kernel_gemm_dense_2q(
mGS_b = mGradSignal[None, None, batch_idx]

# Causal-aware K-block bound for this 2Q CTA. GradSignal has already
# been zeroed past each q token's bottom-right causal column limit, but
# been zeroed past each q token's ratio-causal column limit, but
# using that limit here avoids TMA/MMA work for wholly future K blocks.
ratio = Int32(self.ratio)
q_start_b = seqlen_k_b * ratio - seqlen_q_b
max_kv_needed_raw = (q_start_b + q0_local + Int32(1)) // ratio
max_kv_needed_raw = (q_causal_offset_b + q0_local + Int32(1)) // ratio
max_kv_needed = max_kv_needed_raw if max_kv_needed_raw > Int32(0) else Int32(0)
max_kv_needed = max_kv_needed if max_kv_needed < seqlen_k_b else seqlen_k_b
num_kv_blocks = (max_kv_needed + self.block_I - 1) // self.block_I
Expand Down Expand Up @@ -1905,6 +1908,7 @@ def dense_indexer_backward_sm100(
block_I=128,
ratio=1,
is_varlen=False,
has_q_causal_offsets=False,
):
"""Build / fetch a compiled SM100 dense backward gradient kernel.

Expand All @@ -1915,10 +1919,10 @@ def dense_indexer_backward_sm100(
grid + row stride. Inputs are packed (T_q, ...) tensors plus
cu_seqlens_q / k.

``ratio`` is the indexer compression ratio. Bottom-right causal mask is
applied: kv_local < (seqlen_k_b * ratio - seqlen_q_b + q_local + 1) // ratio.
Per batch we require ``seqlen_q_b <= seqlen_k_b * ratio``. ``ratio`` must
be passed explicitly — auto-inferring from S_q / S_k is unsafe under THD.
``ratio`` is the indexer compression ratio. The causal rule is
``kv_local < clamp((q_causal_offset_b + q_local + 1) // ratio, 0, seqlen_k_b)``,
where ``q_causal_offset_b`` is 0 when no offsets tensor is supplied. ``ratio``
must be passed explicitly — auto-inferring from S_q / S_k is unsafe under THD.

``grad_scale`` is intentionally **not** an argument to this factory: it's
a host scalar consumed only as a multiplicative factor inside
Expand All @@ -1941,16 +1945,17 @@ def dense_indexer_backward_sm100(
block_I,
ratio,
is_varlen,
has_q_causal_offsets,
)


def _build_cute_dsl_kernel(batch, max_seqlen_q, max_seqlen_k, heads, dim, sm_scale, block_I, ratio, is_varlen):
def _build_cute_dsl_kernel(batch, max_seqlen_q, max_seqlen_k, heads, dim, sm_scale, block_I, ratio, is_varlen, has_q_causal_offsets):
from cudnn.deepseek_sparse_attention.utils.tensor_conversion import to_cute_tensor

if torch.cuda.get_device_capability()[0] < 10:
raise RuntimeError("Requires SM100+")

# Kernel 1: ScoreGradDense — applies bottom-right ratio causal mask so
# Kernel 1: ScoreGradDense — applies ratio causal mask so
# masked / padding columns produce grad_signal=0 (won't contaminate GEMM).
score_grad_obj = ScoreGradDense(ratio=ratio)

Expand All @@ -1961,7 +1966,7 @@ def _build_cute_dsl_kernel(batch, max_seqlen_q, max_seqlen_k, heads, dim, sm_sca
# Only params that change generated code. max_seqlen_q/k are now runtime
# Int32 args to ScoreGradDense (drive the launch grid + causal-mask bound)
# and to the GEMM, so neither is keyed; batch/sm_scale likewise runtime.
compile_key = (is_varlen, heads, dim, block_I, ratio)
compile_key = (is_varlen, heads, dim, block_I, ratio, bool(has_q_causal_offsets))

def _ensure_compiled_score_grad(
IdxScoreRaw,
Expand All @@ -1970,6 +1975,7 @@ def _ensure_compiled_score_grad(
AttnL1Norm,
CuSeqlensQ,
CuSeqlensK,
QCausalOffsets,
grad_scale,
current_stream=None,
):
Expand All @@ -1990,6 +1996,7 @@ def _ensure_compiled_score_grad(
s = _resolve_stream(current_stream)
cuq_arg = to_cute_tensor(CuSeqlensQ) if CuSeqlensQ is not None else None
cuk_arg = to_cute_tensor(CuSeqlensK) if CuSeqlensK is not None else None
q_offsets_arg = to_cute_tensor(QCausalOffsets) if QCausalOffsets is not None else None
_score_grad_compile_cache[compile_key] = cute.compile(
score_grad_obj,
to_cute_tensor(IdxScoreRaw),
Expand All @@ -1998,6 +2005,7 @@ def _ensure_compiled_score_grad(
to_cute_tensor(AttnL1Norm),
cuq_arg,
cuk_arg,
q_offsets_arg,
cutlass.Float32(float(grad_scale)),
cutlass.Int32(max_seqlen_q),
cutlass.Int32(max_seqlen_k),
Expand All @@ -2015,6 +2023,7 @@ def _ensure_compiled_gemm(
GradSignal,
CuSeqlensQ,
CuSeqlensK,
QCausalOffsets,
current_stream=None,
):
"""Lazy-compile kernel 2 (2Q GEMM).
Expand All @@ -2026,6 +2035,7 @@ def _ensure_compiled_gemm(
s = _resolve_stream(current_stream)
cuq_arg = to_cute_tensor(CuSeqlensQ) if CuSeqlensQ is not None else None
cuk_arg = to_cute_tensor(CuSeqlensK) if CuSeqlensK is not None else None
q_offsets_arg = to_cute_tensor(QCausalOffsets) if QCausalOffsets is not None else None
_gemm_compile_cache[compile_key] = cute.compile(
gemm_obj,
to_cute_tensor(IndexQ),
Expand All @@ -2037,6 +2047,7 @@ def _ensure_compiled_gemm(
to_cute_tensor(GradSignal),
cuq_arg,
cuk_arg,
q_offsets_arg,
cutlass.Float32(sm_scale),
cutlass.Int32(max_seqlen_q),
cutlass.Int32(max_seqlen_k),
Expand All @@ -2054,6 +2065,7 @@ def _run_gemm_only(
GradSignal,
CuSeqlensQ=None,
CuSeqlensK=None,
QCausalOffsets=None,
current_stream=None,
):
"""Run only kernel 2 (2Q GEMM). Caller must have run kernel 1 and zeroed dIndexK_f32."""
Expand All @@ -2063,6 +2075,10 @@ def _run_gemm_only(
assert CuSeqlensQ is not None and CuSeqlensK is not None, "THD-compiled kernel requires cu_seqlens_q/k at runtime"
else:
assert CuSeqlensQ is None and CuSeqlensK is None, "BSHD-compiled kernel must not receive cu_seqlens_q/k"
if has_q_causal_offsets:
assert QCausalOffsets is not None, "offset-compiled kernel requires q_causal_offsets at runtime"
else:
assert QCausalOffsets is None, "non-offset compiled kernel must not receive q_causal_offsets"
s = _resolve_stream(current_stream)
_ensure_compiled_gemm(
IndexQ,
Expand All @@ -2074,6 +2090,7 @@ def _run_gemm_only(
GradSignal,
CuSeqlensQ,
CuSeqlensK,
QCausalOffsets,
current_stream=current_stream,
)
with torch.cuda.nvtx.range("indexer_backward_dsl_dense_gemm_2q"):
Expand All @@ -2087,6 +2104,7 @@ def _run_gemm_only(
GradSignal,
CuSeqlensQ,
CuSeqlensK,
QCausalOffsets,
cutlass.Float32(sm_scale),
cutlass.Int32(max_seqlen_q),
cutlass.Int32(max_seqlen_k),
Expand All @@ -2107,6 +2125,7 @@ def _run(
grad_scale,
CuSeqlensQ=None,
CuSeqlensK=None,
QCausalOffsets=None,
current_stream=None,
):
"""Full dense backward: kernel 1 (score grad) + kernel 2 (2Q GEMM).
Expand All @@ -2123,6 +2142,10 @@ def _run(
assert CuSeqlensQ is not None and CuSeqlensK is not None, "THD-compiled kernel requires cu_seqlens_q/k at runtime"
else:
assert CuSeqlensQ is None and CuSeqlensK is None, "BSHD-compiled kernel must not receive cu_seqlens_q/k"
if has_q_causal_offsets:
assert QCausalOffsets is not None, "offset-compiled kernel requires q_causal_offsets at runtime"
else:
assert QCausalOffsets is None, "non-offset compiled kernel must not receive q_causal_offsets"
s = _resolve_stream(current_stream)

# Kernel 1 (CuTe DSL): in-place overwrites IdxScoreRaw with grad_signal.
Expand All @@ -2134,6 +2157,7 @@ def _run(
AttnL1Norm,
CuSeqlensQ,
CuSeqlensK,
QCausalOffsets,
grad_scale,
current_stream=current_stream,
)
Expand All @@ -2145,6 +2169,7 @@ def _run(
AttnL1Norm,
CuSeqlensQ,
CuSeqlensK,
QCausalOffsets,
cutlass.Float32(float(grad_scale)),
cutlass.Int32(max_seqlen_q),
cutlass.Int32(max_seqlen_k),
Expand All @@ -2163,6 +2188,7 @@ def _run(
IdxScoreRaw,
CuSeqlensQ,
CuSeqlensK,
QCausalOffsets,
current_stream=current_stream,
)

Expand Down
Loading