Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from ..compilation.backend import Backend
from ..compilation.utils import capture_piecewise_cuda_graph
from ..distributed import Distributed
from ..distributed.communicator import init_pp_comm
from ..distributed.communicator import ReduceOp, init_pp_comm
from ..expert_statistic import ExpertStatistic
from ..memory_buffer_utils import with_shared_pool
from ..metadata import KVCacheParams
Expand Down Expand Up @@ -795,8 +795,12 @@ def _general_warmup(self, resource_manager: ResourceManager,
self._create_warmup_request(resource_manager,
num_tokens, num_gen_tokens),
resource_manager) as batch:
if batch is None:
continue # Not enough KV cache space
# Synchronize skip decisions across all ranks to prevent
# hangs when one rank skips but others proceed into
# collective ops inside the forward pass.
if not self._all_ranks_warmup_ready(batch is not None,
num_tokens):
continue
logger.info(
f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens"
)
Expand All @@ -810,6 +814,53 @@ def _general_warmup(self, resource_manager: ResourceManager,
f"{num_gen_tokens} generation tokens. Skipping.")
torch.cuda.empty_cache()

def _all_ranks_warmup_ready(self, has_batch: bool,
num_tokens: int) -> bool:
"""Ensure all ranks agree on whether to proceed with warmup.

In multi-rank setups, warmup forward passes involve collective
operations (allreduce, allgather). If one rank skips warmup (e.g.,
due to insufficient KV cache or GPU memory) while others proceed,
the proceeding ranks will hang indefinitely at the next collective
op, eventually timing out and crashing.

This method synchronizes the decision across all ranks: warmup
proceeds only if every rank is ready.
"""
if self.dist is None:
return has_batch

is_ready = has_batch

# Pre-check GPU memory to avoid OOM that would leave other ranks
# stuck in collective ops mid-forward. The OOM catch in
# _general_warmup only helps the failing rank; other ranks remain
# blocked in the forward pass.
if is_ready:
free_mem, _ = torch.cuda.mem_get_info()
# Estimate memory needed for the warmup forward pass:
# attention workspace + activation tensors + temporary buffers.
# Use ~500 KiB per token as a conservative estimate, with a
# floor of 4 GiB.
min_free = max(num_tokens * 512 * 1024, 4 * 1024**3)
if free_mem < min_free:
is_ready = False
logger.info(
f"Insufficient free GPU memory "
f"({free_mem / 1024**3:.1f} GiB free, "
f"need {min_free / 1024**3:.1f} GiB) for warmup with "
f"{num_tokens} tokens on rank {self.dist.rank}")

all_ready = self.dist.allreduce(
1 if is_ready else 0, op=ReduceOp.MIN) > 0

if not all_ready and has_batch:
logger.warning(
f"Skipping warmup with {num_tokens} tokens: "
f"not all ranks have sufficient resources")

return all_ready

def _run_autotuner_warmup(self, resource_manager: ResourceManager):
"""Runs a forward pass to populate the autotuner cache."""
if not self.llm_args.enable_autotuner:
Expand Down
Loading