Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,8 @@ def __init__(self,
num_experts=num_experts,
experts_per_token=top_k,
moe_ep_size=model_config.mapping.moe_ep_size,
dtype=torch.float32)
dtype=torch.float32,
ep_rank=model_config.mapping.moe_ep_rank)

def _compute_shared_expert_tp_size(
self, intermediate_size: int,
Expand Down Expand Up @@ -1052,6 +1053,7 @@ def _create_ideal_expert_load_balanced_logits(
num_experts=num_experts,
experts_per_token=self.top_k,
moe_ep_size=self.model_config.mapping.moe_ep_size,
ep_rank=self.model_config.mapping.moe_ep_rank,
device=device,
dtype=torch.float32)

Expand Down
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/models/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def __init__(
num_experts=pretrained_config.num_local_experts,
experts_per_token=pretrained_config.num_experts_per_tok,
moe_ep_size=config.mapping.moe_ep_size,
dtype=pretrained_config.torch_dtype)
dtype=pretrained_config.torch_dtype,
ep_rank=config.mapping.moe_ep_rank)

@staticmethod
def swiglu(x, alpha: float = 1.702):
Expand Down Expand Up @@ -227,6 +228,7 @@ def _create_ideal_expert_load_balanced_logits(
num_experts=num_experts,
experts_per_token=pretrained_config.experts_per_token,
moe_ep_size=self.config.mapping.moe_ep_size,
ep_rank=self.config.mapping.moe_ep_rank,
device=device,
dtype=pretrained_config.torch_dtype)

Expand Down
29 changes: 23 additions & 6 deletions tensorrt_llm/_torch/modules/fused_moe/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def get_perfect_router_cache_stats():
total_memory = 0
cached_batch_sizes = []

for (num_tokens, num_experts, experts_per_token,
moe_ep_size), logits in _PERFECT_ROUTER_LOGITS_CACHE.items():
for cache_key, logits in _PERFECT_ROUTER_LOGITS_CACHE.items():
num_tokens = cache_key[0]
total_memory += logits.numel() * logits.element_size()
cached_batch_sizes.append(num_tokens)

Expand All @@ -42,7 +42,8 @@ def get_perfect_router_cache_stats():
def precompute_common_perfect_router_logits(num_experts: int,
experts_per_token: int,
moe_ep_size: int,
dtype: torch.dtype):
dtype: torch.dtype,
ep_rank: int = 0):
Comment on lines +45 to +46
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix yapf formatting to pass CI.

The pipeline failure indicates this function signature needs formatting adjustment. The dtype and ep_rank parameters should be on properly formatted lines per yapf rules.

🔧 Suggested fix
 def precompute_common_perfect_router_logits(num_experts: int,
                                             experts_per_token: int,
                                             moe_ep_size: int,
-                                            dtype: torch.dtype,
-                                            ep_rank: int = 0):
+                                            dtype: torch.dtype, ep_rank: int = 0):

