From ef935a124e357c553279e5268c1bdbbfc4b955e4 Mon Sep 17 00:00:00 2001
From: Ziyi Xiong <219238287+ziyixiong-nv@users.noreply.github.com>
Date: Fri, 17 Apr 2026 17:13:00 -0700
Subject: [PATCH 1/2] [nvbugs/5859886][fix] Skip DeepEP when NVLink symmetric memory init fails

When NVLinkOneSided workspace initialization fails (MNNVL allocation or
NVSHMEM moe_a2a_initialize), cache the failure and propagate it to skip
DeepEP/DeepEPLowLatency strategies in CommunicationFactory. DeepEP also
relies on NVSHMEM internally and would hang during forward pass if the
NVLink symmetric memory infrastructure is unavailable.

Changes:
- NVLinkOneSided: wrap workspace init (MnnvlMemory + moe_a2a_initialize)
  in try-except, set _WORKSPACE_INIT_FAILED on failure to avoid repeated
  attempts across MoE layers and signal the factory.
- CommunicationFactory: check _WORKSPACE_INIT_FAILED before trying
  DeepEP/DeepEPLowLatency; fall through to AllGatherReduceScatter (NCCL).
- Remove test waiver for test_fp8_blockscale[disable_skip_indexer].

Signed-off-by: Ziyi Xiong <219238287+ziyixiong-nv@users.noreply.github.com> --- .../communication/communication_factory.py | 11 ++++++- .../communication/nvlink_one_sided.py | 33 ++++++++++++++----- tests/integration/test_lists/waives.txt | 1 - 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py index cbcf0502ae93..195622f77fb8 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py @@ -174,7 +174,16 @@ def create_strategy( logger.debug(f"NVLinkTwoSided not available: {e}") # Try DeepEP (if enabled and weight dtype is bfloat16) - if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16: + # Skip DeepEP/DeepEPLowLatency if NVLink symmetric memory init is known to + # be broken (detected by NVLinkOneSided workspace init failure). DeepEP also + # relies on NVSHMEM/symmetric memory internally, so it would hang during + # forward pass if the NVLink memory infrastructure is unavailable. + if NVLinkOneSided._WORKSPACE_INIT_FAILED: + logger.info( + "Skipping DeepEP/DeepEPLowLatency: NVLink symmetric memory " + "initialization previously failed (detected via NVLinkOneSided)." 
+ ) + elif os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16: try: strategy = DeepEP( mapping, diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py index e37d5db10819..4488f10af759 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py @@ -57,6 +57,11 @@ class NVLinkOneSided(Communication): # Single shared workspace/memory across the process _WORKSPACE: dict | None = None + # Track if workspace initialization (MNNVL + NVSHMEM) has failed, to avoid + # repeated attempts and to signal other NVSHMEM-dependent strategies (e.g. + # DeepEP) to skip initialization — they share the same NVLink/symmetric + # memory infrastructure and will also fail or hang. + _WORKSPACE_INIT_FAILED: bool = False # MetaInfo indices - initialized from C++ constants FLAG_VAL_OFFSET_INDEX = None @@ -224,20 +229,30 @@ def __init__( # Initialize or reuse workspace MnnvlMemory.initialize() + if self._WORKSPACE_INIT_FAILED: + raise RuntimeError( + "NVLinkOneSided: workspace initialization (MNNVL/NVSHMEM) previously " + "failed on this node, skipping repeated initialization attempt." + ) + if self._WORKSPACE is None: tllm_logger.info( f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes." 
f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}" ) - mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank) - workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) - metainfo = torch.ops.trtllm.moe_a2a_initialize( - workspace, - self.ep_rank, - self.ep_size, - self.max_num_tokens_per_rank, - self.eplb_stats_num_experts, - ) + try: + mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank) + workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) + metainfo = torch.ops.trtllm.moe_a2a_initialize( + workspace, + self.ep_rank, + self.ep_size, + self.max_num_tokens_per_rank, + self.eplb_stats_num_experts, + ) + except (RuntimeError, AssertionError) as e: + NVLinkOneSided._WORKSPACE_INIT_FAILED = True + raise RuntimeError(f"NVLinkOneSided workspace initialization failed: {e}") from e NVLinkOneSided._WORKSPACE = { "workspace_size_per_rank": self.workspace_size_per_rank, "max_num_tokens_per_rank": self.max_num_tokens_per_rank, diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 9c19eb7d302d..e4d1251da901 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -222,7 +222,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutl accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5838211) full:A10/unittest/kv_cache_manager_v2_tests/ SKIP (https://nvbugs/5841954) examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5846178) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer] SKIP (https://nvbugs/5859886) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-fp8] SKIP (https://nvbugs/5651865) test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/5864769) 
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5879577)

