def _all_ranks_warmup_ready(self,
                            has_batch: bool,
                            num_tokens: int,
                            *,
                            bytes_per_token: int = 512 * 1024,
                            min_free_floor: int = 4 * 1024**3) -> bool:
    """Ensure all ranks agree on whether to proceed with warmup.

    In multi-rank setups, warmup forward passes involve collective
    operations (allreduce, allgather). If one rank skips warmup (e.g.,
    due to insufficient KV cache or GPU memory) while others proceed,
    the proceeding ranks will hang indefinitely at the next collective
    op, eventually timing out and crashing.

    This method synchronizes the decision across all ranks: warmup
    proceeds only if every rank is ready.

    Args:
        has_batch: Whether this rank was able to build a warmup batch
            (False when there is not enough KV cache space).
        num_tokens: Token count of the planned warmup request; used to
            estimate the GPU memory the forward pass will need.
        bytes_per_token: Estimated GPU memory required per warmup token
            (attention workspace + activation tensors + temporary
            buffers). Defaults to 512 KiB, a conservative estimate.
        min_free_floor: Absolute minimum free GPU memory required
            regardless of token count. Defaults to 4 GiB.

    Returns:
        True if every rank can run this warmup; False if any rank must
        skip it (in which case all ranks skip together).
    """
    # Single-process case: there are no collectives to deadlock on, so
    # the local decision is the global one.
    if self.dist is None:
        return has_batch

    is_ready = has_batch

    # Pre-check GPU memory to avoid OOM that would leave other ranks
    # stuck in collective ops mid-forward. The OOM catch in
    # _general_warmup only helps the failing rank; other ranks remain
    # blocked in the forward pass.
    if is_ready:
        free_mem, _ = torch.cuda.mem_get_info()
        # Estimate memory needed for the warmup forward pass: a
        # per-token cost with an absolute floor (see parameter docs).
        min_free = max(num_tokens * bytes_per_token, min_free_floor)
        if free_mem < min_free:
            is_ready = False
            logger.info(
                f"Insufficient free GPU memory "
                f"({free_mem / 1024**3:.1f} GiB free, "
                f"need {min_free / 1024**3:.1f} GiB) for warmup with "
                f"{num_tokens} tokens on rank {self.dist.rank}")

    # MIN-reduce the per-rank readiness flags (0/1): the result is 1
    # only when every rank reported 1, so every rank derives the same
    # verdict and either all proceed or all skip.
    all_ready = self.dist.allreduce(
        1 if is_ready else 0, op=ReduceOp.MIN) > 0

    # Only log on ranks that actually had a batch to run; ranks that
    # skipped locally already logged their own reason.
    if not all_ready and has_batch:
        logger.warning(
            f"Skipping warmup with {num_tokens} tokens: "
            f"not all ranks have sufficient resources")

    return all_ready