diff --git a/README.md b/README.md index 1945725f..e0ecd14e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,26 @@ # FlashMLA +## SGLang Maintenance Rules + +- `main` tracks upstream updates and should stay clean. +- `rebase` is the active SGL integration branch and can be force-updated during rebases. +- `sgl` is the promoted stable branch for SGLang consumption. +- Always pin SGLang to an immutable commit SHA (or release tag), not a branch name. +- Only `upstream-*` tags are required for this workflow. + +### Rebase Workflow (Every Cycle) + +1. Update upstream on `main`. +2. Tag `main` snapshot as `upstream-` (example: `upstream-0217`). +3. Rebase and fix SGL compatibility on `rebase` (this branch may be force-pushed). +4. Run validation on `rebase` and finalize the stable commit SHA. +5. Promote `sgl` to the validated `rebase` commit (prefer fast-forward). +6. In SGLang, pin `GIT_TAG` to the promoted immutable SHA/tag. + +### Required Tags + +- `upstream-*`: records the upstream baseline on `main` before SGL rebase work. + ## Introduction FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations: diff --git a/csrc/api/api.cpp b/csrc/api/api.cpp index f43f2a09..569d2946 100644 --- a/csrc/api/api.cpp +++ b/csrc/api/api.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include "sparse_fwd.h" #include "sparse_decode.h" @@ -10,6 +12,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("dense_decode_fwd", &dense_attn_decode_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); + m.def("sparse_prefill_fwd", []( + const at::Tensor &q, + const at::Tensor &kv, + const at::Tensor &indices, + float sm_scale, + int d_v, + const at::Tensor &attn_sink, + const at::Tensor &topk_length) { + return sparse_attn_prefill_interface( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length); + }); m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun); m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun); } diff --git a/csrc/api/common.h b/csrc/api/common.h index 2c930ed9..0e1c6f1e 100644 --- a/csrc/api/common.h +++ b/csrc/api/common.h @@ -2,7 +2,7 @@ #include -#include +#include #include #include #include @@ -228,4 +228,3 @@ class ImplBase { run_(params, required_features); } }; - diff --git a/csrc/extension/python_api.cpp b/csrc/extension/python_api.cpp new file mode 100644 index 00000000..1aa14b05 --- /dev/null +++ b/csrc/extension/python_api.cpp @@ -0,0 +1,31 @@ +#include + +#include +#include + +#include +#include + +extern +std::vector +fwd_kvcache_mla_fp8( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) + const int64_t head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const double softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits, // batch_size + 1 + const std::optional &descale_q, // None or batch_size + const std::optional &descale_k // None or batch_size +); + +extern +std::vector +get_mla_decoding_metadata_dense_fp8( + at::Tensor &seqlens_k, + const int64_t num_heads_per_head_k, + const int64_t num_heads_k +); diff --git a/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp b/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp new file mode 100644 index 00000000..8ce97be5 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp @@ -0,0 +1,228 @@ +#include +#include +#include +#include +#include +#include + +#include "flash_mla.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +std::vector +fwd_kvcache_mla_fp8( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) + const int64_t head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const double softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits, // batch_size + 1 + const std::optional &descale_q, // None or batch_size + const std::optional &descale_k // None or batch_size +) { + int head_size_v_int = static_cast(head_size_v); + + // Check the architecture + auto dprops = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dprops->major == 9 && dprops->minor == 0, "Dense FP8 MLA is only supported on SM90"); + + // Check data types + TORCH_CHECK(q.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(kcache.dtype() == q.dtype(), "query and key must have the same dtype"); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); + + // Check device + CHECK_DEVICE(q); + CHECK_DEVICE(kcache); + CHECK_DEVICE(seqlens_k); + CHECK_DEVICE(block_table); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_DEVICE(num_splits); + if (descale_q.has_value()) CHECK_DEVICE(descale_q.value()); + if (descale_k.has_value()) CHECK_DEVICE(descale_k.value()); + + // Check layout + TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension"); + CHECK_CONTIGUOUS(seqlens_k); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + CHECK_CONTIGUOUS(num_splits); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_q = sizes[2]; + const int head_size_k = sizes[3]; + TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); + TORCH_CHECK(head_size_v_int == 512, "Only head_size_v == 512 is supported"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8"); + auto descale_q_ = descale_q.value(); + auto descale_k_ = descale_k.value(); + CHECK_DEVICE(descale_q_); + CHECK_DEVICE(descale_k_); + TORCH_CHECK(descale_q_.stride(-1) == 1); + TORCH_CHECK(descale_k_.stride(-1) == 1); + TORCH_CHECK(descale_q_.dtype() == torch::kFloat); + TORCH_CHECK(descale_k_.dtype() == torch::kFloat); + CHECK_SHAPE(descale_q_, 1); + CHECK_SHAPE(descale_k_, 1); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int num_q_heads_per_hk = num_heads_q / num_heads_k; + const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) + .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); + + CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); + CHECK_SHAPE(seqlens_k, batch_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_SHAPE(num_splits, batch_size+1); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + caffe2::TypeMeta out_type; + out_type = torch::kBFloat16; + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v_int}, opts.dtype(out_type)); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse); + + // Set up parameters for the dense FP8 kernel + DecodingParams_fp8 params = {}; + // Set the sizes. + params.b = batch_size; + params.s_q = seqlen_q_ori; + params.q_seq_per_hk = q_seq_per_hk; + params.seqlens_k_ptr = seqlens_k.data_ptr(); + params.h_q = num_heads_q; + params.h_k = num_heads_k; + params.num_blocks = num_blocks; + params.q_head_per_hk = num_q_heads_per_hk; + params.is_causal = is_causal; + params.d = head_size_k; + params.d_v = head_size_v_int; + params.scale_softmax = static_cast(softmax_scale); + params.scale_softmax_log2 = float(static_cast(softmax_scale) * M_LOG2E); + params.topk = -1; // Dense attention + + // FP8-specific parameters + params.h_h_k_ratio = 1; + params.descale_q_ptr = reinterpret_cast(descale_q.value().data_ptr()); + params.descale_k_ptr = reinterpret_cast(descale_k.value().data_ptr()); + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.indices_ptr = nullptr; + params.softmax_lse_ptr = softmax_lse.data_ptr(); + + // All stride are in elements, not bytes. + params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(-3); + params.k_row_stride = kcache.stride(1); + params.o_row_stride = out.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = kcache.stride(2); + params.o_head_stride = out.stride(-2); + params.indices_batch_stride = 0; + params.indices_row_stride = 0; + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + params.num_sm_parts = tile_scheduler_metadata.size(0); + params.num_splits_ptr = num_splits.data_ptr(); + + // Set up accumulation tensors + const int total_num_splits = batch_size + params.num_sm_parts; + at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v_int}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse_accum); + CHECK_CONTIGUOUS(out_accum); + params.total_num_splits = total_num_splits; + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // Call the actual kernel implementation +#ifdef FLASH_MLA_DISABLE_FP8 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP8. Please remove this flag from your environment and re-compile FlashMLA."); +#else + run_mha_fwd_splitkv_mla(params, stream); +#endif + + // Reshape outputs back to original format + out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v_int}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v_int}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) + .reshape({batch_size, num_heads_q, seqlen_q_ori}); + + return {out, softmax_lse}; +} + +std::vector +get_mla_decoding_metadata_dense_fp8( + at::Tensor &seqlens_k, + const int64_t num_heads_per_head_k, + const int64_t num_heads_k +) { + int num_heads_per_head_k_int = static_cast(num_heads_per_head_k); + int num_heads_k_int = static_cast(num_heads_k); + // This should match the logic in the MLA kernel. + static constexpr int block_size_m = 64; + static constexpr int block_size_n = 64; + static constexpr int fixed_overhead_num_blocks = 5; + CHECK_DEVICE(seqlens_k); + TORCH_CHECK(seqlens_k.is_contiguous()); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); + int batch_size = seqlens_k.size(0); + int *seqlens_k_ptr = seqlens_k.data_ptr(); + auto options = seqlens_k.options(); + auto dprops = at::cuda::getCurrentDeviceProperties(); + int sm_count = dprops->multiProcessorCount; + int num_sm_parts = sm_count / num_heads_k_int / cutlass::ceil_div(num_heads_per_head_k_int, block_size_m); + auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); + auto num_splits = torch::empty({batch_size + 1}, options); + int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + int *num_splits_ptr = num_splits.data_ptr(); + at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + Mla_metadata_params params = {}; + params.seqlens_k_ptr = seqlens_k_ptr; + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; + params.num_splits_ptr = num_splits_ptr; + params.batch_size = batch_size; + params.block_size_n = block_size_n; + params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; + params.num_sm_parts = num_sm_parts; + get_mla_metadata_func(params, stream); + return {tile_scheduler_metadata, num_splits}; +} \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu new file mode 100644 index 00000000..b87902c4 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu @@ -0,0 +1,10 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#include "flash_fwd_mla_kernel.h" + +#ifndef FLASH_MLA_DISABLE_FP8 +template void run_mha_fwd_splitkv_mla(DecodingParams_fp8 ¶ms, cudaStream_t stream); +#endif \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/flash_fwd_mla_kernel.h b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_kernel.h new file mode 100644 index 00000000..7aefe624 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_kernel.h @@ -0,0 +1,709 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#pragma once + +#include +#include +#include +#include + +using namespace cute; + +#include "named_barrier.h" +#include "utils.h" +#include "softmax.h" +#include "static_switch.h" +#include "flash_mla.h" +#include "fp8_transpose_v.h" + + +template +constexpr auto getSmemLayoutK() { + constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; + + if constexpr (major == GMMA::Major::K) { + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else { + return GMMA::Layout_K_SW32_Atom{}; + } + } else { + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } else { + return GMMA::Layout_MN_SW32_Atom{}; + } + } + +} + +template +struct Flash_fwd_kernel_traits_mla { + using Element = elem_type; + using ElementO = elem_type_o; + using ElementAccum = float; + using index_t = int64_t; + + static constexpr bool Is_FP8 = cute::is_same_v; + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + static constexpr int kNWarpsS = 4; + static constexpr int kNThreadsS = kNWarpsS * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; + static_assert(kHeadDimV % 32 == 0); + static_assert(kHeadDimV <= kHeadDim); + + static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3; + + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K; + + using TiledMma = decltype(make_tiled_mma( + cute::GMMA::ss_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout, _1, _1>>{})); + + static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + using TiledMmaO = decltype(make_tiled_mma( + cute::GMMA::rs_op_selector, Int, Int>, + GMMA::Major::K, MmaMajorV>(), + Layout, Int, _1>>{})); + + using SmemLayoutQ = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutK = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutV = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + + using SmemLayoutP = std::conditional_t< + Is_FP8, + Layout, Int, _1, _2, Int>>, + Layout, Int, _1, _2, Int>> + >; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; + + using SmemLayoutAtomO = decltype(composition( + Swizzle{}, + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO); + static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO"); + static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO; + static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO"); + + using GmemLayoutAtom = Layout< + Shape, Int>, + Stride, _1>>; + + + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 vals per read + + using GmemLayoutAtomO = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom, ElementO>{}, + GmemLayoutAtomO{}, + Layout>>{})); // Val layout, 8 vals per store + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum; + using GmemLayoutAtomOaccum = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>>{})); // Val layout, 4 vals per store + + + // ------ for f8 ------ + using SmemFp8Tranpose = SmemTransposeFp8_64x64; + using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt; +}; + +namespace flash { + +using namespace cute; + +template +struct SharedStorageMLA { + using SmemV_t = std::conditional_t>, + cute::array_aligned>; + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // Double buffer + SmemV_t smem_vt; + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; + }; + struct { + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; + }; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void store(const DecodingParams_fp8 ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, + SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + using Element = typename Kernel_traits::ElementO; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + + // Epilogue + + const int split_offset = __ldg(params.num_splits_ptr + bidb); + + Tensor lse = softmax.template normalize_softmax_lse(tOrO, scale_softmax, descale_k); + + using ElementO = std::conditional_t; + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), + Shape>{}, Stride<_1>{}); + + using GmemTiledCopyO = std::conditional_t; + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + if (tidx >= kNThreadsS) { return; } + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) + Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.q_seq_per_hk - m_block * kBlockM + ); +} + +template +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const DecodingParams_fp8 ¶ms, + const int bidb, const int bidh, const int m_block, + const int n_split_idx, const int seqlen_k, + const int n_block_min, const int n_block_max, const bool NoSplit, + SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreads = Kernel_traits::kNThreads; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + static_assert(kNThreads == 256 and kNThreadsS == 128); + using Element = typename Kernel_traits::Element; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + int n_block = n_block_max - 1; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); + + auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + auto sVt = [&](){ + if constexpr(Kernel_traits::Is_FP8){ + return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + } + }(); + + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _); + Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); + Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); + Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + clear(tOrO); + + flash::Softmax<2 * size<1>(tOrO)> softmax; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll 1 + for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { + __syncthreads(); + + Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); + + const bool is_masking_step = masking_step > 0; + const bool is_first_masking_step = masking_step == n_masking_steps; + + if (is_masking_step) { + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; + } else { + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk + int row = int(get<0>(tScS(i))); + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; + if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; + } + } + } + + // We have key_padding_mask so we'll need to Check_inf + Tensor scale_o = is_first_masking_step + ? softmax.template softmax(tSrS, scale_softmax_log2) + : is_masking_step ? + softmax.template softmax(tSrS, scale_softmax_log2) + : softmax.template softmax(tSrS, scale_softmax_log2); + + if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + + cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0 + cute::copy(scale_o, tScale_osScale_o); + + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); + + flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) { + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); + __syncthreads(); + } + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + cute::copy(softmax.row_max, tRow_maxsRow_max); + cute::copy(softmax.row_sum, tRow_sumsRow_sum); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + } else { + const int *block_table = params.block_table + bidb * params.block_table_batch_stride; + int cur_block_table = __ldg(&block_table[n_block]); + + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.q_seq_per_hk - m_block * kBlockM); + + const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; + auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); + Tensor tKgK = gmem_thr_copy_K.partition_S(gK); + Tensor tKsK = gmem_thr_copy_K.partition_D(sK); + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tKsK.data() = tKsK.data() + sK_offset; + if constexpr (!Kernel_traits::Is_FP8) { + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + // We need to clear the sK smem tiles because K is V. + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, + seqlen_k - n_block * kBlockN); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + + if (n_block - 1 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 1]); + } + +#pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + flash::cp_async_wait<0>(); + __syncthreads(); + + if (n_block - 1 >= n_block_min) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + } + + if constexpr (Kernel_traits::Is_FP8) { + auto TransV = [&]() { + using SmemFp8Tranpose = typename Kernel_traits::SmemFp8Tranpose; + SmemFp8Tranpose smem_transpose_V; + Tensor sV_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename SmemFp8Tranpose::SmemLayoutTransposeV{})); + Tensor sVt_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename SmemFp8Tranpose::SmemLayoutTransposeVt{})); + + if (n_block % 2 == 1) { + sV_divide.data() = sV_divide.data() + size(sK); + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < shape<2>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < shape<1>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++i) { + smem_transpose_V.transpose(flatten(sV_divide(_, i, j)), flatten(sVt_divide(_, i, j))); + } + } + }; + + TransV(); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::TransVReady)); + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + + if (n_block - 2 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 2]); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto tSrS_layout = flash::convert_layout_acc_Aregs(partition_fragment_C(tiled_mma, Shape, Int>{}).layout()); + Tensor tOrP = make_tensor(tSrS_layout); + Tensor scale_o = make_tensor(Shape<_2>{}); + cute::copy(tScale_osScale_o, scale_o); + cute::copy(tPsP, tOrP); + + flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) __syncthreads(); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + if constexpr (!Kernel_traits::Is_FP8) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + cute::copy(tRow_maxsRow_max, softmax.row_max); + cute::copy(tRow_sumsRow_sum, softmax.row_sum); + } + + if (NoSplit) + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); + else + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); +} + +template +__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams_fp8 params) { + constexpr int kBlockN = Kernel_traits::kBlockN; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int partition_idx = blockIdx.z; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b || begin_idx < 0) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + float descale_k = 1.f; + float scale_softmax = params.scale_softmax; + float scale_softmax_log2 = params.scale_softmax_log2; + if constexpr (Kernel_traits::Is_FP8) { + float descale_q = __ldg(params.descale_q_ptr); + descale_k = __ldg(params.descale_k_ptr); + scale_softmax = scale_softmax * descale_q * descale_k; + scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k; + } + +#pragma unroll 1 + for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; + const int seqlen_k = __ldg(params.seqlens_k_ptr + batch_id); + const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; + const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); + if (batch_id > begin_idx) { + __syncthreads(); // Barrier between two tiles. + } + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(256, 1, 1) +flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const DecodingParams_fp8 params) { + constexpr int kNThreads = 128; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int hs = params.h_k * params.q_seq_per_hk; + const int batch_idx = bidx / hs; + const int hs_idx = bidx % hs; + + const int split_offset = __ldg(params.num_splits_ptr + batch_idx); + const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; + FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); + if (actual_num_splits <= 1) return; + + __shared__ ElementAccum sLseScale[kMaxSplits]; + + const index_t row_offset_lseaccum = split_offset * hs + hs_idx; + const index_t row_offset_lse = bidx; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, make_stride(hs)); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + float local_lse[kNLsePerThread]; + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; + } + + float max_lse = -INFINITY; + for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); + for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); + for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse; + if (tidx == 0) gLSE(0) = global_lse; + + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + __syncthreads(); + + static_assert(kHeadDimV % kNThreads == 0); + constexpr int Elements = kHeadDimV / kNThreads; + const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape>{}, Stride<_1>{}); + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + Layout>>{}, + Layout>>{})); + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + for (int split = 0; split < actual_num_splits; ++split) { + cute::copy(tOgOaccum, tOrOaccum); + ElementAccum lse_scale = sLseScale[split]; + for (int i = 0; i < size(tOrO); ++i) { + tOrO(i) += lse_scale * tOrOaccum(i); + } + tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; + } + + Tensor rO = flash::convert_type(tOrO); + const int head_idx = (bidx - batch_idx * hs) / params.q_seq_per_hk; + const int row = bidx - batch_idx * hs - head_idx * params.q_seq_per_hk; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); + cute::copy(rO, gO); +} + +} // namespace flash + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_flash_splitkv_fwd_mla(DecodingParams_fp8 ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); + const int num_m_block = cute::ceil_div(params.q_seq_per_hk, Kernel_traits::kBlockM); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + auto kernel = &flash::flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(SharedStorage); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); + + dim3 grid_combine(params.b * params.h_k * params.q_seq_per_hk); + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { + auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< + typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + combine_kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_splitkv_mla(DecodingParams_fp8 ¶ms, cudaStream_t stream) { + static_assert(Headdim == 576); + FLASH_ASSERT(params.d_v == 512); + using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>; + run_flash_splitkv_fwd_mla>(params, stream); +} \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu new file mode 100644 index 00000000..96c2bd36 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu @@ -0,0 +1,77 @@ +#include "flash_fwd_mla_kernel.h" + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/flash_mla.h b/csrc/extension/sm90/dense_fp8/flash_mla.h new file mode 100644 index 00000000..2f4cff3d --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/flash_mla.h @@ -0,0 +1,81 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Keep a self-contained fp8 decode params struct so this extension can stay +// compatible when upstream refactors csrc/params.h. +struct DecodingParams_fp8 { + using index_t = int64_t; + + int b; // batch size + int s_q; + int q_seq_per_hk; // Number of q(s) per KV head + int d, d_v; // K/V dimension + int h_q, h_k; // Number of Q/K heads + int num_blocks; // Number of blocks in total + int q_head_per_hk; // Number of q heads per KV head + bool is_causal; + float scale_softmax, scale_softmax_log2; + int topk; + + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ o_ptr; + void* __restrict__ softmax_lse_ptr; + int* __restrict__ indices_ptr; + + index_t q_batch_stride; + index_t k_batch_stride; + index_t o_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t o_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t o_head_stride; + index_t indices_batch_stride; + index_t indices_row_stride; + + int* __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + int* __restrict__ seqlens_k_ptr; + + int* __restrict__ tile_scheduler_metadata_ptr; + int num_sm_parts; + int* __restrict__ num_splits_ptr; + + int total_num_splits; + void* __restrict__ softmax_lseaccum_ptr; + void* __restrict__ oaccum_ptr; + + int h_h_k_ratio; + float* __restrict__ descale_q_ptr = nullptr; + float* __restrict__ descale_k_ptr = nullptr; +}; + +static constexpr int TileSchedulerMetaDataSize = 8; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_splitkv_mla(DecodingParams_fp8 ¶ms, cudaStream_t stream); + +struct Mla_metadata_params { + int *__restrict__ seqlens_k_ptr; + int *__restrict__ tile_scheduler_metadata_ptr; + int *__restrict__ num_splits_ptr; + int batch_size; + int block_size_n; + int fixed_overhead_num_blocks; + int num_sm_parts; +}; +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/extension/sm90/dense_fp8/fp8_transpose_v.h b/csrc/extension/sm90/dense_fp8/fp8_transpose_v.h new file mode 100644 index 00000000..9001e1fd --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/fp8_transpose_v.h @@ -0,0 +1,88 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + + +/** + * ref to Fa3's SmemTranspose64x64: + * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26 +*/ + +#pragma once + +template +struct SmemTransposeFp8_64x64 { + static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); + + using Element = cutlass::float_e4m3_t; + using TransposeShapeAtomV = Shape<_64, _64>; + using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- src layout + using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); + using SmemShapeLDSM = Shape, Shape<_16, _4>>; + using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}))); + using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); + + // For fp8, this is the memory transpose. + using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutVt = + decltype(tile_to_shape(SmemLayoutAtomVt{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- dst layout + using SmemLayoutVtTrans = decltype(composition( + SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{}))); + using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); + using SmemShapeSTSM = Shape, Shape<_16, _4>>; + using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}))); + using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); + + + using ldsm_thread_shape = Shape<_4, _1, _8, _4>; + using ldsm_value_shape = Shape<_2, _8, _2, _1>; + using ldsm_value_stride = Stride<_2, _4, _1, _0>; + using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopyLDSM tiled_copy_ldsm; + + using stsm_thread_shape = Shape<_4, _1, _8, _4>; + // using stsm_thread_stride = Stride<_1, _0, _4, _32>; + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; + + using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopySTSM tiled_copy_stsm; + + template + CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) { + using namespace cute; + + auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); + auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); + + auto tXsX = thr_copy_ldsm.partition_S(s_in); + auto tXrX = make_tensor(shape(tXsX)); + auto tXsX_out = thr_copy_stsm.partition_D(s_out); + + cute::copy(tiled_copy_ldsm, tXsX, tXrX); + + auto data = tXrX.data(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size(tXrX); n += 8) { + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); + } + + cute::copy(tiled_copy_stsm, tXrX, tXsX_out); + } +}; diff --git a/csrc/extension/sm90/dense_fp8/named_barrier.h b/csrc/extension/sm90/dense_fp8/named_barrier.h new file mode 100644 index 00000000..8f2e546a --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/named_barrier.h @@ -0,0 +1,21 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#pragma once + +#include "cutlass/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + SReady = 1, + SoftmaxReady = 2, + TransVReady = 3, +}; + +} // flash \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/softmax.h b/csrc/extension/sm90/dense_fp8/softmax.h new file mode 100644 index 00000000..1996e850 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/softmax.h @@ -0,0 +1,202 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h + +#pragma once + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } + return tensor; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scale_o); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scale_o; + clear(scale_o); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scale_o(mi) = scores_scale; + row_sum(mi) *= scores_scale; + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + return scale_o; + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/static_switch.h b/csrc/extension/sm90/dense_fp8/static_switch.h new file mode 100644 index 00000000..c0f73311 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/static_switch.h @@ -0,0 +1,70 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#pragma once + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + +#define FLASH_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + exit(1); \ + } \ + } while(0) + + +#define FLASH_DEVICE_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while(0) + + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/utils.h b/csrc/extension/sm90/dense_fp8/utils.h new file mode 100644 index 00000000..cd6f95bb --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/utils.h @@ -0,0 +1,279 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h + +#pragma once + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. + return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. + // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { + // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits + static_assert(decltype(size<0, 0>(frag))::value == 2); + static_assert(decltype(size<0, 1>(frag))::value == 2); + static_assert(decltype(size<0, 2>(frag))::value % 2 == 0); + static_assert(decltype(stride<0, 0>(frag))::value == 1); + static_assert(sizeof(typename Fragment::value_type) == 4); + Tensor frag_64b = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) + #pragma unroll + for (int mi = 0; mi < size<1>(frag_64b); ++mi) { + #pragma unroll + for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) { + cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { + // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. + using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/flashmla_utils.h b/csrc/flashmla_utils.h new file mode 100644 index 00000000..5b8741ab --- /dev/null +++ b/csrc/flashmla_utils.h @@ -0,0 +1,3 @@ +#pragma once + +#include "utils.h" diff --git a/csrc/kerutils/include/kerutils/supplemental/torch_tensors.h b/csrc/kerutils/include/kerutils/supplemental/torch_tensors.h index 5b4e564c..0a0c57e7 100644 --- a/csrc/kerutils/include/kerutils/supplemental/torch_tensors.h +++ b/csrc/kerutils/include/kerutils/supplemental/torch_tensors.h @@ -2,7 +2,7 @@ #include -#include +#include #include "kerutils/common/common.h" diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp new file mode 100644 index 00000000..2cda3be1 --- /dev/null +++ b/csrc/python_api.cpp @@ -0,0 +1,164 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "api/common.h" +#include "api/dense_decode.h" +#include "api/sparse_decode.h" +#include "api/sparse_fwd.h" + +std::vector get_mla_decoding_metadata( + at::Tensor& seqlens_k, + const int64_t num_q_tokens_per_head_k, + const int64_t h_k, + const std::optional h_q, + const bool is_fp8_kvcache, + const std::optional topk) { + TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device"); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous"); + + const int batch_size = seqlens_k.size(0); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + + const int num_q_tokens_per_head_k_int = static_cast(num_q_tokens_per_head_k); + const int h_k_int = static_cast(h_k); + const std::optional h_q_int = h_q.has_value() ? std::make_optional(static_cast(*h_q)) : std::nullopt; + const std::optional topk_int = topk.has_value() ? std::make_optional(static_cast(*topk)) : std::nullopt; + + TORCH_CHECK(h_k_int > 0, "num_heads_k must be positive"); + if (topk_int.has_value()) { + TORCH_CHECK(h_q_int.has_value(), "num_heads_q must be provided when topk is provided"); + TORCH_CHECK(is_fp8_kvcache, "Sparse decoding requires is_fp8_kvcache=true"); + } + + // Keep dense FP8 metadata on the dedicated API path for compatibility. + TORCH_CHECK(!(is_fp8_kvcache && !topk_int.has_value()), + "Use get_mla_decoding_metadata_dense_fp8 for dense fp8 metadata"); + + const int num_heads_q = h_q_int.value_or(h_k_int * num_q_tokens_per_head_k_int); + TORCH_CHECK(num_heads_q > 0, "num_heads_q must be positive"); + + const int heads_ratio = std::max(1, num_heads_q / h_k_int); + const int s_q = std::max(1, num_q_tokens_per_head_k_int / heads_ratio); + + Arch arch; + int num_sm_parts; + if (topk_int.has_value()) { + if (arch.is_sm100f()) { + // sm100 sparse kernels use a larger split count envelope. + num_sm_parts = std::max(arch.num_sms / s_q, 1); + } else { + const int heads_per_64 = std::max(1, num_heads_q / 64); + num_sm_parts = std::max(arch.num_sms / s_q / heads_per_64, 1); + } + } else { + num_sm_parts = std::max( + arch.num_sms / h_k_int / cutlass::ceil_div(s_q * num_heads_q / h_k_int, 64), + 1); + } + + at::cuda::CUDAGuard device_guard{static_cast(seqlens_k.get_device())}; + auto opts = seqlens_k.options().dtype(torch::kInt32); + + at::Tensor tile_scheduler_metadata = + torch::empty({num_sm_parts, DecodingSchedMetaSize / static_cast(sizeof(int))}, opts); + at::Tensor num_splits = torch::empty({batch_size + 1}, opts); + + GetDecodeSchedMetaParams params = { + batch_size, + s_q, + 64, + 5, + topk_int.value_or(-1), + -1, + nullptr, + nullptr, + seqlens_k.data_ptr(), + reinterpret_cast(tile_scheduler_metadata.data_ptr()), + num_splits.data_ptr(), + num_sm_parts, + at::cuda::getCurrentCUDAStream().stream(), + }; + + smxx::decode::run_get_decoding_sched_meta_kernel(params); + return {tile_scheduler_metadata, num_splits}; +} + +std::vector fwd_kvcache_mla( + at::Tensor& q, + const at::Tensor& kcache, + const int64_t head_size_v, + const at::Tensor& seqlens_k, + const at::Tensor& block_table, + const double softmax_scale, + bool is_causal, + const at::Tensor& tile_scheduler_metadata, + const at::Tensor& num_splits, + const bool& is_fp8, + const std::optional& indices) { + const int head_size_v_int = static_cast(head_size_v); + const float softmax_scale_float = static_cast(softmax_scale); + + std::optional tile_scheduler_metadata_opt = tile_scheduler_metadata; + std::optional num_splits_opt = num_splits; + + if (indices.has_value()) { + TORCH_CHECK(is_fp8, "Sparse decode path requires is_fp8=true"); + auto result = sparse_attn_decode_interface( + q, + kcache, + indices.value(), + std::nullopt, + std::nullopt, + tile_scheduler_metadata_opt, + num_splits_opt, + std::nullopt, + std::nullopt, + std::nullopt, + head_size_v_int, + softmax_scale_float); + return {std::get<0>(result), std::get<1>(result)}; + } + + TORCH_CHECK(!is_fp8, + "Dense FP8 decode is exposed via fwd_kvcache_mla_fp8, not fwd_kvcache_mla"); + auto result = dense_attn_decode_interface( + q, + kcache, + head_size_v_int, + seqlens_k, + block_table, + softmax_scale_float, + is_causal, + tile_scheduler_metadata_opt, + num_splits_opt); + return {std::get<0>(result), std::get<1>(result)}; +} + +std::vector sparse_prefill_fwd( + const at::Tensor& q, + const at::Tensor& kv, + const at::Tensor& indices, + double sm_scale, + int64_t d_v) { + auto result = sparse_attn_prefill_interface( + q, + kv, + indices, + static_cast(sm_scale), + static_cast(d_v), + std::nullopt, + std::nullopt); + // Keep SGL compatibility: this API historically returns max_logits/lse in log2 space. + result[1].mul_(LOG_2_E); + result[2].mul_(LOG_2_E); + return result; +} diff --git a/csrc/sm100/decode/head64/config.h b/csrc/sm100/decode/head64/config.h index 401f3acf..a7ba1f65 100644 --- a/csrc/sm100/decode/head64/config.h +++ b/csrc/sm100/decode/head64/config.h @@ -3,6 +3,7 @@ #include "kernel.h" #include +#include #include #include @@ -38,6 +39,7 @@ static constexpr int D_ROPE = 64; static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64; static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true; static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8; // Padding is included +using scale_t = std::conditional_t; static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE; // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens 2) large enough to cover the entire KV cache. Since TMA copy's coordinate can only be 32bit signed integers, this number must >= 128, perferrably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks. static_assert(D_NOPE + D_ROPE == D_Q); static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V)); @@ -45,7 +47,7 @@ static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V)); static constexpr int B_H = 64; static constexpr int B_TOPK = 64; static constexpr int NUM_BUFS = 2; -static constexpr int NUM_INDEX_BUFS = 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales +static constexpr int NUM_INDEX_BUFS = MODEL_TYPE == ModelType::V32 ? 2 : 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales static constexpr int NUM_THREADS = 128*3; // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant static constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN @@ -182,7 +184,7 @@ struct SharedMemoryPlan { CUTE_ALIGNAS(16) float rowwise_max_buf[128]; char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8]; int tma_coord[NUM_INDEX_BUFS][B_TOPK]; - e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN]; + scale_t scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN]; array_aligned tmem_start_addr; transac_bar_t bar_last_store_done; transac_bar_t bar_q_tma, bar_q_utccp; diff --git a/csrc/sm100/decode/head64/kernel.cuh b/csrc/sm100/decode/head64/kernel.cuh index 7c46921c..b1a26a7d 100644 --- a/csrc/sm100/decode/head64/kernel.cuh +++ b/csrc/sm100/decode/head64/kernel.cuh @@ -9,7 +9,7 @@ #include "kerutils/kerutils.cuh" -#include "utils.h" +#include "flashmla_utils.h" #include "sm100/helpers.h" #include "config.h" @@ -691,7 +691,7 @@ KernelTemplate plan.bar_valid_coord_scale_free[rs.index_buf_idx].wait(rs.index_bar_phase^1); int tma_coords[2]; - e8m0 scales[2*NUM_SCALES_EACH_TOKEN]; + scale_t scales[2*NUM_SCALES_EACH_TOKEN]; char valid_mask = 0; CUTE_UNROLL for (int i = 0; i < 2; ++i) { @@ -704,11 +704,13 @@ KernelTemplate if constexpr (MODEL_TYPE == ModelType::V32) { int64_t offset = is_token_valid ? block_idx*cur_k_block_stride + idx_in_block*cur_k_row_stride : 0; float4 cur_scale_fp32 = __ldg((float4*)(cur_k_scales_ptr + offset)); - e8m0 res[4]; - *(__nv_fp8x2_storage_t*)(res+0) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.x, cur_scale_fp32.y}, __NV_NOSAT, cudaRoundZero); - *(__nv_fp8x2_storage_t*)(res+2) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.z, cur_scale_fp32.w}, __NV_NOSAT, cudaRoundZero); - if (!is_token_valid) *(uint32_t*)res = (uint32_t)0; - *(uint32_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = *(uint32_t*)(res); + __nv_bfloat16 res[4]; + res[0] = __float2bfloat16(cur_scale_fp32.x); + res[1] = __float2bfloat16(cur_scale_fp32.y); + res[2] = __float2bfloat16(cur_scale_fp32.z); + res[3] = __float2bfloat16(cur_scale_fp32.w); + if (!is_token_valid) *(uint64_t*)res = (uint64_t)0; + *(uint64_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = *(uint64_t*)(res); } else { int64_t offset = block_idx*cur_k_block_stride + idx_in_block*8; // Each token has 7 scale factors with an extra 1B padding uint64_t scalesx8 = is_token_valid ? __ldg((uint64_t*)(cur_k_scales_ptr + offset)) : 0; @@ -719,7 +721,7 @@ KernelTemplate valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1); valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2); if constexpr (MODEL_TYPE == ModelType::V32) { - *(uint64_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(uint64_t*)scales; + *(__int128_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(__int128_t*)scales; } else { *(__int128_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(__int128_t*)scales; } @@ -783,10 +785,7 @@ KernelTemplate for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) { int row_idx = local_row_idx*NUM_GROUPS + group_idx; bf16 scales[4]; - e8m0 scales_e8m0[4]; - *(uint32_t*)scales_e8m0 = *(uint32_t*)plan.scales[rs.index_buf_idx][row_idx]; - *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0)); - *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2)); + *(uint64_t*)scales = *(uint64_t*)plan.scales[rs.index_buf_idx][row_idx]; uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0); CUTE_UNROLL diff --git a/csrc/sm100/prefill/dense/common/utils.hpp b/csrc/sm100/prefill/dense/common/utils.hpp index fdaeff08..f4b7767b 100644 --- a/csrc/sm100/prefill/dense/common/utils.hpp +++ b/csrc/sm100/prefill/dense/common/utils.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include #include "cutlass/numeric_types.h" #include "helper.h" @@ -30,4 +30,4 @@ struct cutlass_dtype<__nv_fp8_e5m2> { }; template -using cutlass_dtype_t = typename cutlass_dtype::type; \ No newline at end of file +using cutlass_dtype_t = typename cutlass_dtype::type; diff --git a/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu index ab66f0fd..5fd6303e 100644 --- a/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu +++ b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu @@ -31,15 +31,19 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, - int mask_mode_code, float sm_scale, int max_seqlen_q, - int max_seqlen_kv, bool is_varlen) { + int64_t mask_mode_code, double sm_scale, int64_t max_seqlen_q, + int64_t max_seqlen_kv, bool is_varlen) { const c10::cuda::OptionalCUDAGuard device_guard(q.device()); CHECK(q.scalar_type() == k.scalar_type()); auto scalar_type_in = q.scalar_type(); auto scalar_type_out = o.scalar_type(); int head_dim_qk = q.size(-1); int head_dim_vo = v.size(-1); - MaskMode mask_mode = static_cast(mask_mode_code); + const int mask_mode_code_i32 = static_cast(mask_mode_code); + const float sm_scale_f32 = static_cast(sm_scale); + const int max_seqlen_q_i32 = static_cast(max_seqlen_q); + const int max_seqlen_kv_i32 = static_cast(max_seqlen_kv); + MaskMode mask_mode = static_cast(mask_mode_code_i32); if (scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) { @@ -65,12 +69,12 @@ void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tenso apply_config([&](auto mask, auto varlen, auto in, auto out) { if (head_dim_qk == 192 && head_dim_vo == 128) { call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v, - cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, - max_seqlen_q, max_seqlen_kv); + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale_f32, + max_seqlen_q_i32, max_seqlen_kv_i32); } else if (head_dim_qk == 128 && head_dim_vo == 128) { call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v, - cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, - max_seqlen_q, max_seqlen_kv); + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale_f32, + max_seqlen_q_i32, max_seqlen_kv_i32); } else { std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl; diff --git a/csrc/sm100/prefill/dense/interface.h b/csrc/sm100/prefill/dense/interface.h index 80ef2bca..41163366 100644 --- a/csrc/sm100/prefill/dense/interface.h +++ b/csrc/sm100/prefill/dense/interface.h @@ -1,11 +1,14 @@ #pragma once +#include + #include void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, - int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); + int64_t mask_mode_code, double softmax_scale, int64_t max_seqlen_q, int64_t max_seqlen_kv, + bool is_varlen); void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, diff --git a/csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh b/csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh index ec1192b9..c1a32314 100644 --- a/csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh +++ b/csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh @@ -9,7 +9,7 @@ #include #include "params.h" -#include "utils.h" +#include "flashmla_utils.h" #include "sm100/helpers.h" #include "config.h" diff --git a/csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh b/csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh index b510b27f..2831e6fd 100644 --- a/csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh +++ b/csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh @@ -10,7 +10,7 @@ #include #include "params.h" -#include "utils.h" +#include "flashmla_utils.h" #include "sm100/helpers.h" #include "sm100/prefill/sparse/common_subroutine.h" #include "config.h" diff --git a/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh index 388abc87..a550f3a7 100644 --- a/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh +++ b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh @@ -8,7 +8,7 @@ #include #include "params.h" -#include "utils.h" +#include "flashmla_utils.h" #include "sm100/prefill/sparse/common_subroutine.h" #include "sm100/helpers.h" diff --git a/csrc/sm90/decode/dense/splitkv_mla.cuh b/csrc/sm90/decode/dense/splitkv_mla.cuh index cdd54413..2768b87b 100644 --- a/csrc/sm90/decode/dense/splitkv_mla.cuh +++ b/csrc/sm90/decode/dense/splitkv_mla.cuh @@ -1,6 +1,6 @@ #include -#include "utils.h" +#include "flashmla_utils.h" #include "params.h" #include "config.h" diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh index 99945689..5509af15 100644 --- a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh +++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh @@ -11,7 +11,7 @@ #include -#include "utils.h" +#include "flashmla_utils.h" #include "components/dequant.h" #include "components/helpers.h" #include "config.h" diff --git a/csrc/sm90/prefill/sparse/config.h b/csrc/sm90/prefill/sparse/config.h index 75005664..a60a97a2 100644 --- a/csrc/sm90/prefill/sparse/config.h +++ b/csrc/sm90/prefill/sparse/config.h @@ -15,7 +15,11 @@ namespace sm90::fwd { using namespace cute; -template +template< + int D_QK, + bool HAVE_TOPK_LENGTH, + bool ENABLE_ODD_TAIL_SKIP +> class KernelTemplate { public: diff --git a/csrc/sm90/prefill/sparse/phase1.cuh b/csrc/sm90/prefill/sparse/phase1.cuh index bf2fff84..8c900288 100644 --- a/csrc/sm90/prefill/sparse/phase1.cuh +++ b/csrc/sm90/prefill/sparse/phase1.cuh @@ -2,7 +2,7 @@ #include "config.h" -#include "utils.h" +#include "flashmla_utils.h" #include "../../helpers.h" namespace sm90::fwd { @@ -39,9 +39,9 @@ void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes : "memory"); } -template +template template -__device__ void KernelTemplate::devfunc(const SparseAttnFwdParams ¶ms, const TMAParams &tma_params) { +__device__ void KernelTemplate::devfunc(const SparseAttnFwdParams ¶ms, const TMAParams &tma_params) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) const int q_h_idx = blockIdx.x % (params.h_q/B_H); const int s_q_idx = blockIdx.x / (params.h_q/B_H); @@ -305,6 +305,10 @@ __device__ void KernelTemplate::devfunc(const SparseAttn CUTE_NO_UNROLL for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + bool has_peer_block = true; + if constexpr (ENABLE_ODD_TAIL_SKIP) { + has_peer_block = block_idx + 1 < num_topk_blocks; + } Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{}); Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{}); @@ -330,6 +334,14 @@ __device__ void KernelTemplate::devfunc(const SparseAttn warpgroup_wait<0>(); plan.bar_k0_free[0].arrive(); + if (!has_peer_block) { + save_rS_to_sS(rS, sS0, idx_in_warpgroup); + fence_view_async_shared(); + NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready); + cur_bar_wait_phase ^= 1; + continue; + } + // Wait for new sM, scale rS, save, inform WG1 NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready); float new_rM[2], scale_factors[2]; @@ -396,9 +408,43 @@ __device__ void KernelTemplate::devfunc(const SparseAttn CUTE_NO_UNROLL for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + bool has_peer_block = true; + if constexpr (ENABLE_ODD_TAIL_SKIP) { + has_peer_block = block_idx + 1 < num_topk_blocks; + } Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{}); Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{}); + if (!has_peer_block) { + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready); + float new_rM[2], scale_factors[2]; + *(float2*)new_rM = plan.sM[idx_in_warpgroup/4]; + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + scale_factors[row] = exp2f(rM[row] - new_rM[row]); + rM[row] = new_rM[row]; + rL[row] *= scale_factors[row]; + } + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + CUTE_UNROLL + for (int i = row*2; i < size(rO); i += 4) { + rO(i) *= scale_factors[row]; + rO(i+1) *= scale_factors[row]; + } + } + + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready); + gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + warpgroup_wait<0>(); + plan.bar_k0_free[1].arrive(); + + cur_bar_wait_phase ^= 1; + continue; + } + // Issue rP1 = sQ @ sK1, and wait pipelined_wait_and_qkt_gemm(); warpgroup_wait<0>(); @@ -475,23 +521,38 @@ __device__ void KernelTemplate::devfunc(const SparseAttn int64_t token_indices[2][NUM_ROWS_PER_GROUP]; bool is_token_valid[2][NUM_ROWS_PER_GROUP]; auto load_token_indices = [&](int block_idx) { + [[maybe_unused]] int src_lane = (idx_in_warpgroup & 31) & ~(GROUP_SIZE - 1); CUTE_UNROLL for (int buf_idx = 0; buf_idx < 2; ++buf_idx) { CUTE_UNROLL for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx; - int t = __ldg(gIndices + offs); - token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster + int t; + if constexpr (ENABLE_ODD_TAIL_SKIP) { + t = -1; + if (idx_in_group == 0) { + t = __ldg(gIndices + offs); + } + t = __shfl_sync(0xffffffff, t, src_lane); + } else { + t = __ldg(gIndices + offs); + } bool is_cur_token_valid = t >= 0 && t < params.s_kv; if constexpr (HAVE_TOPK_LENGTH) { is_cur_token_valid &= offs < topk_length; } + token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster is_token_valid[buf_idx][local_row] = is_cur_token_valid; } } }; - int64_t cache_policy = createpolicy_evict_last(); + int64_t cache_policy; + if constexpr (ENABLE_ODD_TAIL_SKIP) { + cache_policy = createpolicy_evict_first(); + } else { + cache_policy = createpolicy_evict_last(); + } auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) { // Copy some K/V tiles from global memory to shared memory // A tile has a shape of 64 (B_TOPK) x 64 @@ -520,42 +581,55 @@ __device__ void KernelTemplate::devfunc(const SparseAttn CUTE_NO_UNROLL for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { - load_token_indices(block_idx); + bool has_peer_block = true; + if constexpr (ENABLE_ODD_TAIL_SKIP) { + has_peer_block = block_idx + 1 < num_topk_blocks; + } + load_token_indices(block_idx); + + // V0L + plan.bar_k0_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 0, 4); + commit_to_mbar(plan.bar_k0_ready[0]); + + if (has_peer_block) { + // V1R + plan.bar_k1_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 4, D_K/64); + commit_to_mbar(plan.bar_k1_ready[1]); + } - // V0L - plan.bar_k0_free[0].wait(cur_bar_wait_phase); - copy_tiles(block_idx+0, 0, 0, 4); - commit_to_mbar(plan.bar_k0_ready[0]); + // V0R + plan.bar_k0_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 4, D_K/64); + commit_to_mbar(plan.bar_k0_ready[1]); - // V1R - plan.bar_k1_free[1].wait(cur_bar_wait_phase); - copy_tiles(block_idx+1, 1, 4, D_K/64); - commit_to_mbar(plan.bar_k1_ready[1]); - - // V0R - plan.bar_k0_free[1].wait(cur_bar_wait_phase); - copy_tiles(block_idx+0, 0, 4, D_K/64); - commit_to_mbar(plan.bar_k0_ready[1]); - - // V1L - plan.bar_k1_free[0].wait(cur_bar_wait_phase); - copy_tiles(block_idx+1, 1, 0, 4); - commit_to_mbar(plan.bar_k1_ready[0]); - - // Valid mask - // NOTE: V1R's finish implies maskings of the last round have finished - if (idx_in_group == 0) { - CUTE_UNROLL - for (int buf_idx = 0; buf_idx < 2; ++buf_idx) + if (has_peer_block) { + // V1L + plan.bar_k1_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 0, 4); + commit_to_mbar(plan.bar_k1_ready[0]); + } + + // Valid mask + // NOTE: V1R's finish implies maskings of the last round have finished + if (idx_in_group == 0) { CUTE_UNROLL - for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) - plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; - plan.bar_is_kv_valid_ready.arrive(); - } + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + plan.is_kv_valid[0][local_row*NUM_GROUPS+group_idx] = is_token_valid[0][local_row]; + } + if (has_peer_block) { + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + plan.is_kv_valid[1][local_row*NUM_GROUPS+group_idx] = is_token_valid[1][local_row]; + } + } + plan.bar_is_kv_valid_ready.arrive(); + } - cur_bar_wait_phase ^= 1; + cur_bar_wait_phase ^= 1; + } } - } #else @@ -571,8 +645,8 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __gri Kernel::devfunc(params, tma_params); } -template -void KernelTemplate::run(const SparseAttnFwdParams ¶ms) { +template +void KernelTemplate::run(const SparseAttnFwdParams ¶ms) { KU_ASSERT(params.h_kv == 1); KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings KU_ASSERT(params.topk > 0); @@ -620,7 +694,7 @@ void KernelTemplate::run(const SparseAttnFwdParams ¶ shape_Q, tma_Q, tensor_map_O }; - auto kernel = &sparse_attn_fwd_kernel, decltype(tma_params)>; + auto kernel = &sparse_attn_fwd_kernel, decltype(tma_params)>; constexpr size_t smem_size = sizeof(SharedMemoryPlan); KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -640,7 +714,14 @@ void KernelTemplate::run(const SparseAttnFwdParams ¶ template void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { - KernelTemplate::run(params); + // In the merged DSV4 layout, topk == 128 is the pure SWA case. + // Keep SWA on the original path. C128/C4 use the odd-tail-capable path; + // within that path, the peer block is still checked per query. + if (params.topk == 128) { + KernelTemplate::run(params); + } else { + KernelTemplate::run(params); + } } } diff --git a/csrc/smxx/decode/combine/combine.cu b/csrc/smxx/decode/combine/combine.cu index 283f9364..cbd906c1 100644 --- a/csrc/smxx/decode/combine/combine.cu +++ b/csrc/smxx/decode/combine/combine.cu @@ -9,7 +9,7 @@ #include #include "params.h" -#include "utils.h" +#include "flashmla_utils.h" using namespace cute; diff --git a/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu index 083da60c..925fa53f 100644 --- a/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu +++ b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu @@ -4,10 +4,29 @@ #include #include -#include "utils.h" +#include "flashmla_utils.h" namespace smxx::decode { +__device__ __forceinline__ int get_effective_seq_len( + const GetDecodeSchedMetaParams& params, + int req_idx, + int block_size_n) { + if (params.topk == -1) { + return params.seqlens_k_ptr[req_idx]; + } + + int cur_s_k = params.topk_length ? params.topk_length[req_idx] : params.topk; + if (cur_s_k == 0) { + cur_s_k = 1; + } + if (params.extra_topk != -1) { + cur_s_k = ((cur_s_k + block_size_n - 1) / block_size_n) * block_size_n; + cur_s_k += params.extra_topk_length ? params.extra_topk_length[req_idx] : params.extra_topk; + } + return cur_s_k; +} + __global__ void __launch_bounds__(32, 1, 1) get_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) { int *seqlens_k_ptr = params.seqlens_k_ptr; @@ -106,10 +125,110 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) } } +// Fallback path when dynamic shared memory requirement exceeds HW limit. +__global__ void __launch_bounds__(32, 1, 1) +get_mla_metadata_kernel_low_smem(__grid_constant__ const GetDecodeSchedMetaParams params) { + if (threadIdx.x != 0) { + return; + } + + int batch_size = params.b; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + int total_num_blocks = 0; + for (int req_idx = 0; req_idx < batch_size; ++req_idx) { + int cur_s_k = get_effective_seq_len(params, req_idx, block_size_n); + int last_token_idx = max(cur_s_k - 1, 0); + int num_blocks = last_token_idx / block_size_n + 1; + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + } + + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + int now_req_idx = 0; + int now_block = 0; + int now_n_split_idx = 0; + int cum_num_splits = 0; + params.num_splits_ptr[0] = 0; + + for (int part = 0; part < num_sm_parts; ++part) { + DecodingSchedMeta cur_meta; + cur_meta.begin_req_idx = now_req_idx; + cur_meta.begin_block_idx = now_block; + cur_meta.begin_split_idx = now_n_split_idx; + cur_meta.is_first_req_splitted = (now_block != 0); + + int remain_payload = payload; + while (now_req_idx < batch_size) { + int cur_s_k = get_effective_seq_len(params, now_req_idx, block_size_n); + int last_token_idx = max(cur_s_k - 1, 0); + int num_blocks = last_token_idx / block_size_n + 1; + int now_remain_blocks = num_blocks - now_block; + + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + params.num_splits_ptr[now_req_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_req_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + + cur_meta.end_req_idx = now_block > 0 ? now_req_idx : now_req_idx - 1; + cur_meta.end_block_idx = now_block; + if (now_block > 0) { + int cur_s_k = get_effective_seq_len(params, cur_meta.end_req_idx, block_size_n); + int cur_last_block_idx = max(cur_s_k - 1, 0) / block_size_n; + cur_meta.is_last_req_splitted = cur_meta.end_block_idx != cur_last_block_idx + 1 && cur_s_k != 0; + } else { + int prev_s_k = get_effective_seq_len(params, cur_meta.end_req_idx, block_size_n); + int prev_last_block_idx = max(prev_s_k - 1, 0) / block_size_n; + cur_meta.end_block_idx = prev_s_k == 0 ? 0 : prev_last_block_idx + 1; + cur_meta.is_last_req_splitted = false; + } + if (cur_meta.begin_req_idx == cur_meta.end_req_idx) { + cur_meta.is_first_req_splitted = cur_meta.is_last_req_splitted = + cur_meta.is_first_req_splitted || cur_meta.is_last_req_splitted; + } + params.tile_scheduler_metadata_ptr[part] = cur_meta; + } + + FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0); +} + void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams ¶ms) { int smem_size = sizeof(int) * (params.b*5+1); - CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params); + int max_smem = 0; + int dev = 0; + CHECK_CUDA(cudaGetDevice(&dev)); + CHECK_CUDA(cudaDeviceGetAttribute( + &max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + + if (smem_size <= max_smem) { + CHECK_CUDA(cudaFuncSetAttribute( + get_mla_metadata_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params); + } else { + printf("[WARNING] batch_size=%d requires %dB shared memory (max=%dB), using low-smem fallback kernel.\n", + params.b, smem_size, max_smem); + fflush(stdout); + CHECK_CUDA(cudaFuncSetAttribute( + get_mla_metadata_kernel_low_smem, + cudaFuncAttributeMaxDynamicSharedMemorySize, + 0)); + get_mla_metadata_kernel_low_smem<<<1, 32, 0, params.stream>>>(params); + } CHECK_CUDA_KERNEL_LAUNCH(); }