diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 8b2da5404489..669fe9073b09 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -991,7 +991,8 @@ def __init__(self, num_experts=num_experts, experts_per_token=top_k, moe_ep_size=model_config.mapping.moe_ep_size, - dtype=torch.float32) + dtype=torch.float32, + ep_rank=model_config.mapping.moe_ep_rank) def _compute_shared_expert_tp_size( self, intermediate_size: int, @@ -1052,6 +1053,7 @@ def _create_ideal_expert_load_balanced_logits( num_experts=num_experts, experts_per_token=self.top_k, moe_ep_size=self.model_config.mapping.moe_ep_size, + ep_rank=self.model_config.mapping.moe_ep_rank, device=device, dtype=torch.float32) diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py index 4d46611a7fd8..8dc7d48cf65b 100644 --- a/tensorrt_llm/_torch/models/modeling_gpt_oss.py +++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py @@ -199,7 +199,8 @@ def __init__( num_experts=pretrained_config.num_local_experts, experts_per_token=pretrained_config.num_experts_per_tok, moe_ep_size=config.mapping.moe_ep_size, - dtype=pretrained_config.torch_dtype) + dtype=pretrained_config.torch_dtype, + ep_rank=config.mapping.moe_ep_rank) @staticmethod def swiglu(x, alpha: float = 1.702): @@ -227,6 +228,7 @@ def _create_ideal_expert_load_balanced_logits( num_experts=num_experts, experts_per_token=pretrained_config.experts_per_token, moe_ep_size=self.config.mapping.moe_ep_size, + ep_rank=self.config.mapping.moe_ep_rank, device=device, dtype=pretrained_config.torch_dtype) diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 69498c96cfc3..7ede9231ef40 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -27,8 +27,8 @@ def get_perfect_router_cache_stats(): total_memory = 0 cached_batch_sizes = [] - for (num_tokens, num_experts, experts_per_token, - moe_ep_size), logits in _PERFECT_ROUTER_LOGITS_CACHE.items(): + for cache_key, logits in _PERFECT_ROUTER_LOGITS_CACHE.items(): + num_tokens = cache_key[0] total_memory += logits.numel() * logits.element_size() cached_batch_sizes.append(num_tokens) @@ -42,7 +42,8 @@ def get_perfect_router_cache_stats(): def precompute_common_perfect_router_logits(num_experts: int, experts_per_token: int, moe_ep_size: int, - dtype: torch.dtype): + dtype: torch.dtype, + ep_rank: int = 0): """ Pre-compute logits for common batch sizes to avoid first-time computation overhead. Only precomputes if cache is empty (avoids redundant work across multiple MLPBlock instances). @@ -91,6 +92,7 @@ def precompute_common_perfect_router_logits(num_experts: int, num_experts=num_experts, experts_per_token=experts_per_token, moe_ep_size=moe_ep_size, + ep_rank=ep_rank, device=torch.device('cpu'), # Precompute on CPU dtype=dtype) @@ -110,6 +112,7 @@ def precompute_common_perfect_router_logits(num_experts: int, def get_cached_perfect_router_logits(num_tokens: int, num_experts: int, experts_per_token: int, moe_ep_size: int, + ep_rank: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: """ @@ -118,7 +121,8 @@ def get_cached_perfect_router_logits(num_tokens: int, num_experts: int, """ global _PERFECT_ROUTER_LOGITS_CACHE - cache_key = (num_tokens, num_experts, experts_per_token, moe_ep_size) + cache_key = (num_tokens, num_experts, experts_per_token, moe_ep_size, + ep_rank) if cache_key in _PERFECT_ROUTER_LOGITS_CACHE: # Return cached logits moved to the correct device @@ -135,6 +139,7 @@ def get_cached_perfect_router_logits(num_tokens: int, num_experts: int, num_experts=num_experts, experts_per_token=experts_per_token, moe_ep_size=moe_ep_size, + ep_rank=ep_rank, device=device, dtype=dtype) @@ -655,6 +660,7 @@ def create_renormalize_expert_load_balanced_logits( num_experts: int, experts_per_token: int, moe_ep_size: int, + ep_rank: int, device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor: """ @@ -670,6 +676,8 @@ def create_renormalize_expert_load_balanced_logits( The function creates routing logits that ensure perfect load balancing across GPUs by cycling through experts in a GPU-aware pattern. Each token is assigned to exactly k=experts_per_token experts, distributed evenly across all GPUs. + The schedule is offset by ``ep_rank`` so different sender ranks use the same + rank-aware pattern as ``tests/microbenchmarks/bench_moe_comm.py --perfect_router``. Strategy: 1. First cycle through one expert from each GPU (GPU representatives) @@ -732,6 +740,7 @@ def create_renormalize_expert_load_balanced_logits( num_experts: Total number of experts experts_per_token: Number of experts each token should be routed to (top-k) moe_ep_size: Number of GPUs for MoE expert parallelism + ep_rank: Sender EP rank used to stagger the routing schedule device: Device to create tensors on dtype: Data type for the logits tensor @@ -755,6 +764,10 @@ def create_renormalize_expert_load_balanced_logits( if moe_ep_size == 0: raise ValueError("moe_ep_size cannot be zero") + if not 0 <= ep_rank < moe_ep_size: + raise ValueError( + f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})") + # Create logits tensor on the same device and dtype as input # Shape: [num_tokens, num_experts] - will hold routing probabilities logits = torch.zeros(num_tokens, num_experts, device=device, dtype=dtype) @@ -773,12 +786,16 @@ def create_renormalize_expert_load_balanced_logits( # i_tensor: sequential indices from 0 to final_size-1 i_tensor = torch.arange(final_size, device=device) + # Match the bench_moe_comm perfect-router schedule by offsetting each sender + # rank before cycling over target EP ranks and local experts. + schedule = i_tensor + ep_rank + # gpu_idx: which GPU this assignment should go to (cycles through 0,1,2,3,0,1,2,3,...) - gpu_idx = i_tensor % num_gpus + gpu_idx = schedule % num_gpus # expert_offset: which expert within the GPU (0,0,0,0,1,1,1,1,2,2,2,2,...) # This ensures we use all experts from each GPU before moving to next expert - expert_offset = (i_tensor // num_gpus) % experts_per_gpu + expert_offset = (schedule // num_gpus) % experts_per_gpu # indices: actual expert indices by combining GPU base + offset indices = gpu_representatives[gpu_idx] + expert_offset diff --git a/tests/unittest/_torch/modules/test_moe_routing.py b/tests/unittest/_torch/modules/test_moe_routing.py index 405ef0299f93..8a8673fcf827 100644 --- a/tests/unittest/_torch/modules/test_moe_routing.py +++ b/tests/unittest/_torch/modules/test_moe_routing.py @@ -223,6 +223,7 @@ def test_renormalize_expert_load_balanced_logits(num_tokens, num_experts=num_experts, experts_per_token=experts_per_token, moe_ep_size=moe_ep_size, + ep_rank=0, device=device, dtype=torch.float32) @@ -260,5 +261,41 @@ def test_renormalize_expert_load_balanced_logits(num_tokens, expected_assignments), f"Load balance failed for {description}" +@pytest.mark.parametrize("num_tokens", [1, 3, 8]) +@pytest.mark.parametrize("ep_rank", [0, 1, 3]) +def test_rank_aware_perfect_router_matches_bench_moe_comm_schedule( + num_tokens: int, ep_rank: int) -> None: + """Verify the rank-aware helper matches bench_moe_comm's perfect-router schedule.""" + num_experts = 8 + experts_per_token = 2 + moe_ep_size = 4 + + logits = create_renormalize_expert_load_balanced_logits( + num_tokens=num_tokens, + num_experts=num_experts, + experts_per_token=experts_per_token, + moe_ep_size=moe_ep_size, + ep_rank=ep_rank, + device=torch.device("cpu"), + dtype=torch.float32) + + routing = RenormalizeMoeRoutingMethod( + top_k=experts_per_token, force_enable_pytorch_op=True) + indices, _ = routing.apply(logits) + + experts_per_rank = num_experts // moe_ep_size + flat_slots = torch.arange(num_tokens * experts_per_token, dtype=torch.int64) + schedule = flat_slots + ep_rank + expected = ( + (schedule % moe_ep_size) * experts_per_rank + + (schedule // moe_ep_size) % experts_per_rank).view( + num_tokens, experts_per_token).to(torch.int32) + + assert torch.equal( + torch.sort(indices.cpu(), dim=1).values, + torch.sort(expected, dim=1).values, + ) + + if __name__ == '__main__': pytest.main()