Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 78 additions & 77 deletions cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -123,8 +123,75 @@ struct TopKIdx<K_, true>
topK[J].compValIdx = pairMin; \
}

/// Compile-time predicate: `value` is true exactly when N is positive and has a single bit set.
template <int N>
struct IsPowerOf2
{
    // Clearing the lowest set bit of a positive power of two leaves zero; non-positive N never qualifies.
    static constexpr bool value = (N <= 0) ? false : ((N & (N - 1)) == 0);
};

/**
 * Sorts the N elements pointed to by `topK` so that the largest `compValIdx` ends up at index 0.
 *
 * Only the `compValIdx` member of each element is compared and moved.
 *  - N a power of two: bitonic sorting network.
 *  - any other N:      odd-even transposition network built from TOPK_SWAP.
 * Every loop bound is a compile-time constant and every loop is marked `#pragma unroll`.
 *
 * @tparam N        Number of elements in `topK`; must be in [1, 32].
 * @tparam RedType  Element type; must expose an ordered, assignable `compValIdx` member.
 */
template <int N, typename RedType>
struct Sort
{
    static_assert(N > 0 && N <= 32, "Sort supports N in [1, 32]");

    /// Sorts `topK[0 .. N-1]` in place, in descending `compValIdx` order.
    static __device__ void run(RedType* topK)
    {
        if constexpr (IsPowerOf2<N>::value)
        {
            // k: size of the bitonic sequences being merged; j: compare distance inside a merge step.
#pragma unroll
            for (int k = 2; k <= N; k *= 2)
            {
#pragma unroll
                for (int j = k / 2; j > 0; j /= 2)
                {
#pragma unroll
                    for (int i = 0; i < N; ++i)
                    {
                        int const partner = i ^ j;
                        // Visit each (i, partner) pair once, from its lower index.
                        if (partner > i)
                        {
                            // Blocks with bit k clear are ordered largest-first, the others smallest-first.
                            // In the final merge (k == N) the bit is clear for every i, so the whole array
                            // ends up largest-first.
                            bool const descending = (i & k) == 0;
                            bool const outOfOrder = descending ? (topK[i].compValIdx < topK[partner].compValIdx)
                                                               : (topK[i].compValIdx > topK[partner].compValIdx);
                            if (outOfOrder)
                            {
                                auto const held = topK[i].compValIdx;
                                topK[i].compValIdx = topK[partner].compValIdx;
                                topK[partner].compValIdx = held;
                            }
                        }
                    }
                }
            }
        }
        else
        {
            // Each pass performs one even phase (pairs (0,1), (2,3), ...) followed by one odd phase
            // (pairs (1,2), (3,4), ...). An odd-even transposition network is complete after N phases,
            // so ceil(N / 2) passes are sufficient; running more passes would only repeat no-op compares.
            constexpr int kNumPasses = (N + 1) / 2;
#pragma unroll
            for (int pass = 0; pass < kNumPasses; ++pass)
            {
                // TOPK_SWAP(I, J) (defined earlier in this file) writes the pair minimum to index J.
#pragma unroll
                for (int i = 0; i < N - 1; i += 2)
                {
                    TOPK_SWAP(i, i + 1);
                }
#pragma unroll
                for (int i = 1; i < N - 1; i += 2)
                {
                    TOPK_SWAP(i, i + 1);
                }
            }
        }
    }
};

template <typename RedType>
struct Sort<1, RedType>
Expand Down Expand Up @@ -170,28 +237,27 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> con
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
{
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(K <= kWARP_SIZE, "Top K must have K <= kWARP_SIZE");
using RedType = TopKRedType<Type>;
RedType topK{value, idx};
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
for (int kk = 0; kk < actualK; ++kk)
{
topK = kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK;
// get the next largest value
packedMax = topK.reduce(warp);
RedType::unpack(out[kk], outIdx[kk], packedMax);
}
};

template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K)
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(K <= kWARP_SIZE, "Top K must have K <= kWARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N < 5, "Only support candidates number less than or equal to 128");
static_assert(N <= 32, "Only support candidates number less than or equal to 32*32=1024");
using RedType = TopKRedType<Type>;
RedType topK[N];
#pragma unroll
Expand All @@ -200,12 +266,9 @@ __device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp, Ty
topK[nn] = RedType{value[nn], idx[nn]};
}