From dd0769e158bf9ee7dbe45aa954e2b2e6ca635b11 Mon Sep 17 00:00:00 2001
From: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Date: Tue, 21 Apr 2026 20:05:10 -0700
Subject: [PATCH 2/2] [nvbugs/5859886][fix] Address PR review comments

- Update copyright year to 2025-2026
- Mark _WORKSPACE and _WORKSPACE_INIT_FAILED as ClassVar
- Move _WORKSPACE_INIT_FAILED check before MnnvlMemory.initialize()
- Release workspace/mnnvl_mem on failure to prevent CUDA memory leak
- Broaden exception handling to match PR #13235 pattern

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
---
 .../communication/nvlink_one_sided.py | 25 ++++++++++++-------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
index 4488f10af759..8e5531ca6018 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ """ import os -from typing import List, Optional, Tuple +from typing import ClassVar, List, Optional, Tuple import torch @@ -56,12 +56,12 @@ class NVLinkOneSided(Communication): MAX_PAYLOADS = 8 # Single shared workspace/memory across the process - _WORKSPACE: dict | None = None + _WORKSPACE: ClassVar[dict | None] = None # Track if workspace initialization (MNNVL + NVSHMEM) has failed, to avoid # repeated attempts and to signal other NVSHMEM-dependent strategies (e.g. # DeepEP) to skip initialization — they share the same NVLink/symmetric # memory infrastructure and will also fail or hang. - _WORKSPACE_INIT_FAILED: bool = False + _WORKSPACE_INIT_FAILED: ClassVar[bool] = False # MetaInfo indices - initialized from C++ constants FLAG_VAL_OFFSET_INDEX = None @@ -226,20 +226,22 @@ def __init__( ) self.workspace_size_per_rank = 2048 * 1024 * 1024 - # Initialize or reuse workspace - MnnvlMemory.initialize() - if self._WORKSPACE_INIT_FAILED: raise RuntimeError( "NVLinkOneSided: workspace initialization (MNNVL/NVSHMEM) previously " "failed on this node, skipping repeated initialization attempt." ) + # Initialize or reuse workspace + MnnvlMemory.initialize() + if self._WORKSPACE is None: tllm_logger.info( f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes." f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}" ) + mnnvl_mem = None + workspace = None try: mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank) workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) @@ -250,9 +252,14 @@ def __init__( self.max_num_tokens_per_rank, self.eplb_stats_num_experts, ) - except (RuntimeError, AssertionError) as e: + except Exception: + # Release CUDA physical memory immediately to prevent leak. 
+ # Without explicit cleanup, MnnvlMemory objects stay alive + # (held by exception traceback references) until GC runs. + workspace = None + mnnvl_mem = None NVLinkOneSided._WORKSPACE_INIT_FAILED = True - raise RuntimeError(f"NVLinkOneSided workspace initialization failed: {e}") from e + raise NVLinkOneSided._WORKSPACE = { "workspace_size_per_rank": self.workspace_size_per_rank, "max_num_tokens_per_rank": self.max_num_tokens_per_rank,