diff --git a/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh index c94ff267e5cf..efe0f396ae05 100644 --- a/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh +++ b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh @@ -1,6 +1,6 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -123,8 +123,75 @@ struct TopKIdx topK[J].compValIdx = pairMin; \ } +template +struct IsPowerOf2 +{ + static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0); +}; + template -struct Sort; +struct Sort +{ + static_assert(N > 0 && N <= 32, "Sort supports N in [1, 32]"); + + static __device__ void run(RedType* topK) + { + if constexpr (IsPowerOf2::value) + { +#pragma unroll + for (int k = 2; k <= N; k *= 2) + { +#pragma unroll + for (int j = k / 2; j > 0; j /= 2) + { +#pragma unroll + for (int i = 0; i < N; ++i) + { + int ixj = i ^ j; + if (ixj > i) + { + if ((i & k) == 0) + { + if (topK[i].compValIdx < topK[ixj].compValIdx) + { + auto tmp = topK[i].compValIdx; + topK[i].compValIdx = topK[ixj].compValIdx; + topK[ixj].compValIdx = tmp; + } + } + else + { + if (topK[i].compValIdx > topK[ixj].compValIdx) + { + auto tmp = topK[i].compValIdx; + topK[i].compValIdx = topK[ixj].compValIdx; + topK[ixj].compValIdx = tmp; + } + } + } + } + } + } + } + else + { +#pragma unroll + for (int pass = 0; pass < N; ++pass) + { +#pragma unroll + for (int i = 0; i < N - 1; i += 2) + { + TOPK_SWAP(i, i + 1); + } +#pragma unroll + for (int i = 1; i < N - 1; i += 2) + { + TOPK_SWAP(i, i + 1); + } + } + } + } +}; template struct Sort<1, RedType> @@ -170,28 +237,27 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile con int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K) { static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(K <= kWARP_SIZE, "Top K must have K <= kWARP_SIZE"); using RedType = TopKRedType; RedType topK{value, idx}; typename RedType::TypeCmp packedMax{}; #pragma unroll - for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct + for (int kk = 0; kk < actualK; ++kk) { topK = kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK; - // get the next largest value packedMax = topK.reduce(warp); RedType::unpack(out[kk], outIdx[kk], packedMax); } }; -template -__device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K) +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K) { static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(K <= kWARP_SIZE, "Top K must have K <= kWARP_SIZE"); static_assert(N > 0, "Top K must have N > 0"); - static_assert(N < 5, "Only support candidates number less than or equal to 128"); + static_assert(N <= 32, "Only support candidates number less than or equal to 32*32=1024"); using RedType = TopKRedType; RedType topK[N]; #pragma unroll @@ -200,12 +266,9 @@ __device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Ty topK[nn] = RedType{value[nn], idx[nn]}; } - if constexpr (!IsSorted) - { - Sort::run(topK); - } + Sort::run(topK); + typename RedType::TypeCmp packedMax{}; -#pragma unroll for (int kk = 0; kk < actualK; ++kk) { bool update = kk > 0 && packedMax == topK[0].compValIdx; @@ -214,73 +277,11 @@ __device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Ty { topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; } - // get the next largest value packedMax = topK[0].reduce(warp); RedType::unpack(out[kk], outIdx[kk], packedMax); } }; -template -__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], - int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); - static_assert(N > 0, "Top K must have N > 0"); - static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512"); - static_assert( - N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4"); - using RedType = TopKRedType; - - if constexpr (N <= 4) - { - reduceTopKFunc(warp, out, outIdx, value, idx, minValue, actualK); - } - else - { - - constexpr int numLoops = N / 4; - constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1; - - Type topKBufferValue[numResults]; - int32_t topKBufferIdx[numResults]; - int32_t laneIdx = threadIdx.x % kWARP_SIZE; - - for (int ii = 0; ii < numResults; ++ii) - { - topKBufferValue[ii] = minValue; - topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correct - } - for (int loop = 0; loop < numLoops; ++loop) - { - int start = loop * 4; - Type topKValue[K]; - int32_t topKIdx[K]; - Type inValue[4]; - int32_t inIdx[4]; - for (int i = 0; i < 4; ++i) - { - inValue[i] = value[start + i]; - inIdx[i] = idx[start + i]; - } - reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK); - int inOffset = laneIdx % K; - if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) - { - topKBufferValue[0] = topKValue[inOffset]; - topKBufferIdx[0] = topKIdx[inOffset]; - } - if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) - { - topKBufferValue[1] = topKValue[inOffset]; - topKBufferIdx[1] = topKIdx[inOffset]; - } - } - - reduceTopKFunc(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, actualK); - } -}; - #undef TOPK_SWAP } // namespace reduce_topk diff --git a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu index 21f68c71824d..8256c4e6ca73 100644 --- a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu +++ b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,15 +32,13 @@ TRTLLM_NAMESPACE_BEGIN namespace kernels { static constexpr int WARP_SIZE = 32; -static constexpr int NumNemotronExperts = 512; -static constexpr int NumKimiK2Experts = 384; static constexpr int NumDeepseekExperts = 256; -static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); -static constexpr int MaxNumExpertsUnit = 128; +static constexpr int MaxSupportedExpertCount = 1024; static constexpr int NumTopGroupScores = 2; static constexpr int DefaultMaxNumTopExperts = 8; -static constexpr int MaxSupportedTopExperts = 22; -static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxSupportedTopExperts = 32; +static constexpr int DefaultMaxNumTopGroups = 4; +static constexpr int LargeMaxNumTopGroups = 8; static __device__ inline float sigmoid_accurate(float x) { @@ -48,7 +46,7 @@ static __device__ inline float sigmoid_accurate(float x) } template + int MaxNumTopExperts = DefaultMaxNumTopExperts, int MaxNumTopGroups = DefaultMaxNumTopGroups> __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, BiasT* routingBias, int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, int64_t const topk, int64_t const numExperts, int64_t const numExpertsPerGroup, double const routedScalingFactor) @@ -57,208 +55,137 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx cudaGridDependencySynchronize(); #endif - // declare shared memory structure - // number of experts is bounded by number of threads __shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts]; __shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts]; - // number of expert groups is bounded by number of warps - int constexpr NumWarps = MaxNumExperts / WARP_SIZE; - __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; - // needed for warp reduce auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - // for the final reduction of weight norm, only some lanes need to participate int32_t laneIdx = threadIdx.x % WARP_SIZE; int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0); - if constexpr (UseGroups) - { - if (warpIdx >= numGroup) - { - return; - } - } - - // note that for invalid scores, we simply use a negative value: - // they work well even with the compacted format used in topK, and - // sigmoid / bias activated scores cannot be negative static constexpr float invalidScoreFloat = float{-INFINITY}; - const OutputT invalidScore = OutputT{invalidScoreFloat}; - // load bias already; each warp represents one expert group - auto threadExpert = threadIdx.x; - bool expertSelected = threadExpert < numExperts; - if constexpr (UseGroups) - { - threadExpert = warpIdx * numExpertsPerGroup + laneIdx; - expertSelected = laneIdx < numExpertsPerGroup; - } - - auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert; - auto biasVal = expertSelected ? static_cast(routingBias[threadExpert]) : invalidScoreFloat; topkValues += blockIdx.x * topk; topkIndices += blockIdx.x * topk; - // get our assigned thread score; each warp represents one expert group - float score = expertSelected ? static_cast(scores[scoreIdx]) : invalidScoreFloat; - auto scoreSigmoid = sigmoid_accurate(score); - // write the sigmoid score to shared for later use - if (expertSelected) + if constexpr (UseGroups) { - smemScoreSigmoid[threadExpert] = scoreSigmoid; - } + int constexpr NumWarps = MaxNumExperts / WARP_SIZE; + __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; - // get the score with bias - // note that with invalid values, because sigmoid is < 1 and bias is -1, - // we must get a negative value, which is smaller than any valid value - auto scoreBias = float{scoreSigmoid + float{biasVal}}; + if (warpIdx >= numGroup) + { + return; + } - if (expertSelected) - { - smemScoreBias[threadExpert] = scoreBias; - } + auto threadExpert = warpIdx * numExpertsPerGroup + laneIdx; + bool expertSelected = laneIdx < numExpertsPerGroup; - // registers for top group score reduction - float topExpGroupScores[NumTopGroupScores]; - [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; - float topGroups[MaxNumTopGroups]; // bound of numGroup - int32_t topGroupIdx[MaxNumTopGroups]; - float expertScoreGroup[MaxNumTopGroups]; - int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[MaxNumTopExperts]; // bound of topk - int32_t topExperts[MaxNumTopExperts]; + auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert; + auto biasVal = expertSelected ? static_cast(routingBias[threadExpert]) : invalidScoreFloat; + float score = expertSelected ? static_cast(scores[scoreIdx]) : invalidScoreFloat; + auto scoreSigmoid = sigmoid_accurate(score); + if (expertSelected) + { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + if (expertSelected) + { + smemScoreBias[threadExpert] = scoreBias; + } - if constexpr (UseGroups) - { + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, /* minValue */ invalidScoreFloat); - // get the final group score and write it to shared if (warp.thread_rank() == 0) { auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; smemGroupScores[warpIdx] = groupScore; } - } - // make group scores available to all warps - __syncthreads(); + __syncthreads(); + + float topScores[MaxNumTopExperts]; + int32_t topExperts[MaxNumTopExperts]; - if constexpr (UseGroups) - { if (warpIdx == 0) { - // a single warp performs the selection of top groups, and goes on to select the final experts + float topGroups[MaxNumTopGroups]; + int32_t topGroupIdx[MaxNumTopGroups]; float groupScore = laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat; - reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, /* minValue */ invalidScoreFloat); - // final expert selection: get relevant indexes and scores from shared + + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; #pragma unroll for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { // bound of numGroup + { auto groupIdx = topGroupIdx[ii]; expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx; - expertScoreGroup[ii] = (ii < topkGroup) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; } - tensorrt_llm::kernels::reduce_topk::reduceTopK( + reduce_topk::reduceTopK( warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, /* minValue */ invalidScoreFloat, topk); - } - } - else if constexpr (MaxNumExperts > MaxNumExpertsUnit) - { - // without groups, and the expert number is larger than MaxNumExpertsUnit, - // we need to use multiple warps to calculate the intermediate topk results - - int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; - __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; - __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; - if (warpIdx < NumExpertWarps) - { - int offset = warpIdx * WARP_SIZE * MaxNumTopGroups; -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { - auto expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] - = offset + expertIdx < numExperts ? smemScoreBias[offset + expertIdx] : invalidScoreFloat; - } - reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, topk); + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); if (laneIdx < topk) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; - } - else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) - { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1; - } - } - __syncthreads(); - if (warpIdx == 0) - { - int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; - float intermediateScore[NumInterTopKPerThread]; - int32_t intermediateExpert[NumInterTopKPerThread]; - for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) - { - int ii = i / WARP_SIZE; - if (i < NumInterTopK) - { - intermediateScore[ii] = smemInterTopScores[i]; - intermediateExpert[ii] = smemInterTopExperts[i]; - } - else - { - intermediateScore[ii] = invalidScoreFloat; - intermediateExpert[ii] = MaxNumExperts - 1; - } + topkValues[laneIdx] = static_cast(finalScore); + topkIndices[laneIdx] = expertIdx; } - reduce_topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, - /* minValue */ invalidScoreFloat, topk); } } else { - // without groups, and the expert number is smaller than MaxNumExpertsUnit - // each thread just takes `MaxNumTopGroups` experts + for (int e = threadIdx.x; e < numExperts; e += blockDim.x) + { + auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + e; + auto biasVal = static_cast(routingBias[e]); + float score = static_cast(scores[scoreIdx]); + auto scoreSigmoid = sigmoid_accurate(score); + smemScoreSigmoid[e] = scoreSigmoid; + smemScoreBias[e] = scoreSigmoid + biasVal; + } + + __syncthreads(); + + float topScores[MaxNumTopExperts]; + int32_t topExperts[MaxNumTopExperts]; + if (warpIdx == 0) { + constexpr int NumChunks = (MaxNumExperts + WARP_SIZE - 1) / WARP_SIZE; + float localScores[NumChunks]; + int32_t localIdx[NumChunks]; #pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) + for (int ii = 0; ii < NumChunks; ++ii) { auto expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + localIdx[ii] = expertIdx; + localScores[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; } - reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + reduce_topk::reduceTopK(warp, topScores, topExperts, localScores, localIdx, /* minValue */ invalidScoreFloat, topk); - } - } - if (warpIdx == 0) - { - // determine our lane's expert index and write to output - int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; - // norm the value - float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F; - auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); - auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); - // store the topk scores and experts to output - if (laneIdx < topk) - { - topkValues[laneIdx] = static_cast(finalScore); - topkIndices[laneIdx] = expertIdx; + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); + if (laneIdx < topk) + { + topkValues[laneIdx] = static_cast(finalScore); + topkIndices[laneIdx] = expertIdx; + } } } @@ -272,43 +199,57 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk int64_t const num_experts, int64_t const n_group, int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, cudaStream_t const stream) { - - // Check if we can use the optimized deepseek_v3_topk_kernel - bool const is_single_group = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount); + bool const is_single_group + = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount) && (topk <= MaxSupportedTopExperts); int64_t const experts_per_group = num_experts / n_group; bool const is_multi_group = (n_group > 1) && (num_experts <= NumDeepseekExperts) && (experts_per_group <= WARP_SIZE) - && (experts_per_group * topk_group <= MaxNumExpertsUnit); + && (topk <= DefaultMaxNumTopExperts) && (experts_per_group * topk_group <= LargeMaxNumTopGroups * WARP_SIZE); if (is_single_group || is_multi_group) { cudaLaunchConfig_t config; auto* kernel_instance = &deepseek_v3_topk_kernel; int num_threads = NumDeepseekExperts; - if (is_single_group) + + if (is_multi_group) { - // Special case for Nemotron, which selects top 22 from 512 experts, and 1 group only. - if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts) + if (experts_per_group * topk_group <= DefaultMaxNumTopGroups * WARP_SIZE) { - kernel_instance = &deepseek_v3_topk_kernel; - num_threads = NumNemotronExperts; + kernel_instance = &deepseek_v3_topk_kernel; } - else if (num_experts > NumKimiK2Experts && num_experts <= MaxSupportedExpertCount) + else + { + kernel_instance = &deepseek_v3_topk_kernel; + } + num_threads = NumDeepseekExperts; + } + else if (is_single_group) + { + if (num_experts <= 128) { kernel_instance - = &deepseek_v3_topk_kernel; - num_threads = MaxSupportedExpertCount; + = &deepseek_v3_topk_kernel; + num_threads = 128; } - else if (num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts) + else if (num_experts <= 256) { - kernel_instance = &deepseek_v3_topk_kernel; - num_threads = NumKimiK2Experts; + kernel_instance + = &deepseek_v3_topk_kernel; + num_threads = 256; + } + else if (num_experts <= 512) + { + kernel_instance + = &deepseek_v3_topk_kernel; + num_threads = 256; } else { - kernel_instance = &deepseek_v3_topk_kernel; - num_threads = MaxNumExpertsUnit; + kernel_instance + = &deepseek_v3_topk_kernel; + num_threads = 256; } } @@ -328,11 +269,10 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk } else { - // TODO: call the generic path (previous implementation) or signal unsupported config. TLLM_CHECK_WITH_INFO(false, - "invokeNoAuxTc: unsupported configuration (n_group=%ld, num_experts=%ld, topk_group=%ld). Please use " - "original pytorch implementation.", - n_group, num_experts, topk_group); + "invokeNoAuxTc: unsupported configuration (n_group=%ld, num_experts=%ld, topk_group=%ld, topk=%ld). " + "Please use original pytorch implementation.", + n_group, num_experts, topk_group, topk); } } diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 69498c96cfc3..e9424d863d6d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -257,16 +257,15 @@ def noaux_tc(self, logits, e_score_correction_bias): _, num_experts = logits.shape if self.n_group > 1: - if self.top_k > 8 or (num_experts / n_group) > 32 or ( - num_experts / n_group) * self.topk_group > 128: + experts_per_group = num_experts // n_group + if (self.top_k > 8 or num_experts > 256 or experts_per_group > 32 + or experts_per_group * self.topk_group > 256): if self.is_fused: warnings.warn( "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." ) self.is_fused = False - elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22) - or (self.topk_group == 1 and self.top_k != 22)): - # We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3. + elif num_experts > 1024 or self.top_k > 32: if self.is_fused: warnings.warn( "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."