[https://nvbugs/6084764][fix] Cache NVLinkOneSided init failure to prevent OOM from repeated MnnvlMemory alloc#13235

Closed

tensorrt-cicd wants to merge 1 commit into NVIDIA:main from tensorrt-cicd:repair-bot-bug6084764

Conversation


@tensorrt-cicd tensorrt-cicd commented Apr 20, 2026

Summary

  • Fix for NVBugs 6084764: [TensorRT-LLM][L0][Pre-Merge][main] torch.OutOfMemoryError: CUDA out of memory
  • Root cause: after a failed NVLinkOneSided workspace initialization, every subsequent MoE layer retried the allocation, leaking CUDA memory through repeated MnnvlMemory allocations
  • Fix: cache the NVLinkOneSided init failure in a class-level flag so initialization is attempted only once, and release MnnvlMemory explicitly in the except block
  • Automated fix generated by repair-bot

Test plan

  • Verify fix on the same GPU type as the original failure
  • Check for regressions in related tests

Links

Summary by CodeRabbit

  • Bug Fixes
    • Improved initialization error handling in the communication module to prevent repeated allocation attempts after a prior failure.

Cache NVLinkOneSided init failure to prevent OOM from repeated MnnvlMemory allocations

When NVLinkOneSided workspace initialization fails (e.g. moe_a2a_initialize
raises RuntimeError), the _WORKSPACE class variable stays None. On subsequent
MoE layers (58 for DeepSeek-R1), the factory retries NVLinkOneSided each time,
allocating ~1.76GB of CUDA physical memory via cuMemCreate per attempt. These
MnnvlMemory objects are held alive by exception traceback references and not
freed until GC runs, accumulating ~102GB of leaked GPU memory. Combined with
~97GB model weights on B200 (192GB), this leaves almost no room for KV cache
(only 5.39 GiB instead of ~80GB), causing OOM.
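The leak mechanism described above — objects kept alive by exception traceback references — can be reproduced in isolation. In this hypothetical sketch, `Buffer` stands in for an MnnvlMemory-like allocation; the weak reference shows the object survives the failed initializer until the exception (and its traceback) are dropped:

```python
import gc
import weakref


class Buffer:
    """Stand-in for an MnnvlMemory-like object owning device memory."""


def failing_init():
    buf = Buffer()           # local "allocation" inside the initializer
    ref = weakref.ref(buf)   # observe liveness without keeping it alive
    try:
        raise RuntimeError("workspace init failed")
    except RuntimeError as exc:
        # exc.__traceback__ references this function's frame, whose
        # locals include `buf`, so returning the exception keeps the
        # allocation alive after the function exits.
        return ref, exc


ref, err = failing_init()
assert ref() is not None   # still alive via err.__traceback__
err = None                 # drop the exception and its traceback
gc.collect()
assert ref() is None       # the allocation is now collectable
```

This is why the fix sets the local references to None inside the except block before re-raising, rather than waiting for garbage collection.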

Fix: Add _WORKSPACE_INIT_FAILED class flag to skip repeated initialization
attempts after the first failure. Also explicitly release MnnvlMemory in the
except block to prevent memory leaks from traceback references.
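The pattern can be sketched as a class-level guard around the one-time workspace allocation. This is a simplified model, not the actual TensorRT-LLM code; `allocate_workspace` and the class body are illustrative:

```python
class NVLinkOneSided:
    """Simplified stand-in for the real class."""

    _WORKSPACE = None              # shared, allocated once per process
    _WORKSPACE_INIT_FAILED = False

    @classmethod
    def get_workspace(cls, allocate_workspace):
        # Fail fast if a previous attempt already failed: without this
        # guard, every MoE layer would retry and re-allocate memory.
        if cls._WORKSPACE_INIT_FAILED:
            raise RuntimeError("NVLinkOneSided workspace init previously failed")
        if cls._WORKSPACE is None:
            mnnvl_mem = None
            try:
                mnnvl_mem = allocate_workspace()   # may raise RuntimeError
                cls._WORKSPACE = {"mnnvl_mem": mnnvl_mem}
            except Exception:
                # Drop the local reference so the allocation is not kept
                # alive by the exception's traceback until GC runs.
                mnnvl_mem = None
                cls._WORKSPACE_INIT_FAILED = True
                raise
        return cls._WORKSPACE
```

With this guard, a failing allocator runs exactly once: later calls raise immediately without touching CUDA memory.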

Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
@tensorrt-cicd tensorrt-cicd requested a review from a team as a code owner April 20, 2026 17:52
@tensorrt-cicd tensorrt-cicd requested a review from yuxianq April 20, 2026 17:52

coderabbitai Bot commented Apr 20, 2026

📝 Walkthrough

This change adds defensive error handling to the NVLinkOneSided class by introducing a class-level flag to track workspace initialization failures. When initialization fails, the flag prevents subsequent instances from attempting allocation, instead raising an error immediately without resource waste.

Changes

  • Workspace Initialization Error Handling — tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py: Added _WORKSPACE_INIT_FAILED class flag to gate initialization attempts. Wrapped workspace allocation (MnnvlMemory, tensor conversion, and torch.ops.trtllm.moe_a2a_initialize) in try/except to set the flag and prevent re-allocation after a prior failure. Successful initialization path unchanged.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 3 passed

  • Title check ✅ Passed — The title clearly and specifically describes the main change: caching NVLinkOneSided initialization failure to prevent OOM from repeated MnnvlMemory allocations, which aligns with the changeset.
  • Description check ✅ Passed — The description provides a clear summary of the issue, root cause, fix approach, and test plan. While it omits the detailed checklist items from the template, the essential information is present and well documented.
  • Docstring coverage ✅ Passed — Docstring coverage is 100.00%, above the required threshold of 80.00%.


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year for this modified file.

This file was meaningfully modified in 2026, but the header still shows 2025.

As per coding guidelines: "Add NVIDIA copyright header on ALL new files and update year on modified files" and "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py` at
line 1, The file header year is outdated (shows 2025) and must be updated to
reflect the 2026 modification; locate the file-level SPDX/copyright header at
the top of nvlink_one_sided.py and change the year to 2026 so the header reads
the correct latest modification year (e.g., update the SPDX-FileCopyrightText
line in nvlink_one_sided.py).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py`:
- Around line 253-260: The except block in NVLinkOneSided that currently uses a
broad "except Exception" should be narrowed to "except RuntimeError" so only
expected MNNVL initialization failures (from MnnvlMemory,
as_torch_strided_tensor, torch.ops.trtllm.moe_a2a_initialize) trigger setting
NVLinkOneSided._WORKSPACE_INIT_FAILED and cleanup of workspace and mnnvl_mem;
change the handler to "except RuntimeError:" keep the workspace = None and
mnnvl_mem = None cleanup and re-raise the error, ensuring other unexpected
exceptions are not swallowed and do not permanently disable NVLinkOneSided.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 96f64544-581e-4f60-9a6a-0d885d11a983

📥 Commits

Reviewing files that changed from the base of the PR and between 7b84136 and bfd18ff.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py

Comment on lines +253 to +260

```python
            except Exception:
                # Release CUDA physical memory immediately to prevent leak.
                # Without explicit cleanup, MnnvlMemory objects stay alive
                # (held by exception traceback references) until GC runs.
                workspace = None
                mnnvl_mem = None
                NVLinkOneSided._WORKSPACE_INIT_FAILED = True
                raise
```

⚠️ Potential issue | 🟠 Major


Narrow exception handling to only catch expected initialization failures.

except Exception is too broad and can permanently disable NVLinkOneSided for unrelated errors. Since MNNVL initialization operations (MnnvlMemory, as_torch_strided_tensor, torch.ops.trtllm.moe_a2a_initialize) are expected to raise RuntimeError on initialization failures, catch RuntimeError specifically. Setting the class-level _WORKSPACE_INIT_FAILED flag prevents any future initialization attempts, so it should only be set for actual initialization errors, not incidental exceptions.
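The distinction can be illustrated with a toy guard (hypothetical names, not the real NVLinkOneSided): only RuntimeError marks initialization as permanently failed, while unrelated exceptions propagate without setting the flag.

```python
class Comm:
    """Toy class illustrating narrowed failure caching."""

    _WORKSPACE_INIT_FAILED = False

    @classmethod
    def init_workspace(cls, allocate):
        if cls._WORKSPACE_INIT_FAILED:
            raise RuntimeError("workspace init previously failed")
        try:
            return allocate()
        except RuntimeError:
            # Expected MNNVL init failure: remember it and fail fast
            # on later calls. Any other exception type propagates
            # without permanently disabling this backend.
            cls._WORKSPACE_INIT_FAILED = True
            raise
```

Under this scheme a transient ValueError (say, from a bug elsewhere) leaves the flag unset, so a later retry is still possible; a RuntimeError disables further attempts.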

Proposed change

```diff
-            except Exception:
+            except RuntimeError:
                 # Release CUDA physical memory immediately to prevent leak.
                 # Without explicit cleanup, MnnvlMemory objects stay alive
                 # (held by exception traceback references) until GC runs.
                 workspace = None
                 mnnvl_mem = None
                 NVLinkOneSided._WORKSPACE_INIT_FAILED = True
                 raise
```


```python
    # Single shared workspace/memory across the process
    _WORKSPACE: dict | None = None
    _WORKSPACE_INIT_FAILED: bool = False
```
Collaborator

This solution is a duplicate of #13172; we may close this PR instead. We can unwaive 6084764 in #13172.

ziyixiong-nv added a commit to ziyixiong-nv/TensorRT-LLM that referenced this pull request Apr 22, 2026
- Update copyright year to 2025-2026
- Mark _WORKSPACE and _WORKSPACE_INIT_FAILED as ClassVar
- Move _WORKSPACE_INIT_FAILED check before MnnvlMemory.initialize()
- Release workspace/mnnvl_mem on failure to prevent CUDA memory leak
- Broaden exception handling to match PR NVIDIA#13235 pattern

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
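The ClassVar change mentioned in that commit can look like the following skeleton (illustrative only, not the actual file contents): `ClassVar` tells type checkers these are shared class attributes rather than per-instance fields.

```python
from typing import ClassVar, Optional


class NVLinkOneSided:
    """Illustrative skeleton showing only the class-level state."""

    # Shared across all instances of the process; annotated as ClassVar
    # so type checkers do not treat them as instance attributes.
    _WORKSPACE: ClassVar[Optional[dict]] = None
    _WORKSPACE_INIT_FAILED: ClassVar[bool] = False
```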
@bobboli bobboli closed this Apr 22, 2026