diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/trtllmGen_fmha_export/KernelTraits.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/trtllmGen_fmha_export/KernelTraits.h index 46ddcdc34a59..052845834ea8 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/trtllmGen_fmha_export/KernelTraits.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/trtllmGen_fmha_export/KernelTraits.h @@ -127,9 +127,13 @@ struct KernelConfig : public KernelConfigBase { } // Set numStagesQ for headDim > 128 kernels. - if (mNumInstsQ * mNumInstsKv == 1) { + bool const isGenerationSkipSoftmax = + options.mSkipsSoftmaxWhenPossible && !isContextKernel(options.mFmhaKernelType); + if (mNumInstsQ * mNumInstsKv == 1 && !isGenerationSkipSoftmax) { TLLM_CHECK_INFO(mTileSizeQ == 64 || (mHeadDimQk > 128 && mHeadDimV > 128), "Consider using numInstsQ = 2 for better performance."); + } + if (mNumInstsQ * mNumInstsKv == 1) { // There is no enough shared memory for 2 stages when the headDim is not split into multiple // stages. if (mHeadDimPerStageKv == 0 && keepsMmaAbForDsMlaGen) {