diff --git a/ATTRIBUTIONS-Python.md b/ATTRIBUTIONS-Python.md index a5e5ab387733..c4cc08df553b 100644 --- a/ATTRIBUTIONS-Python.md +++ b/ATTRIBUTIONS-Python.md @@ -5260,8 +5260,7 @@ For more information, please refer to - `Source`: https://github.com/tox-dev/py-filelock - `Tracker`: https://github.com/tox-dev/py-filelock/issues - -## flashinfer-python (0.6.6) +## flashinfer-python (0.6.8) ### Licenses License: `Apache-2.0` @@ -33239,7 +33238,7 @@ License: `NVIDIA Proprietary Software` - `Homepage`: https://developer.nvidia.com/cusparselt -## nvidia-cutlass-dsl (4.2.1) +## nvidia-cutlass-dsl (4.4.2) ### Licenses License: `None` diff --git a/constraints.txt b/constraints.txt index 7dd0d6747765..6f4ae99d5d95 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,3 +6,5 @@ wheel>=0.46.2 tornado>=6.5.5 # WAR against https://github.com/advisories/GHSA-3936-cmfr-pm3m black>=26.3.1 +# Upgrade base image nvidia-cutlass-dsl 4.3.5 to 4.4.2 +nvidia-cutlass-dsl>=4.4.2 diff --git a/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp b/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp index 54ccdce82279..7608079fb396 100644 --- a/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp @@ -20,8 +20,6 @@ #include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/utils/runtimeUtils.h" -#include - using namespace tensorrt_llm::runtime; namespace tensorrt_llm::batch_manager::rnn_state_manager @@ -258,40 +256,16 @@ std::vector RnnStateManager::getStateIndices( std::vector const& requestIds, std::vector const& isPadding) { TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size"); - - std::unordered_set availableSlots; - availableSlots.reserve(mMaxNumSequences); - for (SizeType32 i = 0; i < mMaxNumSequences; ++i) - { - availableSlots.insert(i); - } - - for (size_t i = 0; i < requestIds.size(); ++i) - { - if (!isPadding[i]) - { - 
availableSlots.erase(getCacheIndex(requestIds[i])); - } - } - + // Every id (real or CUDA-graph padding sentinel) has a permanent slot + // allocated by allocateCacheBlocks; padding entries all share their + // sentinel's slot, so they never alias a live request and never + // consume free-pool slots. std::vector result; result.reserve(requestIds.size()); - auto availableIt = availableSlots.begin(); - - for (size_t i = 0; i < requestIds.size(); ++i) + for (auto const& rid : requestIds) { - if (isPadding[i]) - { - TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Run out of available slots for padding"); - result.push_back(*availableIt); - ++availableIt; - } - else - { - result.push_back(getCacheIndex(requestIds[i])); - } + result.push_back(getCacheIndex(rid)); } - return result; } diff --git a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu index faa1f2d9fcab..9cea8385c6ed 100644 --- a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu +++ b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu @@ -56,7 +56,7 @@ struct Causal_conv1d_fwd_kernel_traits static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; }; -template +template __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(ConvParamsBase params) { constexpr int kWidth = Ktraits::kWidth; @@ -94,13 +94,18 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C ? false : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; - int* cache_indices - = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); - int cache_index = cache_indices == nullptr ? 
batch_id : cache_indices[batch_id]; - // cache_index == params.pad_slot_id is defined as padding, so we exit early - if (cache_index == params.pad_slot_id) + int cache_index; + if constexpr (kHasConvStateIndices) { - return; + cache_index = reinterpret_cast(params.cache_indices_ptr)[batch_id]; + if (cache_index == params.pad_slot_id) + { + return; + } + } + else + { + cache_index = batch_id; } input_t* conv_states = params.conv_states_ptr == nullptr ? nullptr : reinterpret_cast(params.conv_states_ptr) @@ -121,6 +126,35 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C smem_exchange[kNThreads - 1] = reinterpret_cast(initial_state)[0]; } + // Save final conv_state from the tail of x directly, instead of reconstructing it + // from smem_exchange after the main loop. + if (conv_states != nullptr && tidx == 0) + { + if (seqlen >= kWidth - 1) + { +#pragma unroll + for (int w = 0; w < kWidth - 1; ++w) + { + conv_states[w] = x[(seqlen - (kWidth - 1) + w) * params.x_l_stride]; + } + } + else + { +#pragma unroll + for (int w = 0; w < kWidth - 1; ++w) + { + if (w < (kWidth - 1) - seqlen) + { + conv_states[w] = has_initial_state ? 
conv_states[w + seqlen] : input_t(0.0f); + } + else + { + conv_states[w] = x[(w - ((kWidth - 1) - seqlen)) * params.x_l_stride]; + } + } + } + } + float weight_vals[kWidth]; #pragma unroll for (int i = 0; i < kWidth; ++i) @@ -208,7 +242,7 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C out_vals[i + 1] = acc1; } - if (params.silu_activation) + if constexpr (kSiluActivation) { #pragma unroll for (int i = 0; i < kNElts; i += 2) @@ -239,90 +273,6 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); } out += kChunkSize; - - int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); - // in case the final state is separated between the last "smem_exchange" and - // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), - // (which occurs when `final_state_position` is a non-positive index) - // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it - if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth) - { - input_t vals_load[kNElts] = {0}; - if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)) - { - // chunk = n_chunks - 2, a segment of the final state sits in the last index - reinterpret_cast(vals_load)[0] = smem_exchange[kNThreads - 1]; -#pragma unroll - for (int w = 0; w < -final_state_position; ++w) - { - conv_states[w] = vals_load[kNElts + final_state_position + w]; - } - } - if ((chunk == n_chunks - 1) && tidx == 0) - { - // chunk = n_chunks - 1, the second segment of the final state first positions - reinterpret_cast(vals_load)[0] = smem_exchange[0]; - for (int w = -final_state_position; w < kWidth - 1; ++w) - { - conv_states[w] = vals_load[w + final_state_position]; - } - return; - } - } - } - // Final state is stored in the smem_exchange last token slot, - // in case seqlen < 
kWidth, we would need to take the final state from the - // initial state which is stored in conv_states - // in case seqlen > kWidth, we would need to load the last kWidth - 1 data - // and load it into conv_state accordingly - int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; - if (conv_states != nullptr && tidx == last_thread) - { - input_t x_vals_load[kNElts * 2] = {0}; - // in case we are on the first kWidth tokens - if (last_thread == 0 && seqlen < kWidth) - { - // Need to take the initial state - reinterpret_cast(x_vals_load)[0] = smem_exchange[0]; - int const offset = seqlen - (kWidth - 1); -#pragma unroll - for (int w = 0; w < kWidth - 1; ++w) - { - // pad the existing state - if ((w - seqlen) >= 0 && has_initial_state) - { - conv_states[w - seqlen] = conv_states[w]; - } - else if ((w - seqlen) >= 0 && !has_initial_state) - { - conv_states[w - seqlen] = input_t(0.0f); - } - } -#pragma unroll - for (int w = 0; w < kWidth - 1; ++w) - { - if (offset + w >= 0) - conv_states[w] = x_vals_load[offset + w]; - } - } - else - { - // in case the final state is in between the threads data - int const offset = ((seqlen - (kWidth - 1)) % (kNElts)); - if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)) - { - // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a - // illegal access error on H100. - // Therefore, we access last_thread + 1, only if the final state data sits there - reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; - } - reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; -#pragma unroll - for (int w = 0; w < kWidth - 1; ++w) - { - conv_states[w] = x_vals_load[offset + w]; - } - } } } @@ -331,20 +281,31 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream) { static constexpr int kNElts = sizeof(input_t) == 4 ? 
4 : 8; bool const kVarlen = params.query_start_loc_ptr != nullptr; - BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, + // Enable vectorized 128-bit loads when total tokens are aligned. For varlen with + // batch==1 (common prefill), seq_start is always 0 so alignment is guaranteed. + bool const canVecLoad = params.seqlen % kNElts == 0 && (!kVarlen || params.batch == 1); + BOOL_SWITCH(canVecLoad, kIsVecLoad, [&] { using Ktraits = Causal_conv1d_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch, params.dim); - - auto kernel = &causal_conv1d_fwd_kernel; - - if (kSmemSize >= 48 * 1024) - { - TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); + bool const hasConvStateIdx = params.cache_indices_ptr != nullptr; + BOOL_SWITCH(hasConvStateIdx, kHasCSI, + [&] + { + BOOL_SWITCH(params.silu_activation, kSilu, + [&] + { + auto kernel = &causal_conv1d_fwd_kernel; + if (kSmemSize >= 48 * 1024) + { + TLLM_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + }); + }); TLLM_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -357,12 +318,10 @@ void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream) constexpr int kWideThreads = 128; constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; constexpr int kShortSeqThreshold = kNarrowThreads * kNElts; - // Varlen prefill launches one block per sequence/channel pair, so the per-sequence - // work is usually much smaller than params.seqlen suggests. That path also disables - // the wide vector-load specialization, so the 128-thread kernel tends to overprovision - // threads for many short chunks. Prefer the narrower launch for varlen and for short - // fixed-length inputs; keep the wider launch for long dense sequences. 
- bool const preferNarrowKernel = isVarlen || params.seqlen <= kShortSeqThreshold; + // Pick the wider 128-thread kernel when the average per-sequence length exceeds + // one chunk; otherwise the narrower 64-thread kernel avoids overprovisioning. + int const avgSeqlen = isVarlen ? (params.seqlen / max(params.batch, 1)) : params.seqlen; + bool const preferNarrowKernel = avgSeqlen <= kShortSeqThreshold; if (preferNarrowKernel) { @@ -406,7 +365,7 @@ struct Causal_conv1d_update_kernel_traits static_assert(kNBytes == 2 || kNBytes == 4); }; -template +template __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kernel(ConvParamsBase params) { constexpr int kWidth = Ktraits::kWidth; @@ -423,14 +382,18 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kerne input_t* x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; - // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor - // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. - int const conv_state_batch_coord - = params.conv_state_indices_ptr == nullptr ? 
batch_id : params.conv_state_indices_ptr[batch_id]; - // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early - if (conv_state_batch_coord == params.pad_slot_id) + int conv_state_batch_coord; + if constexpr (kHasConvStateIndices) { - return; + conv_state_batch_coord = params.conv_state_indices_ptr[batch_id]; + if (conv_state_batch_coord == params.pad_slot_id) + { + return; + } + } + else + { + conv_state_batch_coord = batch_id; } input_t* conv_state = reinterpret_cast(params.conv_state_ptr) + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; @@ -506,7 +469,7 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kerne { out_val += weight_vals[j] * x_vals[j]; } - if (params.silu_activation) + if constexpr (kSiluActivation) { out_val = out_val / (1 + expf(-out_val)); } @@ -520,31 +483,119 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kerne } } +// Specialized kernel for the dominant decode case (seqlen=1, non-circular, silu). +// Drops the per-token loop and circular-buffer bookkeeping from the general kernel. 
+template +__global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kernel_sl1(ConvParamsBase params) +{ + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + int const tidx = threadIdx.x; + int const batch_id = blockIdx.x; + int const channel_id = blockIdx.y * kNThreads + tidx; + if (channel_id >= params.dim) + return; + + int conv_state_batch_coord; + if constexpr (kHasConvStateIndices) + { + conv_state_batch_coord = params.conv_state_indices_ptr[batch_id]; + if (conv_state_batch_coord == params.pad_slot_id) + return; + } + else + { + conv_state_batch_coord = batch_id; + } + + input_t* conv_state = reinterpret_cast(params.conv_state_ptr) + + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; + weight_t* weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t* x + = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; + + float w[kWidth]; +#pragma unroll + for (int i = 0; i < kWidth; ++i) + w[i] = float(__ldg(&weight[i * params.weight_width_stride])); + + float s[kWidth]; +#pragma unroll + for (int i = 0; i < kWidth - 1; ++i) + s[i] = float(conv_state[i * params.conv_state_l_stride]); + s[kWidth - 1] = float(x[0]); + + float out_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); +#pragma unroll + for (int i = 0; i < kWidth; ++i) + out_val = __fmaf_rn(w[i], s[i], out_val); + out_val = out_val * __frcp_rn(1.0f + __expf(-out_val)); + x[0] = input_t(out_val); + + // Shift conv_state left by one and append the new token. 
+#pragma unroll + for (int i = 0; i < kWidth - 1; ++i) + conv_state[i * params.conv_state_l_stride] = input_t(s[i + 1]); +} + template void causal_conv1d_update_launch(ConvParamsBase& params, cudaStream_t stream) { using Ktraits = Causal_conv1d_update_kernel_traits; dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = params.cache_seqlens == nullptr ? &causal_conv1d_update_kernel - : &causal_conv1d_update_kernel; - kernel<<>>(params); + bool const hasConvStateIndices = params.conv_state_indices_ptr != nullptr; + bool const isCircularBuffer = params.cache_seqlens != nullptr; + + // Fast path for the standard decode case (seqlen=1, non-circular, silu) when + // conv_state holds exactly width-1 elements (no extra trailing padding to shift). + if (params.seqlen == 1 && !isCircularBuffer && params.silu_activation && params.conv_state_len == params.width - 1) + { + BOOL_SWITCH(hasConvStateIndices, kHasCSI, + [&] + { + auto kernel = &causal_conv1d_update_kernel_sl1; + kernel<<>>(params); + }); + } + else + { + BOOL_SWITCH(isCircularBuffer, kIsCircBuf, + [&] + { + BOOL_SWITCH(hasConvStateIndices, kHasCSI, + [&] + { + BOOL_SWITCH(params.silu_activation, kSilu, + [&] + { + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + }); + }); + }); + } TLLM_CUDA_KERNEL_LAUNCH_CHECK(); } template void causal_conv1d_update_cuda(ConvParamsBase& params, cudaStream_t stream) { + // Wider blocks (128 vs 64 threads) halve block count, reducing scheduling overhead. 
+ constexpr int kNThreads = 128; if (params.width == 2) { - causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + causal_conv1d_update_launch(params, stream); } else if (params.width == 3) { - causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + causal_conv1d_update_launch(params, stream); } else if (params.width == 4) { - causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + causal_conv1d_update_launch(params, stream); } } diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index ad162658899a..e1fd9bc7f08a 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -60,12 +60,30 @@ using tensorrt_llm::common::launchWithPdlWhenEnabled; __VA_ARGS__; \ break; \ } \ + case 18: \ + { \ + constexpr int TOP_K = 18; \ + __VA_ARGS__; \ + break; \ + } \ case 16: \ { \ constexpr int TOP_K = 16; \ __VA_ARGS__; \ break; \ } \ + case 14: \ + { \ + constexpr int TOP_K = 14; \ + __VA_ARGS__; \ + break; \ + } \ + case 12: \ + { \ + constexpr int TOP_K = 12; \ + __VA_ARGS__; \ + break; \ + } \ case 10: \ { \ constexpr int TOP_K = 10; \ diff --git a/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh index c94ff267e5cf..efe0f396ae05 100644 --- a/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh +++ b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh @@ -1,6 +1,6 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -123,8 +123,75 @@ struct TopKIdx topK[J].compValIdx = pairMin; \ } +template +struct IsPowerOf2 +{ + static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0); +}; + template -struct Sort; +struct Sort +{ + static_assert(N > 0 && N <= 32, "Sort supports N in [1, 32]"); + + static __device__ void run(RedType* topK) + { + if constexpr (IsPowerOf2::value) + { +#pragma unroll + for (int k = 2; k <= N; k *= 2) + { +#pragma unroll + for (int j = k / 2; j > 0; j /= 2) + { +#pragma unroll + for (int i = 0; i < N; ++i) + { + int ixj = i ^ j; + if (ixj > i) + { + if ((i & k) == 0) + { + if (topK[i].compValIdx < topK[ixj].compValIdx) + { + auto tmp = topK[i].compValIdx; + topK[i].compValIdx = topK[ixj].compValIdx; + topK[ixj].compValIdx = tmp; + } + } + else + { + if (topK[i].compValIdx > topK[ixj].compValIdx) + { + auto tmp = topK[i].compValIdx; + topK[i].compValIdx = topK[ixj].compValIdx; + topK[ixj].compValIdx = tmp; + } + } + } + } + } + } + } + else + { +#pragma unroll + for (int pass = 0; pass < N; ++pass) + { +#pragma unroll + for (int i = 0; i < N - 1; i += 2) + { + TOPK_SWAP(i, i + 1); + } +#pragma unroll + for (int i = 1; i < N - 1; i += 2) + { + TOPK_SWAP(i, i + 1); + } + } + } + } +}; template struct Sort<1, RedType> @@ -170,28 +237,27 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile con int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K) { static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(K <= kWARP_SIZE, "Top K must have K <= kWARP_SIZE"); using RedType = TopKRedType; RedType topK{value, idx}; typename RedType::TypeCmp packedMax{}; #pragma unroll - for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct + for (int kk = 0; kk < actualK; ++kk) { topK = kk > 0 && packedMax == topK.compValIdx ? 
RedType{minValue, idx} : topK; - // get the next largest value packedMax = topK.reduce(warp); RedType::unpack(out[kk], outIdx[kk], packedMax); } }; -template -__device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K) +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K) { static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(K <= kWARP_SIZE, "Top K must have K <= kWARP_SIZE"); static_assert(N > 0, "Top K must have N > 0"); - static_assert(N < 5, "Only support candidates number less than or equal to 128"); + static_assert(N <= 32, "Only support candidates number less than or equal to 32*32=1024"); using RedType = TopKRedType; RedType topK[N]; #pragma unroll @@ -200,12 +266,9 @@ __device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Ty topK[nn] = RedType{value[nn], idx[nn]}; } - if constexpr (!IsSorted) - { - Sort::run(topK); - } + Sort::run(topK); + typename RedType::TypeCmp packedMax{}; -#pragma unroll for (int kk = 0; kk < actualK; ++kk) { bool update = kk > 0 && packedMax == topK[0].compValIdx; @@ -214,73 +277,11 @@ __device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Ty { topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? 
topK[nn + 1] : topK[nn]; } - // get the next largest value packedMax = topK[0].reduce(warp); RedType::unpack(out[kk], outIdx[kk], packedMax); } }; -template -__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], - int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); - static_assert(N > 0, "Top K must have N > 0"); - static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512"); - static_assert( - N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4"); - using RedType = TopKRedType; - - if constexpr (N <= 4) - { - reduceTopKFunc(warp, out, outIdx, value, idx, minValue, actualK); - } - else - { - - constexpr int numLoops = N / 4; - constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1; - - Type topKBufferValue[numResults]; - int32_t topKBufferIdx[numResults]; - int32_t laneIdx = threadIdx.x % kWARP_SIZE; - - for (int ii = 0; ii < numResults; ++ii) - { - topKBufferValue[ii] = minValue; - topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correct - } - for (int loop = 0; loop < numLoops; ++loop) - { - int start = loop * 4; - Type topKValue[K]; - int32_t topKIdx[K]; - Type inValue[4]; - int32_t inIdx[4]; - for (int i = 0; i < 4; ++i) - { - inValue[i] = value[start + i]; - inIdx[i] = idx[start + i]; - } - reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK); - int inOffset = laneIdx % K; - if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) - { - topKBufferValue[0] = topKValue[inOffset]; - topKBufferIdx[0] = topKIdx[inOffset]; - } - if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) - { - topKBufferValue[1] = topKValue[inOffset]; - topKBufferIdx[1] = topKIdx[inOffset]; - } - } - - reduceTopKFunc(warp, out, outIdx, 
topKBufferValue, topKBufferIdx, minValue, actualK); - } -}; - #undef TOPK_SWAP } // namespace reduce_topk diff --git a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu index 21f68c71824d..8256c4e6ca73 100644 --- a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu +++ b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,15 +32,13 @@ TRTLLM_NAMESPACE_BEGIN namespace kernels { static constexpr int WARP_SIZE = 32; -static constexpr int NumNemotronExperts = 512; -static constexpr int NumKimiK2Experts = 384; static constexpr int NumDeepseekExperts = 256; -static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); -static constexpr int MaxNumExpertsUnit = 128; +static constexpr int MaxSupportedExpertCount = 1024; static constexpr int NumTopGroupScores = 2; static constexpr int DefaultMaxNumTopExperts = 8; -static constexpr int MaxSupportedTopExperts = 22; -static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxSupportedTopExperts = 32; +static constexpr int DefaultMaxNumTopGroups = 4; +static constexpr int LargeMaxNumTopGroups = 8; static __device__ inline float sigmoid_accurate(float x) { @@ -48,7 +46,7 @@ static __device__ inline float sigmoid_accurate(float x) } template + int MaxNumTopExperts = DefaultMaxNumTopExperts, int MaxNumTopGroups = DefaultMaxNumTopGroups> __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, BiasT* routingBias, int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, int64_t const topk, int64_t const numExperts, int64_t const numExpertsPerGroup, double const routedScalingFactor) @@ -57,208 +55,137 @@ __global__ 
void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx cudaGridDependencySynchronize(); #endif - // declare shared memory structure - // number of experts is bounded by number of threads __shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts]; __shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts]; - // number of expert groups is bounded by number of warps - int constexpr NumWarps = MaxNumExperts / WARP_SIZE; - __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; - // needed for warp reduce auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - // for the final reduction of weight norm, only some lanes need to participate int32_t laneIdx = threadIdx.x % WARP_SIZE; int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0); - if constexpr (UseGroups) - { - if (warpIdx >= numGroup) - { - return; - } - } - - // note that for invalid scores, we simply use a negative value: - // they work well even with the compacted format used in topK, and - // sigmoid / bias activated scores cannot be negative static constexpr float invalidScoreFloat = float{-INFINITY}; - const OutputT invalidScore = OutputT{invalidScoreFloat}; - // load bias already; each warp represents one expert group - auto threadExpert = threadIdx.x; - bool expertSelected = threadExpert < numExperts; - if constexpr (UseGroups) - { - threadExpert = warpIdx * numExpertsPerGroup + laneIdx; - expertSelected = laneIdx < numExpertsPerGroup; - } - - auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert; - auto biasVal = expertSelected ? static_cast(routingBias[threadExpert]) : invalidScoreFloat; topkValues += blockIdx.x * topk; topkIndices += blockIdx.x * topk; - // get our assigned thread score; each warp represents one expert group - float score = expertSelected ? 
static_cast(scores[scoreIdx]) : invalidScoreFloat; - auto scoreSigmoid = sigmoid_accurate(score); - // write the sigmoid score to shared for later use - if (expertSelected) + if constexpr (UseGroups) { - smemScoreSigmoid[threadExpert] = scoreSigmoid; - } + int constexpr NumWarps = MaxNumExperts / WARP_SIZE; + __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; - // get the score with bias - // note that with invalid values, because sigmoid is < 1 and bias is -1, - // we must get a negative value, which is smaller than any valid value - auto scoreBias = float{scoreSigmoid + float{biasVal}}; + if (warpIdx >= numGroup) + { + return; + } - if (expertSelected) - { - smemScoreBias[threadExpert] = scoreBias; - } + auto threadExpert = warpIdx * numExpertsPerGroup + laneIdx; + bool expertSelected = laneIdx < numExpertsPerGroup; - // registers for top group score reduction - float topExpGroupScores[NumTopGroupScores]; - [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; - float topGroups[MaxNumTopGroups]; // bound of numGroup - int32_t topGroupIdx[MaxNumTopGroups]; - float expertScoreGroup[MaxNumTopGroups]; - int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[MaxNumTopExperts]; // bound of topk - int32_t topExperts[MaxNumTopExperts]; + auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert; + auto biasVal = expertSelected ? static_cast(routingBias[threadExpert]) : invalidScoreFloat; + float score = expertSelected ? 
static_cast(scores[scoreIdx]) : invalidScoreFloat; + auto scoreSigmoid = sigmoid_accurate(score); + if (expertSelected) + { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + if (expertSelected) + { + smemScoreBias[threadExpert] = scoreBias; + } - if constexpr (UseGroups) - { + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, /* minValue */ invalidScoreFloat); - // get the final group score and write it to shared if (warp.thread_rank() == 0) { auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; smemGroupScores[warpIdx] = groupScore; } - } - // make group scores available to all warps - __syncthreads(); + __syncthreads(); + + float topScores[MaxNumTopExperts]; + int32_t topExperts[MaxNumTopExperts]; - if constexpr (UseGroups) - { if (warpIdx == 0) { - // a single warp performs the selection of top groups, and goes on to select the final experts + float topGroups[MaxNumTopGroups]; + int32_t topGroupIdx[MaxNumTopGroups]; float groupScore = laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat; - reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, /* minValue */ invalidScoreFloat); - // final expert selection: get relevant indexes and scores from shared + + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; #pragma unroll for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { // bound of numGroup + { auto groupIdx = topGroupIdx[ii]; expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx; - expertScoreGroup[ii] = (ii < topkGroup) && expertSelected ? 
smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; } - tensorrt_llm::kernels::reduce_topk::reduceTopK( + reduce_topk::reduceTopK( warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, /* minValue */ invalidScoreFloat, topk); - } - } - else if constexpr (MaxNumExperts > MaxNumExpertsUnit) - { - // without groups, and the expert number is larger than MaxNumExpertsUnit, - // we need to use multiple warps to calculate the intermediate topk results - - int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; - __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; - __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; - if (warpIdx < NumExpertWarps) - { - int offset = warpIdx * WARP_SIZE * MaxNumTopGroups; -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { - auto expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] - = offset + expertIdx < numExperts ? smemScoreBias[offset + expertIdx] : invalidScoreFloat; - } - reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, topk); + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + float scoreNorm = laneIdx < topk ? 
smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); if (laneIdx < topk) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; - } - else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) - { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1; - } - } - __syncthreads(); - if (warpIdx == 0) - { - int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; - float intermediateScore[NumInterTopKPerThread]; - int32_t intermediateExpert[NumInterTopKPerThread]; - for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) - { - int ii = i / WARP_SIZE; - if (i < NumInterTopK) - { - intermediateScore[ii] = smemInterTopScores[i]; - intermediateExpert[ii] = smemInterTopExperts[i]; - } - else - { - intermediateScore[ii] = invalidScoreFloat; - intermediateExpert[ii] = MaxNumExperts - 1; - } + topkValues[laneIdx] = static_cast(finalScore); + topkIndices[laneIdx] = expertIdx; } - reduce_topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, - /* minValue */ invalidScoreFloat, topk); } } else { - // without groups, and the expert number is smaller than MaxNumExpertsUnit - // each thread just takes `MaxNumTopGroups` experts + for (int e = threadIdx.x; e < numExperts; e += blockDim.x) + { + auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + e; + auto biasVal = static_cast(routingBias[e]); + float score = static_cast(scores[scoreIdx]); + auto scoreSigmoid = sigmoid_accurate(score); + smemScoreSigmoid[e] = scoreSigmoid; + smemScoreBias[e] = scoreSigmoid + biasVal; + } + + __syncthreads(); + + float topScores[MaxNumTopExperts]; + int32_t topExperts[MaxNumTopExperts]; + if 
(warpIdx == 0) { + constexpr int NumChunks = (MaxNumExperts + WARP_SIZE - 1) / WARP_SIZE; + float localScores[NumChunks]; + int32_t localIdx[NumChunks]; #pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) + for (int ii = 0; ii < NumChunks; ++ii) { auto expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + localIdx[ii] = expertIdx; + localScores[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; } - reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + reduce_topk::reduceTopK(warp, topScores, topExperts, localScores, localIdx, /* minValue */ invalidScoreFloat, topk); - } - } - if (warpIdx == 0) - { - // determine our lane's expert index and write to output - int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; - // norm the value - float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F; - auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); - auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); - // store the topk scores and experts to output - if (laneIdx < topk) - { - topkValues[laneIdx] = static_cast(finalScore); - topkIndices[laneIdx] = expertIdx; + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + float scoreNorm = laneIdx < topk ? 
smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); + if (laneIdx < topk) + { + topkValues[laneIdx] = static_cast(finalScore); + topkIndices[laneIdx] = expertIdx; + } } } @@ -272,43 +199,57 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk int64_t const num_experts, int64_t const n_group, int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, cudaStream_t const stream) { - - // Check if we can use the optimized deepseek_v3_topk_kernel - bool const is_single_group = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount); + bool const is_single_group + = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount) && (topk <= MaxSupportedTopExperts); int64_t const experts_per_group = num_experts / n_group; bool const is_multi_group = (n_group > 1) && (num_experts <= NumDeepseekExperts) && (experts_per_group <= WARP_SIZE) - && (experts_per_group * topk_group <= MaxNumExpertsUnit); + && (topk <= DefaultMaxNumTopExperts) && (experts_per_group * topk_group <= LargeMaxNumTopGroups * WARP_SIZE); if (is_single_group || is_multi_group) { cudaLaunchConfig_t config; auto* kernel_instance = &deepseek_v3_topk_kernel; int num_threads = NumDeepseekExperts; - if (is_single_group) + + if (is_multi_group) { - // Special case for Nemotron, which selects top 22 from 512 experts, and 1 group only. 
- if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts) + if (experts_per_group * topk_group <= DefaultMaxNumTopGroups * WARP_SIZE) { - kernel_instance = &deepseek_v3_topk_kernel; - num_threads = NumNemotronExperts; + kernel_instance = &deepseek_v3_topk_kernel; } - else if (num_experts > NumKimiK2Experts && num_experts <= MaxSupportedExpertCount) + else + { + kernel_instance = &deepseek_v3_topk_kernel; + } + num_threads = NumDeepseekExperts; + } + else if (is_single_group) + { + if (num_experts <= 128) { kernel_instance - = &deepseek_v3_topk_kernel; - num_threads = MaxSupportedExpertCount; + = &deepseek_v3_topk_kernel; + num_threads = 128; } - else if (num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts) + else if (num_experts <= 256) { - kernel_instance = &deepseek_v3_topk_kernel; - num_threads = NumKimiK2Experts; + kernel_instance + = &deepseek_v3_topk_kernel; + num_threads = 256; + } + else if (num_experts <= 512) + { + kernel_instance + = &deepseek_v3_topk_kernel; + num_threads = 256; } else { - kernel_instance = &deepseek_v3_topk_kernel; - num_threads = MaxNumExpertsUnit; + kernel_instance + = &deepseek_v3_topk_kernel; + num_threads = 256; } } @@ -328,11 +269,10 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk } else { - // TODO: call the generic path (previous implementation) or signal unsupported config. TLLM_CHECK_WITH_INFO(false, - "invokeNoAuxTc: unsupported configuration (n_group=%ld, num_experts=%ld, topk_group=%ld). Please use " - "original pytorch implementation.", - n_group, num_experts, topk_group); + "invokeNoAuxTc: unsupported configuration (n_group=%ld, num_experts=%ld, topk_group=%ld, topk=%ld). 
" + "Please use original pytorch implementation.", + n_group, num_experts, topk_group, topk); } } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h index 7f8e2b06b00b..72996839037d 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h @@ -233,100 +233,6 @@ namespace moe::dev TLLM_LOG_ERROR("Unsupported pair"); \ } -#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mPaddingLog2 > 0) \ - { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, false), kernel, numBlocks, numThreads, smemSize, stream); \ - } - -#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mDtypeExpW == tg::Dtype::Fp32) \ - { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ - { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ - } - -#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == 
tg::Dtype::Fp32) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ - } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define LAUNCH_ROUTING_WITH_NUM_EXPERTS( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, numExperts, numTopExperts) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Fp32) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) \ - { \ - LAUNCH_TILEN(data, 
coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ - } - //////////////////////////////////////////////////////////////////////////////////////////////////// namespace activation { diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu deleted file mode 100644 index b59580f9f153..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "routingDeepSeek/RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Forward declarations for split-compiled launch wrappers. 
-void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream); -void launchInitExpertCounts(Data& data, int numThreadsHist, void* stream); -void launchClusterKernel(Data& data, int numThreadsHist, void* stream); -void launchCoopKernel(Data& data, int numBlocksCoop, int numThreadsHist, void* stream); -void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream); -void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run(Data& data, void* stream) -{ - TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) - { - TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for DeepSeek routing."); - } - if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToExpandedIdx != nullptr - || data.mPtrPermutedIdxToTokenIdx != nullptr) - TLLM_CHECK_WITH_INFO( - (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, - "If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required"); - TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); - int const numBlocks = data.mNumTokens; - int const numThreadsHist = getMaxNumExperts(data.mNumExperts); - - bool const useSingleCluster = data.mNumTokens <= 1024; - if (!useSingleCluster) - { - // Reset the global histograms (not used in single-cluster code path). - // Cover both for the cooperative and two-kernel code paths. 
- TLLM_CHECK_WITH_INFO( - data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); - } - else - { - data.mPtrExpertCounts = nullptr; // Set it to nullptr for single-cluster code path, as it won't be used - } - - // Number of blocks we can use in the cooperative kernel - // The number of blocks must be: - // >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉ - // <= numSms, assuming an occupancy of 1 block/SM - // - // If too small for the given numTokens, fall back to the less performant two-step method. - // - // The upper bound is a strict requirement. The number of blocks should be determined by querying - // the device properties, or conservatively low. - static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); - // WAR: Reserve 8 SMs for overlapping kernels. - int const numBlocksCoop = smCount - 8; - - // Maximum number of tokens supported by the kernel using a cooperative launch. - int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; - if (data.mPtrTopKIds == nullptr) - { - TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxSupportedTopExperts, - "Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, data.mNumExperts); - TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount, - "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount); - TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", - MaxSupportedTopExperts, data.mTopK); - - // Routing needs to be executed - validate routing kernel constraints - if (data.mNumExpertGroups > 1) - { - // Note: Routing-specific constraints (experts per group, topK limits) are checked when routing is actually - // needed (data.mPtrTopKIds == nullptr) - TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups, - "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, 
MaxNumGroups); - TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0, - "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts, - data.mNumExpertGroups); - TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize, - "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts " - "per group", - WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups); - TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, - "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); - - TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups, - "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, - data.mNumExpertGroups); - TLLM_CHECK_WITH_INFO(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", - data.mNumExperts); - } - - int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); - launchMainKernel(data, numBlocks, numThreadsMain, stream); - } - else - { - // Reset the global histograms. - launchInitExpertCounts(data, numThreadsHist, stream); - } - - if (data.mPtrPermutedIdxSize != nullptr) - { - if (useSingleCluster) - { - launchClusterKernel(data, numThreadsHist, stream); - } - else if (data.mNumTokens <= maxTokensCoop) - { - launchCoopKernel(data, numBlocksCoop, numThreadsHist, stream); - } - else - { - const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; - const int32_t histogramEltsPerBlock = 8 * numThreadsHist; - const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (both kernels use a grid-stride loop). 
- const int32_t maxNumBlocks = 1024; - - int const numBlocksHistogram - = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); - int const numBlocksOffsets - = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - - launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); - launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu deleted file mode 100644 index ff4bb808d92e..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "routingRenormalize/RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Forward declarations of per-kernel launch wrappers (defined in routingRenormalize/*.cu). 
-void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); -void launchClusterKernel(Data const& data, void* stream); -void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, void* stream); -void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream); -void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream); -void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream); - -//////////////////////////////////////////////////////////////////////////////////////////////////// -void run(Data const& data, void* stream) -{ - TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) - { - TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Renormalize routing."); - } - TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr - && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, - "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); - TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", - MaxSupportedTopExperts, data.mTopK); - TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExperts, - "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, MaxSupportedExperts); - // similar check - TLLM_CHECK_WITH_INFO( - data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - - bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens - || (data.mNumTokens <= DynBlockKernelMaxNumTokens && data.mNumExperts <= 
DynBlockKernelMaxNumExperts); - - bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) - ? MaxNumTokensSingleClusterScores - : MaxNumTokensSingleCluster); - - if (!useSingleCluster && !useSingleBlock) - { - TLLM_CHECK_WITH_INFO((data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), - "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); - TLLM_CHECK_WITH_INFO( - data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); - } - uint32_t const numThreadsHist = min(1024, getMaxNumExperts(data.mNumExperts)); - if (useSingleBlock) - { - //@TODO: For now we use the single block kernel for cases with token number no larger than 4. - // We will future tune this threshold based on the performance. - launchBlockKernel(data, numThreadsHist, stream); - } - else if (useSingleCluster) - { - launchClusterKernel(data, stream); - } - else - { - uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; - uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; - uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (all kernels use a grid-stride loop). - uint32_t const maxNumBlocks = 1024; - - int const numBlocksHistogram - = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); - int const numBlocksOffsets - = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - - if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) - { - launchHistogramScoresKernel(data, maxNumBlocks, numThreadsHist, stream); - } - else - { - // Reset the global histograms. 
- launchInitExpertCounts(data, numThreadsHist, stream); - } - launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); - launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/IntFastDiv.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/IntFastDiv.h similarity index 100% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/IntFastDiv.h rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/IntFastDiv.h diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchBlockKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustom.cu similarity index 50% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchBlockKernel.cu rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustom.cu index 2a4f9257aa9f..dc819f9a8d1c 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchBlockKernel.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustom.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,13 +13,92 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "RoutingRenormalizeCommon.cuh" + +// Custom routing: entry point, kernel definitions, and launch wrappers. +// +// Kernel inventory: +// 1. 
routingIndicesBlockKernel — single-block fused kernel (≤4 tokens) +// 1b. routingIndicesDynBlockKernel — dynamic-block fused kernel (≤16 tokens, ≤512 experts) +// 2. routingIndicesClusterKernel — single-cluster fused kernel (≤256 tokens, SM90+) +// 3. routingIndicesHistogramScoresKernel — TopK + histogram from raw scores +// 4. routingIndicesCoopKernel — cooperative histogram + offsets (defined in RoutingKernel.cuh) +// 5. routingInitExpertCounts — zero expert counts (defined in RoutingKernel.cuh) +// 6. routingIndicesHistogramKernel — histogram from packed TopK (defined in RoutingKernel.cuh) +// 7. routingIndicesOffsetsKernel — prefix-scan + permutation (defined in RoutingKernel.cuh) + +#include "RoutingCustomPolicy.cuh" namespace moe::dev::routing { -namespace routingRenormalize +namespace routingCustom { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Dual warp-level exclusive prefix scan over NumExpertWarps * 32 values. +// Scans val1 and val2 simultaneously while sharing the same two __syncthreads() barriers, +// reducing 4 barriers (two separate scans) to 2. 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__device__ __forceinline__ void warpExclusiveScan(int32_t val1, int32_t val2, int32_t laneIdx, int32_t warpIdx, + int32_t* warpTotals1, int32_t* warpTotals2, int32_t& prefix1, int32_t& prefix2, int32_t& totalSum1) +{ + static_assert(NumExpertWarps <= WarpSize, "NumExpertWarps must fit in one warp for the cross-warp scan"); + + int32_t inc1 = val1, inc2 = val2; +#pragma unroll + for (int j = 1; j < WarpSize; j *= 2) + { + int32_t n1 = __shfl_up_sync(0xffffffff, inc1, j); + int32_t n2 = __shfl_up_sync(0xffffffff, inc2, j); + if (laneIdx >= j) + { + inc1 += n1; + inc2 += n2; + } + } + + if (warpIdx < NumExpertWarps && laneIdx == WarpSize - 1) + { + warpTotals1[warpIdx] = inc1; + warpTotals2[warpIdx] = inc2; + } + __syncthreads(); + + if (warpIdx == 0) + { + int32_t wt1 = (laneIdx < NumExpertWarps) ? warpTotals1[laneIdx] : 0; + int32_t wt2 = (laneIdx < NumExpertWarps) ? warpTotals2[laneIdx] : 0; +#pragma unroll + for (int j = 1; j < NumExpertWarps; j *= 2) + { + int32_t n1 = __shfl_up_sync(0xffffffff, wt1, j); + int32_t n2 = __shfl_up_sync(0xffffffff, wt2, j); + if (laneIdx >= j) + { + wt1 += n1; + wt2 += n2; + } + } + if (laneIdx < NumExpertWarps) + { + warpTotals1[laneIdx] = wt1; + warpTotals2[laneIdx] = wt2; + } + } + __syncthreads(); + + totalSum1 = warpTotals1[NumExpertWarps - 1]; + int32_t wp1 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals1[warpIdx - 1] : 0; + int32_t wp2 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals2[warpIdx - 1] : 0; + prefix1 = inc1 - val1 + wp1; + prefix2 = inc2 - val2 + wp2; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 1. Block kernel — single-block fused kernel for ≤4 tokens. +// Fuses TopK, histogram, prefix-scan, and permutation in one block. 
+// //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -29,7 +108,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa // types used in this kernel using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; using TypePacked = PackedScoreIdx; static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; // When MaxNumExperts > 1024, cap actual thread count at 1024 and let each thread handle @@ -63,7 +142,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // then wait on primary grid - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -75,15 +154,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa { if (laneIdx < params.mTopK) { - auto expertIdx = params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; - if (expertIdx != -1) + auto const expandedIdx = warpIdx * params.mTopK + laneIdx; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { - int offset = warpIdx * MaxNumExperts + expertIdx; - smemKIdx[offset] = static_cast(laneIdx); + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = int32_t{-1}; } - else + auto expertIdx = params.mPtrTopKIds[expandedIdx]; + if (expertIdx > -1 && expertIdx < params.mNumExperts) { - params.mPtrExpandedIdxToPermutedIdx[warpIdx * params.mTopK + laneIdx] = int32_t{-1}; + int offset = warpIdx * MaxNumExperts + expertIdx; + smemKIdx[offset] = static_cast(laneIdx); } } } @@ -91,18 +171,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelPa else if (params.mPtrScores != nullptr) { // in this case, each warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; - BaseType minScore = BaseType{-INFINITY}; if (validToken) { - routingTopKExperts(warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb); + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); if (laneIdx < params.mTopK) { @@ -115,6 +191,31 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa } } // end if (validToken) } + else if (params.mPtrTopKPacked != nullptr) + { + if (validToken) + { + if (laneIdx < params.mTopK) + { + auto const expandedIdx = warpIdx * params.mTopK + laneIdx; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) + { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = int32_t{-1}; + } + auto const scoreIdx = params.mPtrTopKPacked[expandedIdx]; + int const expertIdx = static_cast(scoreIdx.idx); + if (expertIdx >= 0 && expertIdx < params.mNumExperts) + { + int const offset = warpIdx * MaxNumExperts + expertIdx; + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) + { + params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + } + } + } + } __syncthreads(); // Each thread handles ExpertsPerThread contiguous experts. @@ -155,7 +256,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelPa #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCtaPerExpert[e] = divUpLog2(accExpertCount[e], params.mPaddingLog2); } @@ -174,7 +275,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { tmpCountPerExpert[e] = divUpMulLog2(accExpertCount[e], params.mPaddingLog2); } @@ -205,7 +306,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa params.mPtrCtaIdxXyToBatchIdx[ctaOffsetPerExpert[e] + cta] = mappedLocalIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffsetPerExpert[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffsetPerExpert[e], params.mPaddingLog2) + accExpertCount[e]; @@ -220,11 +321,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa } } - // at this point, we can write out padded count if (threadIdx.x == 0) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -236,14 +336,6 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // we can trigger the next kernel at this point - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { #pragma unroll @@ -277,85 +369,43 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelPa } } } -} -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Dual warp-level exclusive prefix scan over NumExpertWarps * 32 values. -// Scans val1 and val2 simultaneously while sharing the same two __syncthreads() barriers, -// reducing 4 barriers (two separate scans) to 2. -//////////////////////////////////////////////////////////////////////////////////////////////////// -template -__device__ __forceinline__ void warpExclusiveScan(int32_t val1, int32_t val2, int32_t laneIdx, int32_t warpIdx, - int32_t* warpTotals1, int32_t* warpTotals2, int32_t& prefix1, int32_t& prefix2, int32_t& totalSum1) -{ - static_assert(NumExpertWarps <= WarpSize, "NumExpertWarps must fit in one warp for the cross-warp scan"); - - int32_t inc1 = val1, inc2 = val2; -#pragma unroll - for (int j = 1; j < WarpSize; j *= 2) - { - int32_t n1 = __shfl_up_sync(0xffffffff, inc1, j); - int32_t n2 = __shfl_up_sync(0xffffffff, inc2, j); - if (laneIdx >= j) - { - inc1 += n1; - inc2 += n2; - } - } - - if (warpIdx < NumExpertWarps && laneIdx == WarpSize - 1) - { - warpTotals1[warpIdx] = inc1; - warpTotals2[warpIdx] = inc2; - } - __syncthreads(); - - if (warpIdx == 0) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger the secondary kernel AFTER all global memory writes (including permutation indices). + // The downstream kernels depend on all routing outputs being visible. + if (params.mUsePdl) { - int32_t wt1 = (laneIdx < NumExpertWarps) ? warpTotals1[laneIdx] : 0; - int32_t wt2 = (laneIdx < NumExpertWarps) ? 
warpTotals2[laneIdx] : 0; -#pragma unroll - for (int j = 1; j < NumExpertWarps; j *= 2) - { - int32_t n1 = __shfl_up_sync(0xffffffff, wt1, j); - int32_t n2 = __shfl_up_sync(0xffffffff, wt2, j); - if (laneIdx >= j) - { - wt1 += n1; - wt2 += n2; - } - } - if (laneIdx < NumExpertWarps) - { - warpTotals1[laneIdx] = wt1; - warpTotals2[laneIdx] = wt2; - } + cudaTriggerProgrammaticLaunchCompletion(); } - __syncthreads(); +#endif +} - totalSum1 = warpTotals1[NumExpertWarps - 1]; - int32_t wp1 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals1[warpIdx - 1] : 0; - int32_t wp2 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals2[warpIdx - 1] : 0; - prefix1 = inc1 - val1 + wp1; - prefix2 = inc2 - val2 + wp2; +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesBlockKernel, 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// -// Dynamic-block routing kernel: uses a dynamic thread count and dynamic shared memory. // -// Compared to routingIndicesBlockKernel (which fixes blockDim = MaxExperts): -// 1. Thread count = min(max(numTokens*32, MaxExperts), 1024) so each token -// gets its own warp — eliminates the Phase-1 TopK batch loop for small batches. -// 2. Warp-level Hillis-Steele scan replaces CUB BlockScan, fusing two scans -// into one (2 barriers instead of 4) with no compile-time thread count dependency. -// 3. Dynamic shared memory enables flexible token counts (up to 16). +// 1b. Dynamic-block kernel — single-block with dynamic thread count and dynamic shared memory. +// +// Compared to routingIndicesBlockKernel (which fixes blockDim = MaxExperts): +// 1. Thread count = min(max(numTokens*32, MaxExperts), 1024) so each token +// gets its own warp — eliminates the Phase-1 TopK batch loop for small batches. +// 2. 
Warp-level Hillis-Steele scan replaces CUB BlockScan, fusing two scans +// into one (2 barriers instead of 4) with no compile-time thread count dependency. +// 3. Dynamic shared memory enables flexible token counts (up to 16). +// //////////////////////////////////////////////////////////////////////////////////////////////////// + template __global__ void routingIndicesDynBlockKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; using TypePacked = PackedScoreIdx; static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; static constexpr int NumThreadsExperts = MaxNumExperts <= 1024 ? MaxNumExperts : 1024; @@ -370,11 +420,6 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) int32_t const laneIdx = cutlass::arch::LaneId(); int32_t const numWarps = blockDim.x / WarpSize; - // Dynamic shared memory layout: - // [0 .. numSlots) : int8_t smemKIdx - // [numSlots .. numSlots*3) : int16_t smemOffset - // [aligned .. +NumExpertWarps] : int32_t warpTotals1 (scan: numCtaPerExpert) - // [+NumExpertWarps] : int32_t warpTotals2 (scan: tmpCountPerExpert) extern __shared__ char dynSmem[]; int const numSlots = params.mNumTokens * MaxNumExperts; int8_t* smemKIdx = reinterpret_cast(dynSmem); @@ -387,8 +432,6 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - // Initialize smemKIdx only — smemOffset is only read when kIdx >= 0, - // which implies Phase 2 has already written it (no init needed). 
for (int i = threadIdx.x; i < numSlots; i += blockDim.x) { smemKIdx[i] = int8_t{-1}; @@ -396,41 +439,40 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) __syncthreads(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } #endif - // ── Phase 1: TopK — one warp per token (loop only when numTokens > numWarps) ── + // Phase 1: TopK — one warp per token (loop only when numTokens > numWarps) for (int tokenIdx = warpIdx; tokenIdx < params.mNumTokens; tokenIdx += numWarps) { if (params.mPtrTopKIds != nullptr) { if (laneIdx < params.mTopK) { - auto expertIdx = params.mPtrTopKIds[tokenIdx * params.mTopK + laneIdx]; - if (expertIdx > -1 && expertIdx < params.mNumExperts) + auto const expandedIdx = tokenIdx * params.mTopK + laneIdx; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { - smemKIdx[tokenIdx * MaxNumExperts + expertIdx] = static_cast(laneIdx); + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = int32_t{-1}; } - else + auto expertIdx = params.mPtrTopKIds[expandedIdx]; + if (expertIdx > -1 && expertIdx < params.mNumExperts) { - params.mPtrExpandedIdxToPermutedIdx[tokenIdx * params.mTopK + laneIdx] = int32_t{-1}; + smemKIdx[tokenIdx * MaxNumExperts + expertIdx] = static_cast(laneIdx); } } } else if (params.mPtrScores != nullptr) { - BaseType score[VecSize]; - int32_t idx[VecSize]; BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; auto scoreOff = tokenIdx * params.mNumExperts; - routingTopKExperts(warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOff, params.mNormTopkProb); + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOff, params); if (laneIdx < params.mTopK) { @@ -445,19 +487,27 @@ __global__ 
void routingIndicesDynBlockKernel(KernelParams params) { if (laneIdx < params.mTopK) { - auto expandedIdx = tokenIdx * params.mTopK + laneIdx; + auto const expandedIdx = tokenIdx * params.mTopK + laneIdx; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) + { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = int32_t{-1}; + } auto scoreIdx = params.mPtrTopKPacked[expandedIdx]; - smemKIdx[tokenIdx * MaxNumExperts + static_cast(scoreIdx.idx)] = static_cast(laneIdx); - if (params.mPtrTopKWeights != nullptr) + int const expertIdx = static_cast(scoreIdx.idx); + if (expertIdx >= 0 && expertIdx < params.mNumExperts) { - params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + smemKIdx[tokenIdx * MaxNumExperts + expertIdx] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) + { + params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + } } } } } __syncthreads(); - // ── Phase 2: Histogram — each expert-thread counts tokens assigned to its expert(s) ── + // Phase 2: Histogram int accExpertCount[ExpertsPerThread]; if (threadIdx.x < NumThreadsExperts) { @@ -493,7 +543,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } } - // ── Phase 3: Prefix-scan (merged dual warp-level scan, 2 barriers instead of 4) ── + // Phase 3: Prefix-scan (merged dual warp-level scan, 2 barriers instead of 4) int32_t numCtaPerExpert[ExpertsPerThread]; int32_t tmpCountPerExpert[ExpertsPerThread]; int32_t ctaOffsetPerExpert[ExpertsPerThread]; @@ -505,7 +555,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) { if (threadIdx.x < NumThreadsExperts) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCtaPerExpert[e] = divUpLog2(accExpertCount[e], params.mPaddingLog2); tmpCountPerExpert[e] = divUpMulLog2(accExpertCount[e], params.mPaddingLog2); @@ -546,7 +596,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } } - // ── Phase 4: CTA configs ── + // Phase 4: CTA configs if (threadIdx.x < 
NumThreadsExperts) { #pragma unroll @@ -564,7 +614,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) = (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffsetPerExpert[e] + cta] = mappedLocalIdx; int32_t mnLimit1, mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffsetPerExpert[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffsetPerExpert[e], params.mPaddingLog2) + accExpertCount[e]; @@ -583,7 +633,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) if (threadIdx.x == 0) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -596,13 +646,13 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif - // ── Phase 5: Permutation ── + // Phase 5: Permutation if (threadIdx.x < NumThreadsExperts) { for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) @@ -641,30 +691,368 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } } +void launchDynBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) +{ + int32_t const maxExperts = queryDispatchedMaxExperts(data); + int const numSlots = data.mNumTokens * maxExperts; + int const smemSize + = numSlots + numSlots * 2 + 128 + 2 * (maxExperts / WarpSize) * static_cast(sizeof(int32_t)); + int const threads = std::min(std::max(data.mNumTokens * static_cast(WarpSize), maxExperts), 1024); + + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesDynBlockKernel, 1, threads, smemSize, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 2. 
Cluster kernel — single-cluster fused kernel for ≤256 tokens (SM90+). +// Uses distributed shared memory across 8 blocks in a cluster. +// //////////////////////////////////////////////////////////////////////////////////////////////////// -void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams params) { - if (data.mNumTokens <= DynBlockKernelMaxNumTokens && data.mNumExperts <= DynBlockKernelMaxNumExperts) + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; + using TypePacked = PackedScoreIdx; + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * KernelParams::MaxNumTopExperts]; + + uint32_t const clusterBlockRank = blockIdx.x; + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; + auto scoreOffset = warpTokenIdx * params.mNumExperts; + bool validToken = warpTokenIdx < params.mNumTokens; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + if (params.mUsePdl) { - int32_t const maxExperts = getMaxNumExperts(data.mNumExperts); - int const numSlots = data.mNumTokens * maxExperts; - int const smemSize - = numSlots + numSlots * 2 + 128 + 2 * (maxExperts / WarpSize) * static_cast(sizeof(int32_t)); - int const threads = std::min(std::max(data.mNumTokens * static_cast(WarpSize), maxExperts), 1024); + cudaGridDependencySynchronize(); + } - LAUNCH_ROUTING_RENORMALIZE( - data, false, routingIndicesDynBlockKernel, 1, threads, smemSize, stream, 
data.mDoSoftmaxBeforeTopK); + if (params.mPtrScores != nullptr) + { + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + if (validToken) + { + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); + if (laneIdx < params.mTopK) + { + smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] + = TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; + } + } + } + + __cluster_barrier_arrive(); + __cluster_barrier_wait(); + + if (params.mPtrScores != nullptr) + { + routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); } else { - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesBlockKernel, 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); + routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); + } +} +#else +__global__ void __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams /* params */) +{ + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + +void launchClusterKernel(Data const& data, void* stream) +{ + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 3. HistogramScores kernel — computes TopK from raw scores and initializes expert counts. +// Used as step 1 of the multi-kernel pipeline when input is raw logits. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelParams::MaxNumExperts : 1024) + routingIndicesHistogramScoresKernel(KernelParams params) +{ + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; + // Cap actual thread count at 1024 when MaxNumExperts > 1024. + static constexpr int NumThreadsBlock = KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024; + + // VecSize stays based on MaxNumExperts — each warp still processes all experts for one token. + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const warpIdx = threadIdx.x / WarpSize; + // Use NumThreadsBlock (actual thread count) for grid-stride warp/thread addressing + int32_t const globalWarpIdx = blockIdx.x * NumThreadsBlock / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * NumThreadsBlock / WarpSize; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid. 
+ if (params.mUsePdl) + { + cudaGridDependencySynchronize(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + // initialize the mPtrExpertCounts — use NumThreadsBlock for grid-stride + int32_t expertCountsNum = 2 * params.mNumExperts; + int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; + int32_t globalThreadStride = gridDim.x * NumThreadsBlock; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + + // in this case, each warp represents a token, and we use a grid-stride loop + // over all warps/tokens + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) + { + auto scoreOffset = tokenIdx * params.mNumExperts; + + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); + + if (laneIdx < params.mTopK) + { + PackedScoreIdx packedScore{ + static_cast(warpTopKScore[laneIdx]), static_cast(warpTopKExpertIdx[laneIdx])}; + params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger secondary kernel AFTER writing all packed scores, so the next kernel + // (routingIndicesHistogramKernel) sees the completed mPtrTopKPacked writes. 
+ if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +} + +static void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 4. Coop kernel — cooperative histogram + offsets via grid-sync. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream) +{ + if (data.mNumExperts <= NumExperts128Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts128Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts160Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts160Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts256Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts256Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts384Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts384Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts512Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, 
/*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts512Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts576Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts576Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts1024Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts1024Experts, NumTop8Experts); + } + else + { + TLLM_LOG_ERROR("Coop kernel does not support numExperts > %d", NumExperts1024Experts); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 5-7. Launch wrappers for shared kernels (defined in RoutingKernel.cuh): +// - InitExpertCounts (zero expert counts) +// - Histogram kernel (histogram from packed TopK) +// - Offsets kernel (prefix-scan + permutation) +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, 
routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Entry point +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data const& data, void* stream) +{ + TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline which handles all path selection + // (single-block, single-cluster, coop, multi-kernel) automatically. + // No routing-method-specific logic needed. + if (data.mPtrTopKIds != nullptr || (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) + { + if (data.mPtrTopKIds != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for custom routing."); + } + uint32_t const numThreadsHist = min(1024, getMaxNumExperts(data.mNumExperts)); + runPostTopKPipeline(data, numThreadsHist, stream); + return; + } + + // After this point, input is mPtrScores (raw logits that need topK computation). 
+ TLLM_CHECK_WITH_INFO(data.mPtrScores != nullptr, "Expected mPtrScores to be non-null at this point."); + TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr + && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, + "Custom routing kernel expects permuted idx and grouped Gemm launch config buffers"); + TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", + MaxSupportedTopExperts, data.mTopK); + TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExperts, + "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, MaxSupportedExperts); + TLLM_CHECK_WITH_INFO( + data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + + static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + + bool const useStaticBlock = data.mNumTokens <= BlockKernelMaxNumTokens; + int32_t const dispatchedMaxExperts = queryDispatchedMaxExperts(data); + bool const useDynBlock = !useStaticBlock && data.mNumTokens <= DynBlockKernelMaxNumTokens + && dispatchedMaxExperts <= DynBlockKernelMaxNumExperts; + bool const useSingleBlock = useStaticBlock || useDynBlock; + bool const useSingleCluster = (smMajor >= 9) && (data.mNumTokens <= MaxNumTokensSingleClusterScores); + + if (!useSingleCluster && !useSingleBlock) + { + TLLM_CHECK_WITH_INFO( + data.mPtrTopKPacked != nullptr, "When #tokens is large, `mPtrTopKPacked` is a required input."); + TLLM_CHECK_WITH_INFO( + data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); + } + + uint32_t const numThreadsHist = min(1024, getMaxNumExperts(data.mNumExperts)); + + Data lastKernelData = data; + + if (useDynBlock) + { + launchDynBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if (useStaticBlock) + { + launchBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if 
(useSingleCluster) + { + launchClusterKernel(lastKernelData, stream); + } + else + { + uint32_t const maxNumBlocks = 1024; + + launchHistogramScoresKernel(data, maxNumBlocks, numThreadsHist, stream); + + bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) && (data.mPtrPermutedIdxSize != nullptr); + bool useCoop = false; + int numBlocksCoop = 0; + + if (canUseCoop) + { + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + numBlocksCoop = smCount - 8; + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + useCoop = (data.mNumTokens <= maxTokensCoop); + } + + if (useCoop) + { + launchInitExpertCounts(data, numThreadsHist, stream); + launchCoopKernel(lastKernelData, numBlocksCoop, numThreadsHist, stream); + } + else + { + uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; + uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + + int const numBlocksHistogram + = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets + = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); + launchOffsetsKernel(lastKernelData, numBlocksOffsets, numThreadsHist, stream); + } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace routingRenormalize +} // namespace routingCustom } // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustomPolicy.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustomPolicy.cuh new file mode 100644 index 000000000000..ddea2e28f1e1 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustomPolicy.cuh @@ -0,0 +1,788 @@ 
+/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "RoutingKernel.cuh" + +namespace moe::dev::routing +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Preprocess policies: applied to all expert scores BEFORE topK selection. +// +// Each policy must provide: +// - template using BaseType +// The data type used for intermediate score computation. +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data, populated from the host-side Data struct. +// Empty for policies that don't need extra data (zero register cost). +// - template +// static void apply(warp, score[VecSize], idx[VecSize], numExperts, params) +// Transforms scores in-place before topK selection. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// No-op: scores are passed through unchanged. +struct NoOpPreprocess +{ + /// BaseType: when no preprocess is applied, use the input type directly. 
+ template + using BaseType = InputT; + + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (& /*score*/)[VecSize], int32_t const (& /*idx*/)[VecSize], int32_t /*numExperts*/, + ParamsT const& /*params*/) + { + } +}; + +/// Softmax: applies softmax over all expert scores before topK selection. +struct SoftmaxPreprocess +{ + /// BaseType: softmax is always computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&score)[VecSize], int32_t const (& /*idx*/)[VecSize], int32_t /*numExperts*/, + ParamsT const& /*params*/) + { + calcSoftmax(warp, score); + } +}; + +/// Sigmoid: applies sigmoid(score) for topK selection (no bias). +struct SigmoidPreprocess +{ + /// BaseType: sigmoid is computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (&score)[VecSize], int32_t const (&idx)[VecSize], int32_t numExperts, ParamsT const& /*params*/) + { +#pragma unroll + for (int i = 0; i < VecSize; i++) + { + float s = sigmoid_accurate(static_cast(score[i])); + score[i] = idx[i] < numExperts ? static_cast(s) : DataType{-INFINITY}; + } + } +}; + +/// SigmoidBias: applies sigmoid(score) + bias[expertIdx] for topK selection. +/// Used by DeepSeek-style routing where expert selection is based on biased sigmoid scores. +struct SigmoidBiasPreprocess +{ + /// BaseType: sigmoid is computed in float for numerical stability. 
+ template + using BaseType = float; + + template + struct Params + { + // Store as void const* to support any bias dtype (float, bfloat16, etc.) without conversion. + void const* ptrRoutingBias = nullptr; + batchedGemm::trtllm::gen::Dtype dtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; + + void set(routingCustom::Data const& data) + { + ptrRoutingBias = data.mPtrRoutingBias; + dtypeBias = data.mDtypeBias; + } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (&score)[VecSize], int32_t const (&idx)[VecSize], int32_t numExperts, ParamsT const& params) + { +#pragma unroll + for (int i = 0; i < VecSize; i++) + { + float s = sigmoid_accurate(static_cast(score[i])); + float bias + = idx[i] < numExperts ? loadScalar(params.ptrRoutingBias, idx[i], params.dtypeBias) : float{-INFINITY}; + score[i] = static_cast(s + bias); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Postprocess policies: applied to the top-K scores AFTER topK selection. +// +// Each policy must provide: +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data. Empty when not needed. +// - template +// static void apply(warp, warpTopKScore[K], warpTopKExpertIdx[K], laneIdx, topK, params) +// Transforms top-K scores in-place after topK selection. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// No-op: top-K scores are left unchanged. +struct NoOpPostprocess +{ + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (& /*warpTopKScore*/)[K], int32_t const (& /*warpTopKExpertIdx*/)[K], int32_t /*laneIdx*/, + int32_t /*topK*/, ParamsT const& /*params*/) + { + } +}; + +/// Softmax: applies softmax over the top-K scores. 
+struct SoftmaxPostprocess +{ + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t const (& /*warpTopKExpertIdx*/)[K], int32_t laneIdx, int32_t topK, + ParamsT const& /*params*/) + { + DataType minScore = DataType{-INFINITY}; + auto softmaxScore = calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); + if (laneIdx < topK) + { + warpTopKScore[laneIdx] = softmaxScore; + } + } +}; + +/// SumNormalize: divides each top-K score by the sum of all top-K scores. +/// Used when softmax has already been applied before topK selection. +struct SumNormalizePostprocess +{ + template + struct Params + { + bool normTopkProb = true; + + void set(routingCustom::Data const& data) + { + normTopkProb = data.mNormTopkProb; + } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t const (& /*warpTopKExpertIdx*/)[K], int32_t laneIdx, int32_t topK, + ParamsT const& params) + { + float sum = float{1.f}; + if (params.normTopkProb) + { + sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); + sum = cg::reduce(warp, sum, cg::plus()); + } + if (laneIdx < topK) + { + warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; + } + } +}; + +/// ScaledSumNormalize: recovers un-biased sigmoid scores by subtracting per-expert bias from the +/// selection scores (sigmoid + bias), then normalizes by sum and applies routeScale. +/// Used by DeepSeek-style routing: final_weight = sigmoid(raw) * routeScale / (sum + epsilon). +/// DeepSeek uses epsilon=0 (no guard); MiniMax2 uses epsilon=1e-20 to prevent division by zero. +struct ScaledSumNormalizePostprocess +{ + template + struct Params + { + // Store as void const* to support any bias dtype (float, bfloat16, etc.) without conversion. 
+ void const* ptrRoutingBias = nullptr; + batchedGemm::trtllm::gen::Dtype dtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; + float routeScale = 1.0f; + float sumEpsilon = 0.0f; + + void set(routingCustom::Data const& data) + { + ptrRoutingBias = data.mPtrRoutingBias; + dtypeBias = data.mDtypeBias; + routeScale = data.mRouteScale; + sumEpsilon = data.mSumEpsilon; + } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t const (&warpTopKExpertIdx)[K], int32_t laneIdx, int32_t topK, + ParamsT const& params) + { + // Recover sigmoid score: selection_score = sigmoid(raw) + bias, so sigmoid = score - bias + float biasVal + = laneIdx < topK ? loadScalar(params.ptrRoutingBias, warpTopKExpertIdx[laneIdx], params.dtypeBias) : 0.f; + float sigmoidScore = laneIdx < topK ? (static_cast(warpTopKScore[laneIdx]) - biasVal) : 0.f; + float sum = cg::reduce(warp, sigmoidScore, cg::plus()); + if (laneIdx < topK) + { + warpTopKScore[laneIdx] + = static_cast(sigmoidScore * params.routeScale / (sum + params.sumEpsilon)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ExpertSelectPolicy: encapsulates the entire expert selection logic. +// +// Each policy must provide: +// - template using BaseType +// The data type used for intermediate score computation. +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data, populated from the host-side Data struct. +// Empty for policies that don't need extra data (zero register cost). +// - template +// static void apply(warp, warpTopKScore[K], warpTopKExpertIdx[K], laneIdx, numExperts, topK, +// ptrScores, params) +// Selects the top-K experts and computes their weights. 
+// +// The default TopKExpertSelect wraps existing PreprocessPolicy + PostprocessPolicy, +// but users can write completely custom policies that bypass the preprocess+topK+postprocess +// pattern (e.g., lookup-table-based expert selection). +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default ExpertSelectPolicy: preprocess + topK reduction + postprocess. +/// Wraps existing PreprocessPolicy and PostprocessPolicy as internal composition. +template +struct TopKExpertSelect +{ + /// BaseType: delegated to the preprocess policy. + template + using BaseType = typename PreprocessPolicy_::template BaseType; + + /// Params: combines preprocess and postprocess runtime parameters. + template + struct Params + { + typename PreprocessPolicy_::template Params mPreprocessParams; + typename PostprocessPolicy_::template Params mPostprocessParams; + + void set(routingCustom::Data const& data) + { + mPreprocessParams.set(data); + mPostprocessParams.set(data); + } + }; + + /// Selects top-K experts using preprocess → topK reduction → postprocess. + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t (&warpTopKExpertIdx)[K], int32_t const laneIdx, int32_t const numExperts, + int32_t topK, InputType const* ptrScores, KP const& params) + { + DataType minScore = DataType{-INFINITY}; + DataType score[VecSize]; + int32_t idx[VecSize]; + + for (int i = 0; i < VecSize; i++) + { + auto expertIdx = i * WarpSize + laneIdx; + auto newScore = expertIdx < numExperts ? static_cast(ptrScores[expertIdx]) : minScore; + score[i] = newScore; + idx[i] = expertIdx; + } + + // Apply preprocess (e.g. softmax over all scores, sigmoid + bias, ...) 
+ PreprocessPolicy_::apply(warp, score, idx, numExperts, params.mExpertSelectParams.mPreprocessParams); + + // Get the top-k scores and their corresponding expert indices + topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); + + // Apply postprocess (e.g. renormalize, softmax over top-K, scaled renormalize, ...) + PostprocessPolicy_::apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, topK, params.mExpertSelectParams.mPostprocessParams); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace routingCustom +{ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Expert-count tiers (must be multiples of WarpSize=32 and of 4). +// Each tier covers all values ≤ the tier constant. +static constexpr int NumExperts128Experts = 128; +static constexpr int NumExperts160Experts = 160; +static constexpr int NumExperts256Experts = 256; +static constexpr int NumExperts384Experts = 384; +static constexpr int NumExperts512Experts = 512; +static constexpr int NumExperts576Experts = 576; +static constexpr int NumExperts1024Experts = 1024; +static constexpr int MaxSupportedExperts = 2048; + +// TopK tiers (must be ≤ WarpSize=32). 
+static constexpr int NumTop4Experts = 4; +static constexpr int NumTop8Experts = 8; +static constexpr int NumTop16Experts = 16; +static constexpr int NumTop22Experts = 22; +static constexpr int MaxSupportedTopExperts = 32; + +static constexpr int NumThreads = 1024; +static constexpr int NumWarps = NumThreads / WarpSize; + +static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; +static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; + +static constexpr int BlockKernelMaxNumTokens = 4; +static constexpr int DynBlockKernelMaxNumTokens = 16; +static constexpr int DynBlockKernelMaxNumExperts = 512; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int32_t constexpr getMaxNumExperts(int32_t numExperts) +{ + if (numExperts <= NumExperts128Experts) + { + return NumExperts128Experts; + } + else if (numExperts <= NumExperts160Experts) + { + return NumExperts160Experts; + } + else if (numExperts <= NumExperts256Experts) + { + return NumExperts256Experts; + } + else if (numExperts <= NumExperts384Experts) + { + return NumExperts384Experts; + } + else if (numExperts <= NumExperts512Experts) + { + return NumExperts512Experts; + } + else if (numExperts <= NumExperts576Experts) + { + return NumExperts576Experts; + } + else if (numExperts <= NumExperts1024Experts) + { + return NumExperts1024Experts; + } + else if (numExperts <= MaxSupportedExperts) + { + return MaxSupportedExperts; + } + else + { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TIER PAIR TYPES — compile-time (MaxNumExperts, MaxNumTopExperts) configuration. +// +// Each Tier declares a supported kernel instantiation. +// TierList, ...> is an ordered list tried from first to last. +// The dispatch picks the FIRST pair where numExperts ≤ E AND topK ≤ K. 
+// +// Pairs must be sorted so that tighter tiers come first: +// - Sort by E ascending, then by K ascending within equal E. +// - A config (numExperts, topK) always matches the tightest available pair. +// - If the tightest expert tier doesn't have a topK that covers the runtime topK, +// the dispatch falls through to the next larger expert tier that does. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Tier +{ + static constexpr int kExperts = E_; + static constexpr int kTopK = K_; +}; + +template +struct TierList +{ +}; + +// Recursive dispatch: try each tier in order, call `fn` with the first match. +// fn receives (integral_constant, integral_constant) as compile-time args. +// Base case: empty list — no match. +template +inline bool dispatchTierPairs(TierList<>*, Data const& /*data*/, Fn&& /*fn*/) +{ + return false; +} + +// Recursive case: check First, then recurse on Rest... +template +inline bool dispatchTierPairs(TierList*, Data const& data, Fn&& fn) +{ + if (data.mNumExperts <= First::kExperts && data.mTopK <= First::kTopK) + { + fn(std::integral_constant{}, std::integral_constant{}); + return true; + } + return dispatchTierPairs(static_cast*>(nullptr), data, std::forward(fn)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// POLICY TIER CONFIGURATION +// +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ HOW TO ADD A NEW ROUTING POLICY │ +// │ │ +// │ 1. Define PreProc/PostProc structs (above in this file) │ +// │ 2. Add PolicyTraits with tier list (below) │ +// │ 3. Add enum value to RoutingPreprocessType/RoutingPostprocessType │ +// │ in RoutingKernel.h (if new enum needed) │ +// │ 4. Add an else-if branch to dispatchRoutingPolicy() (bottom of file) │ +// │ — LAUNCH_ROUTING_CUSTOM and queryDispatchedMaxExperts │ +// │ automatically pick it up │ +// │ 5. 
Set the policy in runner.cu for the routing method │ +// └─────────────────────────────────────────────────────────────────────────┘ +// +// PolicyTraits::Pairs declares the supported (expert, topK) pairs. +// Only these pairs are compiled as kernel instantiations. +// To add support for a new model config, add a Tier to the appropriate TierList. +// +// THREAD-COUNT SAFETY: LAUNCH_ROUTING_FOR_POLICY automatically clamps the launch thread +// count to at least min(MaxNumExperts, 1024) from the dispatched tier. This prevents +// mismatches when a policy's smallest tier is larger than getMaxNumExperts() returns for +// the same numExperts (e.g., 72 experts → getMaxNumExperts returns 128, but a policy +// whose smallest tier is 256 would produce MaxNumExperts=256). See the comment on +// LAUNCH_ROUTING_FOR_POLICY for details. +// +// ┌──────────────────────────────────────────────────────────────────────────────┐ +// │ Policy (PreProc + PostProc) Supported pairs │ +// ├──────────────────────────────────────────────────────────────────────────────┤ +// │ Softmax + None (Default) (128,8) │ +// │ None + Softmax (Renormalize) (128,4) (128,8) (160,8) (256,8) │ +// │ (256,16) (512,8) (512,16) │ +// │ (512,22) (576,8) (2048,32) │ +// │ Sigmoid + SumNorm (SigmoidRenorm) (128,8) │ +// │ SigmoidBias + ScaleS (DS nGroup≤1) (128,8) (256,8) (384,8) (512,8) │ +// │ (512,22) │ +// │ Softmax + SumNorm (RenormNaive) (128,4) (128,8) (256,8) (512,8) │ +// │ (2048,8) │ +// │ None + None (fallback) (128,8) │ +// └──────────────────────────────────────────────────────────────────────────────┘ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default: fallback for new/unknown policies. +/// Provides K8 (tight for most models) + K32 (catch-all for high topK) at each common expert tier. +/// Omits 160/384/576 — those are model-specific and handled by explicit specializations. 
+/// If a new policy needs a tighter tier, add a PolicyTraits specialization. +template +struct PolicyTraits +{ + using Pairs = TierList, Tier<128, 32>, Tier<256, 8>, Tier<256, 32>, Tier<512, 8>, Tier<512, 32>, + Tier<2048, 8>, Tier<2048, 32>>; +}; + +/// Softmax + None (Default): single config. +template <> +struct PolicyTraits +{ + using Pairs = TierList>; +}; + +/// None + Softmax (Renormalize): many model configs. +template <> +struct PolicyTraits +{ + using Pairs + = TierList, // Mixtral 8x7B (topK=2), Qwen2-MoE (topK=4), Arctic (topK=2), DBRX (topK=4), GPT-OSS + Tier<128, 8>, // DeepSeek-V2-Lite (topK=6), Mixtral 8x22B (topK=2) + Tier<160, 8>, // Qwen3-Coder-480B + Tier<256, 8>, // Mistral Large 3 (topK=8) + Tier<256, 16>, // Models with 256 experts and topK 9..16 + Tier<512, 8>, // Various 512-expert models + Tier<512, 16>, // Various 512-expert models with high topK + Tier<512, 22>, // Nemotron Super V3 (512 experts, topK=22) + Tier<576, 8>, // Customized model with 576 experts + Tier<2048, 32> // Large-expert fallback + >; +}; + +/// Sigmoid + SumNormalize (SigmoidRenorm): single config. +template <> +struct PolicyTraits +{ + using Pairs = TierList>; +}; + +/// SigmoidBias + ScaledSumNormalize (DeepSeek nGroup≤1 / MiniMax2 / Kimi-K2 / Nemotron SuperV3). +template <> +struct PolicyTraits +{ + using Pairs = TierList, // Small expert counts (≤128 experts, e.g. DeepSeek-V2-Lite) + Tier<256, 8>, // MiniMax M2 (256 experts, topK=6) + Tier<384, 8>, // Kimi K2 (384 experts) + Tier<512, 8>, // DeepSeek nGroup≤1 (256 experts → E512 fallback) + Tier<512, 22>, // Nemotron Super V3 (512 experts, topK=22, nGroup≤1) + Tier<1024, 32> // Default fallback (expert count may grow beyond 512) + >; +}; + +/// Softmax + SumNormalize (RenormalizeNaive): no specialization needed. +/// At runtime, RenormalizeNaive is always converted to the Renormalize path +/// (None + Softmax) by the runner, so this policy is never dispatched. 
+/// If it ever is, the default PolicyTraits provides broad fallback coverage. + +/// None + None (fallback for unknown preprocess/postprocess in LAUNCH_ROUTING_CUSTOM). +template <> +struct PolicyTraits +{ + using Pairs = TierList>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// EXAMPLE: Custom ExpertSelectPolicy that bypasses the PreProc→topK→PostProc pattern. +// +// To enable it: +// 1. Uncomment the struct and PolicyTraits below. +// 2. Add an enum value (e.g., RoutingPreprocessType::FirstK) in RoutingKernel.h. +// 3. Add a branch in LAUNCH_ROUTING_CUSTOM that calls LAUNCH_ROUTING_FOR_EXPERT_SELECT. +// 4. Set the enum in runner.cu for the desired routing method type. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* +struct FirstKExpertSelect +{ + template using BaseType = float; + template struct Params { void set(routingCustom::Data const&) {} }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const&, + DataType (&warpTopKScore)[K], int32_t (&warpTopKExpertIdx)[K], int32_t const laneIdx, + int32_t const, int32_t topK, InputType const*, KP const&) + { + if (laneIdx < topK) + { + warpTopKExpertIdx[laneIdx] = laneIdx; + warpTopKScore[laneIdx] = static_cast(1.0f / topK); + } + } +}; + +template <> struct PolicyTraits +{ + using Pairs = TierList>; +}; +*/ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GENERIC DISPATCH MACROS +// +// These macros are fixed infrastructure — they never need editing when adding new +// policies or changing tier support. All configuration lives in PolicyTraits above. +// +// The dispatch iterates PolicyTraits::Pairs (a TierList) via dispatchTierPairs. +// A generic lambda captures the kernel name (macro requirement) and receives +// (expert, topK) as compile-time integral_constants. 
////////////////////////////////////////////////////////////////////////////////////////////////////

// Generic per-policy dispatch. Iterates PolicyTraits<PreProc, PostProc>::Pairs,
// picking the first (expert, topK) pair that covers the runtime values.
//
// IMPORTANT: numThreads is clamped to at least min(MaxNumExperts, 1024) from the dispatched tier.
// Many routing kernels derive their internal NumThreadsBlock from MaxNumExperts and use it for
// grid-stride addressing, initArr strides, and cub::BlockScan. If the caller's numThreads
// (typically getMaxNumExperts(mNumExperts)) is smaller than the tier's MaxNumExperts, the kernel
// would compute wrong indices, skip initialization, and corrupt memory. The max() below
// guarantees the launch thread count always matches or exceeds the kernel's NumThreadsBlock:
// - "derive from tier" kernels: numThreadsHist < MaxNumExperts -> bumped to MaxNumExperts
// - "fixed 1024" kernels (cluster): numThreads=1024 >= MaxNumExperts -> unchanged
#define LAUNCH_ROUTING_FOR_POLICY(                                                                                     \
    data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, PreProc, PostProc)                              \
    [&](auto pt_tag_)                                                                                                  \
    {                                                                                                                  \
        using Pairs_ = typename decltype(pt_tag_)::Pairs;                                                              \
        bool dispatched_ = dispatchTierPairs(static_cast<Pairs_*>(nullptr), data,                                      \
            [&](auto eTag_, auto kTag_)                                                                                \
            {                                                                                                          \
                constexpr int tierMaxExp_ = decltype(eTag_)::value;                                                    \
                constexpr int tierThreads_ = tierMaxExp_ <= 1024 ? tierMaxExp_ : 1024;                                 \
                int const effectiveThreads_ = std::max(static_cast<int>(numThreads), tierThreads_);                    \
                LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, effectiveThreads_, smemSize, stream, \
                    PreProc, PostProc, decltype(eTag_)::value, decltype(kTag_)::value);                                \
            });                                                                                                        \
        if (!dispatched_)                                                                                              \
        {                                                                                                              \
            TLLM_LOG_ERROR("No tier covers numExperts=%d topK=%d", data.mNumExperts, data.mTopK);                      \
        }                                                                                                              \
    }(PolicyTraits<PreProc, PostProc>{})

////////////////////////////////////////////////////////////////////////////////////////////////////
// CUSTOM EXPERT SELECT DISPATCH
////////////////////////////////////////////////////////////////////////////////////////////////////

// Generic dispatch for custom ExpertSelectPolicy. PolicyTraits key is <ExpertSelect>.
// Same numThreads clamping as LAUNCH_ROUTING_FOR_POLICY — see comment above.
#define LAUNCH_ROUTING_FOR_EXPERT_SELECT(                                                                              \
    data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, ExpertSelect)                                   \
    [&](auto pt_tag_)                                                                                                  \
    {                                                                                                                  \
        using Pairs_ = typename decltype(pt_tag_)::Pairs;                                                              \
        bool dispatched_ = dispatchTierPairs(static_cast<Pairs_*>(nullptr), data,                                      \
            [&](auto eTag_, auto kTag_)                                                                                \
            {                                                                                                          \
                constexpr int tierMaxExp_ = decltype(eTag_)::value;                                                    \
                constexpr int tierThreads_ = tierMaxExp_ <= 1024 ? tierMaxExp_ : 1024;                                 \
                int const effectiveThreads_ = std::max(static_cast<int>(numThreads), tierThreads_);                    \
                LAUNCH_ROUTING_WITH_EXPERT_SELECT(data, coopLaunch, kernel, numBlocks, effectiveThreads_, smemSize,    \
                    stream, ExpertSelect, decltype(eTag_)::value, decltype(kTag_)::value);                             \
            });                                                                                                        \
        if (!dispatched_)                                                                                              \
        {                                                                                                              \
            TLLM_LOG_ERROR("No tier covers numExperts=%d topK=%d", data.mNumExperts, data.mTopK);                      \
        }                                                                                                              \
    }(PolicyTraits<ExpertSelect>{})

////////////////////////////////////////////////////////////////////////////////////////////////////
// PUBLIC DISPATCH MACROS
//
// These are the only macros that call sites use.
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Lightweight dispatch for utility kernels (histogram, init-counts, offsets) that do NOT use +// expert select policies, InputT, or MaxNumTopExperts. +// - Always uses NoOp expert select (no policy dispatch). +// - Always uses a fixed NumTop8Experts (no topK-tier dispatch). +// - Dispatches only on expert tiers. +// This is intentionally NOT routed through LAUNCH_ROUTING_FOR_POLICY to avoid +// instantiating all topK tiers — utility kernels don't use topK at all. +#define LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mNumExperts <= NumExperts128Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts128Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts160Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts160Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts256Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts256Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts384Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts384Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts512Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts512Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts576Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, 
coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts576Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts1024Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts1024Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= MaxSupportedExperts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, MaxSupportedExperts, NumTop8Experts); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +// Single source of truth for runtime → compile-time policy dispatch. +// Maps (mPreprocessType, mPostprocessType) to compile-time (PreProc, PostProc) policy types. +// The callback `fn` receives instances of the policy types (e.g., SigmoidBiasPreprocess{}). +// Both LAUNCH_ROUTING_CUSTOM and queryDispatchedMaxExperts use this function, +// so they are always in sync. See "HOW TO ADD A NEW ROUTING POLICY" above. +template +inline void dispatchRoutingPolicy(Data const& data, Fn&& fn) +{ + if (data.mPreprocessType == RoutingPreprocessType::SigmoidBias) + fn(SigmoidBiasPreprocess{}, ScaledSumNormalizePostprocess{}); + else if (data.mPreprocessType == RoutingPreprocessType::Sigmoid) + fn(SigmoidPreprocess{}, SumNormalizePostprocess{}); + else if (data.mPreprocessType == RoutingPreprocessType::Softmax + && data.mPostprocessType == RoutingPostprocessType::None) + fn(SoftmaxPreprocess{}, NoOpPostprocess{}); + else if (data.mPreprocessType == RoutingPreprocessType::Softmax) + fn(SoftmaxPreprocess{}, SumNormalizePostprocess{}); + else if (data.mPostprocessType == RoutingPostprocessType::Softmax) + fn(NoOpPreprocess{}, SoftmaxPostprocess{}); + else + fn(NoOpPreprocess{}, NoOpPostprocess{}); +} + +// Query the MaxNumExperts that the policy tier dispatch would select for the given data. 
+inline int32_t queryDispatchedMaxExperts(Data const& data) +{ + int32_t result = getMaxNumExperts(data.mNumExperts); + dispatchRoutingPolicy(data, + [&](auto preProc, auto postProc) + { + using Pairs = typename PolicyTraits::Pairs; + dispatchTierPairs( + static_cast(nullptr), data, [&](auto eTag, auto /*kTag*/) { result = decltype(eTag)::value; }); + }); + return result; +} + +// Top-level dispatch: maps runtime preprocess/postprocess enums to compile-time policy types, +// then delegates to LAUNCH_ROUTING_FOR_POLICY which reads PolicyTraits for tier support. +#define LAUNCH_ROUTING_CUSTOM(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + dispatchRoutingPolicy(data, \ + [&](auto preProc_, auto postProc_) \ + { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + decltype(preProc_), decltype(postProc_)); \ + }) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingCustom +} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDeepSeek.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDeepSeek.cu new file mode 100644 index 000000000000..60bf83030fd2 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDeepSeek.cu @@ -0,0 +1,621 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// DeepSeek routing: entry point, constants, dispatch macros, kernel definitions, and launch wrappers. +// +// Kernel inventory: +// 1. routingMainKernel — DeepSeek-specific main kernel (sigmoid + bias + group TopK) +// 2. routingIndicesClusterKernel — single-cluster fused kernel (SM90+) +// 3. launchCoopKernel — delegates to routingCustom's coop implementation +// 4. launchInitExpertCounts — zero expert counts +// 5. launchHistogramKernel — histogram from packed TopK +// 6. launchOffsetsKernel — prefix-scan + permutation + +#include "RoutingCustomPolicy.cuh" +#include "RoutingKernel.cuh" + +namespace moe::dev::routing +{ + +// Forward declaration of routingCustom's coop kernel (used by DeepSeek's coop path) +namespace routingCustom +{ +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream); +} // namespace routingCustom + +namespace routingDeepSeek +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Constants and dispatch macros +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int NumNemotronExperts = 512; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int NumExperts1024Experts = 1024; +static constexpr int MaxSupportedExpertCount + = std::max({NumExperts1024Experts, NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); +static constexpr int NumTopGroupScores = 2; +static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxNumGroups = 8; + +static constexpr int NumTop8Experts = 8; +static constexpr int NumTop22Experts = 22; +static constexpr int MaxSupportedTopExperts = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int constexpr 
getMaxNumExperts(int32_t numExperts) +{ + if (numExperts <= topk::MaxNumExpertsUnit) + { + return topk::MaxNumExpertsUnit; + } + else if (numExperts <= NumDeepseekExperts) + { + return NumDeepseekExperts; + } + else if (numExperts <= NumKimiK2Experts) + { + return NumKimiK2Experts; + } + else if (numExperts <= NumNemotronExperts) + { + return NumNemotronExperts; + } + else if (numExperts <= NumExperts1024Experts) + { + return NumExperts1024Experts; + } + else + { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper macro: dispatch on topK tier for a given numExperts tier. +#define LAUNCH_DEEPSEEK_WITH_TOPK( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput, numExperts) \ + if (data.mTopK <= NumTop8Experts) \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts, NumTop8Experts); \ + } \ + else if (data.mTopK <= NumTop22Experts) \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts, NumTop22Experts); \ + } \ + else \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts, MaxSupportedTopExperts); \ + } + +#define LAUNCH_ROUTING_DEEPSEEK( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, topk::MaxNumExpertsUnit); \ + } \ + else if (data.mNumExperts <= NumDeepseekExperts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, 
coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumDeepseekExperts); \ + } \ + else if (data.mNumExperts <= NumKimiK2Experts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumKimiK2Experts); \ + } \ + else if (data.mNumExperts <= NumNemotronExperts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumNemotronExperts); \ + } \ + else if (data.mNumExperts <= NumExperts1024Experts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumExperts1024Experts); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 1. Main kernel — DeepSeek-specific routing with sigmoid activation, bias, and group TopK. +// Handles both grouped and non-grouped expert selection. 
+// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void routingMainKernel(KernelParams params) +{ + // declare types + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + + // declare shared memory structure + // number of experts is bounded by number of threads + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; + // number of expert groups is bounded by number of warps + __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; + + // needed for warp reduce + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + // for the final reduction of weight norm, only some lanes need to participate + int32_t laneIdx = threadIdx.x % WarpSize; + int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + // note that for invalid scores, we simply use a negative value: + // they work well even with the compacted format used in topK, and + // sigmoid / bias activated scores cannot be negative + static constexpr float invalidScoreFloat = float{-INFINITY}; + const OutputT invalidScore = OutputT{invalidScoreFloat}; + + // load bias already; each warp represents one expert group + auto threadExpert = threadIdx.x; + bool expertSelected = threadExpert < params.mNumExperts; + if constexpr (KernelParams::UseGroups) + { + threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx; + // Inactive warps (warpIdx >= mNumExpertGroups) must NOT return early because they + // still need to reach the __syncthreads() barriers below. Setting expertSelected + // to false is enough to keep them from doing any out-of-bounds reads or smem writes. 
+ expertSelected = (warpIdx < params.mNumExpertGroups) && (laneIdx < params.mNumExpertsPerGroup); + } + auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert; + auto biasVal = expertSelected + ? static_cast(loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias)) + : invalidScore; + + // initialize the mPtrExpertCounts + if (params.mPtrExpertCounts) + { + int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; + int32_t globalThreadStride = gridDim.x * blockDim.x; + int32_t expertCountsNum = 2 * params.mNumExperts; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + } + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // trigger the secondary kernel when using PDL, then wait on primary + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrScores != nullptr) + { + // get our assigned thread score; each warp represents one expert group + float score = expertSelected ? 
static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; + // get the sigmoid score + // note that for invalid values, we simply use a negative value: + // sigmoid scores are always strictly positive + auto scoreSigmoid = sigmoid_accurate(score); + // write the sigmoid score to shared for later use + if (expertSelected) + { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + // get the score with bias + // note that with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + + if (expertSelected) + { + smemScoreBias[threadExpert] = scoreBias; + } + + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[KernelParams::MaxNumTopExperts]; + + if constexpr (KernelParams::UseGroups) + { + topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, + /* minValue */ invalidScoreFloat); + // get the final group score and write it to shared + if (cute::elect_one_sync()) + { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); + + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + if constexpr (KernelParams::UseGroups) + { // a single warp performs the selection of top groups, and goes on to select the final experts + if (warpIdx == 0) + { + float groupScore = laneIdx < params.mNumExpertGroups ?
smemGroupScores[laneIdx] : invalidScoreFloat; + topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + // final expert selection: get relevant indexes and scores from shared +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) + { // bound of params.mNumLimitedGroups + auto groupIdx = topGroupIdx[ii]; + expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; + // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. + // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, + // thus groupIdx <= params.mNumExpertGroups - 1 => + // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup + // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, + // so the access is safe here + expertScoreGroup[ii] + = (ii < params.mNumLimitedGroups) && (groupIdx < params.mNumExpertGroups) && expertSelected + ? smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) + { + // without groups, each thread just takes `MaxNumTopGroups` experts + int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; + __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) + { + int offset = warpIdx * WarpSize * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) + { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts ? 
smemScoreBias[offset + expertIdx] + : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + + if (laneIdx < params.mTopK) + { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } + else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) + { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = invalidScoreFloat; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] + = MaxSupportedExpertCount - 1; + } + } + __syncthreads(); + if (warpIdx == 0) + { + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) + { + int ii = i / WarpSize; + if (i < NumInterTopK) + { + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; + } + else + { + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; + } + } + topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + else + { + if (warpIdx == 0) + { + // without groups, each thread just takes `MaxNumTopGroups` experts +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) + { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] + = expertIdx < params.mNumExperts ? 
smemScoreBias[expertIdx] : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + + if (warpIdx == 0) + { + // determine our lane's expert index and write to output + int32_t expertIdx = 0; +#pragma unroll + for (int ii = 0; ii < params.mTopK; ++ii) + { // bound of params.mTopK + expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; + } + // determine whether our expert is local to this GPU + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent + && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; + + // write expert idx out already + auto idxTopK = blockIdx.x * params.mTopK + laneIdx; + if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) + { + PackedScoreIdx packedScore{static_cast(finalScore), static_cast(expertIdx)}; + params.mPtrTopKPacked[idxTopK] = packedScore; + } + + if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) + { + params.mPtrTopKWeights[idxTopK] = finalScore; + } + } + } +} + +static void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 2. Cluster kernel — single-cluster fused kernel (SM90+). 
+// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesClusterKernel(KernelParams params) +{ + using OutputT = typename KernelParams::OutputT; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const clusterBlockRank = blockIdx.x; + + //@todo: try to move it into routingPermutation + if (params.mUsePdl) + { + cudaGridDependencySynchronize(); + } + routingPermutation(params, nullptr, warpIdx, clusterBlockRank); +} +#else +__global__ void routingIndicesClusterKernel(KernelParams params) +{ + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif + +static void launchClusterKernel(Data& data, int numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 3-6. Launch wrappers for shared kernels. +// Coop delegates to routingCustom; others use LAUNCH_ROUTING_DEEPSEEK macro. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static void launchCoopKernel(Data& data, int numBlocksCoop, int /*numThreadsHist*/, void* stream) +{ + // Use routingCustom's coop kernel implementation (they are identical). + // Convert DeepSeek Data to Custom Data for launching. 
+ routingCustom::Data customData; + // Copy base fields + static_cast(customData) = static_cast(data); + // Set routingCustom-specific defaults (not needed for coop kernel) + customData.mDtypeOutput = data.mDtypeOutput; + // The coop kernel doesn't read routing logits (mPtrInput), only mPtrTopKPacked. + // Set mDtypeInput = mDtypeOutput so the dispatched template is , + // avoiding an unnecessary mixed-type instantiation. + customData.mDtypeInput = data.mDtypeOutput; + customData.mPreprocessType = RoutingPreprocessType::None; + customData.mPostprocessType = RoutingPostprocessType::Softmax; + + // Recompute numThreadsHist using routingCustom's expert tiers (128, 512, 2048), + // since the custom coop kernel dispatch selects template parameters based on these tiers. + // DeepSeek's getMaxNumExperts uses different tiers (256, 384, 512) which would mismatch. + uint32_t const customNumThreadsHist + = std::min(1024u, static_cast(routingCustom::getMaxNumExperts(data.mNumExperts))); + routingCustom::launchCoopKernel(customData, numBlocksCoop, customNumThreadsHist, stream); +} + +static void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +static void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Entry point +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data& 
data, void* stream) +{ + TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline which handles all path selection + // (single-block, single-cluster, coop, multi-kernel) automatically. + // No routing-method-specific logic needed. + if (data.mPtrTopKIds != nullptr || (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) + { + if (data.mPtrTopKIds != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for DeepSeek routing."); + } + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + runPostTopKPipeline(data, numThreadsHist, stream); + return; + } + + // After this point, input is mPtrScores (raw logits that need DeepSeek-specific routing). 
+ TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); + TLLM_CHECK_WITH_INFO(data.mNumExperts >= data.mTopK, "Routing kernel expects topK (%d) to be <= numExperts (%d)", + data.mTopK, data.mNumExperts); + TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount, + "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount); + TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", + MaxSupportedTopExperts, data.mTopK); + + if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToExpandedIdx != nullptr + || data.mPtrPermutedIdxToTokenIdx != nullptr) + TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr && data.mPtrPermutedIdxSize, + "If permuted index is required, `mPtrTopKPacked` is also required"); + + // Routing needs to be executed - validate routing kernel constraints + if (data.mNumExpertGroups > 1) + { + TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups, + "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups); + TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0, + "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts, + data.mNumExpertGroups); + TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize, + "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts " + "per group", + WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups); + TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, + "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); + + TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups, + "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, + data.mNumExpertGroups); + 
TLLM_CHECK_WITH_INFO( + data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + } + + int const numBlocks = data.mNumTokens; + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + // Step 1: Run DeepSeek-specific topK computation (writes to mPtrTopKPacked) + int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); + launchMainKernel(data, numBlocks, numThreadsMain, stream); + + // Step 2: Permutation pipeline (reads from mPtrTopKPacked written by step 1) + if (data.mPtrPermutedIdxSize != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrCtaIdxXyToBatchIdx != nullptr && data.mPtrCtaIdxXyToMnLimit != nullptr + && data.mPtrNumNonExitingCtas != nullptr, + "DeepSeek routing step 2 requires grouped-GEMM launch config buffers " + "(mPtrCtaIdxXyToBatchIdx, mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas)"); + + bool const useSingleCluster = (smMajor >= 9) && (data.mNumTokens <= 1024); + if (!useSingleCluster) + { + TLLM_CHECK_WITH_INFO( + data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); + } + else + { + data.mPtrExpertCounts = nullptr; // Set it to nullptr for single-cluster code path, as it won't be used + } + + // Number of blocks we can use in the cooperative kernel + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + // WAR: Reserve 8 SMs for overlapping kernels. + int const numBlocksCoop = smCount - 8; + // Maximum number of tokens supported by the kernel using a cooperative launch. 
+ int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + + if (useSingleCluster) + { + launchClusterKernel(data, numThreadsHist, stream); + } + else if ((smMajor >= 9) && (data.mNumTokens <= maxTokensCoop)) + { + launchCoopKernel(data, numBlocksCoop, numThreadsHist, stream); + } + else + { + const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; + const int32_t histogramEltsPerBlock = 8 * numThreadsHist; + const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + const int32_t maxNumBlocks = 1024; + + int const numBlocksHistogram + = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets + = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); + launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#undef LAUNCH_DEEPSEEK_WITH_TOPK +#undef LAUNCH_ROUTING_DEEPSEEK + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDevKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDevKernel.h new file mode 100644 index 000000000000..91ff1c19d40e --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDevKernel.h @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../DevKernel.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Routing-specific launch macros. +// These macros build on top of LAUNCH_ESC from DevKernel.h. +// +// Unlike the generic LAUNCH_PDL (which instantiates 2 kernels for UsePdl=true/false), +// LAUNCH_PDL_ROUTING instantiates only 1 kernel and passes UsePdl as a runtime field +// in KernelParams. This halves routing kernel instantiations. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define LAUNCH_PDL_ROUTING(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ + do \ + { \ + cudaLaunchConfig_t config{}; \ + config.gridDim = numBlocks; \ + config.blockDim = numThreads; \ + config.dynamicSmemBytes = smemSize; \ + config.stream = (cudaStream_t) stream; \ + \ + cudaLaunchAttribute attributes[2] = {}; \ + attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + attributes[0].val.programmaticStreamSerializationAllowed = int(data.mUsePdl); \ + attributes[1].id = cudaLaunchAttributeCooperative; \ + attributes[1].val.cooperative = int(coopLaunch); \ + config.attrs = attributes; \ + config.numAttrs = 2; \ + auto params = KernelParams::setKernelParams(data); \ + auto kernelTyped = kernel>; \ + if (smemSize > 48 * 1024) \ + TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \ + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \ + } 
while (0) + +// Llama4 dispatch: uses data.mDtypeOutput. +#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ + } + +// DeepSeek dispatch: uses data.mDtypeOutput. +#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32 && extraFlag) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && extraFlag) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && forceFloatInput) \ 
+ { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// routingCustom dispatch: uses data.mDtypeOutput (OutputT) and data.mDtypeInput (InputT). +// These are routingCustom::Data fields, NOT used by DeepSeek/Llama4 macros. +// Wraps (PreProc, PostProc) into TopKExpertSelect for the standard preprocess→topK→postprocess flow. +#define LAUNCH_ROUTING_WITH_POLICIES( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, PreProc, PostProc, numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, float, numExperts, numTopExperts, TopKExpertSelect), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && data.mDtypeInput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, TopKExpertSelect), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, TopKExpertSelect), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeOutput"); \ + } + +// routingCustom dispatch for custom ExpertSelectPolicy types that don't use PreProc/PostProc. 
+// Use this when the policy does NOT follow the standard preprocess→topK→postprocess pattern. +// ExpertSelect must satisfy the ExpertSelectPolicy concept (see RoutingCustomPolicy.cuh). +#define LAUNCH_ROUTING_WITH_EXPERT_SELECT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, ExpertSelect, numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, ExpertSelect), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && data.mDtypeInput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, ExpertSelect), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, ExpertSelect), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeOutput"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingFromTopKIds.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingFromTopKIds.cu new file mode 100644 index 000000000000..0fb7e9e520de --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingFromTopKIds.cu @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingCustomPolicy.cuh" +#include "RoutingKernel.cuh" +#include "RoutingKernel.h" +#include + +namespace moe::dev::routing +{ +namespace routingCustom +{ +// Forward declarations of launch functions +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); +void launchDynBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); +void launchClusterKernel(Data const& data, void* stream); +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream); +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream); +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream); +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream); +} // namespace routingCustom + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Implementation of shared post-topK pipeline for all routing methods. +// When topK is already computed (mPtrTopKIds or mPtrTopKPacked), we don't need +// routing-method-specific logic, so all methods can use the same workflow. +// This function handles all path selection: single-block, single-cluster, coop, multi-kernel. 
+template +void runPostTopKPipeline(DataType const& data, uint32_t /*numThreadsHist*/, void* stream) +{ + // Convert to routingCustom::Data for launching (kernels are shared) + routingCustom::Data customData; + // Copy base fields + static_cast(customData) = static_cast(data); + // Set routingCustom-specific defaults (not needed for utility kernels) + customData.mDtypeOutput = data.mDtypeOutput; + // The post-TopK kernels don't read routing logits (mPtrInput), only mPtrTopKPacked. + // Set mDtypeInput = mDtypeOutput so the dispatched template is , + // avoiding an unnecessary mixed-type instantiation. + customData.mDtypeInput = data.mDtypeOutput; + customData.mPreprocessType = RoutingPreprocessType::None; + // Softmax is chosen for its broad tier coverage, not because we need softmax. + // The TopKIds/TopKPacked branches never call ExpertSelectPolicy::apply(), + // so the postprocess is never executed. Using Softmax avoids extra template + // instantiations by reusing tiers already compiled for other models. + customData.mPostprocessType = RoutingPostprocessType::Softmax; + + // Recompute numThreadsHist using routingCustom's expert tiers, since we launch custom kernels. + // Different routing methods (DeepSeek, Llama4) may have different expert tier thresholds + // that don't match routingCustom's tiers (128, 512, 2048). 
+ uint32_t const numThreadsHist + = std::min(1024u, static_cast(routingCustom::getMaxNumExperts(data.mNumExperts))); + + // Determine which path to use based on token count + static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + bool const useStaticBlock = data.mNumTokens <= routingCustom::BlockKernelMaxNumTokens; + int32_t const dispatchedMaxExperts = routingCustom::queryDispatchedMaxExperts(customData); + bool const useDynBlock = !useStaticBlock && data.mNumTokens <= routingCustom::DynBlockKernelMaxNumTokens + && dispatchedMaxExperts <= routingCustom::DynBlockKernelMaxNumExperts; + + // runPostTopKPipeline only handles pre-computed topK (mPtrTopKIds or mPtrTopKPacked), + // never raw scores. The cluster kernel's routingPermutation uses thread-per-expanded-index + // for both input types (LoadExpertIdxFromGlobal=true), so the capacity is + // NumBlocksPerCluster * NumThreads = 8192 tokens. + // (The smaller 256-token limit only applies to the mPtrScores path which does warp-per-token + // topK selection in the first phase, but that path is never taken here.) 
+ bool const useSingleCluster = (smMajor >= 9) && (data.mNumTokens <= routingCustom::MaxNumTokensSingleCluster); + + routingCustom::Data lastKernelData = customData; + + if (useDynBlock) + { + routingCustom::launchDynBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if (useStaticBlock) + { + routingCustom::launchBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if (useSingleCluster) + { + routingCustom::launchClusterKernel(lastKernelData, stream); + } + else + { + bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) && (data.mPtrPermutedIdxSize != nullptr); + bool useCoop = false; + int numBlocksCoop = 0; + + if (canUseCoop) + { + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + numBlocksCoop = smCount - 8; + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + useCoop = (data.mNumTokens <= maxTokensCoop); + } + + TLLM_CHECK_WITH_INFO( + data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); + + if (useCoop) + { + routingCustom::launchInitExpertCounts(customData, numThreadsHist, stream); + routingCustom::launchCoopKernel(lastKernelData, numBlocksCoop, numThreadsHist, stream); + } + else + { + routingCustom::launchInitExpertCounts(customData, numThreadsHist, stream); + + int32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + int32_t const histogramEltsPerBlock = 8 * numThreadsHist; + int32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + int32_t const maxNumBlocks = 1024; + + int const numBlocksHistogram + = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets + = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + routingCustom::launchHistogramKernel(customData, numBlocksHistogram, numThreadsHist, stream); + routingCustom::launchOffsetsKernel(lastKernelData, 
numBlocksOffsets, numThreadsHist, stream); + } + } +} + +// Explicit instantiations for the three routing method Data types +template void runPostTopKPipeline(routingCustom::Data const&, uint32_t, void*); +template void runPostTopKPipeline(routingDeepSeek::Data const&, uint32_t, void*); +template void runPostTopKPipeline(routingLlama4::Data const&, uint32_t, void*); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.cuh similarity index 73% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.cuh index 4bc7b56aa18b..f5c57b1611ca 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include "DevKernel.h" +#include "RoutingDevKernel.h" #include "RoutingKernel.h" #include "RoutingKernelTopK.cuh" @@ -48,6 +48,21 @@ static constexpr int NumEltsPerOffsetTilePerThread = 8; //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dereference a type-erased pointer at the given index, reading the value in its native dtype. +/// Returns float since routing computations are done in float for numerical stability. 
+__forceinline__ __device__ float loadScalar(void const* ptr, int idx, batchedGemm::trtllm::gen::Dtype dtype) +{ + namespace tg = batchedGemm::trtllm::gen; + switch (dtype) + { + case tg::Dtype::Fp32: return static_cast(ptr)[idx]; + case tg::Dtype::Bfloat16: return static_cast(static_cast<__nv_bfloat16 const*>(ptr)[idx]); + default: return 0.f; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + static __device__ inline float sigmoid_accurate(float x) { return 0.5f * tanhf(0.5f * x) + 0.5f; @@ -391,7 +406,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCta[e] = divUpLog2(count[e], params.mPaddingLog2); } @@ -411,7 +426,6 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx int expert = threadIdx.x * ExpertsPerThread + e; if (expert < params.mNumExperts) { - // Strided loop to share this work between blocks. 
for (int32_t cta = clusterBlockRank; cta < numCta[e]; cta += NumBlocksPerCluster) { const int32_t localExpertIdx @@ -419,7 +433,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffset[e], params.mPaddingLog2) + count[e]; @@ -432,9 +446,8 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx params.mPtrCtaIdxXyToMnLimit[ctaOffset[e] + cta] = min(mnLimit1, mnLimit2); } - // get the padded offset associated with this expert int32_t offset; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { offset = mulLog2(ctaOffset[e], params.mPaddingLog2); } @@ -443,16 +456,14 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx offset = mulTileN(ctaOffset[e], params.mTileTokensDim); } - // write expert offsets to shared smemExpertOffset[expert] = offset + blockExpertOffset[e]; } } - // write out padded count if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -472,17 +483,6 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx // implement break with EXIT. __cluster_barrier_wait(); - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, - // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! 
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - // each thread has the same "expanded indexes" assigned to it as above // at this point, we know the final offsets of experts and the offsets within // experts, which allows writing the final index values @@ -515,6 +515,18 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; } } + + // Trigger the secondary kernel AFTER all global memory writes are complete. + // The downstream kernels (permute, FC1 GEMM) depend on mPtrCtaIdxXyToBatchIdx, + // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas, mPtrPermutedIdxSize, AND + // mPtrExpandedIdxToPermutedIdx / mPtrPermutedIdxToTokenIdx. + // Triggering before the permutation writes causes the consumer to read stale data → NaN. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -556,11 +568,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa __syncthreads(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid and trigger secondary kernel. - if constexpr (KernelParams::UsePdl) + // Wait on primary grid (but do NOT trigger yet — trigger after atomicAdd to mPtrExpertCounts). + if (params.mUsePdl) { cudaGridDependencySynchronize(); - cudaTriggerProgrammaticLaunchCompletion(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -640,6 +651,15 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelPa atomicAdd(¶ms.mPtrExpertCounts[expert], localExpertCount); } } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger AFTER all atomicAdds to mPtrExpertCounts are done, so the next kernel + // (routingIndicesOffsetsKernel) sees the complete histogram. + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -673,7 +693,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -704,7 +724,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCta[e] = divUpLog2(count[e], params.mPaddingLog2); } @@ -723,9 +743,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa int expert = threadIdx.x * ExpertsPerThread + e; if (expert < params.mNumExperts) { - // Get the padded offset associated with this expert int32_t offset; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { offset = mulLog2(ctaOffset[e], params.mPaddingLog2); } @@ -734,19 +753,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa offset = mulTileN(ctaOffset[e], params.mTileTokensDim); } - // Write expert offsets to shared smemExpertOffset[expert] = offset; } } - // Sync to make expert offsets available to all threads. 
__syncthreads(); - // The first block writes out padded count (use last warp of actual thread count) if (blockIdx.x == 0 && warpIdx == NumThreadsBlock / WarpSize - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -764,7 +780,6 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa int expert = threadIdx.x * ExpertsPerThread + e; if (expert < params.mNumExperts) { - // Strided loop to share this work between blocks. for (int32_t cta = blockIdx.x; cta < numCta[e]; cta += gridDim.x) { const int32_t localExpertIdx @@ -772,7 +787,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffset[e], params.mPaddingLog2) + count[e]; @@ -965,7 +980,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa // Trigger secondary kernel. // Note: this does not guarantee the visibility of prior writes unless the consumer executes a // dependency sync. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } @@ -988,7 +1003,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -998,11 +1013,307 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid. 
- if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Cooperative launch kernel: fuses histogram + offsets computation for medium token counts. +// This kernel is shared by routingCustom, routingDeepSeek, and can be used by other routing methods. +// It uses cooperative groups to synchronize across multiple CTAs and compute expert counts, +// offsets, and permutation indices in a single kernel launch. +// +// Requirements: +// - MaxNumExperts <= 1024 (enforced by static_assert) +// - SM90+ architecture (cooperative groups) +// - mPtrPermutedIdxSize must be non-null (needed for permutation) +// +// The kernel handles both mPtrTopKIds and mPtrTopKPacked input formats. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesCoopKernel(KernelParams params) +{ + // number of experts is bounded by number of threads (coop kernel requires MaxNumExperts <= 1024) + using OutputT = typename KernelParams::OutputT; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + static constexpr int NumThreads = MaxNumExperts; + static_assert(MaxNumExperts <= 1024, "Coop kernel requires MaxNumExperts <= 1024"); + + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[MaxNumExperts]; + // needed for the exclusive sum of token offsets + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage tempStorage; + // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. + static constexpr int MaxExpandedIdxPerThread = 64; + + // Initialize grid. 
+ cg::grid_group grid = cg::this_grid(); + // Note: the following is more efficient than grid.block_index() because we don't use y and z. + int32_t const gridBlockIdx = blockIdx.x; + int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; + int32_t const numBlocks = gridDim.x; + int32_t const numThreadsPerGrid = numBlocks * NumThreads; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + + auto expandedIdxSize = params.mNumTokens * params.mTopK; + + // pre-fill the counts with 0 — each thread represents one expert + smemExpertCount[threadIdx.x] = 0; + __syncthreads(); + + // then wait on primary grid + if (params.mUsePdl) + { + cudaGridDependencySynchronize(); + } + + // each thread keeps has some number of "expanded indexes" assigned to it + // for each of these, we keep the associated expert and offset within expert in registers + int32_t expertIndexes[MaxExpandedIdxPerThread]; + int32_t expertOffsets[MaxExpandedIdxPerThread]; + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a + // time, and branch between a fast path without bound checks and a slow path with bound checks. + int constexpr IterStride = 4; + static_assert(MaxExpandedIdxPerThread % IterStride == 0); + + // Define a lambda to avoid code duplication in both branches. + // Use shared device function for expert index extraction. 
+ auto loopBody = [&](int ii, int expandedIdx) + { + int32_t expertIdx = getExpertIdxFromInputWithWeights(params, expandedIdx, params.mPtrTopKWeights); + expertIndexes[ii] = expertIdx; + // check whether this expert is local to our GPU at all and ignore if not + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent + && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0; + }; + +#pragma unroll + for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) + { + // Whether it's safe to do multiple iterations without bound checks. + bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; + if (takeFastPath) + { +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) + { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + loopBody(ii, expandedIdx); + } + } + else + { + bool doBreak = false; +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) + { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) + { + doBreak = true; + break; + } + loopBody(ii, expandedIdx); + } + if (doBreak) + { + break; + } + } + } + + // Make histogram (token counts per expert) available to all threads in the block. + __syncthreads(); + + // + // Each thread now represents one expert + // + + // Add the local bin count to the common bin count and get a per-CTA offset. + int32_t const localExpertCount = smemExpertCount[threadIdx.x]; + + int32_t blockExpertOffset = 0; + if (threadIdx.x < params.mNumExperts) + { + blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); + } + + // Sync to wait for completion of the histogram reduction. + grid.sync(); + + // Get total count for this expert. 
+ int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; + + // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. + + // Compute the runtime config for projections + // Whether or not an expert is local is taken into account when smemExpertCount is computed + // so we do not need to take it into account here. + + int32_t numCta; + if (params.mIsPow2) + { + numCta = divUpLog2(count, params.mPaddingLog2); + } + else + { + numCta = divUpTileN(count, params.mTileTokensDim); + } + + int32_t ctaOffset; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) + { + const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if (params.mIsPow2) + { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } + else + { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); + } + + int32_t offset; + if (params.mIsPow2) + { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } + else + { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } + int32_t permutedIdxSize; + if (params.mIsPow2) + { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } + else + { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } + + // write out padded count + if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) + { + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + + // write 
expert offsets to shared + smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; + + // make expert offsets available to all threads + __syncthreads(); + + // each thread has the same "expanded indexes" assigned to it as above + // at this point, we know the final offsets of experts and the offsets within + // experts, which allows writing the final index values +#pragma unroll + for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) + { + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) + { + break; + } + auto expertIdx = expertIndexes[ii]; + // check whether this expert is local to our GPU at all + auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent + && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + auto tokenIdx = expandedIdx / params.mTopK; + auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) + { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + } + if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) + { + params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; + } + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) + { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + + // Trigger the secondary kernel AFTER all global memory writes (including permutation indices). + // The downstream kernels depend on all routing outputs being visible. 
+ if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +} +#else +template +__global__ void routingIndicesCoopKernel(KernelParams params) +{ + assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Shared device functions for coop kernel (used by both routingCustom and routingDeepSeek) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Device function to extract expert index from either mPtrTopKIds or mPtrTopKPacked. +// This is the only difference between routingCustom and routingDeepSeek coop kernels. +// For routingCustom: also writes to mPtrTopKWeights if provided. +// For routingDeepSeek: simpler version that doesn't write weights. +template +__forceinline__ __device__ int32_t getExpertIdxFromInput(KernelParams const& params, int32_t expandedIdx) +{ + if (params.mPtrTopKIds != nullptr) + { + return params.mPtrTopKIds[expandedIdx]; + } + else + { + return params.mPtrTopKPacked[expandedIdx].idx; + } +} + +// Overload for routingCustom that also writes topK weights if needed. 
+template +__forceinline__ __device__ int32_t getExpertIdxFromInputWithWeights( + KernelParams const& params, int32_t expandedIdx, typename KernelParams::OutputT* topKWeights) +{ + if (params.mPtrTopKIds != nullptr) + { + return params.mPtrTopKIds[expandedIdx]; + } + else + { + PackedScoreIdx scoreIdx = params.mPtrTopKPacked[expandedIdx]; + if (topKWeights != nullptr) + { + topKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + return scoreIdx.idx; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace routing } // namespace moe::dev diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.h similarity index 69% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.h index 3daa1848e5d3..bea23942e49d 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -100,8 +100,10 @@ struct DataBase int32_t mNumTokens; int32_t mNumExperts; int32_t mTopK; - int32_t mPaddingLog2; + // Cluster-wide tile size in token dimension. int32_t mTileTokensDim; + // log2() of the padding size in cluster-wide tile. 
+ int32_t mPaddingLog2; /// For expert parallelization int32_t mLocalExpertsStartIdx; @@ -109,15 +111,16 @@ struct DataBase int32_t mNumLocalExperts; }; -template +template struct KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr int MaxNumExperts = MaxNumExperts_; static constexpr int MaxNumTopExperts = MaxNumTopExperts_; - static constexpr bool UsePdl = UsePdl_; - static constexpr bool isPow2 = isPow2_; + + bool mUsePdl = false; + bool mIsPow2 = false; // Public pointer members int32_t* mPtrExpertCounts = nullptr; @@ -146,6 +149,8 @@ struct KernelParamsBase template void setBaseParams(DataType const& data) { + mUsePdl = data.mUsePdl; + mIsPow2 = data.mPaddingLog2 > 0; mPtrExpertCounts = data.mPtrExpertCounts; mPtrPermutedIdxSize = data.mPtrPermutedIdxSize; mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx; @@ -175,12 +180,14 @@ namespace routingDeepSeek //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Bfloat16}; // // Grouped Gemm Launch Config Buffers // void const* mPtrRoutingBias; + // Dtype of the routing bias buffer (Bfloat16 or Fp32). 
+ tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; int32_t mHiddenDim; // not used int32_t mNumExpertGroups; @@ -190,9 +197,8 @@ struct Data : public DataBase bool mUseRoutingSoftmax; }; -template -struct KernelParams : public KernelParamsBase +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -205,7 +211,9 @@ struct KernelParams : public KernelParamsBase*) data.mPtrTopKPacked; - // params.mPtrTopKWeightsFull = static_cast(data.mPtrTopKWeightsFull); - params.mPtrRoutingBias = static_cast(data.mPtrRoutingBias); + params.mPtrRoutingBias = data.mPtrRoutingBias; + params.mDtypeBias = data.mDtypeBias; params.mNumExpertGroups = data.mNumExpertGroups; params.mNumExpertsPerGroup = data.mNumExperts / data.mNumExpertGroups; @@ -247,11 +255,11 @@ namespace routingLlama4 struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Bfloat16}; }; -template -struct KernelParams : public KernelParamsBase +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -277,40 +285,69 @@ void run(Data const& data, void* stream); //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace routingRenormalize +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Routing preprocess/postprocess policy type enums. +// These are used to select the compile-time policy at dispatch time. 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class RoutingPreprocessType +{ + None, // No preprocessing before topK + Softmax, // Apply softmax on all expert scores before topK + Sigmoid, // Apply sigmoid(score) for topK selection (no bias) + SigmoidBias, // Apply sigmoid(score) + bias for topK selection (DeepSeek-style) +}; + +enum class RoutingPostprocessType +{ + None, // No postprocessing after topK + Softmax, // Apply softmax on top-K scores + SumNormalize, // Normalize top-K scores by their sum + ScaledSumNormalize, // Recover sigmoid scores, normalize by sum and scale (DeepSeek-style) +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace routingCustom { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Fp32}; - tg::Dtype mDtypeElt{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Fp32}; // OutputT: expert weights dtype (typically Bfloat16) + tg::Dtype mDtypeInput{tg::Dtype::Bfloat16}; // InputT: routing logits dtype (Bfloat16 or Fp32) - bool mDoSoftmaxBeforeTopK{false}; + RoutingPreprocessType mPreprocessType{RoutingPreprocessType::None}; + RoutingPostprocessType mPostprocessType{RoutingPostprocessType::Softmax}; bool mNormTopkProb{true}; // Default value is true for Qwen3 model - // If true, applies softmax normalization after selecting top-K experts. - // Use this for models that require post-selection normalization (e.g., specific Qwen variants). - // Mutually exclusive with mDoSoftmaxBeforeTopK when both normalization paths are active. - // NOTE: Don't need to use this variable for now. - bool mApplySoftmaxAfterTopK{true}; + + // Optional: per-expert routing bias (used by SigmoidBias preprocess). + void const* mPtrRoutingBias{nullptr}; + // Dtype of the routing bias buffer (Bfloat16 or Fp32). 
Used to read mPtrRoutingBias correctly. + tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; + // Optional: scaling factor applied to final scores (used by ScaledSumNormalize postprocess). + float mRouteScale{1.0f}; + // Optional: epsilon added to the sum before division to prevent division by zero. + // MiniMax2 uses 1e-20f; DeepSeek uses 0.0f (no epsilon). + float mSumEpsilon{0.0f}; }; -template -struct KernelParams : public KernelParamsBase +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; + using ExpertSelectPolicy = ExpertSelectPolicy_; - static constexpr bool DoSoftmaxBeforeTopK = DoSoftmaxBeforeTopK_; + // Expert select policy params — empty structs have zero register cost. + using ExpertSelectParams = typename ExpertSelectPolicy::template Params; PackedScoreIdx* mPtrTopKPacked = nullptr; int32_t mTopK = 0; - bool mNormTopkProb = true; - bool mApplySoftmaxAfterTopK = false; + ExpertSelectParams mExpertSelectParams; static KernelParams setKernelParams(Data const& data) { @@ -318,16 +355,33 @@ struct KernelParams : public KernelParamsBase*) data.mPtrTopKPacked; - params.mNormTopkProb = data.mNormTopkProb; - params.mApplySoftmaxAfterTopK = data.mApplySoftmaxAfterTopK; params.mTopK = data.mTopK; + + // Policy populates only the fields it needs from Data. + params.mExpertSelectParams.set(data); return params; } }; void run(Data const& data, void* stream); -} // namespace routingRenormalize +} // namespace routingCustom + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Shared utility for post-topK pipeline when mPtrTopKIds != nullptr. +// All routing methods (Custom, DeepSeek, Llama4) use the same workflow in this case: +// 1. Reset expert counts +// 2. Run histogram kernel +// 3. Run offsets kernel +// Since the kernels are shared and we don't need routing-method-specific logic, +// we can use routingCustom's launch mechanism. 
+// +// This function works with any Data type that inherits from DataBase. +// Implementation is in RoutingFromTopKIds.cu +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void runPostTopKPipeline(DataType const& data, uint32_t numThreadsHist, void* stream); //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernelTopK.cuh similarity index 100% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernelTopK.cuh diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingLlama4.cu similarity index 92% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingLlama4.cu index 3362eb80c1b6..28435e548d2d 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingLlama4.cu @@ -106,7 +106,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // then wait on primary grid - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -165,9 +165,9 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam static_cast(params.mPtrTopKPacked[threadIdx.x].idx)}; if (params.mPtrTopKWeights != nullptr) { - // we also compute the final score here and write it out if required - auto finalScore = 
OutputT{sigmoid_accurate(float{scoreIdx.score})}; - params.mPtrTopKWeights[threadIdx.x] = finalScore; + // mPtrTopKPacked already contains sigmoid scores (produced by the scores-path + // kernels), so we just pass them through — no need to apply sigmoid again. + params.mPtrTopKWeights[threadIdx.x] = scoreIdx.score; } } } @@ -208,7 +208,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam { auto count = getBits(expertCount, ii); int32_t num; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { num = divUpLog2(count, params.mPaddingLog2); } @@ -231,7 +231,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam { auto count = getBits(expertCount, ii); int32_t finalNumCta; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { finalNumCta = divUpLog2(count, params.mPaddingLog2); } @@ -240,14 +240,12 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam finalNumCta = divUpTileN(count, params.mTileTokensDim); } auto expertIdx = threadIdx.x * ExpertsPerThread + ii; - // during the scan for expert offsets, we can already write out - // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` for (int cta = 0; cta < finalNumCta; ++cta) { params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffsetExp, params.mPaddingLog2) + count; @@ -266,7 +264,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam if (cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -281,7 +279,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // we can 
trigger the next kernel at this point - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } @@ -294,7 +292,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam // of registers auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; int32_t finalExpertOffset[ExpertsPerThread]; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); } @@ -306,7 +304,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam for (int ii = 1; ii < ExpertsPerThread; ++ii) { int32_t tmp; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { tmp = divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); } @@ -387,7 +385,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu auto warp = cg::tiled_partition(block); // then wait on primary grid - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -480,7 +478,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHis #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid and trigger secondary kernel. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); @@ -546,10 +544,22 @@ void run(Data const& data, void* stream) { TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline. 
This avoids Llama4-specific issues: + // - The Llama4 cluster kernel loads one token per warp but useSingleCluster uses + // the thread-based capacity, causing unprocessed tokens for medium token counts. + // - The Llama4 device kernel applies sigmoid to packed scores that may already + // contain sigmoid values (produced by the scores-path kernels). + if (data.mPtrTopKIds != nullptr || (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) { - TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 routing."); + if (data.mPtrTopKIds != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 routing."); + } + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + runPostTopKPipeline(data, numThreadsHist, stream); + return; } TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, @@ -563,15 +573,16 @@ void run(Data const& data, void* stream) TLLM_CHECK_WITH_INFO( data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + // After this point, mPtrTopKIds is guaranteed to be nullptr. + // Input is either mPtrScores (raw logits) or mPtrTopKPacked (topK already computed, needs sigmoid). bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) || data.mNumTokens < WarpKernelMaxNumTokens; - bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) - ? MaxNumTokensSingleClusterScores - : MaxNumTokensSingleCluster); + bool const useSingleCluster = data.mNumTokens + <= ((data.mPtrScores != nullptr) ? 
MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster); if (!useSingleCluster) { - TLLM_CHECK_WITH_INFO((data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), - "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); + TLLM_CHECK_WITH_INFO( + data.mPtrTopKPacked != nullptr, "When #tokens is large, `mPtrTopKPacked` is a required input."); TLLM_CHECK_WITH_INFO( data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); } @@ -606,7 +617,7 @@ void run(Data const& data, void* stream) int const numBlocksOffsets = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) + if (data.mPtrScores != nullptr) { LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/RoutingDeepSeekCommon.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/RoutingDeepSeekCommon.cuh deleted file mode 100644 index b9673be5efe2..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/RoutingDeepSeekCommon.cuh +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "../RoutingKernel.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -static constexpr int NumNemotronExperts = 512; -static constexpr int NumKimiK2Experts = 384; -static constexpr int NumDeepseekExperts = 256; -static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); -static constexpr int NumTopGroupScores = 2; -static constexpr int MaxNumTopGroups = 4; -static constexpr int MaxNumGroups = 8; - -static constexpr int NumTop8Experts = 8; -static constexpr int NumTop22Experts = 22; -static constexpr int MaxSupportedTopExperts = 32; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -int constexpr getMaxNumExperts(int32_t numExperts) -{ - if (numExperts <= topk::MaxNumExpertsUnit) - { - return topk::MaxNumExpertsUnit; - } - else if (numExperts <= NumDeepseekExperts) - { - return NumDeepseekExperts; - } - else if (numExperts <= NumKimiK2Experts) - { - return NumKimiK2Experts; - } - else if (numExperts <= NumNemotronExperts) - { - return NumNemotronExperts; - } - else - { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Helper macro: dispatch on topK tier for a given numExperts tier. 
-#define LAUNCH_DEEPSEEK_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput, numExperts) \ - if (data.mTopK <= NumTop8Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, numExperts, NumTop8Experts); \ - } \ - else if (data.mTopK <= NumTop22Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, numExperts, NumTop22Experts); \ - } \ - else \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, numExperts, MaxSupportedTopExperts); \ - } - -#define LAUNCH_ROUTING_DEEPSEEK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput) \ - if (data.mNumExperts <= topk::MaxNumExpertsUnit) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, topk::MaxNumExpertsUnit); \ - } \ - else if (data.mNumExperts <= NumDeepseekExperts) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, NumDeepseekExperts); \ - } \ - else if (data.mNumExperts <= NumKimiK2Experts) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, NumKimiK2Experts); \ - } \ - else if (data.mNumExperts <= NumNemotronExperts) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, NumNemotronExperts); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchClusterKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchClusterKernel.cu deleted file mode 100644 index 14fc591f4e57..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchClusterKernel.cu +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesClusterKernel(KernelParams params) -{ - using OutputT = typename KernelParams::OutputT; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const clusterBlockRank = blockIdx.x; - - //@todo: try to move it into routingPermutation - // then wait on primary grid - if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } - routingPermutation(params, nullptr, warpIdx, clusterBlockRank); -} -#else -__global__ void routingIndicesClusterKernel(KernelParams params) -{ - assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchClusterKernel(Data& data, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchCoopKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchCoopKernel.cu deleted file mode 100644 index a96db74865dc..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchCoopKernel.cu +++ /dev/null @@ -1,276 +0,0 @@ -/* 
- * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesCoopKernel(KernelParams params) -{ - // number of experts is bounded by number of threads - int constexpr NumThreads = KernelParams::MaxNumExperts; - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; - // needed for the exclusive sum of token offsets - using Scan = cub::BlockScan; - __shared__ typename Scan::TempStorage tempStorage; - // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. - static constexpr int MaxExpandedIdxPerThread = 64; - - // Initialize grid. - cg::grid_group grid = cg::this_grid(); - // Note: the following is more efficient than grid.block_index() because we don't use y and z. 
- int32_t const gridBlockIdx = blockIdx.x; - int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; - int32_t const numBlocks = gridDim.x; - int32_t const numThreadsPerGrid = numBlocks * NumThreads; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - - auto expandedIdxSize = params.mNumTokens * params.mTopK; - - // pre-fill the counts with 0 - smemExpertCount[threadIdx.x] = 0; - __syncthreads(); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } - - // each thread keeps has some number of "expanded indexes" assigned to it - // for each of these, we keep the associated expert and offset within expert in registers - int32_t expertIndexes[MaxExpandedIdxPerThread]; - int32_t expertOffsets[MaxExpandedIdxPerThread]; - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks. - int constexpr IterStride = 4; - static_assert(MaxExpandedIdxPerThread % IterStride == 0); - - // Define a lambda to avoid code duplication in both branches. - auto loopBody = [&](int ii, int expandedIdx) - { - int32_t expertIdx - = params.mPtrTopKIds != nullptr ? params.mPtrTopKIds[expandedIdx] : params.mPtrTopKPacked[expandedIdx].idx; - expertIndexes[ii] = expertIdx; - // check whether this expert is local to our GPU at all and ignore if not - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent - && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; - expertOffsets[ii] = isLocalExpert ? 
atomicAdd(smemExpertCount + expertIdx, 1) : 0; - }; - -#pragma unroll - for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) - { - // Whether it's safe to do multiple iterations without bound checks. - bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; - if (takeFastPath) - { -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) - { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - loopBody(ii, expandedIdx); - } - } - else - { - bool doBreak = false; -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) - { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) - { - doBreak = true; - break; - } - loopBody(ii, expandedIdx); - } - if (doBreak) - { - break; - } - } - } - - // Make histogram (token counts per expert) available to all threads in the block. - __syncthreads(); - - // - // Each thread now represents one expert - // - - // Add the local bin count to the common bin count and get a per-CTA offset. - int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - - int32_t blockExpertOffset = 0; - if (threadIdx.x < params.mNumExperts) - { - blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); - } - - // Sync to wait for completion of the histogram reduction. - grid.sync(); - - // Get total count for this expert. - int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; - - // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. - - // Compute the runtime config for projections - // Whether or not an expert is local is taken into account when smemExpertCount is computed - // so we do not need to take it into account here. 
- - int32_t numCta; - if constexpr (KernelParams::isPow2) - { - numCta = divUpLog2(count, params.mPaddingLog2); - } - else - { - numCta = divUpTileN(count, params.mTileTokensDim); - } - - int32_t ctaOffset; - int32_t numNonExitingCtas; - Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - - for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) - { - const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) - { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; - } - else - { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; - } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - - // get the padded offset associated with this expert - int32_t offset; - if constexpr (KernelParams::isPow2) - { - offset = mulLog2(ctaOffset, params.mPaddingLog2); - } - else - { - offset = mulTileN(ctaOffset, params.mTileTokensDim); - } - int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) - { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); - } - else - { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); - } - - // write out padded count - if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) - { - params.mPtrPermutedIdxSize[0] = permutedIdxSize; - params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; - } - - // write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; - - // make expert offsets available to all threads - __syncthreads(); - - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the 
mPtrCtaIdxXyToBatchIdx, - // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } - -// each thread has the same "expanded indexes" assigned to it as above -// at this point, we know the final offsets of experts and the offsets within -// experts, which allows writing the final index values -#pragma unroll - for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) - { - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) - { - break; - } - auto expertIdx = expertIndexes[ii]; - // check whether this expert is local to our GPU at all - auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent - && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; - auto tokenIdx = expandedIdx / params.mTopK; - auto permutedIdx = isLocalExpert ? 
int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; - if (params.mPtrExpandedIdxToPermutedIdx != nullptr) - { - params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; - } - if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) - { - params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; - } - if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) - { - params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; - } - } -} -#else -__global__ void routingIndicesCoopKernel(KernelParams params) -{ - assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchCoopKernel(Data& data, int numBlocksCoop, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchHistogramKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchHistogramKernel.cu deleted file mode 100644 index 1263e289e134..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchHistogramKernel.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchInitExpertCounts.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchInitExpertCounts.cu deleted file mode 100644 index 5f265878a388..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchInitExpertCounts.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchInitExpertCounts(Data& data, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, (2 * data.mNumExperts - 1) / numThreadsHist + 1, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/false); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchMainKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchMainKernel.cu deleted file mode 100644 index 1edc469cf70b..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchMainKernel.cu +++ /dev/null @@ -1,289 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void routingMainKernel(KernelParams params) -{ - // declare types - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - - // declare shared memory structure - // number of experts is bounded by number of threads - __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; - __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; - // number of expert groups is bounded by number of warps - __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; - - // needed for warp reduce - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - // for the final reduction of weight norm, only some lanes need to participate - int32_t laneIdx = threadIdx.x % WarpSize; - int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - // warps outside the range of expert groups do not participate - if constexpr (KernelParams::UseGroups) - { - if (warpIdx >= params.mNumExpertGroups) - { - return; - } - } - - // note that for invalid scores, we simply use a negative value: - // they work well even with the compacted format used in topK, and - // sigmoid / bias activated scores cannot be negative - static constexpr float invalidScoreFloat = float{-INFINITY}; - const OutputT 
invalidScore = OutputT{invalidScoreFloat}; - - // load bias already; each warp represents one expert group - auto threadExpert = threadIdx.x; - bool expertSelected = threadExpert < params.mNumExperts; - if constexpr (KernelParams::UseGroups) - { - threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx; - expertSelected = laneIdx < params.mNumExpertsPerGroup; - } - auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert; - auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore; - - // initialize the mPtrExpertCounts - if (params.mPtrExpertCounts) - { - int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; - int32_t globalThreadStride = gridDim.x * blockDim.x; - int32_t expertCountsNum = 2 * params.mNumExperts; - initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); - } - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - // trigger the secondary kernel when using PDL, then wait on primary - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - cudaGridDependencySynchronize(); - } -#endif - - if (params.mPtrScores != nullptr) - { - // get our assigned thread score; each warp represents one expert group - float score = expertSelected ? 
static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; - // get the sigmoid score - // note that for invalid values, we simply use a negative value: - // sigmoig scores are always strictly positive - auto scoreSigmoid = sigmoid_accurate(score); - // write the sigmoid score to shared for later use - if (expertSelected) - { - smemScoreSigmoid[threadExpert] = scoreSigmoid; - } - // get the score with bias - // note that with invalid values, because sigmoid is < 1 and bias is -1, - // we must get a negative value, which is smaller than any valid value - auto scoreBias = float{scoreSigmoid + float{biasVal}}; - - if (expertSelected) - { - smemScoreBias[threadExpert] = scoreBias; - } - - // registers for top group score reduction - float topExpGroupScores[NumTopGroupScores]; - [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; - float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups - int32_t topGroupIdx[MaxNumTopGroups]; - float expertScoreGroup[MaxNumTopGroups]; - int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[KernelParams::MaxNumTopExperts]; - - if constexpr (KernelParams::UseGroups) - { - topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, - /* minValue */ invalidScoreFloat); - // get the final group score and write it to shared - if (cute::elect_one_sync()) - { - auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; - smemGroupScores[warpIdx] = groupScore; - } - } - - // make group scores available to all warps - __syncthreads(); - - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - if constexpr (KernelParams::UseGroups) - { // a single warp performs the selection of top groups, and goes on to select the final experts - if (warpIdx == 0) - { - float groupScore = laneIdx < params.mNumExpertGroups ? 
smemGroupScores[laneIdx] : invalidScoreFloat; - topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, - /* minValue */ invalidScoreFloat); - // final expert selection: get relevant indexes and scores from shared -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { // bound of params.mNumLimitedGroups - auto groupIdx = topGroupIdx[ii]; - expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; - // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. - // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, - // thus groupIdx <= params.mNumExpertGroups - 1 => - // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup - // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, - // so the access is safe here - expertScoreGroup[ii] - = (ii < params.mNumLimitedGroups) && (groupIdx < params.mNumExpertGroups) && expertSelected - ? smemScoreBias[expertIdxGroup[ii]] - : invalidScoreFloat; - } - - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } - else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) - { - // without groups, each thread just takes `MaxNumTopGroups` experts - int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; - __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; - __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; - if (warpIdx < NumExpertWarps) - { - int offset = warpIdx * WarpSize * MaxNumTopGroups; -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { - auto expertIdx = ii * WarpSize + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts ? 
smemScoreBias[offset + expertIdx] - : invalidScoreFloat; - } - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - - if (laneIdx < params.mTopK) - { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; - } - else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) - { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = invalidScoreFloat; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] - = MaxSupportedExpertCount - 1; - } - } - __syncthreads(); - if (warpIdx == 0) - { - int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; - float intermediateScore[NumInterTopKPerThread]; - int32_t intermediateExpert[NumInterTopKPerThread]; - for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) - { - int ii = i / WarpSize; - if (i < NumInterTopK) - { - intermediateScore[ii] = smemInterTopScores[i]; - intermediateExpert[ii] = smemInterTopExperts[i]; - } - else - { - intermediateScore[ii] = invalidScoreFloat; - intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; - } - } - topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } - else - { - if (warpIdx == 0) - { - // without groups, each thread just takes `MaxNumTopGroups` experts -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { - auto expertIdx = ii * WarpSize + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] - = expertIdx < params.mNumExperts ? 
smemScoreBias[expertIdx] : invalidScoreFloat; - } - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } - - if (warpIdx == 0) - { - // determine our lane's expert index and write to output - int32_t expertIdx = 0; -#pragma unroll - for (int ii = 0; ii < params.mTopK; ++ii) - { // bound of params.mTopK - expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; - } - // determine whether our expert is local to this GPU - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent - && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; - - float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; - auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); - auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; - - // write expert idx out already - auto idxTopK = blockIdx.x * params.mTopK + laneIdx; - if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) - { - PackedScoreIdx packedScore{static_cast(finalScore), static_cast(expertIdx)}; - params.mPtrTopKPacked[idxTopK] = packedScore; - } - - if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) - { - params.mPtrTopKWeights[idxTopK] = finalScore; - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git 
a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchOffsetsKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchOffsetsKernel.cu deleted file mode 100644 index 0836c21aa5cf..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchOffsetsKernel.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/RoutingRenormalizeCommon.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/RoutingRenormalizeCommon.cuh deleted file mode 100644 index 31bfb399547b..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/RoutingRenormalizeCommon.cuh +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "../RoutingKernel.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static constexpr int NumExperts128Experts = 128; -static constexpr int NumExperts512Experts = 512; -static constexpr int MaxSupportedExperts = 2048; - -static constexpr int NumTop8Experts = 8; -static constexpr int NumTop16Experts = 16; -static constexpr int MaxSupportedTopExperts = 32; - -static constexpr int NumThreads = 1024; -static constexpr int NumWarps = NumThreads / WarpSize; - -static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; -static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; - -static constexpr int BlockKernelMaxNumTokens = 4; -static constexpr int DynBlockKernelMaxNumTokens = 16; -static constexpr int DynBlockKernelMaxNumExperts = 512; - -template -__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile const& warp, - DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[K], int32_t (&warpTopKExpertIdx)[K], - int32_t const laneIdx, int32_t const numExperts, int32_t topK, InputType const* ptrScores, bool const normTopkProb, - bool const applySoftmaxAfterTopK = true) -{ - DataType minScore = DataType{-INFINITY}; - - for (int i = 0; i < VecSize; i++) - { - auto expertIdx = i * WarpSize + laneIdx; - auto newScore = expertIdx < numExperts ? static_cast(ptrScores[expertIdx]) : minScore; - score[i] = newScore; - idx[i] = expertIdx; - } - if constexpr (DoSoftmaxBeforeTopK) - { - calcSoftmax(warp, score); - } - - // Get the top-k scores and their corresponding expert indices - topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); - - // Normalize the scores - if constexpr (DoSoftmaxBeforeTopK) - { - float sum = float{1.f}; - if (normTopkProb) - { - sum = static_cast(laneIdx < topK ? 
warpTopKScore[laneIdx] : 0); - sum = cg::reduce(warp, sum, cg::plus()); - } - if (laneIdx < topK) - { - warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; - } - } - else - { - if (applySoftmaxAfterTopK) - { - auto softmaxScore = calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); - if (laneIdx < topK) - { - warpTopKScore[laneIdx] = softmaxScore; - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -int32_t constexpr getMaxNumExperts(int32_t numExperts) -{ - if (numExperts <= NumExperts128Experts) - { - return NumExperts128Experts; - } - else if (numExperts <= NumExperts512Experts) - { - return NumExperts512Experts; - } - else if (numExperts <= MaxSupportedExperts) - { - return MaxSupportedExperts; - } - else - { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Helper macro: dispatch on topK tier for a given numExperts tier. 
-#define LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, numExperts) \ - if (data.mTopK <= NumTop8Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - numExperts, NumTop8Experts); \ - } \ - else if (data.mTopK <= NumTop16Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - numExperts, NumTop16Experts); \ - } \ - else \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - numExperts, MaxSupportedTopExperts); \ - } - -#define LAUNCH_ROUTING_RENORMALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1) \ - if (data.mNumExperts <= NumExperts128Experts) \ - { \ - LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, NumExperts128Experts); \ - } \ - else if (data.mNumExperts <= NumExperts512Experts) \ - { \ - LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, NumExperts512Experts); \ - } \ - else if (data.mNumExperts <= MaxSupportedExperts) \ - { \ - LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, MaxSupportedExperts); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchClusterKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchClusterKernel.cu deleted file mode 100644 index b8d7f8b91186..000000000000 --- 
a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchClusterKernel.cu +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) - routingIndicesClusterKernel(KernelParams params) -{ - // number of tokens/expanded idx is bounded by total number of warps - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - - using BaseType = std::conditional_t; - using TypePacked = PackedScoreIdx; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * KernelParams::MaxNumTopExperts]; - - uint32_t const clusterBlockRank = blockIdx.x; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const laneIdx = cutlass::arch::LaneId(); - - auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; - auto scoreOffset = warpTokenIdx * params.mNumExperts; - bool validToken = warpTokenIdx < params.mNumTokens; - - auto 
block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } - - if (params.mPtrScores != nullptr) - { - // in this case, each warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - - BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; - int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; - - BaseType minScore = BaseType{-INFINITY}; - if (validToken) - { - routingTopKExperts(warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb); - - if (laneIdx < params.mTopK) - { - smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] - = TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; - } - } // end if (validToken) - } - - // make packed scores available to all threads in cluster - __cluster_barrier_arrive(); - __cluster_barrier_wait(); - - if (params.mPtrScores != nullptr) - { - routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); - } - else - { - routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); - } -} -#else -__global__ void __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams /* params */) -{ - assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); -} -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchClusterKernel(Data const& data, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize 
-} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramKernel.cu deleted file mode 100644 index 7d6f6177a56e..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramKernel.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramScoresKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramScoresKernel.cu deleted file mode 100644 index 03bec526f823..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramScoresKernel.cu +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// this kernel is needed in case we have scores as input for the histogram kernel -template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024) - routingIndicesHistogramScoresKernel(KernelParams params) -{ - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; - // Cap actual thread count at 1024 when MaxNumExperts > 1024. - static constexpr int NumThreadsBlock = KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024; - - // VecSize stays based on MaxNumExperts — each warp still processes all experts for one token. - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - int32_t const laneIdx = cutlass::arch::LaneId(); - int32_t const warpIdx = threadIdx.x / WarpSize; - // Use NumThreadsBlock (actual thread count) for grid-stride warp/thread addressing - int32_t const globalWarpIdx = blockIdx.x * NumThreadsBlock / WarpSize + warpIdx; - int32_t const globalWarpStride = gridDim.x * NumThreadsBlock / WarpSize; - BaseType minScore = BaseType{-INFINITY}; - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. 
- if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - - // initialize the mPtrExpertCounts — use NumThreadsBlock for grid-stride - int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; - int32_t globalThreadStride = gridDim.x * NumThreadsBlock; - initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); - - // in this case, each warp represents a token, and we use a grid-stride loop - // over all warps/tokens - BaseType allScores[VecSize]; - int32_t allExpertIdx[VecSize]; - BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; - int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; - for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) - { - auto scoreOffset = tokenIdx * params.mNumExperts; - - routingTopKExperts(warp, allScores, allExpertIdx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb); - - if (laneIdx < params.mTopK) - { - PackedScoreIdx packedScore{ - static_cast(warpTopKScore[laneIdx]), static_cast(warpTopKExpertIdx[laneIdx])}; - params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; - } - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Trigger secondary kernel AFTER writing all packed scores, so the next kernel - // (routingIndicesHistogramKernel) sees the completed mPtrTopKPacked writes. 
- if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchInitExpertCounts.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchInitExpertCounts.cu deleted file mode 100644 index 807fc89e8978..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchInitExpertCounts.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingInitExpertCounts, (2 * data.mNumExperts - 1) / numThreadsHist + 1, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchOffsetsKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchOffsetsKernel.cu deleted file mode 100644 index fe398c80cdb1..000000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchOffsetsKernel.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu index 467bca9318ac..19ac15b18ad1 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu @@ -15,7 +15,7 @@ */ #include "DevKernel.h" -#include "RoutingKernel.h" +#include "routing/RoutingKernel.h" #include "runner.h" #include "tensorrt_llm/common/config.h" #include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h" @@ -56,7 +56,7 @@ inline int32_t computeLog2(int32_t val, std::string const& name = "") Runner::Runner() {} -Runner::Runner(int32_t tileTokensDim) +Runner::Runner(int32_t tileTokensDim, int32_t clusterSizeInBatchDim) : mTileTokensDim(tileTokensDim) { } @@ -67,15 +67,175 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput, - bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t 
stream) + bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, btg::Dtype dtypeRoutingLogits) { - if (routingMethodType == RoutingMethodType::DeepSeekV3) + if (routingMethodType == RoutingMethodType::DeepSeekV3 && nGroup <= 1) + { + // DeepSeek no-groups case: use routingCustom with SigmoidBias preprocess + // and ScaledSumNormalize postprocess. This is more efficient than the full DeepSeek + // kernel because it uses the warp-level routingTopKExperts flow. + moe::dev::routing::routingCustom::Data routingData; + + // + // Config + // + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::SigmoidBias; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::ScaledSumNormalize; + routingData.mPtrRoutingBias = routingBias; + // Bias is always bfloat16 in the current Runner::run() API (no separate bias dtype param). + // The bias buffer dtype is determined by the caller (thop), not by the routing logits dtype. + routingData.mDtypeBias = btg::Dtype::Bfloat16; + routingData.mRouteScale = routedScalingFactor; + + // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method + routingData.mPtrScores = expertIds == nullptr ? 
routingLogits : nullptr; + // + // Outputs + // + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + routingData.mPtrTopKIds = expertIds; + // + // Grouped Gemm Launch Config Buffers + // + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // + // Inputs + // + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + + moe::dev::routing::routingCustom::run(routingData, stream); + } + else if (routingMethodType == RoutingMethodType::SigmoidRenorm) + { + // SigmoidRenorm: sigmoid(logit) → topK → renormalize. + // No bias, no scaling factor — pure sigmoid activation with top-K renormalization. + moe::dev::routing::routingCustom::Data routingData; + + // + // Config + // + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::Sigmoid; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::SumNormalize; + routingData.mNormTopkProb = true; + + // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method + routingData.mPtrScores = expertIds == nullptr ? 
routingLogits : nullptr; + // + // Outputs + // + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + routingData.mPtrTopKIds = expertIds; + // + // Grouped Gemm Launch Config Buffers + // + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // + // Inputs + // + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + + moe::dev::routing::routingCustom::run(routingData, stream); + } + else if (routingMethodType == RoutingMethodType::MiniMax2) + { + // MiniMaxM2: sigmoid(logit) + bias → topK → renormalize un-biased sigmoid scores. + // Similar to DeepSeek no-groups but with routeScale = 1.0 and epsilon = 1e-20 + // to match the Python reference: weight / (sum + 1e-20). 
+ moe::dev::routing::routingCustom::Data routingData; + + // + // Config + // + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::SigmoidBias; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::ScaledSumNormalize; + routingData.mPtrRoutingBias = routingBias; + // Bias is always bfloat16 in the current Runner::run() API (no separate bias dtype param). + routingData.mDtypeBias = btg::Dtype::Bfloat16; + routingData.mRouteScale = 1.0f; + routingData.mSumEpsilon = 1e-20f; + + // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method + routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr; + // + // Outputs + // + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + routingData.mPtrTopKIds = expertIds; + // + // Grouped Gemm Launch Config Buffers + // + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // + // Inputs + // + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + + moe::dev::routing::routingCustom::run(routingData, 
stream); + } + else if (routingMethodType == RoutingMethodType::DeepSeekV3) { TLLM_CHECK_WITH_INFO(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); TLLM_CHECK_WITH_INFO(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; - routingData.mDtypeExpW = btg::Dtype::Bfloat16; - routingData.mUsePdl = true; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); // output: routingData.mPtrTopKPacked = routingExpertIndexes; @@ -92,6 +252,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 // input: routingData.mPtrRoutingBias = routingBias; + // Bias is always bfloat16 in the current Runner::run() API (no separate bias dtype param). + routingData.mDtypeBias = btg::Dtype::Bfloat16; // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method routingData.mPtrScores = expertIds == nullptr ? 
routingLogits : nullptr; routingData.mPtrTopKIds = expertIds; @@ -117,8 +279,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 TLLM_LOG_WARNING("For Llama routing method, nGroup/topkGroup is ignored, got %d/%d.", nGroup, topkGroup); } moe::dev::routing::routingLlama4::Data routingData; - routingData.mDtypeExpW = btg::Dtype::Bfloat16; - routingData.mUsePdl = true; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); // output: routingData.mPtrTopKPacked = routingExpertIndexes; @@ -157,20 +319,43 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 // routingData.mUseRoutingSoftmax = false; moe::dev::routing::routingLlama4::run(routingData, stream); } - else if (routingMethodType == RoutingMethodType::Renormalize /* default */ - || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */) + else if (routingMethodType == RoutingMethodType::Renormalize + || routingMethodType == RoutingMethodType::RenormalizeNaive || routingMethodType == RoutingMethodType::Default) { - moe::dev::routing::routingRenormalize::Data routingData; + moe::dev::routing::routingCustom::Data routingData; // // Config // - routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; // routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input - routingData.mUsePdl = tensorrt_llm::common::getEnvEnableTrtllmgenMoeRoutingRenormPDL(); - routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive; - routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + if (routingMethodType == RoutingMethodType::Default) + { + // Default: Softmax -> TopK (no postprocessing) + routingData.mPreprocessType = 
moe::dev::routing::RoutingPreprocessType::Softmax; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::None; + } + else + { + // Renormalize and RenormalizeNaive are mathematically equivalent: + // RenormalizeNaive: softmax(all N experts) → topK → divide by sum of topK + // Renormalize: topK(raw scores) → softmax(K experts) + // + // Both produce identical output because: + // 1. softmax is monotonic, so topK selection yields the same experts + // 2. softmax(topK raw scores) = softmax(topK softmax scores) after renormalization, + // since softmax(x_i) / Σ softmax(x_j) = exp(x_i) / Σ exp(x_j) for the topK subset + // + // We always use the Renormalize path (NoOp preprocess + Softmax postprocess) + // because it only computes softmax over K experts instead of all N, which is faster + // — especially for large expert counts (e.g., 256 experts with topK=8). + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::None; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::Softmax; + routingData.mNormTopkProb = true; + } // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method routingData.mPtrScores = expertIds == nullptr ? 
routingLogits : nullptr; @@ -204,7 +389,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; - moe::dev::routing::routingRenormalize::run(routingData, stream); + moe::dev::routing::routingCustom::run(routingData, stream); } else { @@ -291,12 +476,12 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* // tensorrt_llm/_torch/modules/fused_moe/quantization.py:MXFP4WeightTRTLLMGenFusedMoEMethod.input_hidden_alignment validHiddenSize = tensorrt_llm::common::roundUp(validHiddenSize, 512); } - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); bool is_gated_activation = mActType == ActType::SwiGlu; int32_t intermediateSizeFactor = (is_gated_activation ? 2 : 1); mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, numTokens, intermediateSizeFactor * validIntermediateSize, validHiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, + maxNumCgasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, useRoutingScalesOnInput ? 
expertWeights : nullptr, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, @@ -306,31 +491,31 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t configIndex) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); int32_t const intermediateSizeFactor = mActType == ActType::SwiGlu ? 2 : 1; return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim, configIndex); + numTokens, numExperts, maxNumCgasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); bool is_gated_activation = mActType == ActType::SwiGlu; return mRunner.getDefaultValidConfigIndex(numTokens, is_gated_activation ? 
2 * intermediateSize : intermediateSize, - hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, numTokens, 2 * validIntermediateSize, + hiddenSize, {}, numTokens, numExperts, maxNumCgasInBatchDim, numTokens, 2 * validIntermediateSize, validHiddenSize); } bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); bool is_gated_activation = mActType == ActType::SwiGlu; auto const isValid = mRunner.isValidConfigIndex(configIndex, numTokens, is_gated_activation ? 2 * intermediateSize : intermediateSize, hiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, numTokens, 2 * validIntermediateSize, validHiddenSize); + maxNumCgasInBatchDim, numTokens, 2 * validIntermediateSize, validHiddenSize); return isValid; } @@ -391,9 +576,9 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void // The multiple is no less than 128 as TMA requires it for CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B types validIntermediateSize = tensorrt_llm::common::roundUp(validIntermediateSize, 128); } - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); mRunner.run(numTokens, hiddenSize, intermediateSize, numTokens, validHiddenSize, validIntermediateSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim, permutedHiddenState, permutedHiddenStateScale, weights, + numTokens, numExperts, maxNumCgasInBatchDim, permutedHiddenState, permutedHiddenStateScale, weights, weightsScale, /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, 
outputScalesScalar, /* outputScalesGateScalar */ nullptr, ptrBias, /* ptrAlpha */ nullptr, /* ptrBeta */ nullptr, /* clampLimit */ nullptr, output, outputScale, @@ -404,27 +589,27 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t configIndex) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); return mRunner.getWorkspaceSizeInBytes( - numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, configIndex); + numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCgasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); return mRunner.getDefaultValidConfigIndex(numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); + maxNumCgasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); } bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto const maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto const maxNumCgasInBatchDim = 
Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); auto const isValid = mRunner.isValidConfigIndex(configIndex, numTokens, hiddenSize, intermediateSize, {}, numTokens, - numExperts, maxNumCtasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); + numExperts, maxNumCgasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); return isValid; } @@ -482,11 +667,11 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace convertSfData.numTokens = args.num_tokens; convertSfData.sfLayoutSrc = btg::SfLayout::R128c4; convertSfData.sfLayoutDst = btg::SfLayout::Linear; - convertSfData.mUsePdl = true; + convertSfData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); // Setup activation data activationData.mDtypeElt = args.mDtypeElt; - activationData.mUsePdl = true; + activationData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); activationData.mUseDeepSeekFp8 = true; activationData.inPtr = workspace.gemm1_output; activationData.outPtr = workspace.activation_output; @@ -504,7 +689,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace // Setup finalize data finalizeData.mDtypeElt = args.mDtypeOut; finalizeData.mDtypeExpW = args.mDtypeExpW; - finalizeData.mUsePdl = true; + finalizeData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); finalizeData.mUseDeepSeekFp8 = false; finalizeData.inPtr = workspace.gemm2_output; finalizeData.outPtr = args.output; diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h index e922764061de..ba20fca4daf9 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h @@ -17,7 +17,7 @@ #pragma once #include "DevKernel.h" -#include "RoutingKernel.h" +#include "routing/RoutingKernel.h" #include "tensorrt_llm/common/config.h" #include 
"tensorrt_llm/common/cudaDriverWrapper.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -76,12 +76,17 @@ enum class RoutingMethodType : int64_t DeepSeekV3 = 2, // Llama4: Top1 -> Sigmoid Llama4 = 3, - // RenormalizeNaive: Softmax -> TopK -> Renormalize + // RenormalizeNaive: Softmax -> TopK -> Renormalize. + // Mathematically equivalent to Renormalize (TopK -> Softmax), but conceptually applies + // softmax over all N experts first. At runtime, we use the Renormalize kernel path + // (TopK -> Softmax over K) which is faster since softmax is only over K selected experts. RenormalizeNaive = 4, // MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize(without bias) MiniMax2 = 5, + // SigmoidRenorm: Sigmoid -> TopK -> Renormalize + SigmoidRenorm = 6, // Unspecified - Unspecified = 6, + Unspecified = 7, }; inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize, int32_t dtypeSizeBits) @@ -101,45 +106,52 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod case RoutingMethodType::Llama4: return "Llama4"; case RoutingMethodType::RenormalizeNaive: return "RenormalizeNaive"; case RoutingMethodType::MiniMax2: return "MiniMax2"; + case RoutingMethodType::SigmoidRenorm: return "SigmoidRenorm"; default: TLLM_CHECK_WITH_INFO(false, "Invalid routing method"); return ""; }; } -inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t tileTokensDim) +inline int32_t getMaxNumCgasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t cgaTileTokensDim) { - // For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime. - // We launch maximally possible number of CTAs and use ptrNumNonExitingCtas to determine - // the actual number of CTAs to run. + // For MoE, mNumTokens != 0 and the number of CGAs is known only at runtime. + // We launch maximally possible number of CGAs and use ptrNumNonExitingCtas to determine + // the actual number of CGAs to run. 
// Initialize number of tokens with the number of expanded tokens after routing. - int32_t numRemainingTokens = numTokens * topK; - int32_t maxNumCtasInBatchDim = 0; - // First, distribute one token each expert until token depletion to maximize CTA tile count. - int32_t numExpertsFilled = std::min(numExperts, numRemainingTokens); - maxNumCtasInBatchDim += numExpertsFilled; + auto numRemainingTokens = numTokens * topK; + int32_t maxNumCgasInBatchDim = 0; + // First, distribute one token each expert until token depletion to maximize CGA tile count. + auto numExpertsFilled = std::min(numExperts, numRemainingTokens); + maxNumCgasInBatchDim += numExpertsFilled; numRemainingTokens -= numExpertsFilled; - // Next, greedily pour all remaining tokens to one expert to maximize CTA tile count. + // Next, greedily pour all remaining tokens to one expert to maximize CGA tile count. // E.g., at this point tokens over 4 experts are [1, 1, 1, 1], and we have 4 tokens left. - // If each CTA handles 4 tokens/expert, the greedy strategy is to pour all remaining tokens - // to any one expert to get to the 5th CTA tile. Otherwise, we can only get 4 tiles in total. + // If each CGA handles 4 tokens/expert, the greedy strategy is to pour all remaining tokens + // to any one expert to get to the 5th CGA tile. Otherwise, we can only get 4 tiles in total. // // Another way to reason about this is to pour the remaining tokens into buckets of some fixed // capacity. These buckets, if full, can then be attributed to any expert; it does not have to // belong to the same expert every time. if (numRemainingTokens > 0) { - // For every tileTokenDim tokens, we add an extra CTA tile in the token dimension. - // The number of CTA tiles is given by divDown(numRemainingTokens, tokenTileDim). - maxNumCtasInBatchDim += (numRemainingTokens / tileTokensDim); + // For every tileTokenDim tokens, we add an extra CGA tile in the token dimension. 
+ // The number of CGA tiles is given by divDown(numRemainingTokens, tokenTileDim). + maxNumCgasInBatchDim += (numRemainingTokens / cgaTileTokensDim); } - return maxNumCtasInBatchDim; + return maxNumCgasInBatchDim; +} + +// Backward-compatible alias — callers outside routing may still use the old name. +inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t tileTokensDim) +{ + return getMaxNumCgasInBatchDim(numTokens, topK, numExperts, tileTokensDim); } inline int32_t getMaxPermutedPaddedCount( int32_t numTokens, int32_t expertsPerToken, int32_t numExperts, int32_t padding) { - int32_t maxCtas = getMaxNumCtasInBatchDim(numTokens, expertsPerToken, numExperts, padding); - return maxCtas * padding; + int32_t maxCgas = getMaxNumCgasInBatchDim(numTokens, expertsPerToken, numExperts, padding); + return maxCgas * padding; } class Runner @@ -147,7 +159,7 @@ class Runner public: explicit Runner(); - explicit Runner(int32_t tileTokensDim); + explicit Runner(int32_t tileTokensDim, int32_t clusterSizeInBatchDim = 1); void run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, int32_t topK, int32_t nGroups, int32_t topkGroups, int32_t localExpertOffset, int32_t localNumExperts, @@ -156,7 +168,8 @@ class Runner int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput, bool useDeepSeekFp8, - RoutingMethodType routingMethodType, cudaStream_t stream); + RoutingMethodType routingMethodType, cudaStream_t stream, + batchedGemm::trtllm::gen::Dtype dtypeRoutingLogits = batchedGemm::trtllm::gen::Dtype::Bfloat16); private: int32_t mTileTokensDim; diff --git a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp index 5f17e2372b62..cb6765ac6a56 100644 --- 
a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp +++ b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp @@ -20,6 +20,8 @@ #include +namespace btg = batchedGemm::trtllm::gen; + TRTLLM_NAMESPACE_BEGIN namespace torch_ext @@ -74,6 +76,9 @@ std::vector moe_topk_sort_impl(torch::optional con tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits->get_device() : token_selected_experts->get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits->scalar_type() == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(routing_logits_ptr, routing_bias_ptr, num_tokens, num_experts, top_k, n_group.value_or(0), topk_group.value_or(0), local_expert_offset, local_num_experts, routed_scaling_factor.value_or(1.0), expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -82,7 +87,7 @@ std::vector moe_topk_sort_impl(torch::optional con num_tokens_per_expert.data_ptr(), tile_idx_to_expert_idx.data_ptr(), tile_idx_to_mn_limit.data_ptr(), num_non_exiting_tiles.data_ptr(), batchedGemm::trtllm::gen::Dtype::Void /* dtypeElt */, false /* use_routing_scales_on_input */, - false /* use_deep_seek_fp8 */, routing_method_type, stream); + false /* use_deep_seek_fp8 */, routing_method_type, stream, dtypeRoutingLogits); std::vector results{tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles}; diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp index 35486c3be937..e090d816d973 100644 --- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp @@ -68,15 +68,9 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional(routing_method_type) == 
RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16"); - } + TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits has incorrect shape."); } @@ -264,6 +258,9 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -272,7 +269,7 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); + static_cast(routing_method_type), stream, dtypeRoutingLogits); // // FC13 (gemm1) + FC2 (gemm2) diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp index 3c13c695991c..a0b857b24f6a 100644 --- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp @@ -63,15 +63,9 @@ at::Tensor run_fp8_block_scale_moe(at::optional const& routing_logit } else if (routing_logits.has_value()) { - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16"); - } + TORCH_CHECK(routing_logits.value().scalar_type() == 
at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits dim1 must match num_experts."); } @@ -232,6 +226,9 @@ at::Tensor run_fp8_block_scale_moe(at::optional const& routing_logit tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits.value().get_device() : topk_ids.value().get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits.value().scalar_type() == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -239,7 +236,7 @@ at::Tensor run_fp8_block_scale_moe(at::optional const& routing_logit permuted_idx_to_token_idx.data_ptr(), expert_weights_ptr, args.topk_ids, num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, false, true, - static_cast(routing_method_type), stream); + static_cast(routing_method_type), stream, dtypeRoutingLogits); // MoE kernel except routing TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8."); diff --git a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp index 092f8f013620..183a91721690 100644 --- a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp +++ 
b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp @@ -57,24 +57,9 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional con } else if (routing_logits.has_value()) { - if (use_routing_scales_on_input) - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16."); - } - else - { - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16, - "routing_logits must be bfloat16"); - } - } + TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits has incorrect shape."); } @@ -230,6 +215,9 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional con tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits.value().get_device() : topk_ids.value().get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits.value().scalar_type() == at::ScalarType::Float ? 
btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -238,7 +226,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional con num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, use_routing_scales_on_input, false /* use_deep_seek_fp8 */, static_cast(routing_method_type), - stream); + stream, dtypeRoutingLogits); // MoE kernel except routing TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8."); diff --git a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp index 5e8331b77c3f..41788fc3a842 100644 --- a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp @@ -72,15 +72,9 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional } else if (routing_logits.has_value()) { - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16"); - } + TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits dim1 must match num_experts."); @@ -274,6 +268,9 @@ 
torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits.value().get_device() : topk_ids.value().get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits.value().scalar_type() == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -282,7 +279,7 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); + static_cast(routing_method_type), stream, dtypeRoutingLogits); // // FC13 (gemm1) + FC2 (gemm2) diff --git a/cpp/tests/unit_tests/kernels/CMakeLists.txt b/cpp/tests/unit_tests/kernels/CMakeLists.txt index ab8280498e5e..95d33e421050 100644 --- a/cpp/tests/unit_tests/kernels/CMakeLists.txt +++ b/cpp/tests/unit_tests/kernels/CMakeLists.txt @@ -89,7 +89,7 @@ add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}") set(ROUTING_KERNEL_TEST_SRC routing/routingTest.cpp routing/routingLlama4Test.cpp - routing/routingRenormalizeTest.cpp routing/routingDeepSeekTest.cpp) + routing/routingCustomTest.cpp routing/routingDeepSeekTest.cpp) add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}") target_link_libraries(routingKernelsTest PRIVATE Python3::Python) diff --git 
a/cpp/tests/unit_tests/kernels/routing/routingCustomTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingCustomTest.cpp new file mode 100644 index 000000000000..4aa7e0762d81 --- /dev/null +++ b/cpp/tests/unit_tests/kernels/routing/routingCustomTest.cpp @@ -0,0 +1,1549 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tests/unit_tests/kernels/routing/routingTest.h" + +#include + +namespace tk = tensorrt_llm::kernels; +namespace btg = batchedGemm::trtllm::gen; +using namespace tensorrt_llm::runtime; +using namespace tensorrt_llm::tests::kernels::routing; + +namespace +{ + +template +class RoutingCustomKernelTest : public RoutingKernelTest +{ + +protected: + using RoutingKernelTest::mSeed; + using RoutingKernelTest::mStream; + using RoutingKernelTest::mBufferManager; + using typename RoutingKernelTest::PackedType; + +private: + // Routing bias buffers (used by SigmoidBias preprocess) + TensorPtr mPtrRoutingBiasHost; + TensorPtr mPtrRoutingBiasDevice; + + static float sigmoid_accurate(float x) + { + return 0.5f * std::tanh(0.5f * x) + 0.5f; + } + + // Reference implementation for all policy combinations: + // 1. Softmax + NoOp (Default: softmax before topK, raw scores) + // 2. NoOp + Softmax (Renormalize: topK first, then softmax) + // 3. Softmax + SumNormalize (RenormalizeNaive path) + // 4. SigmoidBias + ScaledSumNormalize (DeepSeek-style path) + // 5. 
Sigmoid + SumNormalize (SigmoidRenorm path) + // 6. NoOp + NoOp (raw topK, no transformation) + void computeTopKExperts(RoutingKernelTestParam const& param) override + { + for (int it = 0; it < param.numTokens; ++it) + { + std::vector expWeightsIdx(param.numExperts); + std::vector expIdx(param.topK); + + // Per-expert sigmoid scores — only populated for SigmoidBias preprocess. + std::vector sigmoidScores(param.numExperts, 0.f); + + // --- Read raw scores and apply preprocess --- + for (int ie = 0; ie < param.numExperts; ++ie) + { + float score = static_cast(bufferCast(*this->mPtrScoresHost)[it * param.numExperts + ie]); + + if (param.preprocessType == RoutingPreprocessType::Sigmoid) + { + float sig = sigmoid_accurate(score); + score = ie < param.numExperts ? sig : -std::numeric_limits::infinity(); + } + else if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + float sig = sigmoid_accurate(score); + sigmoidScores[ie] = sig; + float bias = static_cast(bufferCast(*mPtrRoutingBiasHost)[ie]); + score = sig + bias; + } + + expWeightsIdx[ie] = PackedFloat{score, static_cast(ie)}; + } + + // Apply softmax preprocess (over all experts) when requested + if (param.preprocessType == RoutingPreprocessType::Softmax) + { + float maxScore = -std::numeric_limits::infinity(); + for (int ie = 0; ie < param.numExperts; ++ie) + { + maxScore = std::max(maxScore, expWeightsIdx[ie].score); + } + float sum = 0.f; + for (int ie = 0; ie < param.numExperts; ++ie) + { + expWeightsIdx[ie].score = std::exp(expWeightsIdx[ie].score - maxScore); + sum += expWeightsIdx[ie].score; + } + for (int ie = 0; ie < param.numExperts; ++ie) + { + expWeightsIdx[ie].score /= sum; + } + } + + // --- TopK selection --- + std::partial_sort_copy(expWeightsIdx.begin(), expWeightsIdx.end(), expIdx.begin(), expIdx.end(), comp); + + // --- Apply postprocess --- + if (param.postprocessType == RoutingPostprocessType::Softmax) + { + // Softmax over top-K scores + float maxScore = 
-std::numeric_limits::infinity(); + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + maxScore = std::max(maxScore, expIdx[ie].score); + } + float sum = 0.f; + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + sum += std::exp(expIdx[ie].score - maxScore); + } + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + expIdx[ie].score = std::exp(expIdx[ie].score - maxScore) / sum; + } + } + else if (param.postprocessType == RoutingPostprocessType::SumNormalize) + { + // SumNormalize: divide top-K scores by their sum + if (param.normTopkProb) + { + float sum = 0.f; + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + sum += expIdx[ie].score; + } + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + expIdx[ie].score /= sum; + } + } + } + else if (param.postprocessType == RoutingPostprocessType::ScaledSumNormalize) + { + // Recover sigmoid scores, renormalize by their sum, and scale + float sumSigmoid = 0.f; + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + sumSigmoid += sigmoidScores[expIdx[ie].idx]; + } + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + expIdx[ie].score = sigmoidScores[expIdx[ie].idx] * param.routedScalingFactor / sumSigmoid; + } + } + // For NoOp postprocess: scores are left unchanged. 
+ + // --- Store results --- + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + // Set invalid topk indices for the first half of the topk + if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1) + { + expIdx[ie].idx = static_cast(param.invalidExpertIdValue); + } + + PackedType si{static_cast(expIdx[ie].score), expIdx[ie].idx}; + reinterpret_cast(bufferCast(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si; + if (param.useTopKAsInput) + { + bufferCast(*this->mPtrTopKIdsHost)[it * param.topK + ie] + = static_cast(expIdx[ie].idx); + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + } + else if (param.getExpWeights) + { + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + } + } + } + } + +protected: + void allocateBuffers(RoutingKernelTestParam const& param) override + { + RoutingKernelTest::allocateBuffers(param); + int64_t scoresSize = param.numTokens * param.numExperts; + this->mPtrScoresHost = mBufferManager->pinned(ITensor::makeShape({scoresSize}), TRTDataType::value); + this->mPtrScoresDevice = mBufferManager->gpu(ITensor::makeShape({scoresSize}), TRTDataType::value); + + // Allocate routing bias buffers when needed + if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + mPtrRoutingBiasHost = mBufferManager->pinned(ITensor::makeShape({param.numExperts}), TRTDataType::value); + mPtrRoutingBiasDevice = mBufferManager->gpu(ITensor::makeShape({param.numExperts}), TRTDataType::value); + } + } + + void setupBuffers(RoutingKernelTestParam const& param) override + { + RoutingKernelTest::setupBuffers(param); + + // Initialize routing bias with small random values + if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + T* biasPtr = bufferCast(*mPtrRoutingBiasHost); + initData(biasPtr, param.numExperts, mSeed + 7); + mBufferManager->copy(*mPtrRoutingBiasHost, *mPtrRoutingBiasDevice); + } + } + + template + void 
setParams(RoutingKernelTestParam const& param, RoutingData& routingData) + { + RoutingKernelTest::setCommonParams(param, routingData); + + if (sizeof(T) == 4) + { + routingData.mDtypeOutput = btg::Dtype::Fp32; + routingData.mDtypeInput = btg::Dtype::Fp32; + } + else + { + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = btg::Dtype::Bfloat16; + } + + // Set policy types from test param (already derived by build()) + routingData.mPreprocessType = param.preprocessType; + routingData.mPostprocessType = param.postprocessType; + routingData.mNormTopkProb = param.normTopkProb; + + // Set routing bias and scale when using SigmoidBias preprocess + if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + routingData.mPtrRoutingBias = bufferCast(*mPtrRoutingBiasDevice); + // Bias dtype matches T (the test's type parameter) + routingData.mDtypeBias = (sizeof(T) == 4) ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + routingData.mRouteScale = param.routedScalingFactor; + } + + if (param.useTopKAsInput) + { + routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); + routingData.mPtrScores = nullptr; + } + else if (param.useTopKPackedAsInput) + { + // mPtrTopKPacked is already set by setCommonParams; just clear scores and topKIds + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = nullptr; + } + else + { + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); + } + } + + void callTestedFunction( + RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override + { + moe::dev::routing::routingCustom::Data routingData; + setParams(param, routingData); + moe::dev::routing::routingCustom::run(routingData, mStream->get()); + } +}; + +TYPED_TEST_SUITE(RoutingCustomKernelTest, FloatAndBf16Types); + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelization) +{ + auto param = RoutingKernelTestParam() + 
.withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelizationWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelizationWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKAsInput (mPtrTopKIds + mPtrTopKWeights as input) --- +// These test the runPostTopKPipeline path at block, cluster, and coop levels. 
+ +TYPED_TEST(RoutingCustomKernelTest, BlockLevelTopKAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelTopKAsInput) +{ + // Medium token count -> single-cluster path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, CoopLevelTopKAsInput) +{ + // Large token count -> coop path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(192) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + 
.withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithRenormalizeNaive) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKPackedAsInput (mPtrTopKPacked without mPtrScores) --- +// These test the runPostTopKPipeline path for the packed input format. + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelTopKPackedAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelTopKPackedAsInput) +{ + // Medium token count -> single-cluster path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelTopKPackedAsInput) +{ + // Large token count -> coop or multi-kernel path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelization) +{ + auto param = 
RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(10) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(200) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(200) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(256) + .withNumExperts(128) + .withTopK(4) + .withExpertParallelization(2, 1) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithRenormalizeNaiveTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +// --- Tests for Default (Softmax + NoOp postprocess) --- + +TYPED_TEST(RoutingCustomKernelTest, DefaultBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DefaultClusterLevel) +{ + auto param = 
RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DefaultDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DefaultWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// --- Tests for RenormalizeNaive at block and device levels --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeNaiveBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeNaiveDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeNaiveWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelizationTop4) +{ + auto param = RoutingKernelTestParam() + 
.withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelizationLargeN) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationLargeN) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelizationLargeN) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for invalid expertId = numExperts (instead of -1). +// Some frameworks use expertId == numExperts to mark unassigned slots. +// The kernel must handle this without illegal memory access. 
+// These tests exercise the block, cluster, and device-level paths with topKIds input +// where some expert IDs are set to numExperts. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelInvalidExpertIdEqualsNumExperts) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(128) // numExperts as invalid marker + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelInvalidExpertIdEqualsNumExperts) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(128) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelInvalidExpertIdEqualsNumExperts) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(128) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +// Test with numExperts < MaxNumExperts (fall-through tier) — expertId=numExperts +// passes the `< MaxNumExperts` check but should still be treated as invalid. +TYPED_TEST(RoutingCustomKernelTest, BlockLevelInvalidExpertIdFallThroughTier) +{ + // numExperts=100 → dispatches to E128 tier (MaxNumExperts=128). + // expertId=100 passes `100 < 128` but is invalid (only 0..99 are valid). 
+ auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(100) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(100) // numExperts as invalid marker + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelInvalidExpertIdFallThroughTier) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(100) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(100) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for Renormalize with new expert/topK tiers (E160, E576, K22) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// --- E576 experts, K22 topK (exercises the new E576 and K22 tiers) --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + 
this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE576K22TopKAsInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +// --- E160 experts, K8 topK (exercises the new E160 tier) --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE160WithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(160) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for 384-expert tier (Renormalize + SigmoidBias policies). 
+// 384 is in getMaxNumExperts() tiers but was previously missing from some PolicyTraits, +// causing thread-count mismatch bugs. These tests cover block, cluster, and device paths. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// --- Renormalize with E384 (exercises Tier<384,8> in None+Softmax policy) --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384WithEP) +{ + // Mirrors the failing multi-GPU test: e384, topK=8, seq=1, EP=4 + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withExpertParallelization(4, 1) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384TopKAsInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + 
+TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384TopKAsInputWithEP) +{ + // Mirrors the failing multi-GPU test with pre-computed topK: e384, topK=8, seq=1, EP=4 + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withExpertParallelization(4, 1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + +// --- SigmoidBias with E384 (exercises Tier<384,8> in SigmoidBias+ScaledSumNormalize policy) --- + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasBlockLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(4) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasClusterLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(100) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for scores input + cooperative kernel path (scores→topK kernel + coop histogram+offsets). +// These verify the coop fast-path when input is raw mPtrScores (not pre-computed topK). +// Triggered when numTokens > cluster capacity (256) and within coop capacity. +// Requires SM90+ (coop kernel uses grid-sync). 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalize) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeE256K4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(256) + .withTopK(4) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(500) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopDefault) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeNaive) +{ + auto param = RoutingKernelTestParam() + 
.withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopSigmoidBias) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for SigmoidBias + ScaledSumNormalize (DeepSeek-style routing via routingCustom) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SigmoidBias PolicyTraits: only E512 × K8. + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasDeviceLevel) +{ + auto param = RoutingKernelTestParam() + 
.withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(10) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for MiniMax2 (SigmoidBias + ScaledSumNormalize with routeScale=1.0) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MiniMax2 PolicyTraits: same as SigmoidBias, only E512 × K8. 
+ +TYPED_TEST(RoutingCustomKernelTest, MiniMax2BlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2ClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2DeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2WithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(10) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for mixed input/bias dtypes (SigmoidBias with float32 scores + bfloat16 bias, and vice versa). +// These test the loadScalar + mDtypeBias dispatch for cross-dtype bias reading. +// The test allocates bias in the "opposite" dtype from T. 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2MixedBiasDtype) +{ + using OtherT = std::conditional_t, __nv_bfloat16, float>; + + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + + // Allocate and setup normal buffers + this->allocateBuffers(param); + this->setupBuffers(param); + + // Allocate bias in the "opposite" dtype from T + auto otherBiasHost + = this->mBufferManager->pinned(ITensor::makeShape({param.numExperts}), TRTDataType::value); + auto otherBiasDevice + = this->mBufferManager->gpu(ITensor::makeShape({param.numExperts}), TRTDataType::value); + auto biasPtr = bufferCast(*otherBiasHost); + for (int i = 0; i < param.numExperts; i++) + { + biasPtr[i] = static_cast(0.01f * (i % 100)); + } + this->mBufferManager->copy(*otherBiasHost, *otherBiasDevice); + this->mStream->synchronize(); + + // Setup routing data with mixed dtypes + moe::dev::routing::routingCustom::Data routingData; + this->setCommonParams(param, routingData); + routingData.mDtypeOutput = (sizeof(TypeParam) == 4) ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + routingData.mDtypeInput = routingData.mDtypeOutput; + routingData.mPreprocessType = param.preprocessType; + routingData.mPostprocessType = param.postprocessType; + routingData.mNormTopkProb = param.normTopkProb; + routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); + routingData.mPtrRoutingBias = bufferCast(*otherBiasDevice); + // Bias dtype is intentionally different from scores dtype (T) to test mixed-precision support. + // e.g. T=float → OtherT=bfloat16, T=bfloat16 → OtherT=float. + routingData.mDtypeBias = (sizeof(OtherT) == 4) ? 
btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + routingData.mRouteScale = param.routedScalingFactor; + + // Run kernel — verifies it doesn't crash with mixed bias dtype + moe::dev::routing::routingCustom::run(routingData, this->mStream->get()); + this->mStream->synchronize(); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for DeepSeek nGroup=1 via routingCustom (SigmoidBias + ScaledSumNormalize with routeScale != 1.0) +// When nGroup <= 1, DeepSeek routing is equivalent to SigmoidBias + ScaledSumNormalize, +// and production code routes through routingCustom (not routingDeepSeek). +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// DeepSeek nGroup=1: uses SigmoidBias policy (E512 × K8). + +TYPED_TEST(RoutingCustomKernelTest, DeepSeekNoGroupBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withRoutedScalingFactor(2.5f) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeepSeekNoGroupClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withRoutedScalingFactor(2.5f) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeepSeekNoGroupDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withRoutedScalingFactor(2.5f) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for SigmoidRenorm (Sigmoid + SumNormalize) 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// SigmoidRenorm PolicyTraits: only E128 × K8. + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for NoOp + NoOp (raw topK, no score transformation) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, NoOpBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + 
.build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, NoOpClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// NoOp PolicyTraits: only E128 × K8. + +TYPED_TEST(RoutingCustomKernelTest, NoOpDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, NoOpWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic-block kernel tests (5-16 tokens, ≤512 experts) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, DynBlockBasic) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(8) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockMaxTokens) +{ + auto 
param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(16) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(12) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithTopKAsInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(8) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(10) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 0) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithRenormalizeNaive) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(16) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +} // namespace diff --git 
a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp index 0467c1749654..e29ed24adbfc 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp @@ -151,6 +151,7 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest } } +protected: void allocateBuffers(RoutingKernelTestParam const& param) { RoutingKernelTest::allocateBuffers(param); @@ -179,9 +180,11 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest void setParams(RoutingKernelTestParam const& param, RoutingData& routingData) { RoutingKernelTest::setCommonParams(param, routingData); - routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; routingData.mPtrRoutingBias = bufferCast(*this->mPtrRoutingBiasDevice); + // Bias dtype matches T (the test's type parameter) + routingData.mDtypeBias = (sizeof(T) == 4) ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; routingData.mNumExpertGroups = param.nGroup; routingData.mNumLimitedGroups = param.topkGroup; @@ -193,6 +196,12 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); routingData.mPtrScores = nullptr; } + else if (param.useTopKPackedAsInput) + { + // mPtrTopKPacked is already set by setCommonParams; just clear scores and topKIds + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = nullptr; + } else { routingData.mPtrTopKIds = nullptr; @@ -213,199 +222,377 @@ TYPED_TEST_SUITE(RoutingDeepSeekKernelTest, Bf16Types); TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization32) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/32, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, 
/*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(32) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization72) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/72, /*topK=*/6, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(72) + .withTopK(6) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, 
ClusterLevelParallelization512) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/512, /*topK=*/22, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(22) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10 - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1024) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingDeepSeekKernelTest, BlockLevelTopKAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, 
ClusterLevelParallelizationWithTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10 - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1024) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10 - /*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1024) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withNGroup(1) + .withTopkGroup(1) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKPackedAsInput (mPtrTopKPacked without mPtrScores) --- +// These test the runPostTopKPipeline path for the packed input format. 
+ +TYPED_TEST(RoutingDeepSeekKernelTest, BlockLevelTopKPackedAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKPackedAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelTopKPackedAsInput) +{ + // Medium token count -> single-cluster path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKPackedAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelTopKPackedAsInput) +{ + // Large token count -> coop or multi-kernel path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(2048) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/100, - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + 
.withNumExperts(256) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Test DeepSeek main kernel with float32 bias (T=bf16 for scores output, but bias is float32). +// This exercises the loadScalar path with mismatched bias dtype. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelWithFloat32Bias) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withNGroup(8) + .withTopkGroup(4) + .build(); + + this->allocateBuffers(param); + + // Override: allocate bias as float32 instead of T (bf16) + auto float32BiasHost + = this->mBufferManager->pinned(ITensor::makeShape({param.numExperts}), nvinfer1::DataType::kFLOAT); + auto float32BiasDevice + = this->mBufferManager->gpu(ITensor::makeShape({param.numExperts}), nvinfer1::DataType::kFLOAT); + auto biasPtr = bufferCast(*float32BiasHost); + for (int i = 0; i < param.numExperts; i++) + { + biasPtr[i] = 0.01f * (i % 100); + } + this->mBufferManager->copy(*float32BiasHost, *float32BiasDevice); + + // Setup normal buffers (scores, etc.) 
+ float* scoresHostPtr = bufferCast(*this->mPtrScoresHost); + initData(scoresHostPtr, param.numTokens * param.numExperts, 42); + this->mBufferManager->copy(*this->mPtrScoresHost, *this->mPtrScoresDevice); + this->mStream->synchronize(); + + // Setup routing data with float32 bias + moe::dev::routing::routingDeepSeek::Data routingData; + this->setCommonParams(param, routingData); + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); + routingData.mPtrRoutingBias = bufferCast(*float32BiasDevice); + routingData.mDtypeBias = btg::Dtype::Fp32; // Float32 bias with bf16 output + routingData.mNumExpertGroups = param.nGroup; + routingData.mNumLimitedGroups = param.topkGroup; + routingData.mRouteScale = param.routedScalingFactor; + routingData.mUseRoutingSoftmax = false; + + // Run kernel — verifies it doesn't crash with float32 bias + moe::dev::routing::routingDeepSeek::run(routingData, this->mStream->get()); + this->mStream->synchronize(); +}; + TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - 
/*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization512) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/512, /*topK=*/22, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(512) + .withTopK(22) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto 
param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(20300) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, - /*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); - this->runTest(param); -}; - -TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization512) -{ - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, - /*numExperts=*/512, /*topK=*/22, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(20300) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10, - /*numExperts=*/256, /*topK=*/2, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - 
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(10) + .withNumExperts(256) + .withTopK(2) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop2) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/100, - /*numExperts=*/256, /*topK=*/2, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + .withNumExperts(256) + .withTopK(2) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/256, /*topK=*/2, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(256) + 
.withTopK(2) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop8) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/32, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(32) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; } // namespace diff --git a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp index 6c4b5032c665..f889a7b79db6 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp @@ -63,28 +63,25 @@ class RoutingLlama4KernelTest : public RoutingKernelTest (a.score > b.score) || (a.score == b.score && a.idx < b.idx)); //@TODO: check if this is correct }); - // Apply sigmoid to the top-k scores + // Apply sigmoid to top-K scores, then store results. + // mPtrTopKPacked stores SIGMOID scores (matching what the scores-path kernels produce). + // The cluster/device kernels pass these through as-is to mPtrTopKWeights. 
for (int ie = 0; ie < param.topK; ++ie) { auto finalScore = 1.F / (1.F + std::exp(-expIdx[ie].score)); - expIdx[ie].score = static_cast(finalScore); - } - // convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked - for (int ie = 0; ie < param.topK; ++ie) - { - PackedType si{static_cast(expIdx[ie].score), expIdx[ie].idx}; + PackedType si{static_cast(finalScore), expIdx[ie].idx}; reinterpret_cast(bufferCast(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si; if (param.useTopKAsInput) { bufferCast(*this->mPtrTopKIdsHost)[it * param.topK + ie] = static_cast(expIdx[ie].idx); - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(finalScore); } else if (param.getExpWeights) { - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(finalScore); } } } @@ -102,7 +99,7 @@ class RoutingLlama4KernelTest : public RoutingKernelTest void setParams(RoutingKernelTestParam const& param, RoutingData& routingData) { RoutingKernelTest::setCommonParams(param, routingData); - routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; routingData.mPtrTopKPacked = reinterpret_cast(bufferCast(*this->mPtrTopKPackedDevice)); if (param.useTopKAsInput) @@ -110,6 +107,12 @@ class RoutingLlama4KernelTest : public RoutingKernelTest routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); routingData.mPtrScores = nullptr; } + else if (param.useTopKPackedAsInput) + { + // mPtrTopKPacked is already set above; just clear scores and topKIds + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = nullptr; + } else { routingData.mPtrTopKIds = nullptr; @@ -130,69 +133,128 @@ TYPED_TEST_SUITE(RoutingLlama4KernelTest, Bf16Types); TYPED_TEST(RoutingLlama4KernelTest, 
WarpLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/3, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f, - /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(3) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/300, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(300) + .withNumExperts(128) + .withTopK(1) + 
.withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelizationTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/3, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f, - /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(3) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelizationTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelizationTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/300, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, 
/*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(300) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKPackedAsInput (mPtrTopKPacked without mPtrScores) --- +// For Llama4, the kernels apply sigmoid_accurate to packed scores, +// so the packed input path goes through Llama4-specific kernels (not runPostTopKPipeline). + +TYPED_TEST(RoutingLlama4KernelTest, WarpLevelTopKPackedAsInput) +{ + // Small token count -> warp-level kernel + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(3) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKPackedAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelTopKPackedAsInput) +{ + // Medium token count -> cluster-level kernel + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKPackedAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelTopKPackedAsInput) +{ + // Large token count -> multi-kernel pipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(300) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKPackedAsInput(true) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; diff --git a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp deleted file mode 100644 index a6fce8ce49c6..000000000000 --- 
a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp +++ /dev/null @@ -1,453 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tests/unit_tests/kernels/routing/routingTest.h" - -namespace tk = tensorrt_llm::kernels; -namespace btg = batchedGemm::trtllm::gen; -using namespace tensorrt_llm::runtime; -using namespace tensorrt_llm::tests::kernels::routing; - -namespace -{ - -template -class RoutingRenormalizeKernelTest : public RoutingKernelTest -{ - -protected: - using RoutingKernelTest::mSeed; - using RoutingKernelTest::mStream; - using RoutingKernelTest::mBufferManager; - using typename RoutingKernelTest::PackedType; - -private: - // private methods - void computeTopKExperts(RoutingKernelTestParam const& param) override - { - for (int it = 0; it < param.numTokens; ++it) - { - PackedFloat expWeightsIdx[param.numExperts]; - PackedFloat expIdx[param.topK]; - float sum = float{0.0f}; - float maxScore = -std::numeric_limits::infinity(); - for (int ie = 0; ie < param.numExperts; ++ie) - { - float score; - int16_t newIdx = static_cast(ie); - score = static_cast(bufferCast(*this->mPtrScoresHost)[it * param.numExperts + ie]); - - if (param.doSoftmaxBeforeTopK && score > maxScore) - { - maxScore = score; - } - - PackedFloat si{static_cast(score), newIdx}; - expWeightsIdx[ie] = si; - } - - if (param.doSoftmaxBeforeTopK) - { - // Run softmax before topk - for 
(int ie = 0; ie < param.numExperts; ++ie) - { - expWeightsIdx[ie].score - = static_cast(std::exp(static_cast(expWeightsIdx[ie].score) - maxScore)); - sum += expWeightsIdx[ie].score; - } - - for (int ie = 0; ie < param.numExperts; ++ie) - { - float score = static_cast(expWeightsIdx[ie].score); - score /= sum; - expWeightsIdx[ie].score = static_cast(score); - } - } - - // Calculate the top-k scores and indices - std::partial_sort_copy(expWeightsIdx, expWeightsIdx + param.numExperts, expIdx, expIdx + param.topK, comp); - - if (param.doSoftmaxBeforeTopK) - { - // Normalize the value after the topk - if (param.normTopkProb) - { - float sum = float{0.0f}; - for (int ie = 0; ie < param.topK; ++ie) - { - sum += static_cast(expIdx[ie].score); - } - for (int ie = 0; ie < param.topK; ++ie) - { - float score = static_cast(expIdx[ie].score); - score /= sum; - expIdx[ie].score = static_cast(score); - } - } - } - else - { - // Perform softmax after topk - float sum = float{0.0f}; - float maxScore = -std::numeric_limits::infinity(); - float score; - for (int ie = 0; ie < param.topK; ++ie) - { - score = static_cast(expIdx[ie].score); - maxScore = score >= maxScore ? 
score : maxScore; - } - for (int ie = 0; ie < param.topK; ++ie) - { - score = static_cast(expIdx[ie].score) - maxScore; - score = std::exp(score); - sum += score; - } - for (int ie = 0; ie < param.topK; ++ie) - { - score = static_cast(expIdx[ie].score) - maxScore; - score = static_cast(std::exp(score)); - score /= sum; - expIdx[ie].score = static_cast(score); - } - } - - // convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked - for (int ie = 0; ie < param.topK; ++ie) - { - // Set invalid topk indices for the first half of the topk - if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1) - { - expIdx[ie].idx = -1; - } - - PackedType si{static_cast(expIdx[ie].score), expIdx[ie].idx}; - reinterpret_cast(bufferCast(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si; - if (param.useTopKAsInput) - { - bufferCast(*this->mPtrTopKIdsHost)[it * param.topK + ie] - = static_cast(expIdx[ie].idx); - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); - } - else if (param.getExpWeights) - { - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); - } - } - } - } - - void allocateBuffers(RoutingKernelTestParam const& param) override - { - RoutingKernelTest::allocateBuffers(param); - int64_t scoresSize = param.numTokens * param.numExperts; - this->mPtrScoresHost = mBufferManager->pinned(ITensor::makeShape({scoresSize}), TRTDataType::value); - this->mPtrScoresDevice = mBufferManager->gpu(ITensor::makeShape({scoresSize}), TRTDataType::value); - } - - template - void setParams(RoutingKernelTestParam const& param, RoutingData& routingData) - { - RoutingKernelTest::setCommonParams(param, routingData); - - if (sizeof(T) == 4) - { - routingData.mDtypeExpW = btg::Dtype::Fp32; - } - else - { - routingData.mDtypeExpW = btg::Dtype::Bfloat16; - } - - // Special case for RenormalizeNaive - routingData.mDoSoftmaxBeforeTopK = param.routingMethod == 
RoutingMethodType::RenormalizeNaive; - routingData.mNormTopkProb = param.routingMethod == RoutingMethodType::RenormalizeNaive; - - if (param.useTopKAsInput) - { - routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); - routingData.mPtrScores = nullptr; - } - else - { - routingData.mPtrTopKIds = nullptr; - routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); - } - } - - void callTestedFunction( - RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override - { - moe::dev::routing::routingRenormalize::Data routingData; - setParams(param, routingData); - moe::dev::routing::routingRenormalize::run(routingData, mStream->get()); - } -}; - -TYPED_TEST_SUITE(RoutingRenormalizeKernelTest, FloatAndBf16Types); - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithInvalidTopKInput) -{ - 
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, 
/*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormalizeNaive) -{ - RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/512, /*topK=*/10, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/200, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, 
/*numTokens=*/200, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/256, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormalizeNaiveTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, 
- /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockBasic) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/8, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockMaxTokens) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/16, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithExpertParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/12, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithTopKAsInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/8, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, 
/*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithInvalidTopKInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithRenormalizeNaive) -{ - RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/16, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - 
-TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationLargeN) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeN) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; -} // end namespace diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp index a7da5a6d7281..3b5747f28936 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp 
@@ -297,6 +297,24 @@ void RoutingKernelTest::verifyExpertRoutingIndices(RoutingKernelTestParam con EXPECT_EQ(checkSetEqual(ie, permutedIdx, permutedIdxTest, "permuted idx"), true); EXPECT_EQ(checkSetEqual(ie, tokenIdx, tokenIdxTest, "token idx"), true); } + + // Verify that invalid expert entries produce expandedIdxToPermutedIdx == -1. + // The loop above only checks valid experts (0..numExperts-1) and skips invalid entries. + if (param.hasInvalidTopKInput) + { + for (int it = 0; it < param.numTokens * param.topK; ++it) + { + int16_t const expertIdx = expIdxHostPtr[it].idx; + bool const isInvalid = (expertIdx < 0) || (expertIdx >= param.numExperts); + if (isInvalid) + { + int32_t const permIdxTest = hostExpToPermTest[it]; + EXPECT_EQ(permIdxTest, -1) + << "expandedIdxToPermutedIdx[" << it << "] should be -1 for invalid expertId=" << expertIdx + << " but got " << permIdxTest; + } + } + } } template @@ -326,9 +344,18 @@ void RoutingKernelTest::verifyResult(RoutingKernelTestParam const& param) } // expert counts aren't always used, but if tokens > 8 * 1024, we are sure they are used if (param.numTokens > param.singleClusterTokenNum) - { //@Todo: check if this is always true + { assertEqual(bufferCast(*mPtrExpertCountsHost), expertCountsPtr, param.numExperts, "expert counts"); - if (param.routingMethod != RoutingMethodType::DeepSeekV3) + // The second half of mPtrExpertCounts is only filled by the multi-kernel offsets pipeline + // (routingIndicesOffsetsKernel). It is NOT filled by the coop kernel or cluster kernel. + // On SM90+, both the scores path (RoutingCustom.cu) and the post-topK path + // (RoutingFromTopKIds.cu) may use the coop kernel instead of multi-kernel for medium + // token counts. Skip this check whenever the coop path could have been taken. + // The coop path requires SM90+ and numExperts <= 1024. 
+ bool const coopMayBeUsed = (mDeviceProp.major >= 9) && (param.numExperts <= 1024); + bool const useMultiKernelPath = !param.useTopKAsInput && !param.useTopKPackedAsInput + && param.routingMethod != RoutingMethodType::DeepSeekV3 && !coopMayBeUsed; + if (useMultiKernelPath) { assertEqual(bufferCast(*mPtrExpertCountsHost), expertCountsPtr + param.numExperts, param.numExperts, "expert counts (2)"); @@ -370,6 +397,13 @@ void RoutingKernelTest::runTest(RoutingKernelTestParam const& param) mBufferManager->copy(*mPtrTopKWeightsHost, *mPtrTopKWeightsDevice); mStream->synchronize(); } + else if (param.useTopKPackedAsInput) + { + // Set the topk_packed as input (computed by host reference, no scores) + mBufferManager->copy(*mPtrTopKPackedHost, *mPtrTopKPackedDevice); + mBufferManager->copy(*mPtrTopKWeightsHost, *mPtrTopKWeightsDevice); + mStream->synchronize(); + } // Retrieve the workspace size of the routing kernel. auto const workspaceSize = getDeviceWorkspaceSize(param); TensorPtr workspaceDevice diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.h b/cpp/tests/unit_tests/kernels/routing/routingTest.h index 4a50cbbc7950..aed280232927 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.h +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.h @@ -36,6 +36,8 @@ typedef testing::Types FloatAndBf16Types; typedef testing::Types<__nv_bfloat16> Bf16Types; using RoutingMethodType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType; +using RoutingPreprocessType = moe::dev::routing::RoutingPreprocessType; +using RoutingPostprocessType = moe::dev::routing::RoutingPostprocessType; using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr; using namespace tensorrt_llm::runtime; @@ -226,14 +228,13 @@ inline auto comp = [](PackedFloat const& a, PackedFloat const& b) struct RoutingKernelTestParam { RoutingMethodType routingMethod{RoutingMethodType::Renormalize}; - int32_t numTokens; - int32_t numExperts; + int32_t numTokens{0}; + 
int32_t numExperts{0}; uint32_t topK{1}; int32_t localExpertsStartIdx{0}; int32_t localExpertsStrideLog2{0}; - // we don't use any special striding, and we always test the GPU at logical idx 0 - int32_t numLocalExperts{128}; + int32_t numLocalExperts{0}; int32_t paddingLog2{3}; int32_t tileTokensDim{1}; @@ -244,13 +245,19 @@ struct RoutingKernelTestParam int requiredComputeCapability{9}; // Check the input parameters - bool useTopKAsInput{false}; + bool useTopKAsInput{false}; // When true, mPtrTopKIds + mPtrTopKWeights are provided as input + bool useTopKPackedAsInput{false}; // When true, mPtrTopKPacked is provided as input (without mPtrScores) bool hasInvalidTopKInput{false}; + int32_t invalidExpertIdValue{-1}; // Value used to mark invalid topK entries: -1 or numExperts // Special for renormalize routing method bool doSoftmaxBeforeTopK{false}; bool normTopkProb{true}; + // Policy type selection for routingCustom (set automatically by build() if not overridden) + RoutingPreprocessType preprocessType{RoutingPreprocessType::None}; + RoutingPostprocessType postprocessType{RoutingPostprocessType::Softmax}; + // Special for deepseek routing method int32_t nGroup{0}; int32_t topkGroup{0}; @@ -259,59 +266,239 @@ struct RoutingKernelTestParam // Default constructor RoutingKernelTestParam() = default; - // Constructor with required parameters - RoutingKernelTestParam(int32_t nt, int32_t ne, uint32_t tk = 1) - : numTokens(nt) - , numExperts(ne) - , topK(tk) - { - } - - // Constructor with all parameters - RoutingKernelTestParam(RoutingMethodType routingMethod, int32_t numTokens, int32_t numExperts, uint32_t topK, - int32_t expertParallelization = 1, int32_t expertParallelizationId = 0, int32_t tileTokensDim = 1, - int32_t paddingLog2 = 3, int32_t localExpertsStrideLog2 = 0, bool usePdl = true, bool getExpWeights = true, - bool useTopKAsInput = false, bool hasInvalidTopKInput = false, int32_t nGroup = 1, int32_t topkGroup = 1, - float routedScalingFactor = 1.0f, int 
requiredComputeCapability = 9) - : routingMethod(routingMethod) - , numTokens(numTokens) - , numExperts(numExperts) - , topK(topK) - , tileTokensDim(tileTokensDim) - , paddingLog2(paddingLog2) - , localExpertsStrideLog2(localExpertsStrideLog2) - , usePdl(usePdl) - , getExpWeights(getExpWeights) - , useTopKAsInput(useTopKAsInput) - , hasInvalidTopKInput(hasInvalidTopKInput) - , nGroup(nGroup) - , topkGroup(topkGroup) - , routedScalingFactor(routedScalingFactor) - , requiredComputeCapability(requiredComputeCapability) - { - // Check the routing method - if (routingMethod != RoutingMethodType::Renormalize && routingMethod != RoutingMethodType::RenormalizeNaive - && routingMethod != RoutingMethodType::Llama4 && routingMethod != RoutingMethodType::DeepSeekV3) + // Copy / move constructors and assignment operators + RoutingKernelTestParam(RoutingKernelTestParam const& other) = default; + RoutingKernelTestParam(RoutingKernelTestParam&& other) = default; + RoutingKernelTestParam& operator=(RoutingKernelTestParam const& other) = default; + RoutingKernelTestParam& operator=(RoutingKernelTestParam&& other) = default; + ~RoutingKernelTestParam() = default; + + // + // Fluent builder methods — each returns *this so calls can be chained. 
+ // Usage: + // auto param = RoutingKernelTestParam() + // .withRoutingMethod(RoutingMethodType::Renormalize) + // .withNumTokens(4) + // .withNumExperts(128) + // .withTopK(8) + // .build(); + // + + RoutingKernelTestParam& withRoutingMethod(RoutingMethodType val) + { + routingMethod = val; + return *this; + } + + RoutingKernelTestParam& withNumTokens(int32_t val) + { + numTokens = val; + return *this; + } + + RoutingKernelTestParam& withNumExperts(int32_t val) + { + numExperts = val; + return *this; + } + + RoutingKernelTestParam& withTopK(uint32_t val) + { + topK = val; + return *this; + } + + RoutingKernelTestParam& withExpertParallelization(int32_t ep, int32_t epId = 0) + { + mExpertParallelization = ep; + mExpertParallelizationId = epId; + return *this; + } + + RoutingKernelTestParam& withTileTokensDim(int32_t val) + { + tileTokensDim = val; + return *this; + } + + RoutingKernelTestParam& withPaddingLog2(int32_t val) + { + paddingLog2 = val; + return *this; + } + + RoutingKernelTestParam& withLocalExpertsStrideLog2(int32_t val) + { + localExpertsStrideLog2 = val; + return *this; + } + + RoutingKernelTestParam& withUsePdl(bool val) + { + usePdl = val; + return *this; + } + + RoutingKernelTestParam& withGetExpWeights(bool val) + { + getExpWeights = val; + return *this; + } + + RoutingKernelTestParam& withUseTopKAsInput(bool val) + { + useTopKAsInput = val; + return *this; + } + + RoutingKernelTestParam& withUseTopKPackedAsInput(bool val) + { + useTopKPackedAsInput = val; + return *this; + } + + RoutingKernelTestParam& withHasInvalidTopKInput(bool val) + { + hasInvalidTopKInput = val; + return *this; + } + + RoutingKernelTestParam& withInvalidExpertIdValue(int32_t val) + { + invalidExpertIdValue = val; + return *this; + } + + RoutingKernelTestParam& withNGroup(int32_t val) + { + nGroup = val; + return *this; + } + + RoutingKernelTestParam& withTopkGroup(int32_t val) + { + topkGroup = val; + return *this; + } + + RoutingKernelTestParam& 
withRoutedScalingFactor(float val) + { + routedScalingFactor = val; + return *this; + } + + RoutingKernelTestParam& withPreprocessType(RoutingPreprocessType val) + { + preprocessType = val; + mPreprocessTypeOverridden = true; + return *this; + } + + RoutingKernelTestParam& withPostprocessType(RoutingPostprocessType val) + { + postprocessType = val; + mPostprocessTypeOverridden = true; + return *this; + } + + RoutingKernelTestParam& withNormTopkProb(bool val) + { + normTopkProb = val; + mNormTopkProbOverridden = true; + return *this; + } + + RoutingKernelTestParam& withRequiredComputeCapability(int val) + { + requiredComputeCapability = val; + return *this; + } + + /// Finalize and validate. Must be called after all `with*()` setters. + RoutingKernelTestParam& build() + { + // Validate routing method + if (routingMethod != RoutingMethodType::Default && routingMethod != RoutingMethodType::Renormalize + && routingMethod != RoutingMethodType::RenormalizeNaive && routingMethod != RoutingMethodType::Llama4 + && routingMethod != RoutingMethodType::DeepSeekV3 && routingMethod != RoutingMethodType::MiniMax2 + && routingMethod != RoutingMethodType::SigmoidRenorm) { throw std::invalid_argument("Invalid routing method"); } - // Set about the expert parallelization - numLocalExperts = numExperts / expertParallelization; - localExpertsStartIdx = numLocalExperts * expertParallelizationId; + // Derive expert parallelization parameters + numLocalExperts = numExperts / mExpertParallelization; + localExpertsStartIdx = numLocalExperts * mExpertParallelizationId; - // Apply routing method specific settings - if (routingMethod == RoutingMethodType::RenormalizeNaive) + // Apply routing-method-specific settings + if (routingMethod == RoutingMethodType::Default) + { + doSoftmaxBeforeTopK = true; + normTopkProb = false; + } + else if (routingMethod == RoutingMethodType::RenormalizeNaive) { doSoftmaxBeforeTopK = true; normTopkProb = true; } + else if (routingMethod == 
RoutingMethodType::SigmoidRenorm) + { + doSoftmaxBeforeTopK = false; + if (!mNormTopkProbOverridden) + { + normTopkProb = true; + } + } else { doSoftmaxBeforeTopK = false; normTopkProb = false; } + // Derive policy types from routing method when not explicitly set + if (!mPreprocessTypeOverridden) + { + if (routingMethod == RoutingMethodType::Default || routingMethod == RoutingMethodType::RenormalizeNaive) + { + preprocessType = RoutingPreprocessType::Softmax; + } + else if (routingMethod == RoutingMethodType::MiniMax2) + { + preprocessType = RoutingPreprocessType::SigmoidBias; + } + else if (routingMethod == RoutingMethodType::SigmoidRenorm) + { + preprocessType = RoutingPreprocessType::Sigmoid; + } + else + { + preprocessType = RoutingPreprocessType::None; + } + } + if (!mPostprocessTypeOverridden) + { + if (routingMethod == RoutingMethodType::Default) + { + postprocessType = RoutingPostprocessType::None; + } + else if (routingMethod == RoutingMethodType::RenormalizeNaive) + { + postprocessType = RoutingPostprocessType::SumNormalize; + } + else if (routingMethod == RoutingMethodType::MiniMax2) + { + postprocessType = RoutingPostprocessType::ScaledSumNormalize; + } + else if (routingMethod == RoutingMethodType::SigmoidRenorm) + { + postprocessType = RoutingPostprocessType::SumNormalize; + } + else + { + postprocessType = RoutingPostprocessType::Softmax; + } + } + // Set singleClusterTokenNum if (routingMethod == RoutingMethodType::DeepSeekV3) { @@ -322,36 +509,36 @@ struct RoutingKernelTestParam singleClusterTokenNum = 256; } + // Cross-field validation if (hasInvalidTopKInput && !useTopKAsInput) { throw std::invalid_argument("hasInvalidTopKInput is only supported when useTopKAsInput is true"); } - } - - // Copy constructor - RoutingKernelTestParam(RoutingKernelTestParam const& other) = default; - - // Move constructor - RoutingKernelTestParam(RoutingKernelTestParam&& other) = default; - - // Copy assignment operator - RoutingKernelTestParam& 
operator=(RoutingKernelTestParam const& other) = default; - - // Move assignment operator - RoutingKernelTestParam& operator=(RoutingKernelTestParam&& other) = default; + if (useTopKAsInput && useTopKPackedAsInput) + { + throw std::invalid_argument("useTopKAsInput and useTopKPackedAsInput are mutually exclusive"); + } - // Destructor - ~RoutingKernelTestParam() = default; + return *this; + } std::string toString() const { return tensorrt_llm::common::fmtstr( "RoutingKernelTestParam[num_tokens=%d, num_experts=%d, topK=%u, doSoftmaxBeforeTopK=%d, normTopkProb=%d, " "localExpertsStartIdx=%d, localExpertsStrideLog2=%d, numLocalExperts=%d, usePdl=%d, useTopKAsInput=%d, " - "hasInvalidTopKInput=%d]", + "useTopKPackedAsInput=%d, hasInvalidTopKInput=%d]", numTokens, numExperts, topK, doSoftmaxBeforeTopK, normTopkProb, localExpertsStartIdx, - localExpertsStrideLog2, numLocalExperts, usePdl, useTopKAsInput, hasInvalidTopKInput); + localExpertsStrideLog2, numLocalExperts, usePdl, useTopKAsInput, useTopKPackedAsInput, hasInvalidTopKInput); } + +private: + // Builder state — used by build() to derive public fields. 
+ int32_t mExpertParallelization{1}; + int32_t mExpertParallelizationId{0}; + bool mPreprocessTypeOverridden{false}; + bool mPostprocessTypeOverridden{false}; + bool mNormTopkProbOverridden{false}; }; template diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index 96b75cd96a43..3fbcc0b5659e 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -52,8 +52,10 @@ RUN --mount=type=bind,source=docker/common,target=/opt/docker/common \ # Install constraints after install.sh so cleanup() doesn't delete the file mid-RUN COPY constraints.txt /tmp/constraints.txt RUN --mount=type=cache,target=/root/.cache/pip \ - # WAR: uninstall dependencies that has vulnerability - pip3 uninstall -y tornado black nbconvert || true && \ + # WAR: uninstall dependencies that has vulnerability or need upgrading + pip3 uninstall -y tornado black nbconvert nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base || true && \ + # Remove any leftover namespace dirs or dist-info that pip missed + rm -rf $(python3 -c "import site; print(site.getsitepackages()[0])")/nvidia_cutlass_dsl* && \ pip3 install --ignore-installed --no-cache-dir -r /tmp/constraints.txt && \ rm /tmp/constraints.txt diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 43844da06d22..e5c353877b1f 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -3620,6 +3620,10 @@ def launchTestJobs(pipeline, testFilter) trtllm_utils.llmExecStepWithRetry(pipeline, script: "[ -f /etc/pip/constraint.txt ] && : > /etc/pip/constraint.txt || true") // Remove the python3-pygments pip package because the dlfw image already includes a Debian pygments package, which conflicts with the pip-installed version. trtllm_utils.llmExecStepWithRetry(pipeline, script: "apt-get remove -y python3-pygments") + // Remove stale nvidia-cutlass-dsl from the base image to prevent namespace + // directory corruption when pip upgrades to the version required by tensorrt_llm. 
+ trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 uninstall -y nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base || true") + trtllm_utils.llmExecStepWithRetry(pipeline, script: 'rm -rf $(python3 -c "import site; print(site.getsitepackages()[0])")/nvidia_cutlass_dsl*') } trtllm_utils.llmExecStepWithRetry(pipeline, script: "apt-get update && apt-get install -y python3-pip git rsync curl wget") trtllm_utils.checkoutSource(LLM_REPO, env.gitlabCommit, LLM_ROOT, false, true) diff --git a/jenkins/current_image_tags.properties b/jenkins/current_image_tags.properties index 8d751664640b..d8d16c56ee5d 100644 --- a/jenkins/current_image_tags.properties +++ b/jenkins/current_image_tags.properties @@ -13,7 +13,7 @@ # images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead. IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm -LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-x86_64-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202604011104-12600 -LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-sbsa-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202604011104-12600 -LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.1-devel-rocky8-x86_64-rocky8-py310-trt10.15.1.29-skip-tritondevel-202604011104-12600 -LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.1-devel-rocky8-x86_64-rocky8-py312-trt10.15.1.29-skip-tritondevel-202604011104-12600 +LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-x86_64-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202604200956-13064 +LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-sbsa-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202604200956-13064 
+LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.15.1.29-skip-tritondevel-202604200956-13064 +LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.15.1.29-skip-tritondevel-202604200956-13064 diff --git a/requirements.txt b/requirements.txt index 110cf75aa64f..9d9b584bc1c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -55,7 +55,7 @@ ordered-set peft>=0.18.1,<0.19.0 patchelf einops -flashinfer-python==0.6.6 +flashinfer-python==0.6.8 opencv-python-headless xgrammar==0.1.32 llguidance==0.7.29 @@ -72,7 +72,7 @@ xdsl>=0.59.0 # Optional: required for MLIR-based elementwise fusion in AutoDeplo tiktoken blobfile openai-harmony==0.0.4 -nvidia-cutlass-dsl==4.3.4; python_version >= "3.10" +nvidia-cutlass-dsl[cu13]==4.4.2; python_version >= "3.10" plotly numexpr partial_json_parser diff --git a/security_scanning/pyproject.toml b/security_scanning/pyproject.toml index 54e70cca5869..042c236dfb53 100644 --- a/security_scanning/pyproject.toml +++ b/security_scanning/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "peft (>=0.18.1,<0.19.0)", "patchelf (>=0.17.2.4,<0.18.0.0)", "einops (>=0.8.2,<0.9.0)", - "flashinfer-python (==0.6.6)", + "flashinfer-python (==0.6.8)", "opencv-python-headless (>=4.13.0.92,<5.0.0.0)", "xgrammar (==0.1.32)", "llguidance (==0.7.29)", @@ -72,7 +72,7 @@ dependencies = [ "tiktoken (>=0.12.0,<0.13.0)", "blobfile (>=3.2.0,<4.0.0)", "openai-harmony (==0.0.4)", - "nvidia-cutlass-dsl (==4.3.4)", + "nvidia-cutlass-dsl (==4.4.2)", "plotly (>=6.7.0,<7.0.0)", "numexpr (>=2.14.1,<3.0.0)", "partial-json-parser (>=0.2.1.1.post7,<0.3.0.0)", diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index f91d331dbe52..95e210879cc3 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ 
b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -250,7 +250,7 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper): IDX_SHAPE_INFER = IDX_PERMUTED_IDX_TO_EXPANDED_IDX def inputs_pre_hook(self, inputs: List) -> List: - """Pre-hook for gather-based SwiGLU fusion kernel. + """Pre-hook for gather-based activation fusion kernel. Generates: - tile_idx_to_group_idx @@ -324,7 +324,7 @@ def get_dense_gemm_approximate_cta_nums( import cutlass import cutlass.cute as cute - from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion import \ + from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_gather_grouped_gemm_act_fusion import \ BlockScaledContiguousGatherGroupedGemmKernel from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm import \ Sm100BlockScaledContiguousGroupedGemmKernel @@ -1883,7 +1883,7 @@ def _( device=input_scale.device) return output, output_scale - class Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( + class Sm100BlockScaledContiguousGatherGroupedGemmActFusionRunner( TunableRunner): kernel_class = BlockScaledContiguousGatherGroupedGemmKernel kernel_cache = dict() @@ -1899,7 +1899,8 @@ def __init__(self, local_expert_offset: int, tile_size: int, scaling_vector_size: int = 16, - b_tensor_l_sizes: Optional[Tuple[int, ...]] = None): + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, + is_gated: bool = True): """Initialize the runner. Args: @@ -1907,6 +1908,7 @@ def __init__(self, None for single-B mode. Used for kernel cache key. 
""" super().__init__() + self.is_gated = is_gated self.num_experts = num_experts self.top_k = top_k self.num_local_experts = num_local_experts @@ -1938,6 +1940,7 @@ def unique_id(self): self.tile_size, self.scaling_vector_size, self.b_tensor_l_sizes, + self.is_gated, ) def get_valid_tactics( @@ -2061,11 +2064,14 @@ def forward(self, inputs: List, n = b0.size(1) sum(bi.size(0) for bi in b_list) scale_k = k // self.scaling_vector_size - interm_size = n // 2 + interm_size = n // 2 if self.is_gated else n assert m % self.tile_size == 0 assert k % (self.scaling_vector_size * 4) == 0 - assert n % (self.scaling_vector_size * 4 * 2) == 0 + if self.is_gated: + assert n % (self.scaling_vector_size * 4 * 2) == 0 + else: + assert n % (self.scaling_vector_size * 4) == 0 assert b0.size(2) * 2 == k assert a_sf.size(0) == orig_m assert a_sf.size(1) == scale_k @@ -2148,7 +2154,7 @@ def forward(self, inputs: List, cache_key = (self.scaling_vector_size, self.tile_size, self.top_k, mma_tiler_mn, cluster_shape_mn, raster_along_m, - b_tensor_l_sizes) + b_tensor_l_sizes, self.is_gated) if cache_key not in self.__class__.kernel_cache: gemm = self.__class__.kernel_class( @@ -2159,6 +2165,7 @@ def forward(self, inputs: List, topk=self.top_k, raster_along_m=raster_along_m, b_tensor_l_sizes=b_tensor_l_sizes, + is_gated=self.is_gated, ) hardware_info = cutlass.utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( @@ -2190,6 +2197,7 @@ def forward(self, inputs: List, scaling_vector_size=self.scaling_vector_size, max_active_clusters=max_active_clusters, stream=stream, + is_gated=self.is_gated, ) self.__class__.kernel_cache[cache_key] = compiled_gemm else: @@ -2219,10 +2227,10 @@ def forward(self, inputs: List, return c, c_sf @torch.library.custom_op( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b", + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b", mutates_args=(), device_types="cuda") - def 
cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + def cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b( input: torch.Tensor, weight: List[torch.Tensor], input_scale: torch.Tensor, @@ -2239,8 +2247,11 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( local_expert_offset: int, tile_size: int, scaling_vector_size: int = 16, + is_gated: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: - """CuteDSL-based NVFP4 gather grouped GEMM with SwiGLU fusion (multi-B list interface). + """CuteDSL-based NVFP4 gather grouped GEMM with activation fusion (multi-B list interface). + + Supports SwiGLU (is_gated=True) and Relu2 (is_gated=False) epilogue. Args: weight: List of B tensors. Single-B mode: [b], multi-B mode: [b0, b1, ...]. @@ -2251,9 +2262,9 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( b_tensor_l_sizes = tuple(w.size(0) for w in weight) - runner = Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( + runner = Sm100BlockScaledContiguousGatherGroupedGemmActFusionRunner( num_experts, top_k, num_local_experts, local_expert_offset, - tile_size, scaling_vector_size, b_tensor_l_sizes) + tile_size, scaling_vector_size, b_tensor_l_sizes, is_gated) inputs = [ input, weight, input_scale, weight_scale, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, @@ -2261,7 +2272,7 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( ] _, best_tactic = tuner.choose_one( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b", + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b", [runner], runner.get_tuning_config(), inputs, @@ -2272,7 +2283,8 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( return output @torch.library.register_fake( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b") + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b" + ) def _fake_multi_b( input: torch.Tensor, weight: 
List[torch.Tensor], @@ -2290,10 +2302,11 @@ def _fake_multi_b( local_expert_offset: int, tile_size: int, scaling_vector_size: int = 16, + is_gated: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: m = permuted_idx_to_expanded_idx.size(0) n = weight[0].size(1) - interm_size = n // 2 + interm_size = n // 2 if is_gated else n output = torch.empty(m, interm_size // 2, dtype=input.dtype, @@ -2304,10 +2317,10 @@ def _fake_multi_b( return output, output_scale @torch.library.custom_op( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell", mutates_args=(), device_types="cuda") - def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + def cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell( input: torch.Tensor, weight: torch.Tensor, input_scale: torch.Tensor, @@ -2324,13 +2337,14 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( local_expert_offset: int, tile_size: int, scaling_vector_size: int = 16, + is_gated: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: - """CuteDSL-based NVFP4 gather grouped GEMM with SwiGLU fusion (single-B tensor interface). + """CuteDSL-based NVFP4 gather grouped GEMM with activation fusion (single-B tensor interface). Thin wrapper: wraps single tensors into lists and calls - cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b. + cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b. 
""" - return torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + return torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b( input, [weight], input_scale, @@ -2347,10 +2361,11 @@ def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( local_expert_offset, tile_size, scaling_vector_size, + is_gated, ) @torch.library.register_fake( - "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell") + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell") def _fake_single_b( input: torch.Tensor, weight: torch.Tensor, @@ -2368,10 +2383,11 @@ def _fake_single_b( local_expert_offset: int, tile_size: int, scaling_vector_size: int = 16, + is_gated: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: m = permuted_idx_to_expanded_idx.size(0) n = weight.size(1) - interm_size = n // 2 + interm_size = n // 2 if is_gated else n output = torch.empty(m, interm_size // 2, dtype=input.dtype, diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 8a0b544722cb..7b113dad0e0e 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -89,6 +89,18 @@ def prepare_dummy_topk_and_hook( lambda: torch.randn( num_experts, dtype=torch.bfloat16, device=hidden_states.device) }) + if routing_method_type == RoutingMethodType.MiniMax2: + routing_cls_kwargs.update({ + 'callable_e_score_correction_bias': + lambda: torch.randn( + num_experts, dtype=torch.bfloat16, device=hidden_states.device), + 'num_experts': + num_experts, + }) + if routing_method_type == RoutingMethodType.SigmoidRenorm: + routing_cls_kwargs.update({ + 'num_experts': num_experts, + }) routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type]( top_k=top_k, **routing_cls_kwargs) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py 
b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_act_fusion.py similarity index 92% rename from tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py rename to tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_act_fusion.py index c339787301e6..5fb4ecbc9175 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_act_fusion.py @@ -37,7 +37,6 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils from cutlass._mlir.dialects import math from cutlass.cute.nvgpu import cpasync, tcgen05 -from cutlass.cutlass_dsl import Int32 from .custom_pipeline import PipelineCpAsyncUmma from .utils import ( @@ -50,19 +49,25 @@ ) """ -High-performance persistent blockscaled contiguous grouped dense GEMM with gather and SwiGLU fusion -(C = up * silu(gate), where up and gate come from interleaved weight matrix B) -example for the NVIDIA Blackwell architecture using CUTE DSL. +High-performance persistent blockscaled contiguous grouped dense GEMM with gather and activation +fusion example for the NVIDIA Blackwell architecture using CUTE DSL. -This kernel performs FC1 layer computation with SwiGLU activation fusion: +Supported fused activations (selected at construction via ``is_gated``): + - Gated (SwiGLU): C = up * silu(gate), where up/gate come from interleaved weight matrix B + - Non-gated (Relu2): C = relu(alpha * x)^2 + +This kernel performs FC1 layer computation with fused activation: 1. GEMM: acc = alpha * (SFA * A[token_ids]) * (SFB * B) -2. SwiGLU: C = up * silu(gate), where up/gate are extracted from interleaved acc (granularity=64) +2. Activation: + - Gated: C = up * silu(gate), up/gate extracted from interleaved acc (granularity=64) + - Non-gated: C = relu(acc)^2 3. 
Optional Quant: When c_dtype is Float4E2M1FN, generates scale factor C and quantizes output - Matrix A is MxKx1, A can be row-major("K"), ValidM is composed of valid m in different groups - Matrix B is NxKxL, B can be column-major("K"), L is grouped dimension (number of experts) - - B weights are interleaved: [up_0:64, gate_64:128, up_128:192, gate_192:256, ...] -- Matrix C is Mx(N/2)x1, C can be row-major("N"), N is halved due to SwiGLU fusion + - Gated: B weights are interleaved: [up_0:64, gate_64:128, up_128:192, gate_192:256, ...] + - Non-gated: B weights are plain +- Matrix C is MxN_out x1, C can be row-major("N"). N_out = N/2 for gated, N for non-gated. - Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×1 elements - Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, @@ -104,10 +109,10 @@ - Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction. - Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. 5. EPILOGUE warps (warps 0-3): - - Load two accumulator subtiles (up and gate) from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. - - Apply alpha scaling: up_scaled = alpha * up, gate_scaled = alpha * gate - - Compute SwiGLU activation: output = up_scaled * silu(gate_scaled), where silu(x) = x * sigmoid(x) - - If c_dtype is Float4E2M1FN: generate scale factor C (SFC) and quantize output + - Gated: load two accumulator subtiles (up, gate) from TMEM to RMEM via tcgen05.ld, then + apply alpha scaling and compute output = (alpha*up) * silu(alpha*gate). + - Non-gated: load one accumulator subtile per iteration and compute output = relu(alpha*acc)^2. + - If c_dtype is Float4E2M1FN: generate scale factor C (SFC) and quantize output. - Type convert output to c_dtype. - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. 
@@ -155,160 +160,24 @@ """ -# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.4 is released. -def hooked_PersistentTileSchedulerParams_init( - self, - problem_shape_ntile_mnl: cute.Shape, - cluster_shape_mnk: cute.Shape, - swizzle_size: int = 1, - raster_along_m: bool = True, - *, - loc=None, - ip=None, -): - if cluster_shape_mnk[2] != 1: - raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") - if swizzle_size < 1: - raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}") - - self.problem_shape_ntile_mnl = problem_shape_ntile_mnl - # cluster_shape_mnk is kept for reconstruction - self._cluster_shape_mnk = cluster_shape_mnk - self.cluster_shape_mn = cluster_shape_mnk[:2] - self.swizzle_size = swizzle_size - self._raster_along_m = raster_along_m - self._loc = loc - - # Apply swizzle if swizzle_size > 1 - if swizzle_size > 1: - problem_shape_ncluster_mnl = cute.round_up( - self.problem_layout_ncluster_mnl.shape, - (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1), - ) - - if raster_along_m: - self.problem_layout_ncluster_mnl = cute.make_layout( - ( - problem_shape_ncluster_mnl[0], - (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), - problem_shape_ncluster_mnl[2], - ), - stride=( - swizzle_size, - (1, swizzle_size * problem_shape_ncluster_mnl[0]), - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - else: - self.problem_layout_ncluster_mnl = cute.make_layout( - ( - (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), - problem_shape_ncluster_mnl[1], - problem_shape_ncluster_mnl[2], - ), - stride=( - (1, swizzle_size * problem_shape_ncluster_mnl[1]), - swizzle_size, - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - - # Create FastDivmod divisors (only when swizzle_size == 1 for correctness) - # FastDivmod assumes simple col-major/row-major layout, incompatible with swizzled layouts - if 
swizzle_size == 1: - problem_shape_ncluster_mnl = cute.ceil_div( - self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip - ) - if raster_along_m: - self.problem_layout_ncluster_mnl = cute.make_layout( - problem_shape_ncluster_mnl, - stride=( - 1, - problem_shape_ncluster_mnl[0], - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - else: - self.problem_layout_ncluster_mnl = cute.make_layout( - problem_shape_ncluster_mnl, - stride=( - problem_shape_ncluster_mnl[1], - 1, - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip) - cluster_count_m = self.problem_layout_ncluster_mnl.shape[0] - cluster_count_n = self.problem_layout_ncluster_mnl.shape[1] - - # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling) - self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip) - - # cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates - self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip) - - # cluster_shape_n_fdd: Used for the second level decomposition - self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip) - else: - # FastDivmod not applicable with swizzling, set to None - self.batch_fdd = None - self.cluster_shape_m_fdd = None - self.cluster_shape_n_fdd = None - - -def hooked_get_cluster_work_idx_with_fastdivmod( - self, current_work_linear_idx: Int32, *, loc=None, ip=None -) -> Tuple[Int32, Int32, Int32]: - work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd) - - if self.params._raster_along_m: - # raster_along_m=True means column major (m is fastest) - # First, get cluster_m using cluster_shape_m_fdd - cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd) - - # Then decode cluster_n_batch 
to get cluster_n and batch_l using FastDivmod - batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd) - else: - # raster_along_m=False means row major (n is fastest) - # First, get cluster_n using cluster_shape_n_fdd - cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd) - - # Then decode cluster_m_batch to get cluster_m and batch_l using FastDivmod - batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd) - - return (cluster_m, cluster_n, batch_l) - - -cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init -cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = ( - hooked_get_cluster_work_idx_with_fastdivmod -) - - class BlockScaledContiguousGatherGroupedGemmKernel: - """This class implements contiguous grouped matrix multiplication with gather operation and SwiGLU fusion - for FC1 layer computation (C = up * silu(gate), where up/gate come from interleaved GEMM result). + """This class implements contiguous grouped matrix multiplication with gather operation and a + fused activation (SwiGLU or Relu2) for FC1 layer computation. The computation flow: 1. GEMM: acc = alpha * (SFA * A[token_ids]) * (SFB * B) - 2. SwiGLU: C = up * silu(gate), extracted from interleaved acc with granularity=64 + 2. Activation (selected via ``is_gated``): + - Gated (SwiGLU): C = up * silu(gate), from interleaved acc with granularity=64 + - Non-gated (Relu2): C = relu(acc)^2 3. Optional Quant: When c_dtype is Float4E2M1FN, generates SFC and quantizes output - Note: Output C has N/2 columns since pairs of (up, gate) are combined by SwiGLU. + Note: Output C has N/2 columns for gated (pairs of up/gate collapsed), N columns for non-gated. 
Key Features: - Uses LDGSTS instructions for loading A and SFA matrices with gather/permutation capability - Uses TMA (Tensor Memory Access) for loading B and SFB matrices with multicast - Token ID mapping enables efficient gather operation during A/SFA load - - SwiGLU activation fusion in epilogue (up * silu(gate) with interleaved weights) + - Activation fusion in epilogue (gated uses interleaved weights; non-gated uses plain weights) - Optional quantization fusion for Float4E2M1FN output with scale factor generation - Warp specialization: Scheduler (warp 10), A Sync Transform (warp 11, only used when use_2cta_instrs is True), LDGSTS A/SFA (warps 4-7), TMA B/SFB (warp 9), MMA (warp 8), @@ -389,9 +258,10 @@ def __init__( topk: cutlass.Int64, raster_along_m: bool = False, b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, + is_gated: bool = True, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with - gather operation and SwiGLU fusion. + gather operation and fused activation (SwiGLU when is_gated=True, Relu2 otherwise). 
This configuration includes several key aspects: @@ -432,6 +302,7 @@ def __init__( self.sf_vec_size = sf_vec_size self.topk = topk + self.is_gated = is_gated self.acc_dtype = cutlass.Float32 self.use_2cta_instrs = mma_tiler_mn[0] == 256 self.cluster_shape_mn = cluster_shape_mn @@ -599,7 +470,7 @@ def _setup_attributes(self): self.mma_tiler_c = ( self.mma_inst_shape_mn[0], - self.mma_inst_shape_mn[1] // 2, + self.mma_inst_shape_mn[1] // 2 if self.is_gated else self.mma_inst_shape_mn[1], mma_inst_shape_k * mma_inst_tile_k, ) @@ -716,7 +587,7 @@ def _setup_attributes(self): else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols ) - self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1]) + self.epi_tile_n_required = (2 if self.is_gated else 1) * cute.size(self.epi_tile[1]) # Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue self.iter_acc_early_release_in_epilogue = self.num_sf_tmem_cols // self.epi_tile_n_required @@ -739,17 +610,20 @@ def __call__( stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): - """Execute the contiguous grouped GEMM with gather operation and SwiGLU fusion. + """Execute the contiguous grouped GEMM with gather operation and fused activation. This method performs FC1 layer computation: 1. GEMM: acc = alpha * (SFA * A[token_ids]) * (SFB * B) - 2. SwiGLU: C = up * silu(gate), where up/gate are extracted from interleaved acc (granularity=64) + 2. Activation (selected by ``is_gated`` at construction): + - Gated (SwiGLU): C = up * silu(gate), up/gate from interleaved acc (granularity=64) + - Non-gated (Relu2): C = relu(acc)^2 3. Optional Quant: When c_dtype is Float4E2M1FN, generates SFC and quantizes output Data loading: - A and SFA are loaded using LDGSTS instructions with token-based gather - B and SFB are loaded using TMA instructions with multicast - - B weights are interleaved: [up_0:64, gate_64:128, up_128:192, gate_192:256, ...] 
+ - Gated: B weights are interleaved [up_0:64, gate_64:128, up_128:192, gate_192:256, ...]; + non-gated: B weights are plain. Execution steps: 1. Setup static attributes before smem/grid computation @@ -763,13 +637,14 @@ def __call__( shared memory when use_2cta_instrs is True - TMA warp: Load B and SFB with multicast - MMA warp: Perform matrix multiply-accumulate - - Epilogue warps: Apply SwiGLU activation, optional quantization, and store results + - Epilogue warps: Apply fused activation, optional quantization, and store results :param a: Input tensor A (MxKx1), will be gathered using token_id_mapping :type a: cute.Tensor - :param b: Input tensor B (NxKxL), L is the number of experts/groups, weights are interleaved for SwiGLU + :param b: Input tensor B (NxKxL), L is the number of experts/groups; gated mode uses + interleaved up/gate weights, non-gated uses plain weights. :type b: cute.Tensor - :param c: Output tensor C (Mx(N/2)x1), N is halved due to SwiGLU fusion + :param c: Output tensor C; last dim is N/2 for gated (SwiGLU), N for non-gated (Relu2). 
:type c: cute.Tensor :param sfa: Scale factor tensor A, will be gathered using token_id_mapping :type sfa: cute.Tensor @@ -1603,8 +1478,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1633,8 +1508,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1654,8 +1529,8 @@ def kernel( sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) sInfo[(4, tile_info_producer_state.index)] = -1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1747,8 +1622,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1902,8 +1777,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1944,8 +1819,8 @@ def kernel( valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)] is_valid_tile = valid_tile_info[0] == 1 cute.arch.fence_proxy( - 
cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1979,8 +1854,8 @@ def kernel( valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)] is_valid_tile = valid_tile_info[0] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2020,8 +1895,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2302,8 +2177,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2403,8 +2278,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2605,8 +2480,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) 
tile_info_consumer_state.advance() @@ -2715,8 +2590,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2815,35 +2690,45 @@ def kernel( bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) # - # Process accumulator subtiles with SwiGLU fusion and store to global memory - # Each iteration processes a pair of subtiles (up, gate) and computes - # up * silu(gate) + # Process accumulator subtiles with activation fusion and store to global memory. + # - Gated (SwiGLU): processes pairs of subtiles (up, gate), computes up * silu(gate) + # - Non-gated (e.g. Relu2): processes individual subtiles, computes relu(alpha*x)^2 # subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - for subtile_idx in cutlass.range(0, subtile_cnt, 2): - real_subtile_idx = subtile_idx // 2 + for subtile_idx in cutlass.range(0, subtile_cnt, 2 if self.is_gated else 1): + if cutlass.const_expr(self.is_gated): + real_subtile_idx = subtile_idx // 2 + else: + real_subtile_idx = subtile_idx if cutlass.const_expr(self.overlapping_accum): if reverse_subtile: real_subtile_idx = ( self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - - subtile_idx // 2 + - real_subtile_idx ) - # - # Load accumulator from tensor memory buffer to register - # - tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2)] - tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)] - cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up) - cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate) + if cutlass.const_expr(self.is_gated): + # + # Gated: Load pair of accumulator subtiles (up, gate) + # + tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2)] + tTR_tAcc_mn_gate = tTR_tAcc[(None, 
None, None, real_subtile_idx * 2 + 1)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up) + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate) + else: + # + # Non-gated: Load single accumulator subtile + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc_up) # # Async arrive accumulator buffer empty earlier when overlapping_accum is enabled # if cutlass.const_expr(self.overlapping_accum): - if subtile_idx // 2 == self.iter_acc_early_release_in_epilogue: + if real_subtile_idx == self.iter_acc_early_release_in_epilogue: # Fence for TMEM load cute.arch.fence_view_async_tmem_load() with cute.arch.elect_one(): @@ -2851,71 +2736,94 @@ def kernel( acc_consumer_state.advance() acc_vec_up = tTR_rAcc_up.load() - acc_vec_gate = tTR_rAcc_gate.load() - # - # SwiGLU activation: output = up * silu(gate) - # where silu(x) = x * sigmoid(x) - # up and gate are extracted from interleaved accumulator subtiles - # - tCompute = cute.make_rmem_tensor(acc_vec_gate.shape, self.acc_dtype) - if cutlass.const_expr(self.vectorized_f32): - # SwiGLU Packed Version: uses f32x2 packed operations for better performance - # Computes: output = (alpha * up) * silu(alpha * gate) - # where silu(x) = x * sigmoid(x) = x / (1 + exp(-x)) - LOG2_E = cutlass.Float32(1.4426950408889634) - for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2): - acc_vec_up_alpha = cute.arch.mul_packed_f32x2( - (acc_vec_up[i], acc_vec_up[i + 1]), - (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), - ) - acc_vec_gate_alpha = cute.arch.mul_packed_f32x2( - (acc_vec_gate[i], acc_vec_gate[i + 1]), - (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), - ) - tCompute_log2e = cute.arch.mul_packed_f32x2( - (acc_vec_gate_alpha[0], acc_vec_gate_alpha[1]), (-LOG2_E, -LOG2_E) - ) - ( - tCompute[i], - tCompute[i + 1], - ) = cute.arch.add_packed_f32x2( + tCompute = cute.make_rmem_tensor(acc_vec_up.shape, self.acc_dtype) + if 
cutlass.const_expr(self.is_gated): + acc_vec_gate = tTR_rAcc_gate.load() + # + # SwiGLU activation: output = up * silu(gate) + # where silu(x) = x * sigmoid(x) + # up and gate are extracted from interleaved accumulator subtiles + # + if cutlass.const_expr(self.vectorized_f32): + # SwiGLU Packed Version: uses f32x2 packed operations for better performance + LOG2_E = cutlass.Float32(1.4426950408889634) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2): + acc_vec_up_alpha = cute.arch.mul_packed_f32x2( + (acc_vec_up[i], acc_vec_up[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + ) + acc_vec_gate_alpha = cute.arch.mul_packed_f32x2( + (acc_vec_gate[i], acc_vec_gate[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + ) + tCompute_log2e = cute.arch.mul_packed_f32x2( + (acc_vec_gate_alpha[0], acc_vec_gate_alpha[1]), + (-LOG2_E, -LOG2_E), + ) ( - cute.math.exp2(tCompute_log2e[0], fastmath=True), - cute.math.exp2(tCompute_log2e[1], fastmath=True), - ), - (1.0, 1.0), - ) - tCompute[i] = cute.arch.rcp_approx(tCompute[i]) - tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1]) - ( - tCompute[i], - tCompute[i + 1], - ) = cute.arch.mul_packed_f32x2( - (tCompute[i], tCompute[i + 1]), - (acc_vec_gate_alpha[0], acc_vec_gate_alpha[1]), - ) - ( - tCompute[i], - tCompute[i + 1], - ) = cute.arch.mul_packed_f32x2( - (tCompute[i], tCompute[i + 1]), - (acc_vec_up_alpha[0], acc_vec_up_alpha[1]), - ) + tCompute[i], + tCompute[i + 1], + ) = cute.arch.add_packed_f32x2( + ( + cute.math.exp2(tCompute_log2e[0], fastmath=True), + cute.math.exp2(tCompute_log2e[1], fastmath=True), + ), + (1.0, 1.0), + ) + tCompute[i] = cute.arch.rcp_approx(tCompute[i]) + tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1]) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_gate_alpha[0], acc_vec_gate_alpha[1]), + ) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + 
(tCompute[i], tCompute[i + 1]), + (acc_vec_up_alpha[0], acc_vec_up_alpha[1]), + ) + else: + # SwiGLU Unpacked Version: scalar operations + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)): + acc_vec_up_alpha = acc_vec_up[i] * cutlass.Float32(alpha_val) + acc_vec_gate_alpha = acc_vec_gate[i] * cutlass.Float32(alpha_val) + tCompute[i] = acc_vec_up_alpha * silu_f32( + acc_vec_gate_alpha, fastmath=True + ) else: - # SwiGLU Unpacked Version: scalar operations - # Computes: output = (alpha * up) * silu(alpha * gate) - for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)): - acc_vec_up_alpha = acc_vec_up[i] * cutlass.Float32(alpha_val) - acc_vec_gate_alpha = acc_vec_gate[i] * cutlass.Float32(alpha_val) - tCompute[i] = acc_vec_up_alpha * silu_f32( - acc_vec_gate_alpha, fastmath=True - ) + # + # Non-gated activation (relu2): output = relu(alpha * x)^2 + # + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2): + scaled = cute.arch.mul_packed_f32x2( + (acc_vec_up[i], acc_vec_up[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + ) + relu0 = cute.arch.fmax(scaled[0], 0.0) + relu1 = cute.arch.fmax(scaled[1], 0.0) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + (relu0, relu1), + (relu0, relu1), + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)): + scaled = acc_vec_up[i] * cutlass.Float32(alpha_val) + relu_val = cute.arch.fmax(scaled, 0.0) + tCompute[i] = relu_val * relu_val if cutlass.const_expr(self.generate_sfc): # # Quantization path for Float4E2M1FN output: - # 1. Compute per-vector absolute max from SwiGLU result + # 1. Compute per-vector absolute max from activation result # 2. Generate scale factor C (SFC) based on max values # 3. Store SFC to global memory # 4. 
Quantize output by scaling with reciprocal of SFC @@ -3052,8 +2960,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() # @@ -3086,8 +2994,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -3750,6 +3658,7 @@ def wrapper( max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, + is_gated: cutlass.Constexpr = True, ): """Unified wrapper supporting both single-B and multi-B tensors. @@ -3757,7 +3666,7 @@ def wrapper( L sizes are configured via b_tensor_l_sizes in __init__. 
""" scale_k = k // scaling_vector_size - interm_size = n // 2 + interm_size = n // 2 if is_gated else n num_tiles = m // tile_size total_l = self.b_tensor_l_offsets[self.num_b_tensors] diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py index a5571f616dda..4704a55d3117 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py @@ -60,143 +60,6 @@ ) -def hooked_PersistentTileSchedulerParams_init( - self, - problem_shape_ntile_mnl: cute.Shape, - cluster_shape_mnk: cute.Shape, - swizzle_size: int = 1, - raster_along_m: bool = True, - *, - loc=None, - ip=None, -): - if cluster_shape_mnk[2] != 1: - raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") - if swizzle_size < 1: - raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}") - - self.problem_shape_ntile_mnl = problem_shape_ntile_mnl - # cluster_shape_mnk is kept for reconstruction - self._cluster_shape_mnk = cluster_shape_mnk - self.cluster_shape_mn = cluster_shape_mnk[:2] - self.swizzle_size = swizzle_size - self._raster_along_m = raster_along_m - self._loc = loc - - # Apply swizzle if swizzle_size > 1 - if swizzle_size > 1: - problem_shape_ncluster_mnl = cute.round_up( - self.problem_layout_ncluster_mnl.shape, - (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1), - ) - - if raster_along_m: - self.problem_layout_ncluster_mnl = cute.make_layout( - ( - problem_shape_ncluster_mnl[0], - (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), - problem_shape_ncluster_mnl[2], - ), - stride=( - swizzle_size, - (1, swizzle_size * problem_shape_ncluster_mnl[0]), - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - else: - self.problem_layout_ncluster_mnl = 
cute.make_layout( - ( - (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), - problem_shape_ncluster_mnl[1], - problem_shape_ncluster_mnl[2], - ), - stride=( - (1, swizzle_size * problem_shape_ncluster_mnl[1]), - swizzle_size, - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - - # Create FastDivmod divisors (only when swizzle_size == 1 for correctness) - # FastDivmod assumes simple col-major/row-major layout, incompatible with swizzled layouts - if swizzle_size == 1: - problem_shape_ncluster_mnl = cute.ceil_div( - self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip - ) - if raster_along_m: - self.problem_layout_ncluster_mnl = cute.make_layout( - problem_shape_ncluster_mnl, - stride=( - 1, - problem_shape_ncluster_mnl[0], - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - else: - self.problem_layout_ncluster_mnl = cute.make_layout( - problem_shape_ncluster_mnl, - stride=( - problem_shape_ncluster_mnl[1], - 1, - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip) - cluster_count_m = self.problem_layout_ncluster_mnl.shape[0] - cluster_count_n = self.problem_layout_ncluster_mnl.shape[1] - - # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling) - self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip) - - # cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates - self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip) - - # cluster_shape_n_fdd: Used for the second level decomposition - self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip) - else: - # FastDivmod not applicable with swizzling, set to None - self.batch_fdd = None - self.cluster_shape_m_fdd = None - 
self.cluster_shape_n_fdd = None - - -def hooked_get_cluster_work_idx_with_fastdivmod( - self, current_work_linear_idx: cutlass.Int32, *, loc=None, ip=None -) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: - work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd) - - if self.params._raster_along_m: - # raster_along_m=True means column major (m is fastest) - # First, get cluster_m using cluster_shape_m_fdd - cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd) - - # Then decode cluster_n_batch to get cluster_n and batch_l using FastDivmod - batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd) - else: - # raster_along_m=False means row major (n is fastest) - # First, get cluster_n using cluster_shape_n_fdd - cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd) - - # Then decode cluster_m_batch to get cluster_m and batch_l using FastDivmod - batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd) - - return (cluster_m, cluster_n, batch_l) - - -cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init -cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = ( - hooked_get_cluster_work_idx_with_fastdivmod -) - - class Sm100BlockScaledContiguousGroupedGemmKernel: """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. 
@@ -1162,8 +1025,8 @@ def kernel( ) # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1192,8 +1055,8 @@ def kernel( ) # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1213,8 +1076,8 @@ def kernel( sInfo[(2, tile_info_producer_state.index)] = -1 sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1250,8 +1113,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1348,8 +1211,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1443,8 +1306,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1609,8 +1472,8 @@ def kernel( 
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1695,8 +1558,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1801,8 +1664,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() # @@ -1835,8 +1698,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index babf3dbcb261..3a8f1b16d3a6 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -158,144 +158,6 @@ """ -# TODO(zhichenj): Remove this hook helper function after nvidia-cutlass-dsl 4.4 is released. 
-def hooked_PersistentTileSchedulerParams_init( - self, - problem_shape_ntile_mnl: cute.Shape, - cluster_shape_mnk: cute.Shape, - swizzle_size: int = 1, - raster_along_m: bool = True, - *, - loc=None, - ip=None, -): - if cluster_shape_mnk[2] != 1: - raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") - if swizzle_size < 1: - raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}") - - self.problem_shape_ntile_mnl = problem_shape_ntile_mnl - # cluster_shape_mnk is kept for reconstruction - self._cluster_shape_mnk = cluster_shape_mnk - self.cluster_shape_mn = cluster_shape_mnk[:2] - self.swizzle_size = swizzle_size - self._raster_along_m = raster_along_m - self._loc = loc - - # Apply swizzle if swizzle_size > 1 - if swizzle_size > 1: - problem_shape_ncluster_mnl = cute.round_up( - self.problem_layout_ncluster_mnl.shape, - (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1), - ) - - if raster_along_m: - self.problem_layout_ncluster_mnl = cute.make_layout( - ( - problem_shape_ncluster_mnl[0], - (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), - problem_shape_ncluster_mnl[2], - ), - stride=( - swizzle_size, - (1, swizzle_size * problem_shape_ncluster_mnl[0]), - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - else: - self.problem_layout_ncluster_mnl = cute.make_layout( - ( - (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), - problem_shape_ncluster_mnl[1], - problem_shape_ncluster_mnl[2], - ), - stride=( - (1, swizzle_size * problem_shape_ncluster_mnl[1]), - swizzle_size, - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - - # Create FastDivmod divisors (only when swizzle_size == 1 for correctness) - # FastDivmod assumes simple col-major/row-major layout, incompatible with swizzled layouts - if swizzle_size == 1: - problem_shape_ncluster_mnl = cute.ceil_div( - self.problem_shape_ntile_mnl, 
cluster_shape_mnk[:2], loc=loc, ip=ip - ) - if raster_along_m: - self.problem_layout_ncluster_mnl = cute.make_layout( - problem_shape_ncluster_mnl, - stride=( - 1, - problem_shape_ncluster_mnl[0], - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - else: - self.problem_layout_ncluster_mnl = cute.make_layout( - problem_shape_ncluster_mnl, - stride=( - problem_shape_ncluster_mnl[1], - 1, - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], - ), - loc=loc, - ip=ip, - ) - problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip) - cluster_count_m = self.problem_layout_ncluster_mnl.shape[0] - cluster_count_n = self.problem_layout_ncluster_mnl.shape[1] - - # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling) - self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip) - - # cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates - self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip) - - # cluster_shape_n_fdd: Used for the second level decomposition - self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip) - else: - # FastDivmod not applicable with swizzling, set to None - self.batch_fdd = None - self.cluster_shape_m_fdd = None - self.cluster_shape_n_fdd = None - - -def hooked_get_cluster_work_idx_with_fastdivmod( - self, current_work_linear_idx: cutlass.Int32, *, loc=None, ip=None -) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: - work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd) - - if self.params._raster_along_m: - # raster_along_m=True means column major (m is fastest) - # First, get cluster_m using cluster_shape_m_fdd - cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd) - - # Then decode cluster_n_batch to get cluster_n and batch_l using FastDivmod - batch_l, 
cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd) - else: - # raster_along_m=False means row major (n is fastest) - # First, get cluster_n using cluster_shape_n_fdd - cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd) - - # Then decode cluster_m_batch to get cluster_m and batch_l using FastDivmod - batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd) - - return (cluster_m, cluster_n, batch_l) - - -cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init -cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = ( - hooked_get_cluster_work_idx_with_fastdivmod -) - - class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. 
@@ -1490,8 +1352,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1520,8 +1382,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1542,8 +1404,8 @@ def kernel( sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) sInfo[(4, tile_info_producer_state.index)] = cutlass.Int32(0) cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1571,8 +1433,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1877,8 +1739,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1963,8 +1825,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) 
tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2114,8 +1976,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2182,8 +2044,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2333,8 +2195,8 @@ def kernel( if cutlass.const_expr(self.use_blkred): cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) # # Async arrive accumulator buffer empty @@ -2347,8 +2209,8 @@ def kernel( if cutlass.const_expr(self.use_blkred): cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) if is_valid_row: coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[1] @@ -2380,8 +2242,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py index 4dcd157e993d..2aeb4e1df530 100644 --- 
a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py @@ -1129,8 +1129,8 @@ def kernel( ) # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1147,8 +1147,8 @@ def kernel( sInfo[(2, tile_info_producer_state.index)] = -1 sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1183,8 +1183,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1282,8 +1282,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1375,8 +1375,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1533,8 +1533,8 @@ def kernel( tile_info[idx] = sInfo[(idx, 
tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1643,8 +1643,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1951,8 +1951,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() # @@ -1985,8 +1985,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py index 43b64cbeb21b..b9338d4464c9 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py @@ -982,8 +982,8 @@ def kernel( # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() # commit tile info pipeline @@ -1090,8 +1090,8 @@ def 
kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1259,8 +1259,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1407,8 +1407,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1632,8 +1632,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1796,8 +1796,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() @@ -1829,8 +1829,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) 
tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py index 42791b3d994a..760a26dfe774 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without @@ -48,8 +48,17 @@ import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, if_generate +# nvidia-cutlass-dsl 4.4.2 split the sync-object factory: sm90's +# PipelineAsync._make_sync_object no longer accepts Blackwell ops like +# TCGen05Mma/ClcLoad. The sm100 PipelineTmaUmma provides the expanded variant +# that handles every op used by the custom pipelines below. Alias to avoid +# colliding with the local PipelineTmaUmma defined in this module. 
from cutlass.pipeline import (Agent, CooperativeGroup, PipelineAsync, - PipelineOp, PipelineState, agent_sync) + PipelineOp, PipelineState) +from cutlass.pipeline import PipelineTmaUmma as _Sm100PipelineFactory +from cutlass.pipeline import agent_sync + +_make_sync_object = _Sm100PipelineFactory._make_sync_object def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): @@ -179,9 +188,9 @@ def create( producer = (producer_type, producer_group) consumer = (consumer_type, consumer_group) - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer, tx_count) - sync_object_empty = PipelineAsync._make_sync_object( + sync_object_full = _make_sync_object(barrier_storage.align(min_align=8), + num_stages, producer, tx_count) + sync_object_empty = _make_sync_object( barrier_storage.align(min_align=8) + num_stages, num_stages, consumer) @@ -214,7 +223,7 @@ def create( cta_group, ) - def consumer_release(self, state: PipelineState): + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): """ UMMA consumer release buffer empty, cta_group needs to be provided. @@ -225,12 +234,18 @@ def consumer_release(self, state: PipelineState): Returns: None """ - self.sync_object_empty.arrive(state.index, self.consumer_mask, - self.cta_group) + self.sync_object_empty.arrive(state.index, + self.consumer_mask, + self.cta_group, + loc=loc, + ip=ip) def producer_acquire(self, state: PipelineState, - try_acquire_token: Optional[Boolean] = None): + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None): """ Conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. 
@@ -246,15 +261,20 @@ def producer_acquire(self, """ if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(state.index, state.phase), + lambda: self.sync_object_empty.wait( + state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, ) if_generate( self.is_leader_cta, - lambda: self.sync_object_full.arrive(state.index, self.producer_mask - ), + lambda: self.sync_object_full.arrive( + state.index, self.producer_mask, loc=loc, ip=ip), + loc=loc, + ip=ip, ) - def producer_commit(self, state: PipelineState): + def producer_commit(self, state: PipelineState, *, loc=None, ip=None): """ TMA producer commit is a noop since TMA instruction itself updates the transaction count. @@ -323,9 +343,9 @@ def create( producer = (producer_type, producer_group) consumer = (consumer_type, consumer_group) - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer) - sync_object_empty = PipelineAsync._make_sync_object( + sync_object_full = _make_sync_object(barrier_storage.align(min_align=8), + num_stages, producer) + sync_object_empty = _make_sync_object( barrier_storage.align(min_align=8) + num_stages, num_stages, consumer) @@ -357,23 +377,26 @@ def create( cta_group, ) - def producer_commit(self, state: PipelineState): - self.sync_object_full.arrive(state.index, self.producer_mask, - self.cta_group) + def producer_commit(self, state: PipelineState, *, loc=None, ip=None): + self.sync_object_full.arrive(state.index, + self.producer_mask, + self.cta_group, + loc=loc, + ip=ip) - def producer_tail(self, state: PipelineState): + def producer_tail(self, state: PipelineState, *, loc=None, ip=None): cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster()) + cute.arch.block_idx_in_cluster(loc=loc, ip=ip), loc=loc, ip=ip) is_leader_cta = cta_rank_in_cluster % 2 == 0 def then_body(): # Assume state contains that next useful buffer # So we only need to advance 
to num_stages - 1 times to last used buffer for i in range(self.num_stages - 1): - state.advance() - self.producer_acquire(state) + state.advance(loc=loc, ip=ip) + self.producer_acquire(state, loc=loc, ip=ip) - if_generate(is_leader_cta, then_body) + if_generate(is_leader_cta, then_body, loc=loc, ip=ip) @dataclass(frozen=True) @@ -473,12 +496,12 @@ def create( producer = (producer_type, producer_group) consumer = (consumer_type, consumer_group) - sync_object_full = PipelineAsync._make_sync_object( + sync_object_full = _make_sync_object( barrier_storage.align(min_align=8), num_stages, producer, ) - sync_object_empty = PipelineAsync._make_sync_object( + sync_object_empty = _make_sync_object( barrier_storage.align(min_align=8) + num_stages, num_stages, consumer) @@ -515,9 +538,12 @@ def create( cta_group, ) - def consumer_release(self, state: PipelineState): + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): """ UMMA consumer release buffer empty, cta_group needs to be provided. 
""" - self.sync_object_empty.arrive(state.index, self.consumer_mask, - self.cta_group) + self.sync_object_empty.arrive(state.index, + self.consumer_mask, + self.cta_group, + loc=loc, + ip=ip) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py index ccfc548373b3..414984446547 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py @@ -1415,8 +1415,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) epilog_threads = 32 * len(self.epilog_warp_id) cute.arch.barrier( diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc1.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc1.py index b67a3100be4b..de35f01e35cb 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc1.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc1.py @@ -1765,8 +1765,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py index 8a26748c4e80..f5dcec19d49b 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py @@ -1611,8 +1611,8 @@ def kernel( ) # Fence and barrier to make sure shared memory 
store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py index 98f9294d1dcb..fa98d0057844 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py @@ -206,15 +206,22 @@ def fmin(a: Union[float, cutlass.Float32], nan=False, loc=None, ip=None) -> cutlass.Float32: - return cutlass.Float32( - nvvm.fmin( - T.f32(), - cutlass.Float32(a).ir_value(loc=loc, ip=ip), - cutlass.Float32(b).ir_value(loc=loc, ip=ip), - nan=nan, - loc=loc, - ip=ip, - )) + a_ir = cutlass.Float32(a).ir_value(loc=loc, ip=ip) + b_ir = cutlass.Float32(b).ir_value(loc=loc, ip=ip) + if nan: + # CUTLASS DSL 4.4+ dropped the `nan` attribute from nvvm.FminOp; emit the + # NaN-propagating PTX instruction directly to preserve the original semantics. 
+ return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [a_ir, b_ir], + "min.NaN.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + )) + return cutlass.Float32(nvvm.fmin(a_ir, b_ir, loc=loc, ip=ip)) def sigmoid_f32(a: Union[float, cutlass.Float32], diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py index 5be9a3d59e29..852d8cb4b5d3 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py @@ -7,6 +7,7 @@ from tensorrt_llm._torch.utils import split +@register_mapper("HF", "NemotronHPuzzleForCausalLM") @register_mapper("HF", "NemotronHForCausalLM") class NemotronHHfWeightMapper(HfWeightMapper): diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index 6fbb0f22b9c4..833d07ab2919 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -55,6 +55,27 @@ class NemotronHConfig(PretrainedConfig): model_type = "nemotron_h" +class NemotronHPuzzleConfig(PretrainedConfig): + model_type = "nemotron_h_puzzle" + + +def _bc_getattr(bc, key, default=None): + """Get attribute from a block_config entry (dict or dataclass).""" + if isinstance(bc, dict): + return bc.get(key, default) + return getattr(bc, key, default) + + +def _get_layer_moe_param(config, layer_idx: int, param_name: str): + """Get per-layer MoE parameter, falling back to global config.""" + block_configs = getattr(config, 'block_configs', None) + if block_configs and layer_idx < len(block_configs): + val = _bc_getattr(block_configs[layer_idx], param_name) + if val is not None: + return val + return getattr(config, param_name, None) + + class MLPLayer(MLP): def __init__( @@ -152,22 +173,26 @@ def __init__( self.hidden_dim = 
config.hidden_size self.ffn_dim = config.intermediate_size self.layer_idx = layer_idx - self.moe_intermediate_size = (config.moe_intermediate_size[0] - if isinstance( - config.moe_intermediate_size, list) - else config.moe_intermediate_size) - self.use_latent_moe: bool = getattr(config, "moe_latent_size", - None) is not None - self.moe_hidden_size: int = (config.moe_latent_size - if self.use_latent_moe else + + # Per-layer MoE params (models with block_configs have varying params). + def _moe(name): + return _get_layer_moe_param(config, layer_idx, name) + + moe_intermediate = _moe('moe_intermediate_size') + self.moe_intermediate_size = (moe_intermediate[0] if isinstance( + moe_intermediate, list) else moe_intermediate) + + moe_latent = _moe('moe_latent_size') + self.use_latent_moe: bool = moe_latent is not None + self.moe_hidden_size: int = (moe_latent if self.use_latent_moe else config.hidden_size) self.mlp_bias = config.mlp_bias if hasattr(config, "mlp_bias") else False self.moe_n_group = config.n_group - self.num_experts = config.n_routed_experts + self.num_experts = _moe('n_routed_experts') self.hidden_size = config.hidden_size self.num_shared_experts = config.n_shared_experts - self.top_k = config.num_experts_per_tok + self.top_k = _moe('num_experts_per_tok') self.enable_attention_dp = model_config.mapping.enable_attention_dp self.routed_scaling_factor = config.routed_scaling_factor self.mapping = model_config.mapping @@ -177,7 +202,7 @@ def __init__( self.shared_experts = None else: shared_expert_intermediate_size = ( - config.moe_shared_expert_intermediate_size * + _moe('moe_shared_expert_intermediate_size') * config.n_shared_experts) self.shared_experts = MLP( @@ -705,6 +730,7 @@ def forward( return hidden_states +@register_auto_model("NemotronHPuzzleForCausalLM") @register_auto_model("NemotronHForCausalLM") class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel, NemotronHConfig]): @@ -722,6 +748,9 @@ def __init__( raise 
ValueError("layer_norm_epsilon or rms_norm_eps is not set") model_config.pretrained_config.rms_norm_eps = rms_epsilon + # Normalize per-layer block_configs into global config attributes. + self._normalize_puzzle_config(model_config.pretrained_config) + if (not model_config.mapping.enable_attention_dp and model_config.mapping.tp_size not in [1, 2, 4, 8]): raise ValueError("TP has to be either 1, 2, 4 or 8") @@ -778,6 +807,31 @@ def __init__( self.epilogue.extend(self.draft_model.mtp_layers) self.epilogue.append(self.spec_worker) + @staticmethod + def _normalize_puzzle_config(config): + """Set global MoE defaults from block_configs for models with per-layer MoE params.""" + block_configs = getattr(config, 'block_configs', None) + if not block_configs: + return + + def _is_moe(bc): + return _bc_getattr(bc, 'block_type') == 'moe' + + first_moe = next((bc for bc in block_configs if _is_moe(bc)), None) + if first_moe is None: + return + + # Prefer MTP MoE block as fallback (used for MTP layers beyond + # block_configs range), otherwise use first main-model MoE block. 
+ mtp_configs = getattr(config, 'mtp_block_configs', None) or [] + fallback = next((bc for bc in mtp_configs if _is_moe(bc)), first_moe) + + for attr in ('n_routed_experts', 'moe_intermediate_size', + 'num_experts_per_tok', 'moe_latent_size', + 'moe_shared_expert_intermediate_size'): + if not hasattr(config, attr) or getattr(config, attr) is None: + setattr(config, attr, _bc_getattr(fallback, attr)) + def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): new_weights = weight_mapper.preprocess_weights(weights) super().load_weights(weights=new_weights, weight_mapper=weight_mapper) @@ -1076,3 +1130,4 @@ def forward( AutoConfig.register(NemotronHConfig.model_type, NemotronHConfig) +AutoConfig.register(NemotronHPuzzleConfig.model_type, NemotronHPuzzleConfig) diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 7c103d9b25ec..041fa02c0f3b 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -804,7 +804,7 @@ def __init__( case "exaone_moe": from .modeling_exaone_moe import ExaoneMoeMTP mtp_layer = ExaoneMoeMTP - case "nemotron_h": + case "nemotron_h" | "nemotron_h_puzzle": from .modeling_nemotron_h import NemotronHMTP mtp_layer = NemotronHMTP case "qwen3_next": diff --git a/tensorrt_llm/_torch/modules/fused_moe/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/__init__.py index 51d6ba5f9475..37c3a3797539 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/__init__.py +++ b/tensorrt_llm/_torch/modules/fused_moe/__init__.py @@ -9,15 +9,19 @@ from .moe_load_balancer import (MoeLoadBalancer, moe_load_balancer_set_repeated_for_next_layer) from .quantization import FusedMoEQuantScalesFP8 +# yapf: disable from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod, DefaultMoeRoutingMethod, Llama4RenormalizeMoeRoutingMethod, LoadBalancedMoeRoutingMethod, MiniMaxM2MoeRoutingMethod, RenormalizeMoeRoutingMethod, 
RenormalizeNaiveMoeRoutingMethod, RoutingMethodType, + SigmoidRenormMoeRoutingMethod, SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod, create_renormalize_expert_load_balanced_logits) +# yapf: enable + __all__ = [ "BaseMoeRoutingMethod", "create_renormalize_expert_load_balanced_logits", @@ -36,6 +40,7 @@ "MoEWeightLoadingMode", "MiniMaxM2MoeRoutingMethod", "RenormalizeMoeRoutingMethod", + "SigmoidRenormMoeRoutingMethod", "RenormalizeNaiveMoeRoutingMethod", "RoutingMethodType", "SparseMixerMoeRoutingMethod", diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py index e37d5db10819..53d18135c11b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py @@ -224,7 +224,32 @@ def __init__( # Initialize or reuse workspace MnnvlMemory.initialize() - if self._WORKSPACE is None: + need_alloc = self._WORKSPACE is None + if not need_alloc: + assert self._WORKSPACE["max_num_tokens_per_rank"] == self.max_num_tokens_per_rank, ( + "reuse workspace with different max_num_tokens_per_rank" + ) + assert self._WORKSPACE["ep_rank"] == self.ep_rank, ( + "reuse workspace with different ep_rank" + ) + assert self._WORKSPACE["ep_size"] == self.ep_size, ( + "reuse workspace with different ep_size" + ) + assert self._WORKSPACE["eplb_stats_num_experts"] == self.eplb_stats_num_experts, ( + "reuse workspace with different eplb_stats_num_experts" + ) + + # Models with per-layer MoE params may request different workspace sizes across layers. + # Reallocate when a larger workspace is needed; reuse otherwise. + if self._WORKSPACE["workspace_size_per_rank"] < self.workspace_size_per_rank: + tllm_logger.info( + f"NVLinkOneSided: Reallocating workspace " + f"{self._WORKSPACE['workspace_size_per_rank']} -> " + f"{self.workspace_size_per_rank} bytes." 
+ ) + need_alloc = True + + if need_alloc: tllm_logger.info( f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes." f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}" @@ -248,26 +273,8 @@ def __init__( "workspace": workspace, "metainfo": metainfo, } - else: - assert self._WORKSPACE["workspace_size_per_rank"] == self.workspace_size_per_rank, ( - "reuse workspace with different workspace_size_per_rank" - ) - assert self._WORKSPACE["max_num_tokens_per_rank"] == self.max_num_tokens_per_rank, ( - "reuse workspace with different max_num_tokens_per_rank" - ) - assert self._WORKSPACE["ep_rank"] == self.ep_rank, ( - "reuse workspace with different ep_rank" - ) - assert self._WORKSPACE["ep_size"] == self.ep_size, ( - "reuse workspace with different ep_size" - ) - assert self._WORKSPACE["eplb_stats_num_experts"] == self.eplb_stats_num_experts, ( - "reuse workspace with different eplb_stats_num_experts" - ) - self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] - self.workspace = self._WORKSPACE["workspace"] - self.moe_a2a_metainfo = self._WORKSPACE["metainfo"] + # Read max_num_tokens_per_rank from the (possibly grown) workspace. self.max_num_tokens_per_rank = self._WORKSPACE["max_num_tokens_per_rank"] # Initialize dispatch state @@ -276,6 +283,21 @@ def __init__( # Invalid token expert ID (default to -1), the kernels in TRTLLM-gen is hard-code to support -1 only. self.invalid_token_expert_id: int = -1 + # Properties delegate to _WORKSPACE so all instances see the latest + # allocation (workspace may be reallocated when layers need more space). 
+ + @property + def mnnvl_mem(self): + return self._WORKSPACE["mnnvl_mem"] + + @property + def workspace(self): + return self._WORKSPACE["workspace"] + + @property + def moe_a2a_metainfo(self): + return self._WORKSPACE["metainfo"] + @staticmethod def is_platform_supported() -> bool: """ diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 0638904ab8e7..c08661ef73a4 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -261,6 +261,7 @@ def create_moe_backend( layer_idx=layer_idx, init_load_balancer=init_load_balancer, without_comm=without_comm, + activation_type=activation_type, ) elif moe_cls == DeepGemmFusedMoE: return moe_cls( diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 97c3c9072ca8..df32d0f0e62f 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -27,13 +27,14 @@ OptimizationProfile, TunableRunner, TuningConfig) from ...custom_ops.cute_dsl_custom_ops import ( GroupedGemmInputsHelper, - Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner, + Sm100BlockScaledContiguousGatherGroupedGemmActFusionRunner, Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner, Sm100BlockScaledContiguousGroupedGemmRunner, Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner) from ...distributed import allgather from ...model_config import ModelConfig -from ...utils import (AuxStreamType, EventType, Fp4QuantizedTensor, +from ...utils import (ActivationType, AuxStreamType, EventType, + Fp4QuantizedTensor, get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2) from .fused_moe_cutlass import CutlassFusedMoE @@ -327,8 +328,7 @@ def runner_tactic_comb_checker( (Sm100BlockScaledContiguousGroupedGemmRunner, 
Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner, Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner, - Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner - )): + Sm100BlockScaledContiguousGatherGroupedGemmActFusionRunner)): mma_tiler_mn, *_ = tactic if mma_tiler_mn[0] != tile_size: return False @@ -430,6 +430,7 @@ def __init__( layer_idx: Optional[int] = None, init_load_balancer: bool = True, without_comm: bool = False, + activation_type: ActivationType = ActivationType.Swiglu, ): super().__init__( routing_method=routing_method, @@ -445,6 +446,7 @@ def __init__( layer_idx=layer_idx, init_load_balancer=init_load_balancer, without_comm=without_comm, + activation_type=activation_type, ) if self.aux_stream_dict is None: @@ -625,7 +627,10 @@ def run_moe_nvfp4_impl( moe_output.record_stream( self.aux_stream_dict[AuxStreamType.MoeOutputMemset]) - x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + # Fused gather + GEMM + activation + quantize for FC1. + # For gated (SwiGLU): weights are interleaved [up, gate], output is N/2. + # For non-gated (Relu2): weights are plain, output is N. 
+ x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell( input=x.view(torch.float4_e2m1fn_x2), weight=weight_view.w3_w1_weight[0].view(torch.float4_e2m1fn_x2), input_scale=x_sf.view(torch.uint8), @@ -641,6 +646,7 @@ def run_moe_nvfp4_impl( num_local_experts=esp, local_expert_offset=slot_start, tile_size=tile_size, + is_gated=self.is_gated_activation, ) if self.use_fused_finalize: @@ -744,7 +750,7 @@ def run_moe_nvfp4_impl_dwdp( moe_output.record_stream( self.aux_stream_dict[AuxStreamType.MoeOutputMemset]) - x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b( input=x.view(torch.float4_e2m1fn_x2), weight=[ w.view(torch.float4_e2m1fn_x2) for w in weight_view.w3_w1_weight diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 23354f5a5b34..4d5a57f5a1b6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -15,6 +15,7 @@ import inspect import os +from dataclasses import dataclass from functools import cached_property from typing import Dict, List, Optional, Tuple, Union @@ -44,7 +45,16 @@ W4A16MXFP4TRTLLMGenFusedMoEMethod) # isort: on from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod, - DefaultMoeRoutingMethod) + DefaultMoeRoutingMethod, MiniMaxM2MoeRoutingMethod) + + +@dataclass +class RoutingParams: + top_k: int + routing_bias: Optional[torch.Tensor] + n_group: Optional[int] + topk_group: Optional[int] + routed_scaling_factor: Optional[float] class TRTLLMGenFusedMoE(MoE): @@ -528,6 +538,33 @@ def quantize_input(self, x, post_quant_comm: bool = True): def supports_moe_output_in_alltoall_workspace(self): return True + def _extract_routing_params(self) -> RoutingParams: + if isinstance(self.routing_method, 
DeepSeekV3MoeRoutingMethod): + return RoutingParams( + top_k=self.routing_method.routing_impl.top_k, + routing_bias=self.routing_method.e_score_correction_bias, + n_group=self.routing_method.routing_impl.n_group, + topk_group=self.routing_method.routing_impl.topk_group, + routed_scaling_factor=self.routing_method.routing_impl. + routed_scaling_factor, + ) + elif isinstance(self.routing_method, MiniMaxM2MoeRoutingMethod): + return RoutingParams( + top_k=self.routing_method.top_k, + routing_bias=self.routing_method.e_score_correction_bias, + n_group=None, + topk_group=None, + routed_scaling_factor=None, + ) + else: + return RoutingParams( + top_k=self.routing_method.top_k, + routing_bias=None, + n_group=None, + topk_group=None, + routed_scaling_factor=None, + ) + def run_moe( self, x: torch.Tensor, @@ -565,21 +602,12 @@ def run_moe( If do_finalize=True: final_hidden_states tensor If do_finalize=False: tuple of intermediate outputs (for nvfp4 and w4a8_nvfp4_fp8) """ - # Extract routing parameters from routing_method - if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod): - top_k = self.routing_method.routing_impl.top_k - routing_bias = self.routing_method.e_score_correction_bias - n_group = self.routing_method.routing_impl.n_group - topk_group = self.routing_method.routing_impl.topk_group - routed_scaling_factor = self.routing_method.routing_impl.routed_scaling_factor - else: - top_k = self.routing_method.top_k - routing_bias = None - n_group = None - topk_group = None - routed_scaling_factor = None - - routing_bias = routing_bias if router_logits is not None else None + routing_params = self._extract_routing_params() + top_k = routing_params.top_k + routing_bias = routing_params.routing_bias if router_logits is not None else None + n_group = routing_params.n_group + topk_group = routing_params.topk_group + routed_scaling_factor = routing_params.routed_scaling_factor if token_selected_experts is not None: # for cases like deepep low latency where fake 
top_k=1 might be used @@ -594,7 +622,7 @@ def run_moe( if self.has_deepseek_fp8_block_scales: assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False" - # fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+ + # fp8_quantize_1x128 returns 2D x_sf on SM100+, 1D on SM90 if x_sf is None: x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x) result = self.op_backend.run_fp8_block_scale_moe( @@ -780,11 +808,7 @@ def forward_impl( ) -> torch.Tensor: assert x.dtype == torch.bfloat16 - # Get top_k for routing (other routing parameters are extracted inside run_moe) - if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod): - top_k = self.routing_method.routing_impl.top_k - else: - top_k = self.routing_method.top_k + top_k = self._extract_routing_params().top_k run_post_quant_allgather = (self.use_dp and self.parallel_size > 1 and not self.enable_alltoall) @@ -1069,8 +1093,11 @@ def forward_fake( else: is_deepseek_v3_routing = isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod) + is_minimax_routing = isinstance(self.routing_method, + MiniMaxM2MoeRoutingMethod) top_k = self.routing_method.routing_impl.top_k if is_deepseek_v3_routing else self.routing_method.top_k - routing_bias = self.routing_method.e_score_correction_bias if is_deepseek_v3_routing else None + routing_bias = self.routing_method.e_score_correction_bias if ( + is_deepseek_v3_routing or is_minimax_routing) else None return fp4_block_scale_fake_output_without_finalize( x, self.num_experts, diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index c30b7d771aa9..45ee2026901d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -2523,6 +2523,11 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, super().load_expert_w3_w1_weight(module, w1_weight, w3_weight, dst_w3_w1_weight) + # Only interleave for 
gated activations (SwiGLU) where the fused + # gather+GEMM+SwiGLU kernel expects interleaved gate/up weights. + if not module.is_gated_activation: + return + # Interleave FC1 weight for GEMM1 + SwiGLU fusion. w3_w1_weight = dst_w3_w1_weight.cuda().view(float4_e2m1x2) w3_w1_weight_interleaved = interleave_linear_and_gate(w3_w1_weight, @@ -2540,6 +2545,12 @@ def load_expert_w3_w1_weight_scale_nvfp4( w3_weight_scale, dst_w3_w1_weight_scale) + # Only interleave for gated activations (SwiGLU). + # For non-gated, the parent's block_scale_interleave format is already + # the swizzled layout expected by the CuTe DSL grouped GEMM kernels. + if not module.is_gated_activation: + return + # Interleave FC1 scales for GEMM1 + SwiGLU fusion. n = module.intermediate_size_per_partition * 2 k = module.hidden_size @@ -2867,8 +2878,19 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): local_shared_fc2_alpha_tensors, ignore_weight_scale=True) - local_shared_fc31_scale_c = module.fc2_input_scale.data.cpu( - ) * local_shared_fc31_alpha_tensors + # The shared host copy of fc31_scale_c is consumed by online EPLB + # when an expert is migrated into a local slot, so it must match + # the main-slot formula exactly (see load_quant_scales above). + # For Relu2/Silu: fc31_scale_c = fc2_input_scale (broadcast). + # For gated (SwiGlu): fc31_scale_c = fc2_input_scale * fc31_alpha. 
+ if hasattr(module, 'activation_type') and module.activation_type in [ + ActivationType.Relu2, ActivationType.Silu + ]: + local_shared_fc31_scale_c = module.fc2_input_scale.data.cpu( + ).expand(len(local_shared_load_expert_ids)).contiguous() + else: + local_shared_fc31_scale_c = module.fc2_input_scale.data.cpu( + ) * local_shared_fc31_alpha_tensors module.register_all_parameter_slot_and_to_fix_weight_fns({ 'fc31_scale_c': diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index db3de6ea7ad6..f8a93b868ace 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -157,8 +157,10 @@ class RoutingMethodType(IntEnum): RenormalizeNaive = 4, # MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize(without bias) MiniMax2 = 5, + # SigmoidRenorm: Sigmoid -> TopK -> Renormalize + SigmoidRenorm = 6, # Unspecified - Unspecified = 6, + Unspecified = 7, class BaseMoeRoutingMethod(nn.Module): @@ -257,16 +259,15 @@ def noaux_tc(self, logits, e_score_correction_bias): _, num_experts = logits.shape if self.n_group > 1: - if self.top_k > 8 or (num_experts / n_group) > 32 or ( - num_experts / n_group) * self.topk_group > 128: + experts_per_group = num_experts // n_group + if (self.top_k > 8 or num_experts > 256 or experts_per_group > 32 + or experts_per_group * self.topk_group > 256): if self.is_fused: warnings.warn( "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." ) self.is_fused = False - elif num_experts > 512 or (self.top_k > 8 and self.top_k != 22): - # The fused noaux_tc_op kernel supports n_group==1 with top_k<=8 - # or top_k==22, and num_experts<=512. + elif num_experts > 1024 or self.top_k > 32: if self.is_fused: warnings.warn( "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." 
@@ -437,6 +438,38 @@ def routing_method_type(self): return RoutingMethodType.MiniMax2 +class SigmoidRenormMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__( + self, + top_k: int, + num_experts: int, + renormalize: bool = True, + output_dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.top_k = top_k + self.num_experts = num_experts + self.renormalize = renormalize + self.output_dtype = output_dtype + + def apply(self, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + scores = torch.sigmoid(router_logits) + topk_weights, topk_idx = torch.topk(scores, + k=self.top_k, + dim=-1, + sorted=False) + if self.renormalize: + topk_weights = topk_weights / ( + topk_weights.sum(dim=-1, keepdim=True) + 1e-20) + return topk_idx.to(torch.int32), topk_weights.to(self.output_dtype) + + @property + def routing_method_type(self): + return RoutingMethodType.SigmoidRenorm + + class RenormalizeMoeRoutingMethod(BaseMoeRoutingMethod): def __init__( @@ -647,6 +680,8 @@ def routing_method_type(self) -> RoutingMethodType: BaseMoeRoutingMethod, RoutingMethodType.MiniMax2: MiniMaxM2MoeRoutingMethod, + RoutingMethodType.SigmoidRenorm: + SigmoidRenormMoeRoutingMethod, } diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 23557ee2011c..13924e96bb5a 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -153,13 +153,10 @@ def __init__( # Choose between flashinfer and native implementation. (default to flashinfer) self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype - # TODO: Update head_dims and head_group_ratios once flashinfer is updated. + # TODO: Update head_dims once flashinfer is updated. + # Nemotron-v2-Nano (mamba_head_dim=80) is not supported by flashinfer yet. 
supported_head_dims = [64, 128] - supported_head_group_ratios = [1, 8, 16] - head_group_ratio = (self.tp_nheads // - self.tp_ngroups if self.tp_ngroups > 0 else 0) - self._use_flashinfer = (head_dim in supported_head_dims and - head_group_ratio in supported_head_group_ratios) + self._use_flashinfer = head_dim in supported_head_dims # Stochastic rounding requires FlashInfer and fp16 cache self._use_stochastic_rounding = ( config.quant_config.mamba_ssm_stochastic_rounding diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py index a5916657c31f..63ea5b91ee2c 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py @@ -17,14 +17,138 @@ # limitations under the License. import torch +import torch.nn.functional as F from einops import rearrange +from tensorrt_llm._utils import is_sm_100f +from tensorrt_llm.logger import logger +from tensorrt_llm.math_utils import pad_up + from .ssd_bmm import _bmm_chunk_fwd from .ssd_chunk_scan import _chunk_scan_fwd from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen) from .ssd_state_passing import _state_passing_fwd +# FlashInfer fused SSD kernel cache (Blackwell SM100+ only). 
+_flashinfer_ssd_cache: dict = {} + + +def _get_flashinfer_ssd(chunk_size, nheads, headdim, dstate, ngroups): + """Get or compile a cached FlashInfer SSDCombined kernel instance.""" + key = (chunk_size, nheads, headdim, dstate, ngroups) + if key not in _flashinfer_ssd_cache: + from flashinfer.mamba import SSDCombined + _flashinfer_ssd_cache[key] = SSDCombined( + chunk_size=chunk_size, + nheads=nheads, + headdim=headdim, + dstate=dstate, + ngroups=ngroups, + io_dtype=torch.bfloat16, + state_dtype=torch.bfloat16, + has_d=True, + d_has_hdim=False, + has_initial_states=True, + has_varlen=True, + has_z=False, + seq_idx_dtype=torch.int32, + ) + logger.info_once("Using FlashInfer fused SSD kernel for Mamba2 prefill", + key="flashinfer_ssd_prefill") + return _flashinfer_ssd_cache[key] + + +def _mamba_chunk_scan_flashinfer_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + out=None, + return_final_states=False, + state_dtype=None, +): + """FlashInfer fused SSD forward using a single CUTLASS persistent kernel.""" + _, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + io_dtype = torch.bfloat16 + + ssd = _get_flashinfer_ssd(chunk_size, nheads, headdim, dstate, ngroups) + num_seqs = cu_seqlens.shape[0] - 1 + + # Pad seqlen to chunk_size boundary — padded tokens use dt=-100 + # so softplus ≈ 0, contributing nothing to state or output. 
+ pad_len = pad_up(seqlen, chunk_size) - seqlen + if pad_len > 0: + x = F.pad(x, (0, 0, 0, 0, 0, pad_len)) + B = F.pad(B, (0, 0, 0, 0, 0, pad_len)) + C = F.pad(C, (0, 0, 0, 0, 0, pad_len)) + dt = F.pad(dt, (0, 0, 0, pad_len), value=-100.0) + if seq_idx is not None: + seq_idx = F.pad(seq_idx, (0, pad_len), value=int(num_seqs - 1)) + + if x.dtype != io_dtype: + x = x.to(io_dtype) + B = B.to(io_dtype) + C = C.to(io_dtype) + dt = dt.to(io_dtype) + + D_bf16 = D.to(io_dtype) if D is not None and D.dtype != io_dtype else D + + if initial_states is not None: + fi_initial_states = (initial_states if initial_states.dtype == io_dtype + else initial_states.to(io_dtype)) + else: + fi_initial_states = x.new_zeros(num_seqs, + nheads, + headdim, + dstate, + dtype=io_dtype) + + if chunk_indices is None or chunk_offsets is None: + from .mamba2_metadata import cu_seqlens_to_chunk_indices_offsets_triton + chunk_indices, chunk_offsets = ( + cu_seqlens_to_chunk_indices_offsets_triton(cu_seqlens, + chunk_size, + total_seqlens=seqlen)) + + out_view, fstate = ssd.run( + x, + dt, + A, + B, + C, + D=D_bf16, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + initial_states=fi_initial_states, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_final_states=True, + ) + + if out is not None: + out.copy_(out_view[:, :seqlen]) + + if state_dtype is not None and fstate.dtype != state_dtype: + fstate = fstate.to(state_dtype) + + # Both final_states and varlen_states are per-sequence in FlashInfer. + return (fstate, fstate) if return_final_states else fstate + def is_int_pow_2(n): return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 @@ -235,6 +359,30 @@ def mamba_chunk_scan_combined( else: assert (cu_seqlens is not None ), "cu_seqlens must be provided if return_varlen_states is True" + + # Dispatch to FlashInfer fused CUTLASS kernel on Blackwell (SM100+). 
+ if (return_varlen_states and z is None and is_sm_100f()): + return _mamba_chunk_scan_flashinfer_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + out=out, + return_final_states=return_final_states, + state_dtype=state_dtype, + ) + out_x, dt_out, dA_cumsum, states, final_states, *rest = ( _mamba_chunk_scan_combined_fwd( x, diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py index 375a3d350dbe..4b22a466c641 100644 --- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -138,12 +138,10 @@ def free_resources(self, request: LlmRequest): self.mamba_impl.free_cache_block(request.py_request_id) def add_dummy_requests(self, request_ids: List[int], **kwargs): - # For CUDA graph dummy requests, the blocks will be allocated - # when get_state_indices is called. - from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID - request_ids = [ - rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID - ] + # Allocate a permanent slot for every id, including CUDA-graph + # padding sentinels (matches PythonMambaCacheManager). Padding + # entries in get_state_indices then resolve via mCacheIndex to + # the sentinel's reserved slot and never alias a live request. 
if request_ids: self.mamba_impl.allocate_cache_blocks(request_ids) @@ -370,10 +368,16 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): self._prepare_mamba_cache_blocks(request_ids) def add_dummy_requests(self, request_ids: List[int], **kwargs): - from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID - request_ids = [ - rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID - ] + # Allocate a permanent slot for every dummy request ID, including + # the CUDA-graph padding sentinel. Padding entries in a batch all + # reference the same dummy request ID, so they share one slot via + # mamba_cache_index lookup in get_state_indices. This mirrors how + # MTP's per-draft-len padding dummies already behave (they use + # CUDA_GRAPH_DUMMY_REQUEST_ID - draft_len, which was never + # filtered here) and keeps padding writes off every live + # request's slot, even under the overlap scheduler where a prior + # batch's completed requests linger in mamba_cache_index until + # _process_previous_batch runs. 
if request_ids: for r in request_ids: if r not in self.mamba_cache_index: @@ -390,29 +394,10 @@ def free_resources(self, request: LlmRequest): def get_state_indices(self, request_ids: List[int], is_padding: List[bool]) -> List[int]: - assert len(request_ids) == len(is_padding), ( - "request_ids and is_padding must have the same size") - - used_slots = { - self.mamba_cache_index[req_id] - for req_id, pad in zip(request_ids, is_padding) if not pad - } - available_slots = iter( - sorted(set(range(self.state_indices.numel())) - used_slots)) - - def slot_for(req_id: int, pad: bool): - if pad: - try: - return next(available_slots) - except StopIteration: - raise RuntimeError( - "Run out of available slots for padding") from None - return self.mamba_cache_index[req_id] - - result = [ - slot_for(rid, pad) for rid, pad in zip(request_ids, is_padding) - ] - return result + # Padding entries reuse the slot pre-allocated by their dummy + # request in add_dummy_requests; see that method for the + # overlap-scheduler rationale. 
+ return [self.mamba_cache_index[rid] for rid in request_ids] def get_conv_states(self, layer_idx: int) -> torch.Tensor: layer_offset = self.mamba_layer_offsets[layer_idx] @@ -472,14 +457,14 @@ def shutdown(self): @torch.compile(options={"max-autotune": True}) def update_mamba_states(self, attn_metadata: "AttentionMetadata", - num_accepted_tokens: torch.Tensor): + num_accepted_tokens: torch.Tensor, + state_indices: torch.Tensor): batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts num_accepted_draft_tokens = num_accepted_tokens[ num_contexts:num_contexts + num_gens] - 1 - state_indices_d = self.state_indices[num_contexts:num_contexts + - num_gens] + state_indices_d = state_indices[num_contexts:num_contexts + num_gens] conv_states = self.mamba_cache.conv ssm_states = self.mamba_cache.temporal @@ -621,10 +606,17 @@ def mamba_layer_cache( def shutdown(self): self._impl.shutdown() - def update_mamba_states(self, attn_metadata: "AttentionMetadata", - num_accepted_tokens: torch.Tensor): + def update_mamba_states(self, + attn_metadata: "AttentionMetadata", + num_accepted_tokens: torch.Tensor, + state_indices: Optional[torch.Tensor] = None): assert not self._use_cpp, "update_mamba_states is not supported in CppMambaCacheManager" - self._impl.update_mamba_states(attn_metadata, num_accepted_tokens) + # Resolve the forward-path fallback outside the @torch.compile body + # so Dynamo only specializes on a concrete Tensor. 
+ if state_indices is None: + state_indices = self._impl.state_indices + self._impl.update_mamba_states(attn_metadata, num_accepted_tokens, + state_indices) class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): @@ -665,7 +657,13 @@ def __init__( # mamba hybrid cache requires block reuse to be disabled in KV cache config assert not kv_cache_config.enable_block_reuse, "mamba hybrid cache requires block reuse to be disabled in KV cache config" - # initialize mamba cache manager + # Reserve one Mamba slot per possible CUDA-graph padding dummy + # (one per runtime_draft_len in 0..max_draft_len) so a full + # max_batch_size of real requests still leaves room for padding. + max_draft_len = (spec_config.max_draft_len + if spec_config is not None else 0) + pool_size = max_batch_size + max_draft_len + 1 + MambaCacheManager.__init__( self, mamba_d_state, @@ -674,7 +672,7 @@ def __init__( mamba_n_groups, mamba_head_dim, mamba_num_layers, - max_batch_size, + pool_size, max_batch_size, mapping, mamba_cache_dtype, @@ -728,7 +726,10 @@ def update_resources(self, KVCacheManager.update_resources(self, scheduled_batch, attn_metadata, kv_cache_dtype_byte_size) - def update_mamba_states(self, attn_metadata: "AttentionMetadata", - num_accepted_tokens: torch.Tensor): + def update_mamba_states(self, + attn_metadata: "AttentionMetadata", + num_accepted_tokens: torch.Tensor, + state_indices: Optional[torch.Tensor] = None): MambaCacheManager.update_mamba_states(self, attn_metadata, - num_accepted_tokens) + num_accepted_tokens, + state_indices) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 4d6e51551d09..f5d87562940d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -846,6 +846,16 @@ def drafting_loop_wrapper(model): py_executor.kv_cache_transceiver.shutdown() finally: kv_cache_creator.teardown_managers(resources) + + # 
Release Phase-1 CUDA graph pools before final KV allocation to avoid overshoot. + for eng in [model_engine, draft_model_engine]: + if eng is None: + continue + if eng.attn_metadata is not None: + if llm_args.cuda_graph_config is not None: + eng._release_cuda_graphs() + eng.attn_metadata = None + del py_executor # free before constructing new gc.collect() @@ -860,13 +870,6 @@ def drafting_loop_wrapper(model): max_seq_len = kv_cache_creator._max_seq_len update_sampler_max_seq_len(max_seq_len, sampler) - for eng in [model_engine, draft_model_engine]: - if eng is None: - continue - if eng.attn_metadata is not None: - if llm_args.cuda_graph_config is not None: - eng._release_cuda_graphs() - eng.attn_metadata = None with allocation_scope(ExecutorMemoryType.EXTRA_RESOURCES): # run gc.collect() to free memory of the previous py_executor, avoid cudaFree overlap with cuda graph capture diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 8baf2da76615..2cf0b1fac4ba 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -1167,9 +1167,15 @@ def forward( self._is_mamba_hybrid_cache = isinstance( attn_metadata.kv_cache_manager, MambaHybridCacheManager) if num_gens > 0 and self._is_mamba_hybrid_cache: + # Use the forward-path state_indices so the scatter lines up + # with the mixer's reads (covers the full padded batch). 
+ mamba_metadata = getattr(attn_metadata, 'mamba_metadata', None) + mamba_state_indices = (mamba_metadata.state_indices + if mamba_metadata is not None else None) attn_metadata.kv_cache_manager.update_mamba_states( attn_metadata=attn_metadata, - num_accepted_tokens=num_accepted_tokens) + num_accepted_tokens=num_accepted_tokens, + state_indices=mamba_state_indices) # Save the old attn_metadata and spec_metadata self._prepare_attn_metadata_for_spec_dec(attn_metadata) diff --git a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py index 9d6ee345d002..edb8f692a03d 100644 --- a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py +++ b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/cute_dsl_utils.py @@ -108,20 +108,35 @@ def load_cubin_module_data_patched(cubin_data, filepath): return load_cubin_module_data_og(cubin_data) -def cute_compile_patched(*args, **kwargs): - """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" - cubin_path = os.getenv("CUTE_CUBIN_PATH", None) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( - load_cubin_module_data_patched, filepath=cubin_path - ) - output = cute_compile_og(*args, **kwargs) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og - if extract is not None: - sass = extract(cubin_path, None) - pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) - return output +class _CuteCompilePatched: + """Wrapper around cute.compile that optionally dumps SASS via CUTE_CUBIN_PATH. + + Preserves the CompileCallable subscript interface (cute.compile[opts](...)) + so that third-party CuTe DSL kernels (e.g. FlashInfer) keep working. 
+ """ + + def __init__(self, original=None): + self._original = original or cute_compile_og + + def __getitem__(self, item): + return _CuteCompilePatched(self._original[item]) + + def __call__(self, *args, **kwargs): + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + output = self._original(*args, **kwargs) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output + + +cute_compile_patched = _CuteCompilePatched() def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): diff --git a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm100.py b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm100.py index 18a559d7dba9..b1b82458ce09 100644 --- a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm100.py +++ b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm100.py @@ -2295,7 +2295,7 @@ def compute_loop( if const_expr(not self.use_smem_dS_for_mma_dK): cute.arch.fence_view_async_tmem_store() cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + "async.shared", space="cta" ) self.compute_sync_barrier.arrive_and_wait() @@ -2546,7 +2546,7 @@ def dQacc_reduce( cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + "async.shared", space="cta" ) # semaphore acquire if const_expr(self.deterministic and stage == 0): @@ -2904,7 +2904,7 @@ def epilogue_dK_or_dV_tma( 
tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape) cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s) cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + "async.shared", space="cta" ) cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) @@ -2928,7 +2928,7 @@ def epilogue_dK_or_dV_tma( # Barrier since all warps need to wait for SMEM to be freed cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + "async.shared", space="cta" ) cute.arch.barrier( barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE diff --git a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm90.py b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm90.py index 6d1ead4a2acb..b97d751ca639 100644 --- a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm90.py +++ b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_bwd_sm90.py @@ -8,7 +8,6 @@ import cutlass.cute as cute import cutlass.utils.hopper_helpers as sm90_utils_basic from cutlass.cute.nvgpu import cpasync, warpgroup -from cutlass.cute.arch import ProxyKind, SharedSpace from cutlass.cute import FastDivmodDivisor from cutlass import Float32, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum @@ -1416,7 +1415,7 @@ def mma_one_m_block( # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs. 
if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)): - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.barrier( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) @@ -1434,7 +1433,7 @@ def mma_one_m_block( mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1) # smem fence to make sure sdS is written before it's read by WGMMA - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.barrier( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) @@ -1458,7 +1457,7 @@ def mma_one_m_block( ) tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, @@ -1531,7 +1530,7 @@ def epilogue_dKV( sdV = sV if const_expr(not self.dKV_swapAB) else utils.transpose_view(sV) taccdVsdV = smem_thr_copy_dV.partition_D(sdV) cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) @@ -1541,7 +1540,7 @@ def epilogue_dKV( sdK = sK if const_expr(not self.dKV_swapAB) else utils.transpose_view(sK) taccdKsdK = smem_thr_copy_dK.partition_D(sdK) cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") 
cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) @@ -1580,7 +1579,7 @@ def epilogue_dKV( acc_dK.iterator, cute.make_layout(tdKsdKVaccum.shape) ) cute.autovec_copy(tdKrdKaccum_flat, tdKsdKVaccum) - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) @@ -1604,7 +1603,7 @@ def epilogue_dKV( acc_dV.iterator, cute.make_layout(tdKsdKVaccum.shape) ) cute.autovec_copy(tdVrdVaccum_flat, tdKsdKVaccum) - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) diff --git a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd.py b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd.py index 336da9d8737b..9d6f23b27309 100644 --- a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd.py +++ b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd.py @@ -16,7 +16,6 @@ import cutlass.cute as cute from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync, warp, warpgroup -from cutlass.cute.arch import ProxyKind, SharedSpace import cutlass.utils as utils_basic from cutlass.utils import LayoutEnum import cutlass.utils.hopper_helpers as sm90_utils_basic @@ -402,7 +401,7 @@ def epilogue( # sync to make sure all smem stores are done if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, @@ -2259,7 
+2258,7 @@ def first_half_block_overlap( cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) # Fence and barrier to make smem store visible to WGMMA cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + "async.shared", space="cta" ) cute.arch.sync_warp() @@ -2331,7 +2330,7 @@ def mma_one_n_block( softmax.rescale_O(acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() @@ -2398,7 +2397,7 @@ def mma_one_n_block_intrawg_overlap( softmax.rescale_O(acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV return smem_pipe_read diff --git a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd_sm100.py b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd_sm100.py index acdc0be71a8d..d6cb154c2154 100644 --- a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd_sm100.py +++ b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/flash_fwd_sm100.py @@ -2472,8 +2472,8 @@ def correction_epilogue( cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i) # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) if 
const_expr(self.use_correction_warps_for_epi): diff --git a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/utils.py b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/utils.py index 4688323c8300..86d298fb3bc0 100644 --- a/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/utils.py +++ b/tensorrt_llm/_torch/visual_gen/jit_kernels/flash_attention/cute/utils.py @@ -324,16 +324,16 @@ def logf(a: float | Float32, *, loc=None, ip=None) -> Float32: def fmax( a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None ) -> Float32: - return Float32( - nvvm.fmax( - T.f32(), - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, - loc=loc, - ip=ip, - ) - ) + a_ir = Float32(a).ir_value(loc=loc, ip=ip) + b_ir = Float32(b).ir_value(loc=loc, ip=ip) + # CUTLASS DSL 4.4+ dropped the result-type positional arg and the ternary `c` fused + # operand from nvvm.FmaxOp. Decompose the optional ternary form into two binary calls + # to preserve existing callers' semantics (max(max(a, b), c)). 
+ ab = nvvm.fmax(a_ir, b_ir, loc=loc, ip=ip) + if c is None: + return Float32(ab) + c_ir = Float32(c).ir_value(loc=loc, ip=ip) + return Float32(nvvm.fmax(ab, c_ir, loc=loc, ip=ip)) @cute.jit diff --git a/tensorrt_llm/llmapi/reasoning_parser.py b/tensorrt_llm/llmapi/reasoning_parser.py index 049c1fbe4583..32e0b8a4d4ba 100644 --- a/tensorrt_llm/llmapi/reasoning_parser.py +++ b/tensorrt_llm/llmapi/reasoning_parser.py @@ -182,6 +182,7 @@ def parse_delta(self, delta_text: str) -> ReasoningParserResult: "deepseek_v3": "deepseek-r1", "deepseek_v32": "deepseek-r1", "nemotron_h": "nano-v3", + "nemotron_h_puzzle": "nano-v3", } _QWEN3_MODEL_TYPES = frozenset({ diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 0a367f96cb91..dd96e91e50c7 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -6246,18 +6246,12 @@ def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp, task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) - def _run_nvfp4_4gpus_eplb(self, moe_backend, eplb_config): - if moe_backend == "TRTLLM": - pytest.skip( - "TRTLLM + EPLB is not supported yet, see https://nvbugs/5997893." 
- ) - + def _run_nvfp4_4gpus_eplb(self, moe_backend, eplb_config, model_path): kv_cache_config = KvCacheConfig( enable_block_reuse=False, mamba_ssm_cache_dtype="float16", free_gpu_memory_fraction=0.5, ) - model_path = f"{llm_models_root()}/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4" max_batch_size = 32 cuda_graph_config = CudaGraphConfig(max_batch_size=max_batch_size, enable_padding=True) @@ -6305,12 +6299,17 @@ def test_nvfp4_4gpus_static_eplb(self, moe_backend): num_slots=num_slots, initial_global_assignments=initial_global_assignments, layer_updates_per_iter=0) - self._run_nvfp4_4gpus_eplb(moe_backend, eplb_config) + self._run_nvfp4_4gpus_eplb(moe_backend, eplb_config, model_path) @pytest.mark.skip_less_device(4) @pytest.mark.skip_device_not_contain(["GB200"]) @parametrize_with_ids("moe_backend", ["TRTLLM", "CUTLASS"]) def test_nvfp4_4gpus_online_eplb(self, moe_backend): + if moe_backend == "TRTLLM": + pytest.skip( + "TRTLLM + online EPLB is not supported yet, see https://nvbugs/5997893." + ) + model_path = f"{llm_models_root()}/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4" num_experts = 512 # 512 experts per token for Nemotron V3 Super. # num_slots should be larger than or equal to num_experts and should be divisible by parallel_size. # Assign extra 16 expert slots per rank. 
@@ -6318,7 +6317,7 @@ def test_nvfp4_4gpus_online_eplb(self, moe_backend): num_slots = num_experts + extra_num_slots eplb_config = MoeLoadBalancerConfig(num_slots=num_slots, layer_updates_per_iter=2) - self._run_nvfp4_4gpus_eplb(moe_backend, eplb_config) + self._run_nvfp4_4gpus_eplb(moe_backend, eplb_config, model_path) @skip_pre_hopper @pytest.mark.skip_less_mpi_world_size(4) @@ -6374,8 +6373,8 @@ def test_fp8_4gpus(self, attention_dp, use_cpp_mamba, monkeypatch): @skip_pre_blackwell @pytest.mark.skip_less_mpi_world_size(8) - @pytest.mark.parametrize("moe_backend", ["TRTLLM", "CUTLASS"], - ids=["trtllm", "cutlass"]) + @pytest.mark.parametrize("moe_backend", ["TRTLLM", "CUTLASS", "CUTEDSL"], + ids=["trtllm", "cutlass", "cutedsl"]) @pytest.mark.parametrize( "attention_dp", [ diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index d5d9602a7815..c0352a3f7e0d 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -144,6 +144,7 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus_mtp_custom_op TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-trtllm] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-cutlass] TIMEOUT (60) + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-cutedsl] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TP4_PP2] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestNemotron3Super120B::test_auto_dtype[use_py_transceiver=True] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestNemotron3Super120B::test_auto_dtype[use_py_transceiver=False] TIMEOUT (60) diff --git a/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc1.py 
b/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc1.py index be578889672c..9ed3c2918270 100644 --- a/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc1.py +++ b/tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc1.py @@ -736,7 +736,7 @@ def create_alpha_scale_post_swiglu_tensor(l, m, expert_count, weight_per_expert) torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) elif c_dtype is cutlass.Float4E2M1FN: # FP4 quantization with SFC (Scale Factor C) verification - # Reference: run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py + # Reference: run_blockscaled_contiguous_gather_grouped_gemm_act_fusion.py # ============================================================ # Step 1: Compute reference scale factor (SFC) from SwiGLU output diff --git a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_act_fusion.py similarity index 99% rename from tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py rename to tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_act_fusion.py index 3d7c46e5d0aa..87d2d5b33ee3 100644 --- a/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_act_fusion.py @@ -29,19 +29,19 @@ """Example usage of the kernel. 
Functional testing: -python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \ +python run_blockscaled_contiguous_gather_grouped_gemm_act_fusion.py \ --ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \ --sf_dtype Float8E4M3FN --sf_vec_size 16 \ --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ --nkl 4096,7168,8 --fixed_m 128 or use a benchmark file: -python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \ +python run_blockscaled_contiguous_gather_grouped_gemm_act_fusion.py \ --ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \ --sf_dtype Float8E4M3FN --sf_vec_size 16 \ --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ --benchmark benchmark.txt Perf testing: -python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \ +python run_blockscaled_contiguous_gather_grouped_gemm_act_fusion.py \ --ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \ --sf_dtype Float8E4M3FN --sf_vec_size 16 \ --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ @@ -72,11 +72,11 @@ try: from tensorrt_llm._torch.cute_dsl_kernels.blackwell import ( - blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion as kernel_module, + blockscaled_contiguous_gather_grouped_gemm_act_fusion as kernel_module, ) except (ModuleNotFoundError, ImportError): sys.path.insert(0, str(Path(__file__).parents[3] / "tensorrt_llm/_torch/cute_dsl_kernels")) - from blackwell import blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion as kernel_module + from blackwell import blockscaled_contiguous_gather_grouped_gemm_act_fusion as kernel_module BlockScaledContiguousGatherGroupedGemmKernel = ( kernel_module.BlockScaledContiguousGatherGroupedGemmKernel diff --git a/tests/unittest/_torch/executor/test_mamba_cache_manager.py b/tests/unittest/_torch/executor/test_mamba_cache_manager.py new file mode 100644 index 000000000000..d8ff936bb693 --- /dev/null +++ b/tests/unittest/_torch/executor/test_mamba_cache_manager.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for MambaCacheManager padding-slot behavior.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +import torch + +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import ( + CppMambaCacheManager, + PythonMambaCacheManager, +) +from tensorrt_llm.mapping import Mapping + +skip_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") + + +def _make_mgr(max_batch_size=4, max_draft_len=2): + # Pool size mirrors MambaHybridCacheManager's +max_draft_len+1 headroom. + pool = max_batch_size + max_draft_len + 1 + return PythonMambaCacheManager( + d_state=8, + d_conv=4, + num_heads=4, + n_groups=1, + head_dim=8, + num_layers=2, + max_batch_size=pool, + spec_state_size=max_batch_size, + mapping=Mapping(world_size=1, tp_size=1, pp_size=1), + dtype=torch.float16, + ssm_cache_dtype=torch.float16, + speculative_num_draft_tokens=max_draft_len, + ) + + +@skip_no_cuda +def test_padding_slot_not_held_by_parked_real(): + """get_state_indices must not hand the padding position a slot + owned by a live request outside the current batch. Padding entries + all reuse the pre-allocated slot of their dummy request (added via + add_dummy_requests), which is distinct from every real request's + slot.""" + mgr = _make_mgr(max_batch_size=4, max_draft_len=2) + # Four real requests claim slots; pool has max_batch_size+max_draft_len+1 = 7 slots. + mgr._prepare_mamba_cache_blocks([100, 101, 102, 103]) + # Pre-allocate the padding dummy's slot (what _get_padded_batch does + # via kv_cache_manager.add_dummy_requests before get_state_indices). + mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID]) + # Current batch has only two reals; 102 and 103 are "parked". 
+ request_ids = [100, 101, CUDA_GRAPH_DUMMY_REQUEST_ID] + indices = mgr.get_state_indices(request_ids, [False, False, True]) + real_slots = {mgr.mamba_cache_index[r] for r in [100, 101, 102, 103]} + assert indices[2] not in real_slots, ( + f"padding slot {indices[2]} overlaps a real request's slot (real slots: {real_slots})" + ) + # Padding should reuse the dummy's reserved slot, not allocate a new one. + assert indices[2] == mgr.mamba_cache_index[CUDA_GRAPH_DUMMY_REQUEST_ID] + + +@skip_no_cuda +def test_padding_survives_overlap_scheduler_pressure(): + """Regression for the overlap-scheduler + attention-dp + CUDA-graph + padding combo: prior-iter completions linger in mamba_cache_index + until _process_previous_batch runs, so get_state_indices must not + require N unused pool slots to serve N padding entries.""" + mgr = _make_mgr(max_batch_size=4, max_draft_len=0) + # Fill the pool with "live" real requests (simulates completed + # requests from prior iter that haven't been freed yet). + mgr._prepare_mamba_cache_blocks([100, 101, 102, 103]) + # Pre-allocate the padding dummy's slot. + mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID]) + # Current batch: 1 real request + 3 padding entries (attention-dp + # pushed padded_batch_size to 4 on this rank even though only 1 real + # gen request is scheduled here). + request_ids = [100] + [CUDA_GRAPH_DUMMY_REQUEST_ID] * 3 + is_padding = [False] + [True] * 3 + indices = mgr.get_state_indices(request_ids, is_padding) + # All padding entries share the dummy's slot. + dummy_slot = mgr.mamba_cache_index[CUDA_GRAPH_DUMMY_REQUEST_ID] + assert indices[0] == mgr.mamba_cache_index[100] + assert indices[1:] == [dummy_slot] * 3 + + +@skip_no_cuda +def test_update_mamba_states_uses_passed_state_indices(): + """update_mamba_states must scatter using the caller-supplied + state_indices tensor (e.g. 
mamba_metadata.state_indices).""" + mgr = _make_mgr() + mgr._prepare_mamba_cache_blocks([100, 101, 102]) + + ssm, conv = mgr.mamba_cache.temporal, mgr.mamba_cache.conv + ssm.zero_() + conv.zero_() + mgr.mamba_cache.intermediate_ssm.fill_(7.0) + mgr.mamba_cache.intermediate_conv_window.fill_(7.0) + + # Caller passes slots [slot_R1, slot_R2, slot_R3, 0] for a padded + # batch. Slot 0 belongs to the padding dummy. + state_indices = torch.tensor( + [mgr.mamba_cache_index[100], mgr.mamba_cache_index[101], mgr.mamba_cache_index[102], 0], + dtype=torch.int32, + device="cuda", + ) + attn = SimpleNamespace(num_seqs=4, num_contexts=0) + mgr.update_mamba_states( + attn, + torch.tensor([1, 1, 1, 1], dtype=torch.int32, device="cuda"), + state_indices=state_indices, + ) + + for rid in [100, 101, 102]: + slot = mgr.mamba_cache_index[rid] + assert torch.all(ssm[:, slot] == 7.0) + assert torch.all(conv[:, slot] == 7.0) + + +@skip_no_cuda +def test_update_mamba_states_uses_self_state_indices_when_passed(): + """Regression for the pre-patch behavior where update_mamba_states + implicitly read self.state_indices: passing it explicitly as the + caller-supplied tensor must produce the same writes.""" + mgr = _make_mgr() + mgr._prepare_mamba_cache_blocks([100, 101, 102]) + + ssm, conv = mgr.mamba_cache.temporal, mgr.mamba_cache.conv + ssm.zero_() + conv.zero_() + mgr.mamba_cache.intermediate_ssm.fill_(5.0) + mgr.mamba_cache.intermediate_conv_window.fill_(5.0) + + attn = SimpleNamespace(num_seqs=3, num_contexts=0) + mgr.update_mamba_states( + attn, + torch.tensor([1, 1, 1], dtype=torch.int32, device="cuda"), + state_indices=mgr.state_indices, + ) + + for rid in [100, 101, 102]: + slot = mgr.mamba_cache_index[rid] + assert torch.all(ssm[:, slot] == 5.0) + assert torch.all(conv[:, slot] == 5.0) + + +def test_cpp_add_dummy_requests_allocates_sentinel(): + """CppMambaCacheManager.add_dummy_requests must allocate a permanent + slot for every id — including the raw CUDA-graph sentinel — so + 
get_state_indices can return it without racing parked live requests + for a free slot.""" + stub = SimpleNamespace(mamba_impl=MagicMock()) + CppMambaCacheManager.add_dummy_requests(stub, [100, CUDA_GRAPH_DUMMY_REQUEST_ID, 101]) + stub.mamba_impl.allocate_cache_blocks.assert_called_once_with( + [100, CUDA_GRAPH_DUMMY_REQUEST_ID, 101] + ) + + +def test_cpp_add_dummy_requests_noop_on_empty_list(): + stub = SimpleNamespace(mamba_impl=MagicMock()) + CppMambaCacheManager.add_dummy_requests(stub, []) + stub.mamba_impl.allocate_cache_blocks.assert_not_called() diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py index 81b9d4834812..e753ee63679f 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py @@ -329,7 +329,8 @@ def test_nemotron_h_cuda_graph_overlap_scheduler(): "The chemical symbol for water is", ] - sampling_config = SamplingParams(max_tokens=10, + # max_tokens=2 keeps the smoke check tight. + sampling_config = SamplingParams(max_tokens=2, temperature=0.0, return_generation_logits=True) diff --git a/tests/unittest/_torch/models/test_nemotron_h_puzzle.py b/tests/unittest/_torch/models/test_nemotron_h_puzzle.py new file mode 100644 index 000000000000..4ecbc6bdd9b9 --- /dev/null +++ b/tests/unittest/_torch/models/test_nemotron_h_puzzle.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for NemotronHPuzzle model support.""" + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from tensorrt_llm._torch.models.modeling_nemotron_h import ( + NemotronHForCausalLM, + _get_layer_moe_param, +) + + +@dataclass +class _MambaBlock: + block_type: str = "mamba" + + +@dataclass +class _MoeBlock: + block_type: str = "moe" + moe_intermediate_size: int = 1280 + n_routed_experts: int = 512 + num_experts_per_tok: int = 4 + moe_latent_size: int = 1024 + moe_shared_expert_intermediate_size: int = 5376 + + +def _make_puzzle_config(use_dataclass=False): + """Minimal config mimicking the real puzzle model.""" + if use_dataclass: + bcs = [ + _MambaBlock(), + _MoeBlock(num_experts_per_tok=4), + _MambaBlock(), + _MoeBlock(moe_intermediate_size=2048, num_experts_per_tok=12), + ] + else: + bcs = [ + {"block_type": "mamba"}, + { + "block_type": "moe", + "moe_intermediate_size": 1280, + "n_routed_experts": 512, + "num_experts_per_tok": 4, + "moe_latent_size": 1024, + "moe_shared_expert_intermediate_size": 5376, + }, + {"block_type": "mamba"}, + { + "block_type": "moe", + "moe_intermediate_size": 2048, + "n_routed_experts": 512, + "num_experts_per_tok": 12, + "moe_latent_size": 1024, + "moe_shared_expert_intermediate_size": 5376, + }, + ] + return SimpleNamespace( + block_configs=bcs, + mtp_block_configs=[ + {"block_type": "attention"}, + { + "block_type": "moe", + "moe_intermediate_size": 2688, + "n_routed_experts": 512, + "num_experts_per_tok": 22, + "moe_latent_size": 1024, + "moe_shared_expert_intermediate_size": 5376, 
+ }, + ], + ) + + +class TestPerLayerMoeParams: + """The key change: block_configs can be dicts or HF dataclass objects, + and per-layer values must differ while MTP falls back to globals.""" + + @pytest.mark.parametrize("use_dc", [False, True], ids=["dict", "dataclass"]) + def test_varying_params_per_layer(self, use_dc): + config = _make_puzzle_config(use_dataclass=use_dc) + NemotronHForCausalLM._normalize_puzzle_config(config) + + # MoE layer 1: top_k=4, intermediate=1280 + assert _get_layer_moe_param(config, 1, "num_experts_per_tok") == 4 + assert _get_layer_moe_param(config, 1, "moe_intermediate_size") == 1280 + # MoE layer 3: top_k=12, intermediate=2048 + assert _get_layer_moe_param(config, 3, "num_experts_per_tok") == 12 + assert _get_layer_moe_param(config, 3, "moe_intermediate_size") == 2048 + + @pytest.mark.parametrize("use_dc", [False, True], ids=["dict", "dataclass"]) + def test_mtp_layer_gets_global_defaults(self, use_dc): + """MTP layer_idx beyond block_configs range uses globals from mtp_block_configs.""" + config = _make_puzzle_config(use_dataclass=use_dc) + NemotronHForCausalLM._normalize_puzzle_config(config) + + mtp_idx = len(config.block_configs) # beyond range + assert _get_layer_moe_param(config, mtp_idx, "num_experts_per_tok") == 22 + assert _get_layer_moe_param(config, mtp_idx, "moe_intermediate_size") == 2688 + + @pytest.mark.parametrize("use_dc", [False, True], ids=["dict", "dataclass"]) + def test_normalize_sets_all_global_attrs(self, use_dc): + config = _make_puzzle_config(use_dataclass=use_dc) + NemotronHForCausalLM._normalize_puzzle_config(config) + + for attr in ( + "n_routed_experts", + "moe_intermediate_size", + "num_experts_per_tok", + "moe_latent_size", + "moe_shared_expert_intermediate_size", + ): + assert getattr(config, attr) is not None, f"{attr} not set" + + def test_normalize_preserves_existing_attrs(self): + config = _make_puzzle_config() + config.n_routed_experts = 999 + NemotronHForCausalLM._normalize_puzzle_config(config) 
+ assert config.n_routed_experts == 999 + + def test_normalize_noop_without_block_configs(self): + config = SimpleNamespace() + NemotronHForCausalLM._normalize_puzzle_config(config) + assert not hasattr(config, "n_routed_experts") + + def test_standard_config_passthrough(self): + """Non-puzzle model: no block_configs, returns global directly.""" + config = SimpleNamespace(n_routed_experts=512) + assert _get_layer_moe_param(config, 0, "n_routed_experts") == 512 diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py index a26d0466ee87..8d123fb6c75a 100644 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ b/tests/unittest/_torch/modules/moe/moe_test_utils.py @@ -167,31 +167,19 @@ def should_skip_trtllm( return None # Routing method compatibility check (used by test_moe_module.py) - # TRTLLMGen C++ routing kernel (runner.cu) only implements: - # - DeepSeekV3 (requires float32 routing_logits) + # TRTLLMGen C++ routing kernel (runner.cu) implements: + # - DeepSeekV3 (nGroup<=1: SigmoidBias+ScaledSumNormalize; nGroup>1: full DeepSeek kernel) + # - SigmoidRenorm (sigmoid activation, sum-normalize) + # - MiniMax2 (sigmoid activation, bias-added selection, scaled sum-normalize) # - Llama4 (requires top_k=1) - # - Renormalize - # - RenormalizeNaive - # See: cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu:77-212 + # - Renormalize / RenormalizeNaive / Default (softmax-based) + # See: cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu if routing_method_cls is not None: from tensorrt_llm._torch.modules.fused_moe import ( DeepSeekV3MoeRoutingMethod, - DefaultMoeRoutingMethod, Llama4RenormalizeMoeRoutingMethod, - MiniMaxM2MoeRoutingMethod, ) - # Routing methods NOT implemented in C++ kernel - trtllm_unimplemented_routing = ( - DefaultMoeRoutingMethod, # runner.cu:210 - "Unimplemented routing method" - MiniMaxM2MoeRoutingMethod, # runner.cu:210 - "Unimplemented routing 
method" - ) - if routing_method_cls in trtllm_unimplemented_routing: - routing_name = routing_method_cls.__name__ - return ( - f"TRTLLMGen C++ routing kernel does not implement {routing_name}. See runner.cu:210" - ) - # Llama4 routing only supports top_k=1 # See: runner.cu:113 - TLLM_CHECK_WITH_INFO(topK == 1, ...) if routing_method_cls == Llama4RenormalizeMoeRoutingMethod: diff --git a/tests/unittest/_torch/modules/moe/test_moe_module.py b/tests/unittest/_torch/modules/moe/test_moe_module.py index 527cbf1a3437..ebb4cb2ecbac 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_module.py +++ b/tests/unittest/_torch/modules/moe/test_moe_module.py @@ -74,6 +74,7 @@ MiniMaxM2MoeRoutingMethod, RenormalizeMoeRoutingMethod, RenormalizeNaiveMoeRoutingMethod, + SigmoidRenormMoeRoutingMethod, create_moe, ) from tensorrt_llm._torch.modules.fused_moe.communication.deep_ep_low_latency import DeepEPLowLatency @@ -258,9 +259,7 @@ def _run_autotune_test( _ = run_forward_fn() # Check if we should run full tactic replay - if not run_all_tactics or not supports_autotuner_capture( - backend_type, quant_algo, use_flashinfer - ): + if not run_all_tactics or not supports_autotuner_capture(backend_type, quant_algo): # Simple accuracy check for unsupported backends or when run_all_tactics is False with torch.inference_mode(): output = run_forward_fn() @@ -370,6 +369,13 @@ def _create_routing_method(routing_method_cls, top_k, num_experts, dtype): callable_e_score_correction_bias=lambda: e_score_correction_bias, ) + # SigmoidRenorm routing method requires num_experts + if routing_method_cls == SigmoidRenormMoeRoutingMethod: + return routing_method_cls( + top_k=top_k, + num_experts=num_experts, + ) + # Fallback: try with just top_k return routing_method_cls(top_k=top_k) @@ -787,6 +793,7 @@ def init_worker(custom_paths, comm_method_type): Llama4RenormalizeMoeRoutingMethod, # Top1 -> Sigmoid (Llama4) DeepSeekV3MoeRoutingMethod, # Sigmoid -> BiasAdd -> Group TopK (DeepSeek-V3) 
MiniMaxM2MoeRoutingMethod, # Sigmoid -> BiasAdd -> TopK -> Renormalize (MiniMax-M2) + SigmoidRenormMoeRoutingMethod, # Sigmoid -> TopK -> Renormalize ] @@ -1161,7 +1168,7 @@ def test_configurable_moe_single_gpu( comm_methods=COMM_METHODS, swiglu_combos=SWIGLU_COMBOS, model_configs=MOE_MODEL_CONFIGS, - seq_lens=[8] if IS_CI_MODE else SEQ_LENS, + seq_lens=[1, 8] if IS_CI_MODE else SEQ_LENS, dtypes=DTYPES, backend_types=BACKEND_TYPES, quant_algos=QUANT_ALGOS, diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index 3f35261ddc61..2702d02f1130 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -869,7 +869,7 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( c_ref, c_sf_ref = torch.ops.trtllm.fp4_quantize(c_ref, 1 / global_sf, sf_vec_size, False) # Call gather kernel (single-B via multi_b op with single-element lists) - c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b( a, [b_interleaved], a_sf_unswizzled, @@ -930,7 +930,7 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( b_sf_interleaved_list = list(torch.split(b_sf_interleaved, split_sizes, dim=0)) alpha_list = list(torch.split(alpha, split_sizes, dim=0)) c_multi, c_sf_multi = ( - torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell_multi_b( + torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_act_fusion_blackwell_multi_b( a, b_interleaved_list, a_sf_unswizzled, diff --git a/tests/unittest/others/test_kv_cache_transceiver.py b/tests/unittest/others/test_kv_cache_transceiver.py index 6960b1e28d07..3db562baad13 100644 --- a/tests/unittest/others/test_kv_cache_transceiver.py +++ b/tests/unittest/others/test_kv_cache_transceiver.py @@ -440,13 +440,23 @@ def 
test_hybrid_cache_transceiver_single_process(backend, hybrid_dtypes, hybrid_cache_manager_gen.get_buffers(0), hybrid_cache_manager_ctx.get_buffers(0)), "different kv-cache values" - assert torch.equal(hybrid_cache_manager_gen.get_conv_states(1), - hybrid_cache_manager_ctx.get_conv_states( - 1)), "different mamba conv states" + # The transceiver copies a single request's state between + # independently-allocated slots on each side, so we check the + # request's own slot instead of the full state buffer (which has + # extra padding-dummy slots that only the ctx side touched). + slot_ctx = hybrid_cache_manager_ctx._impl.mamba_impl.get_cache_index( + ctx_request.py_request_id) + slot_gen = hybrid_cache_manager_gen._impl.mamba_impl.get_cache_index( + gen_request.py_request_id) + assert torch.equal( + hybrid_cache_manager_gen.get_conv_states(1)[slot_gen], + hybrid_cache_manager_ctx.get_conv_states(1)[slot_ctx]), ( + "different mamba conv states") - assert torch.equal(hybrid_cache_manager_gen.get_ssm_states(1), - hybrid_cache_manager_ctx.get_ssm_states( - 1)), "different mamba ssm states" + assert torch.equal( + hybrid_cache_manager_gen.get_ssm_states(1)[slot_gen], + hybrid_cache_manager_ctx.get_ssm_states(1)[slot_ctx]), ( + "different mamba ssm states") @pytest.mark.timeout(120)