diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 69498c96cfc3..db3de6ea7ad6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -264,9 +264,9 @@ def noaux_tc(self, logits, e_score_correction_bias): "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." ) self.is_fused = False - elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22) - or (self.topk_group == 1 and self.top_k != 22)): - # We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3. + elif num_experts > 512 or (self.top_k > 8 and self.top_k != 22): + # The fused noaux_tc_op kernel supports n_group==1 with top_k<=8 + # or top_k==22, and num_experts<=512. if self.is_fused: warnings.warn( "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index fdb5004495e3..23557ee2011c 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -245,6 +245,16 @@ def post_load_weights(self): and self.norm.nvfp4_scale is None): self._try_attach_nvfp4_scale() + # Pre-expand A, D, dt_bias for the decode path. + self._A_expanded = repeat(self.A, + "h -> h p n", + p=self.head_dim, + n=self.d_state).to(dtype=torch.float32) + self._dt_bias_expanded = repeat(self.dt_bias, + "h -> h p", + p=self.head_dim) + self._D_expanded = repeat(self.D, "h -> h p", p=self.head_dim) + def _try_attach_nvfp4_scale(self): """Attach input_scale from out_proj to norm for fused RMSNorm+Quant.""" @@ -471,10 +481,9 @@ def convert_dt(): g=self.tp_ngroups).contiguous() z_d = rearrange(z_d, "b (h p) -> b h p", p=self.head_dim) - A = repeat(self.A, "h -> h p n", p=self.head_dim, - n=self.d_state).to(dtype=torch.float32) - dt_bias = repeat(self.dt_bias, "h -> h p", p=self.head_dim) - D = repeat(self.D, "h -> h p", p=self.head_dim) + A = self._A_expanded + dt_bias = self._dt_bias_expanded + D = self._D_expanded if is_target_verify: intermediate_ssm_states = layer_cache.intermediate_ssm x_d_mtp = x_d.view(