if constexpr (!IsSorted)
{
Sort<N, RedType>::run(topK);
}
Sort<N, RedType>::run(topK);

typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < actualK; ++kk)
{
bool update = kk > 0 && packedMax == topK[0].compValIdx;
Expand All @@ -214,73 +277,11 @@ __device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp, Ty
{
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
}
// get the next largest value
packedMax = topK[0].reduce(warp);
RedType::unpack(out[kk], outIdx[kk], packedMax);
}
};

// Warp-level top-K over N candidate (value, idx) pairs supplied by each lane.
//
//  - N <= 4: forwards directly to reduceTopKFunc<K, Type, N>.
//  - N  > 4: processes the candidates in groups of four, stages each group's K results across lanes of
//            the warp, then calls reduceTopKFunc once more on the staged entries.
//
// Parameters:
//   warp      cooperative-groups tile of kWARP_SIZE threads, passed through to reduceTopKFunc.
//   out       receives the selected values (written by reduceTopKFunc).
//   outIdx    receives the indices paired with `out` (written by reduceTopKFunc).
//   value     this lane's N candidate values.
//   idx       this lane's N candidate indices, parallel to `value`.
//   minValue  filler value used to initialise the staging buffer; also forwarded to reduceTopKFunc.
//   actualK   forwarded unchanged to every reduceTopKFunc call; defaults to K.
//
// NOTE(review): in the N > 4 path only staging slots [0] and [1] are ever written, and only the last
// group has a wrap-around store. Because laneIdx < kWARP_SIZE, a group whose lane window
// [loop * K, (loop + 1) * K) lies at or beyond kWARP_SIZE is never stored by the first `if`. Results of
// intermediate groups therefore appear to be dropped when numLoops * K exceeds kWARP_SIZE by more than
// one group -- confirm which (K, N) combinations callers actually instantiate.
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512");
static_assert(
N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");
// NOTE(review): this alias is not referenced anywhere in the function body.
using RedType = TopKRedType<Type>;

if constexpr (N <= 4)
{
// Small candidate count: a single reduction handles everything.
reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue, actualK);
}
else
{

// Number of 4-candidate groups, and number of staging entries each lane needs to hold
// numLoops * K results spread over kWARP_SIZE lanes (ceiling division).
constexpr int numLoops = N / 4;
constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;

Type topKBufferValue[numResults];
int32_t topKBufferIdx[numResults];
// Lane index derived from threadIdx.x -- assumes this matches the lane's rank inside `warp`; TODO confirm.
int32_t laneIdx = threadIdx.x % kWARP_SIZE;

// Pre-fill every staging entry with minValue so lanes that receive no result contribute only filler.
for (int ii = 0; ii < numResults; ++ii)
{
topKBufferValue[ii] = minValue;
topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correct
}
for (int loop = 0; loop < numLoops; ++loop)
{
// Copy this group's four candidates into fixed-size arrays for the per-group reduction.
int start = loop * 4;
Type topKValue[K];
int32_t topKIdx[K];
Type inValue[4];
int32_t inIdx[4];
for (int i = 0; i < 4; ++i)
{
inValue[i] = value[start + i];
inIdx[i] = idx[start + i];
}
// Top-K of the current group; presumably every lane receives the same K results -- confirm against
// reduceTopKFunc.
reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK);
// Which of the K group results this lane is responsible for keeping.
int inOffset = laneIdx % K;
// Lanes [loop * K, (loop + 1) * K) keep one result of this group in staging slot 0.
if (laneIdx >= loop * K && laneIdx < (loop + 1) * K)
{
topKBufferValue[0] = topKValue[inOffset];
topKBufferIdx[0] = topKIdx[inOffset];
}
// Last group only: when numLoops * K exceeds kWARP_SIZE, the first (numLoops * K - kWARP_SIZE) lanes
// also keep a result of the last group in staging slot 1.
if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE)))
{
topKBufferValue[1] = topKValue[inOffset];
topKBufferIdx[1] = topKIdx[inOffset];
}
}

// Final reduction over the staged per-group results.
reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, actualK);
}
};

#undef TOPK_SWAP

} // namespace reduce_topk
Expand Down
Loading
Loading