-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[None][perf] make perfect router rank-aware for EP MoE #13175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,8 +27,8 @@ def get_perfect_router_cache_stats(): | |
| total_memory = 0 | ||
| cached_batch_sizes = [] | ||
|
|
||
| for (num_tokens, num_experts, experts_per_token, | ||
| moe_ep_size), logits in _PERFECT_ROUTER_LOGITS_CACHE.items(): | ||
| for cache_key, logits in _PERFECT_ROUTER_LOGITS_CACHE.items(): | ||
| num_tokens = cache_key[0] | ||
| total_memory += logits.numel() * logits.element_size() | ||
| cached_batch_sizes.append(num_tokens) | ||
|
|
||
|
|
@@ -42,7 +42,8 @@ def get_perfect_router_cache_stats(): | |
| def precompute_common_perfect_router_logits(num_experts: int, | ||
| experts_per_token: int, | ||
| moe_ep_size: int, | ||
| dtype: torch.dtype): | ||
| dtype: torch.dtype, | ||
| ep_rank: int = 0): | ||
| """ | ||
| Pre-compute logits for common batch sizes to avoid first-time computation overhead. | ||
| Only precomputes if cache is empty (avoids redundant work across multiple MLPBlock instances). | ||
|
|
@@ -91,6 +92,7 @@ def precompute_common_perfect_router_logits(num_experts: int, | |
| num_experts=num_experts, | ||
| experts_per_token=experts_per_token, | ||
| moe_ep_size=moe_ep_size, | ||
| ep_rank=ep_rank, | ||
| device=torch.device('cpu'), # Precompute on CPU | ||
| dtype=dtype) | ||
|
|
||
|
|
@@ -110,6 +112,7 @@ def precompute_common_perfect_router_logits(num_experts: int, | |
|
|
||
| def get_cached_perfect_router_logits(num_tokens: int, num_experts: int, | ||
| experts_per_token: int, moe_ep_size: int, | ||
| ep_rank: int, | ||
| device: torch.device, | ||
| dtype: torch.dtype) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -118,7 +121,8 @@ def get_cached_perfect_router_logits(num_tokens: int, num_experts: int, | |
| """ | ||
| global _PERFECT_ROUTER_LOGITS_CACHE | ||
|
|
||
| cache_key = (num_tokens, num_experts, experts_per_token, moe_ep_size) | ||
| cache_key = (num_tokens, num_experts, experts_per_token, moe_ep_size, | ||
| ep_rank) | ||
|
|
||
| if cache_key in _PERFECT_ROUTER_LOGITS_CACHE: | ||
| # Return cached logits moved to the correct device | ||
|
|
@@ -135,6 +139,7 @@ def get_cached_perfect_router_logits(num_tokens: int, num_experts: int, | |
| num_experts=num_experts, | ||
| experts_per_token=experts_per_token, | ||
| moe_ep_size=moe_ep_size, | ||
| ep_rank=ep_rank, | ||
| device=device, | ||
| dtype=dtype) | ||
|
|
||
|
|
@@ -655,6 +660,7 @@ def create_renormalize_expert_load_balanced_logits( | |
| num_experts: int, | ||
| experts_per_token: int, | ||
| moe_ep_size: int, | ||
| ep_rank: int, | ||
| device: torch.device, | ||
| dtype: torch.dtype = torch.float32) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -670,6 +676,8 @@ def create_renormalize_expert_load_balanced_logits( | |
| The function creates routing logits that ensure perfect load balancing across GPUs | ||
| by cycling through experts in a GPU-aware pattern. Each token is assigned to | ||
| exactly k=experts_per_token experts, distributed evenly across all GPUs. | ||
| The schedule is offset by ``ep_rank`` so different sender ranks use the same | ||
| rank-aware pattern as ``tests/microbenchmarks/bench_moe_comm.py --perfect_router``. | ||
|
|
||
| Strategy: | ||
| 1. First cycle through one expert from each GPU (GPU representatives) | ||
|
|
@@ -732,6 +740,7 @@ def create_renormalize_expert_load_balanced_logits( | |
| num_experts: Total number of experts | ||
| experts_per_token: Number of experts each token should be routed to (top-k) | ||
| moe_ep_size: Number of GPUs for MoE expert parallelism | ||
| ep_rank: Sender EP rank used to stagger the routing schedule | ||
| device: Device to create tensors on | ||
| dtype: Data type for the logits tensor | ||
|
|
||
|
|
@@ -755,6 +764,10 @@ def create_renormalize_expert_load_balanced_logits( | |
| if moe_ep_size == 0: | ||
| raise ValueError("moe_ep_size cannot be zero") | ||
|
|
||
| if not 0 <= ep_rank < moe_ep_size: | ||
| raise ValueError( | ||
| f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})") | ||
|
|
||
|
Comment on lines
+767
to
+770
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fix yapf formatting for CI compliance. The pipeline failure indicates the `yapf` pre-commit check rejected this file's formatting. 🔧 Suggested fix: if not 0 <= ep_rank < moe_ep_size:
- raise ValueError(
- f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")
+ raise ValueError(f"ep_rank ({ep_rank}) must be in [0, {moe_ep_size})")
Note: Run `yapf` locally to confirm the exact formatting required. 🤖 Prompt for AI Agents |
||
| # Create logits tensor on the same device and dtype as input | ||
| # Shape: [num_tokens, num_experts] - will hold routing probabilities | ||
| logits = torch.zeros(num_tokens, num_experts, device=device, dtype=dtype) | ||
|
|
@@ -773,12 +786,16 @@ def create_renormalize_expert_load_balanced_logits( | |
| # i_tensor: sequential indices from 0 to final_size-1 | ||
| i_tensor = torch.arange(final_size, device=device) | ||
|
|
||
| # Match the bench_moe_comm perfect-router schedule by offsetting each sender | ||
| # rank before cycling over target EP ranks and local experts. | ||
| schedule = i_tensor + ep_rank | ||
|
|
||
| # gpu_idx: which GPU this assignment should go to (cycles through 0,1,2,3,0,1,2,3,...) | ||
| gpu_idx = i_tensor % num_gpus | ||
| gpu_idx = schedule % num_gpus | ||
|
|
||
| # expert_offset: which expert within the GPU (0,0,0,0,1,1,1,1,2,2,2,2,...) | ||
| # This ensures we use all experts from each GPU before moving to next expert | ||
| expert_offset = (i_tensor // num_gpus) % experts_per_gpu | ||
| expert_offset = (schedule // num_gpus) % experts_per_gpu | ||
|
|
||
| # indices: actual expert indices by combining GPU base + offset | ||
| indices = gpu_representatives[gpu_idx] + expert_offset | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -223,6 +223,7 @@ def test_renormalize_expert_load_balanced_logits(num_tokens, | |
| num_experts=num_experts, | ||
| experts_per_token=experts_per_token, | ||
| moe_ep_size=moe_ep_size, | ||
| ep_rank=0, | ||
| device=device, | ||
| dtype=torch.float32) | ||
|
|
||
|
|
@@ -260,5 +261,41 @@ def test_renormalize_expert_load_balanced_logits(num_tokens, | |
| expected_assignments), f"Load balance failed for {description}" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_tokens", [1, 3, 8]) | ||
| @pytest.mark.parametrize("ep_rank", [0, 1, 3]) | ||
| def test_rank_aware_perfect_router_matches_bench_moe_comm_schedule( | ||
| num_tokens: int, ep_rank: int) -> None: | ||
| """Verify the rank-aware helper matches bench_moe_comm's perfect-router schedule.""" | ||
| num_experts = 8 | ||
| experts_per_token = 2 | ||
| moe_ep_size = 4 | ||
|
|
||
| logits = create_renormalize_expert_load_balanced_logits( | ||
| num_tokens=num_tokens, | ||
| num_experts=num_experts, | ||
| experts_per_token=experts_per_token, | ||
| moe_ep_size=moe_ep_size, | ||
| ep_rank=ep_rank, | ||
| device=torch.device("cpu"), | ||
| dtype=torch.float32) | ||
|
|
||
| routing = RenormalizeMoeRoutingMethod( | ||
| top_k=experts_per_token, force_enable_pytorch_op=True) | ||
| indices, _ = routing.apply(logits) | ||
|
|
||
| experts_per_rank = num_experts // moe_ep_size | ||
| flat_slots = torch.arange(num_tokens * experts_per_token, dtype=torch.int64) | ||
| schedule = flat_slots + ep_rank | ||
| expected = ( | ||
| (schedule % moe_ep_size) * experts_per_rank + | ||
| (schedule // moe_ep_size) % experts_per_rank).view( | ||
| num_tokens, experts_per_token).to(torch.int32) | ||
|
|
||
| assert torch.equal( | ||
| torch.sort(indices.cpu(), dim=1).values, | ||
| torch.sort(expected, dim=1).values, | ||
| ) | ||
|
|
||
|
Comment on lines
+264
to
+298
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good test coverage for rank-aware routing. The new parametrized test correctly validates that:
Minor: Fix yapf formatting on line 279 per pipeline failure. 🔧 Run yapf to fix formatting. The pipeline failure indicates the `yapf` pre-commit hook modified this file. 🧰 Tools 🪛 GitHub Actions: Release Checks[error] 279-279: Formatting check failed by pre-commit 'yapf' (file was modified): RenormalizeMoeRoutingMethod call formatting adjusted. [error] 279-279: Formatting check failed by pre-commit 'yapf' (file was modified): expected expression wrapping/indentation adjusted. 🤖 Prompt for AI Agents |
||
|
|
||
| if __name__ == '__main__': | ||
| pytest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix yapf formatting to pass CI.
The pipeline failure indicates this function signature needs formatting adjustment. The
`dtype` and `ep_rank` parameters should be on properly formatted lines per yapf rules. 🔧 Suggested fix
def precompute_common_perfect_router_logits(num_experts: int,
                                            experts_per_token: int,
                                            moe_ep_size: int,
-                                           dtype: torch.dtype,
-                                           ep_rank: int = 0):
+                                           dtype: torch.dtype, ep_rank: int = 0):
Note: Run
`yapf` locally to confirm the exact formatting required. 📝 Committable suggestion
🤖 Prompt for AI Agents