Note: Run yapf locally to confirm the exact formatting required.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
dtype: torch.dtype,
ep_rank: int = 0):
def precompute_common_perfect_router_logits(num_experts: int,
experts_per_token: int,
moe_ep_size: int,
dtype: torch.dtype, ep_rank: int = 0):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/routing.py` around lines 45 - 46, The
function signature with parameters "dtype: torch.dtype," and "ep_rank: int = 0"
is misformatted for yapf; open the function that contains these parameters (the
routing function in fused_moe/routing.py) and reformat the signature so that each
parameter is on its own properly indented line and the closing parenthesis is
aligned per yapf style (i.e., one parameter per line with consistent
indentation), then run yapf to confirm the file now passes CI.

"""
Pre-compute logits for common batch sizes to avoid first-time computation overhead.
Only precomputes if cache is empty (avoids redundant work across multiple MLPBlock instances).
Expand Down Expand Up @@ -91,6 +92,7 @@ def precompute_common_perfect_router_logits(num_experts: int,
num_experts=num_experts,
experts_per_token=experts_per_token,
moe_ep_size=moe_ep_size,
ep_rank=ep_rank,
device=torch.device('cpu'), # Precompute on CPU
dtype=dtype)

Expand All @@ -110,6 +112,7 @@ def precompute_common_perfect_router_logits(num_experts: int,

def get_cached_perfect_router_logits(num_tokens: int, num_experts: int,
experts_per_token: int, moe_ep_size: int,
ep_rank: int,
device: torch.device,
dtype: torch.dtype) -> torch.Tensor:
"""
Expand All @@ -118,7 +121,8 @@ def get_cached_perfect_router_logits(num_tokens: int, num_experts: int,
"""
global _PERFECT_ROUTER_LOGITS_CACHE

cache_key = (num_tokens, num_experts, experts_per_token, moe_ep_size)
cache_key = (num_tokens, num_experts, experts_per_token, moe_ep_size,
ep_rank)

if cache_key in _PERFECT_ROUTER_LOGITS_CACHE:
# Return cached logits moved to the correct device
Expand All @@ -135,6 +139,7 @@ def get_cached_perfect_router_logits(num_tokens: int, num_experts: int,
num_experts=num_experts,
experts_per_token=experts_per_token,
moe_ep_size=moe_ep_size,
ep_rank=ep_rank,
device=device,
dtype=dtype)

Expand Down Expand Up @@ -655,6 +660,7 @@ def create_renormalize_expert_load_balanced_logits(
num_experts: int,
experts_per_token: int,
moe_ep_size: int,
ep_rank: int,
device: torch.device,
dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Expand All @@ -670,6 +676,8 @@ def create_renormalize_expert_load_balanced_logits(
The function creates routing logits that ensure perfect load balancing across GPUs
by cycling through experts in a GPU-aware pattern. Each token is assigned to
exactly k=experts_per_token experts, distributed evenly across all GPUs.
The schedule is offset by ``ep_rank`` so different sender ranks use the same
rank-aware pattern as ``tests/microbenchmarks/bench_moe_comm.py --perfect_router``.

Strategy:
1. First cycle through one expert from each GPU (GPU representatives)
Expand Down Expand Up @@ -732,6 +740,7 @@ def create_renormalize_expert_load_balanced_logits(
num_experts: Total number of experts
experts_per_token: Number of experts each token should be routed to (top-k)
moe_ep_size: Number of GPUs for MoE expert parallelism
ep_rank: Sender EP rank used to stagger the routing schedule
device: Device to create tensors on
dtype: Data type for the logits tensor

Expand All @@ -755,6 +764,10 @@ def create_renormalize_expert_load_balanced_logits(
if moe_ep_size == 0:
raise ValueError("moe_ep_size cannot be zero")

if not 0 <= ep_rank < moe_ep_size:
raise ValueError(
f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")

Comment on lines +767 to +770
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix yapf formatting for CI compliance.

The pipeline failure indicates the ValueError f-string needs different wrapping. Run yapf to fix the formatting.

🔧 Suggested fix
     if not 0 <= ep_rank < moe_ep_size:
-        raise ValueError(
-            f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")
+        raise ValueError(f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")

Note: Run yapf locally to confirm the exact formatting required.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/routing.py` around lines 767 - 770, The
ValueError raise in routing.py (the check using ep_rank and moe_ep_size) is
misformatted for yapf; update the raise statement so the f-string message is
wrapped across lines according to the project's yapf style (e.g., put the
f-string on its own line inside the parentheses and close the parenthesis on the
following line) for the ep_rank / moe_ep_size check, then run yapf to confirm the
formatting; target the raise ValueError(...) statement that references ep_rank and
moe_ep_size.

