diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py
index cbcf0502ae93..195622f77fb8 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py
@@ -174,7 +174,16 @@ def create_strategy(
         logger.debug(f"NVLinkTwoSided not available: {e}")
 
     # Try DeepEP (if enabled and weight dtype is bfloat16)
-    if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16:
+    # Skip DeepEP/DeepEPLowLatency if NVLink symmetric memory init is known to
+    # be broken (detected by NVLinkOneSided workspace init failure). DeepEP also
+    # relies on NVSHMEM/symmetric memory internally, so it would hang during
+    # forward pass if the NVLink memory infrastructure is unavailable.
+    if NVLinkOneSided._WORKSPACE_INIT_FAILED:
+        logger.info(
+            "Skipping DeepEP/DeepEPLowLatency: NVLink symmetric memory "
+            "initialization previously failed (detected via NVLinkOneSided)."
+        )
+    elif os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16:
         try:
             strategy = DeepEP(
                 mapping,
diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
index e37d5db10819..4488f10af759 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
@@ -57,6 +57,11 @@ class NVLinkOneSided(Communication):
 
     # Single shared workspace/memory across the process
     _WORKSPACE: dict | None = None
+    # Track if workspace initialization (MNNVL + NVSHMEM) has failed, to avoid
+    # repeated attempts and to signal other NVSHMEM-dependent strategies (e.g.
+    # DeepEP) to skip initialization — they share the same NVLink/symmetric
+    # memory infrastructure and will also fail or hang.
+    _WORKSPACE_INIT_FAILED: bool = False
 
     # MetaInfo indices - initialized from C++ constants
     FLAG_VAL_OFFSET_INDEX = None
@@ -224,20 +229,30 @@ def __init__(
         # Initialize or reuse workspace
         MnnvlMemory.initialize()
 
+        if self._WORKSPACE_INIT_FAILED:
+            raise RuntimeError(
+                "NVLinkOneSided: workspace initialization (MNNVL/NVSHMEM) previously "
+                "failed on this node, skipping repeated initialization attempt."
+            )
+
         if self._WORKSPACE is None:
             tllm_logger.info(
                 f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes."
                 f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}"
             )
-            mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank)
-            workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
-            metainfo = torch.ops.trtllm.moe_a2a_initialize(
-                workspace,
-                self.ep_rank,
-                self.ep_size,
-                self.max_num_tokens_per_rank,
-                self.eplb_stats_num_experts,
-            )
+            try:
+                mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank)
+                workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
+                metainfo = torch.ops.trtllm.moe_a2a_initialize(
+                    workspace,
+                    self.ep_rank,
+                    self.ep_size,
+                    self.max_num_tokens_per_rank,
+                    self.eplb_stats_num_experts,
+                )
+            except (RuntimeError, AssertionError) as e:
+                NVLinkOneSided._WORKSPACE_INIT_FAILED = True
+                raise RuntimeError(f"NVLinkOneSided workspace initialization failed: {e}") from e
             NVLinkOneSided._WORKSPACE = {
                 "workspace_size_per_rank": self.workspace_size_per_rank,
                 "max_num_tokens_per_rank": self.max_num_tokens_per_rank,
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 9c19eb7d302d..e4d1251da901 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -222,7 +222,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutl
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5838211)
 full:A10/unittest/kv_cache_manager_v2_tests/ SKIP (https://nvbugs/5841954)
 examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5846178)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer] SKIP (https://nvbugs/5859886)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-fp8] SKIP (https://nvbugs/5651865)
 test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/5864769)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5879577)