Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/modules/fused_moe/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ def noaux_tc(self, logits, e_score_correction_bias):
"The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
)
self.is_fused = False
elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22)
or (self.topk_group == 1 and self.top_k != 22)):
Comment thread
Wanli-Jiang marked this conversation as resolved.
# We have a special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3.
elif num_experts > 512 or (self.top_k > 8 and self.top_k != 22):
# The fused noaux_tc_op kernel supports n_group==1 with top_k<=8
# or top_k==22, and num_experts<=512.
if self.is_fused:
warnings.warn(
"The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
Expand Down
17 changes: 13 additions & 4 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,16 @@ def post_load_weights(self):
and self.norm.nvfp4_scale is None):
self._try_attach_nvfp4_scale()

# Pre-expand A, D, dt_bias for the decode path.
self._A_expanded = repeat(self.A,
"h -> h p n",
p=self.head_dim,
n=self.d_state).to(dtype=torch.float32)
self._dt_bias_expanded = repeat(self.dt_bias,
"h -> h p",
p=self.head_dim)
self._D_expanded = repeat(self.D, "h -> h p", p=self.head_dim)
Comment thread
Wanli-Jiang marked this conversation as resolved.

def _try_attach_nvfp4_scale(self):
"""Attach input_scale from out_proj to norm for fused RMSNorm+Quant."""

Expand Down Expand Up @@ -471,10 +481,9 @@ def convert_dt():
g=self.tp_ngroups).contiguous()
z_d = rearrange(z_d, "b (h p) -> b h p", p=self.head_dim)

A = repeat(self.A, "h -> h p n", p=self.head_dim,
n=self.d_state).to(dtype=torch.float32)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.head_dim)
D = repeat(self.D, "h -> h p", p=self.head_dim)
A = self._A_expanded
dt_bias = self._dt_bias_expanded
D = self._D_expanded
if is_target_verify:
intermediate_ssm_states = layer_cache.intermediate_ssm
x_d_mtp = x_d.view(
Expand Down
Loading