
[None][perf] make perfect router rank-aware for EP MoE #13175

Open
bobboli wants to merge 1 commit into NVIDIA:main from bobboli:user/lbo/trtllm-perfect-router-pr

Conversation

Collaborator

@bobboli bobboli commented Apr 18, 2026

Summary

  • make Renormalize perfect-router logits rank-aware by offsetting the EP routing schedule with ep_rank
  • thread mapping.moe_ep_rank through the cached perfect-router helpers used by GPT-OSS and DeepSeek-V3 MoE paths
  • add a routing unit test that checks the generated expert assignments match the bench_moe_comm.py --perfect_router schedule
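The rank-offset schedule described above can be sketched in plain Python. The helper below is hypothetical (the real implementation lives in tensorrt_llm/_torch/modules/fused_moe/routing.py and operates on torch tensors); it only illustrates how ep_rank shifts the round-robin expert assignment, mirroring the expected-assignment formula in the new unit test:

```python
def perfect_router_expert_ids(num_tokens: int, experts_per_token: int,
                              num_experts: int, moe_ep_size: int,
                              ep_rank: int) -> list[list[int]]:
    """Hypothetical sketch of the rank-aware perfect-router schedule.

    Each (token, slot) pair walks the EP ranks round-robin; offsetting the
    flat slot index by ep_rank makes different ranks spread their load over
    different experts instead of all hammering the same ones.
    """
    experts_per_rank = num_experts // moe_ep_size
    assignments = []
    for token in range(num_tokens):
        row = []
        for slot in range(experts_per_token):
            s = token * experts_per_token + slot + ep_rank
            # First pick the target EP rank, then an expert local to it.
            expert = ((s % moe_ep_size) * experts_per_rank
                      + (s // moe_ep_size) % experts_per_rank)
            row.append(expert)
        assignments.append(row)
    return assignments

# With 8 experts over 4 EP ranks, rank 1 rotates the schedule of rank 0:
print(perfect_router_expert_ids(2, 2, 8, 4, 0))  # [[0, 2], [4, 6]]
print(perfect_router_expert_ids(2, 2, 8, 4, 1))  # [[2, 4], [6, 1]]
```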

Testing

  • python -m py_compile tensorrt_llm/_torch/modules/fused_moe/routing.py tensorrt_llm/_torch/models/modeling_gpt_oss.py tensorrt_llm/_torch/models/modeling_deepseekv3.py tests/unittest/_torch/modules/test_moe_routing.py
  • PYTHONPATH=. pytest -q tests/unittest/_torch/modules/test_moe_routing.py -k "rank_aware_perfect_router or renormalize_expert_load_balanced_logits"

Notes

  • this PR intentionally does not include benchmark scripts or result artifacts

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced Mixture of Experts (MoE) routing with rank-aware expert assignment and improved load balancing calculations across expert partitions.
    • Extended logits caching to incorporate partition-specific routing context for optimized performance.
  • Tests

    • Added comprehensive test coverage for rank-aware expert routing scenarios.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
@bobboli bobboli requested review from a team as code owners April 18, 2026 09:40
@bobboli bobboli requested review from HuiGao-NV, hlu1 and syuoni April 18, 2026 09:40
Contributor

coderabbitai bot commented Apr 18, 2026

📝 Walkthrough

Walkthrough

The changes introduce an ep_rank (expert-partition rank) parameter throughout MoE routing components, affecting cache key structure for perfect-router logits, GPU/expert assignment scheduling, and related initialization across model files and tests.

Changes

Cohort / File(s) Summary
Model Configuration Updates
tensorrt_llm/_torch/models/modeling_deepseekv3.py, tensorrt_llm/_torch/models/modeling_gpt_oss.py
Added ep_rank=model_config.mapping.moe_ep_rank parameter to calls for precompute_common_perfect_router_logits(...), get_cached_perfect_router_logits(...), and create_moe(...) functions, enabling rank-aware MoE configuration.
Core Routing Module
tensorrt_llm/_torch/modules/fused_moe/routing.py
Extended cache key structure to include ep_rank dimension; added ep_rank parameter (default 0) to precompute_common_perfect_router_logits(...), get_cached_perfect_router_logits(...), and create_renormalize_expert_load_balanced_logits(...); modified GPU/expert assignment schedule by offsetting iteration with ep_rank value; added validation that ep_rank is within valid range.
Routing Tests
tests/unittest/_torch/modules/test_moe_routing.py
Updated existing test call to include ep_rank=0 parameter; added new parametrized test test_rank_aware_perfect_router_matches_bench_moe_comm_schedule that verifies rank-aware expert routing across multiple ep_rank values against expected assignment schedules.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)

  • Description check — ❓ Inconclusive. The PR description covers the main changes, testing approach, and notes; however, it lacks a clear 'Description' section explaining the issue and solution as required by the template. Resolution: expand the Description section to explain the problem being solved and the solution approach, including why making the perfect router rank-aware is necessary.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed. The title clearly and specifically identifies the main change: making the perfect router rank-aware for expert parallelism (EP) in MoE, with the performance focus indicated.
  • Docstring Coverage — ✅ Passed. Docstring coverage is 83.33%, which meets the required threshold of 80.00%.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/routing.py`:
- Around line 767-770: The ValueError raise in routing.py (the check using
ep_rank and moe_ep_size) is misformatted for yapf; update the raise statement so
the f-string message is wrapped across lines according to the project's yapf
style (e.g., put the f-string on its own line inside the parentheses and close
the parenthesis on the following line) for the ep_rank/ moe_ep_size check, then
run yapf to confirm formatting; target the raise ValueError(...) that references
ep_rank and moe_ep_size.
- Around line 45-46: The function signature with parameters "dtype:
torch.dtype," and "ep_rank: int = 0" is misformatted for yapf; open the function
that contains these parameters (the routing function in fused_moe/routing.py)
and reformat the signature so parameters are each on their own properly indented
lines and the closing parenthesis aligns per yapf style (e.g., one parameter per
line with consistent indentation), then run yapf to confirm the file now passes
CI.

In `@tests/unittest/_torch/modules/test_moe_routing.py`:
- Around line 264-298: The test function
test_rank_aware_perfect_router_matches_bench_moe_comm_schedule has a YAPF
formatting failure around the RenormalizeMoeRoutingMethod(...) call; run yapf
(or your repository formatter) on this file and reformat the call to a single
properly indented expression (e.g., keep
RenormalizeMoeRoutingMethod(top_k=experts_per_token,
force_enable_pytorch_op=True) on one line or align the named args vertically) so
the file passes the style pipeline while preserving the existing arguments and
behavior of RenormalizeMoeRoutingMethod and the surrounding assertions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: c6630ee2-1842-4be4-8ee9-f377df869e07

📥 Commits

Reviewing files that changed from the base of the PR and between 6428102 and f535da9.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/models/modeling_gpt_oss.py
  • tensorrt_llm/_torch/modules/fused_moe/routing.py
  • tests/unittest/_torch/modules/test_moe_routing.py

Comment on lines +45 to +46
dtype: torch.dtype,
ep_rank: int = 0):

⚠️ Potential issue | 🟡 Minor

Fix yapf formatting to pass CI.

The pipeline failure indicates this function signature needs formatting adjustment. The dtype and ep_rank parameters should be on properly formatted lines per yapf rules.

🔧 Suggested fix
 def precompute_common_perfect_router_logits(num_experts: int,
                                             experts_per_token: int,
                                             moe_ep_size: int,
-                                            dtype: torch.dtype,
-                                            ep_rank: int = 0):
+                                            dtype: torch.dtype, ep_rank: int = 0):

Note: Run yapf locally to confirm the exact formatting required.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-                                            dtype: torch.dtype,
-                                            ep_rank: int = 0):
+def precompute_common_perfect_router_logits(num_experts: int,
+                                            experts_per_token: int,
+                                            moe_ep_size: int,
+                                            dtype: torch.dtype, ep_rank: int = 0):

Comment on lines +767 to +770
    if not 0 <= ep_rank < moe_ep_size:
        raise ValueError(
            f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")


⚠️ Potential issue | 🟡 Minor

Fix yapf formatting for CI compliance.

The pipeline failure indicates the ValueError f-string needs different wrapping. Run yapf to fix the formatting.

🔧 Suggested fix
     if not 0 <= ep_rank < moe_ep_size:
-        raise ValueError(
-            f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")
+        raise ValueError(f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")

Note: Run yapf locally to confirm the exact formatting required.


Comment on lines +264 to +298
@pytest.mark.parametrize("num_tokens", [1, 3, 8])
@pytest.mark.parametrize("ep_rank", [0, 1, 3])
def test_rank_aware_perfect_router_matches_bench_moe_comm_schedule(
        num_tokens: int, ep_rank: int) -> None:
    """Verify the rank-aware helper matches bench_moe_comm's perfect-router schedule."""
    num_experts = 8
    experts_per_token = 2
    moe_ep_size = 4

    logits = create_renormalize_expert_load_balanced_logits(
        num_tokens=num_tokens,
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        moe_ep_size=moe_ep_size,
        ep_rank=ep_rank,
        device=torch.device("cpu"),
        dtype=torch.float32)

    routing = RenormalizeMoeRoutingMethod(
        top_k=experts_per_token, force_enable_pytorch_op=True)
    indices, _ = routing.apply(logits)

    experts_per_rank = num_experts // moe_ep_size
    flat_slots = torch.arange(num_tokens * experts_per_token, dtype=torch.int64)
    schedule = flat_slots + ep_rank
    expected = (
        (schedule % moe_ep_size) * experts_per_rank +
        (schedule // moe_ep_size) % experts_per_rank).view(
            num_tokens, experts_per_token).to(torch.int32)

    assert torch.equal(
        torch.sort(indices.cpu(), dim=1).values,
        torch.sort(expected, dim=1).values,
    )


⚠️ Potential issue | 🟡 Minor

Good test coverage for rank-aware routing.

The new parametrized test correctly validates that:

  1. Logits generated with a specific ep_rank produce the expected expert assignments
  2. The schedule formula (schedule % moe_ep_size) * experts_per_rank + (schedule // moe_ep_size) % experts_per_rank matches the implementation in routing.py
  3. Order-insensitive comparison via torch.sort handles potential index ordering differences

Minor: Fix yapf formatting on line 279 per pipeline failure.

🔧 Run yapf to fix formatting

The pipeline failure indicates the RenormalizeMoeRoutingMethod call on lines 282-283 needs formatting adjustment. Run yapf locally to resolve.

🧰 Tools
🪛 GitHub Actions: Release Checks

[error] 279-279: Formatting check failed by pre-commit 'yapf' (file was modified): RenormalizeMoeRoutingMethod call formatting adjusted.


[error] 279-279: Formatting check failed by pre-commit 'yapf' (file was modified): expected expression wrapping/indentation adjusted.