# Create logits tensor on the same device and dtype as input
# Shape: [num_tokens, num_experts] - will hold routing probabilities
logits = torch.zeros(num_tokens, num_experts, device=device, dtype=dtype)
Expand All @@ -773,12 +786,16 @@ def create_renormalize_expert_load_balanced_logits(
# i_tensor: sequential indices from 0 to final_size-1
i_tensor = torch.arange(final_size, device=device)

# Match the bench_moe_comm perfect-router schedule by offsetting each sender
# rank before cycling over target EP ranks and local experts.
schedule = i_tensor + ep_rank

# gpu_idx: which GPU this assignment should go to (cycles through 0,1,2,3,0,1,2,3,...)
gpu_idx = i_tensor % num_gpus
gpu_idx = schedule % num_gpus

# expert_offset: which expert within the GPU (0,0,0,0,1,1,1,1,2,2,2,2,...)
# This ensures we use all experts from each GPU before moving to next expert
expert_offset = (i_tensor // num_gpus) % experts_per_gpu
expert_offset = (schedule // num_gpus) % experts_per_gpu

# indices: actual expert indices by combining GPU base + offset
indices = gpu_representatives[gpu_idx] + expert_offset
Expand Down
37 changes: 37 additions & 0 deletions tests/unittest/_torch/modules/test_moe_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def test_renormalize_expert_load_balanced_logits(num_tokens,
num_experts=num_experts,
experts_per_token=experts_per_token,
moe_ep_size=moe_ep_size,
ep_rank=0,
device=device,
dtype=torch.float32)

Expand Down Expand Up @@ -260,5 +261,41 @@ def test_renormalize_expert_load_balanced_logits(num_tokens,
expected_assignments), f"Load balance failed for {description}"


@pytest.mark.parametrize("num_tokens", [1, 3, 8])
@pytest.mark.parametrize("ep_rank", [0, 1, 3])
def test_rank_aware_perfect_router_matches_bench_moe_comm_schedule(
num_tokens: int, ep_rank: int) -> None:
"""Verify the rank-aware helper matches bench_moe_comm's perfect-router schedule."""
num_experts = 8
experts_per_token = 2
moe_ep_size = 4

logits = create_renormalize_expert_load_balanced_logits(
num_tokens=num_tokens,
num_experts=num_experts,
experts_per_token=experts_per_token,
moe_ep_size=moe_ep_size,
ep_rank=ep_rank,
device=torch.device("cpu"),
dtype=torch.float32)

routing = RenormalizeMoeRoutingMethod(
top_k=experts_per_token, force_enable_pytorch_op=True)
indices, _ = routing.apply(logits)

experts_per_rank = num_experts // moe_ep_size
flat_slots = torch.arange(num_tokens * experts_per_token, dtype=torch.int64)
schedule = flat_slots + ep_rank
expected = (
(schedule % moe_ep_size) * experts_per_rank +
(schedule // moe_ep_size) % experts_per_rank).view(
num_tokens, experts_per_token).to(torch.int32)

assert torch.equal(
torch.sort(indices.cpu(), dim=1).values,
torch.sort(expected, dim=1).values,
)

Comment on lines +264 to +298
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Good test coverage for rank-aware routing.

The new parametrized test correctly validates that:

  1. Logits generated with a specific ep_rank produce the expected expert assignments
  2. The schedule formula (schedule % moe_ep_size) * experts_per_rank + (schedule // moe_ep_size) % experts_per_rank matches the implementation in routing.py
  3. Order-insensitive comparison via torch.sort handles potential index ordering differences

Minor: Fix yapf formatting on line 279 per pipeline failure.

🔧 Run yapf to fix formatting

The pipeline failure indicates the RenormalizeMoeRoutingMethod call on lines 282-283 needs formatting adjustment. Run yapf locally to resolve.

🧰 Tools
🪛 GitHub Actions: Release Checks

[error] 279-279: Formatting check failed by pre-commit 'yapf' (file was modified): RenormalizeMoeRoutingMethod call formatting adjusted.


[error] 279-279: Formatting check failed by pre-commit 'yapf' (file was modified): expected expression wrapping/indentation adjusted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modules/test_moe_routing.py` around lines 264 - 298,
The test function test_rank_aware_perfect_router_matches_bench_moe_comm_schedule
has a YAPF formatting failure around the RenormalizeMoeRoutingMethod(...) call;
run yapf (or your repository formatter) on this file and reformat the call to a
single properly indented expression (e.g., keep
RenormalizeMoeRoutingMethod(top_k=experts_per_token,
force_enable_pytorch_op=True) on one line or align the named args vertically) so
the file passes the style pipeline while preserving the existing arguments and
behavior of RenormalizeMoeRoutingMethod and the surrounding assertions.


if __name__ == '__main__':
pytest.main()
Loading