Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docker/flash_mla/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
__version__ = "1.0.0"

from flash_mla.flash_mla_interface import (
get_mla_metadata,
flash_mla_with_kvcache,
flash_attn_varlen_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_mla_sparse_fwd
)

from flash_mla.txl_nsa_interface import txl_mla

from flash_mla.txl_mla_interface import mla_test
333 changes: 333 additions & 0 deletions docker/flash_mla/flash_mla_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
from typing import Optional, Tuple

import torch

import flash_mla.cuda as flash_mla_cuda

def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k: The number of k heads.
num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.

Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk)


def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md.

Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:
assert causal == False, "causal must be `false` if sparse attention is enabled."
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices
)
return out, softmax_lse


def flash_mla_sparse_fwd(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel

Args:
q: [s_q, h_q, d_qk], bfloat16
kv: [s_kv, h_kv, d_qk], bfloat16
indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
sm_scale: float
d_v: The dimension of value vectors. Can only be 512

Returns:
(output, max_logits, lse)
About the definition of output, max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = flash_mla_cuda.sparse_prefill_fwd(
q, kv, indices, sm_scale, d_v
)
return results


def _flash_attn_varlen_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
qo_total_len, num_qo_heads, head_dim_qk = q.shape
kv_total_len, num_kv_heads, head_dim_vo = v.shape

mask_mode_code = 1 if causal else 0
if softmax_scale is None:
softmax_scale = head_dim_qk ** (-0.5)

if out is None:
out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype)
if lse is None:
# Make lse contiguous on seqlen dim
lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)
flash_mla_cuda.dense_prefill_fwd(
workspace_buffer,
q,
k,
v,
cu_seqlens_qo,
cu_seqlens_kv,
out,
lse,
mask_mode_code,
softmax_scale,
max_seqlen_qo,
max_seqlen_kv,
is_varlen,
)

return out, lse


def _flash_attn_varlen_backward(
do: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
dq: Optional[torch.Tensor] = None,
dk: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
qo_total_len, num_qo_heads, head_dim_qk = q.shape
kv_total_len, num_kv_heads, head_dim_vo = v.shape

# TODO: fix bwd GQA
if num_qo_heads != num_kv_heads:
raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.")

mask_mode_code = 1 if causal else 0
if softmax_scale is None:
softmax_scale = head_dim_qk ** (-0.5)

if dq is None:
dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype)
if dk is None:
dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype)
if dv is None:
dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype)

max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8
bs = cu_seqlens_qo.shape[0] - 1
workspace_bytes = 0
workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc
workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse
if num_qo_heads != num_kv_heads:
workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc
workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)
flash_mla_cuda.dense_prefill_bwd(
workspace_buffer,
do,
q,
k,
v,
out,
lse,
cu_seqlens_qo,
cu_seqlens_kv,
dq,
dk,
dv,
mask_mode_code,
softmax_scale,
max_seqlen_qo,
max_seqlen_kv,
is_varlen,
)

return dq, dk, dv


class FlashAttnVarlenFunc(torch.autograd.Function):
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = _flash_attn_varlen_forward(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal=causal, softmax_scale=softmax_scale,
is_varlen=is_varlen,
)
ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv)
ctx.max_seqlen_qo = max_seqlen_qo
ctx.max_seqlen_kv = max_seqlen_kv
ctx.causal = causal
ctx.softmax_scale = softmax_scale
ctx.is_varlen = is_varlen
return out, lse

def backward(
ctx,
do: torch.Tensor,
dlse: torch.Tensor,
):
del dlse # LSE doesn't support backward currently
q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors
dq, dk, dv = _flash_attn_varlen_backward(
do, q, k, v, out, lse,
cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv,
causal=ctx.causal, softmax_scale=ctx.softmax_scale,
is_varlen=ctx.is_varlen,
)
return dq, dk, dv, None, None, None, None, None, None, None


def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)


def flash_attn_varlen_qkvpacked_func(
qkv: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
head_dim_qk: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:],
cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
causal, softmax_scale, is_varlen,
)


def flash_attn_varlen_kvpacked_func(
q: torch.Tensor,
kv: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
head_dim_qk: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:],
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)
Loading