Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,16 @@ def create_strategy(
logger.debug(f"NVLinkTwoSided not available: {e}")

# Try DeepEP (if enabled and weight dtype is bfloat16)
if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16:
# Skip DeepEP/DeepEPLowLatency if NVLink symmetric memory init is known to
# be broken (detected by NVLinkOneSided workspace init failure). DeepEP also
# relies on NVSHMEM/symmetric memory internally, so it would hang during
# forward pass if the NVLink memory infrastructure is unavailable.
if NVLinkOneSided._WORKSPACE_INIT_FAILED:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Detecting whether NVLink symmetric memory is supported shouldn't be the responsibility of a specific communication backend; inferring it from one backend's initialization failure is not reliable.

logger.info(
"Skipping DeepEP/DeepEPLowLatency: NVLink symmetric memory "
"initialization previously failed (detected via NVLinkOneSided)."
)
elif os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16:
try:
strategy = DeepEP(
mapping,
Expand Down
Comment thread
ziyixiong-nv marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class NVLinkOneSided(Communication):

# Single shared workspace/memory across the process
_WORKSPACE: dict | None = None
# Track if workspace initialization (MNNVL + NVSHMEM) has failed, to avoid
# repeated attempts and to signal other NVSHMEM-dependent strategies (e.g.
# DeepEP) to skip initialization — they share the same NVLink/symmetric
# memory infrastructure and will also fail or hang.
_WORKSPACE_INIT_FAILED: bool = False
Comment thread
ziyixiong-nv marked this conversation as resolved.
Outdated

# MetaInfo indices - initialized from C++ constants
FLAG_VAL_OFFSET_INDEX = None
Expand Down Expand Up @@ -224,20 +229,30 @@ def __init__(
# Initialize or reuse workspace
MnnvlMemory.initialize()

if self._WORKSPACE_INIT_FAILED:
raise RuntimeError(
"NVLinkOneSided: workspace initialization (MNNVL/NVSHMEM) previously "
"failed on this node, skipping repeated initialization attempt."
)

if self._WORKSPACE is None:
tllm_logger.info(
f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes."
f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}"
)
mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
metainfo = torch.ops.trtllm.moe_a2a_initialize(
workspace,
self.ep_rank,
self.ep_size,
self.max_num_tokens_per_rank,
self.eplb_stats_num_experts,
)
try:
mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
metainfo = torch.ops.trtllm.moe_a2a_initialize(
workspace,
self.ep_rank,
self.ep_size,
self.max_num_tokens_per_rank,
self.eplb_stats_num_experts,
)
except (RuntimeError, AssertionError) as e:
NVLinkOneSided._WORKSPACE_INIT_FAILED = True
Comment thread
ziyixiong-nv marked this conversation as resolved.
raise RuntimeError(f"NVLinkOneSided workspace initialization failed: {e}") from e
NVLinkOneSided._WORKSPACE = {
"workspace_size_per_rank": self.workspace_size_per_rank,
"max_num_tokens_per_rank": self.max_num_tokens_per_rank,
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutl
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5838211)
full:A10/unittest/kv_cache_manager_v2_tests/ SKIP (https://nvbugs/5841954)
examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5846178)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer] SKIP (https://nvbugs/5859886)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-fp8] SKIP (https://nvbugs/5651865)
test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/5864769)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5879577)
Expand Down
Loading