diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py index e37d5db10819..95deb58956d0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py @@ -57,6 +57,7 @@ class NVLinkOneSided(Communication): # Single shared workspace/memory across the process _WORKSPACE: dict | None = None + _WORKSPACE_INIT_FAILED: bool = False # MetaInfo indices - initialized from C++ constants FLAG_VAL_OFFSET_INDEX = None @@ -168,6 +169,14 @@ def __init__( transfer (halves NVLink bandwidth usage, output precision is preserved). Corresponds to model_config.use_low_precision_moe_combine. """ + # Skip if workspace initialization previously failed to avoid repeated + # MnnvlMemory allocations that leak CUDA physical memory (held alive by + # exception traceback references), which can exhaust GPU memory. + if NVLinkOneSided._WORKSPACE_INIT_FAILED: + raise RuntimeError( + "NVLinkOneSided workspace initialization previously failed; skipping retry." + ) + super().__init__(mapping) if self.mapping.world_size != self.ep_size: @@ -229,15 +238,26 @@ def __init__( f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes." f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}" ) - mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank) - workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) - metainfo = torch.ops.trtllm.moe_a2a_initialize( - workspace, - self.ep_rank, - self.ep_size, - self.max_num_tokens_per_rank, - self.eplb_stats_num_experts, - ) + mnnvl_mem = None + workspace = None + try: + mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank) + workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) + metainfo = torch.ops.trtllm.moe_a2a_initialize( + workspace, + self.ep_rank, + self.ep_size, + self.max_num_tokens_per_rank, + self.eplb_stats_num_experts, + ) + except Exception: + # Release CUDA physical memory immediately to prevent leak. + # Without explicit cleanup, MnnvlMemory objects stay alive + # (held by exception traceback references) until GC runs. + workspace = None + mnnvl_mem = None + NVLinkOneSided._WORKSPACE_INIT_FAILED = True + raise NVLinkOneSided._WORKSPACE = { "workspace_size_per_rank": self.workspace_size_per_rank, "max_num_tokens_per_rank": self.max_num_tokens_per_rank,