diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 586f88ded..d573cef5b 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -357,6 +357,7 @@ def compute_log_prob( data_iterator, num_microbatches, store_prefix=store_prefix, + step_callback=self.prof.step_train_log_probs, ) def train(self, rollout_id: int, rollout_data_ref: Box, external_data=None): @@ -385,7 +386,16 @@ def train_critic(self, rollout_id: int, rollout_data: RolloutBatch): data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data) # Compute current critic values (used as old_values for value loss and for actor advantages). - rollout_data.update(forward_only(get_values, self.args, self.model, data_iterator, num_microbatches)) + rollout_data.update( + forward_only( + get_values, + self.args, + self.model, + data_iterator, + num_microbatches, + step_callback=self.prof.step_train_log_probs, + ) + ) compute_advantages_and_returns(self.args, rollout_data) @@ -504,6 +514,7 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data self.opt_param_scheduler, data_iterator, num_microbatches, + step_callback=self.prof.step_train_actor, ) self.prof.step(rollout_id=rollout_id) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index f326b1d0d..152566498 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -252,6 +252,7 @@ def forward_only( data_iterator: Sequence[DataIterator], num_microbatches: Sequence[int], store_prefix: str = "", + step_callback: Callable[[], None] | None = None, ) -> dict[str, list[torch.Tensor]]: """Run forward passes only and collect non-loss outputs (e.g., logprobs). @@ -377,6 +378,10 @@ def forward_step( ) microbatch_pbar.close() + # Advance the train_log_probs profiler (if active) one tick per step. 
def _env_flag(name: str) -> bool:
    """Return True when environment variable ``name`` holds a truthy value.

    Truthy values are ``1``, ``true``, ``yes`` and ``on``, compared
    case-insensitively after stripping surrounding whitespace. An unset
    variable and every other value (``0``, ``false``, ``no``, ``off``, an
    empty string, a typo, ...) read as False.

    A whitelist is used on purpose: these flags switch on expensive
    profiling options, so an unrecognised value must never enable them.
    """
    raw_value = os.environ.get(name, "")
    return raw_value.strip().lower() in ("1", "true", "yes", "on")


def _should_profile_this_rank() -> bool:
    """Decide whether the current process should create profilers.

    Only rank 0 profiles by default. Each rank holds a full torch.profiler
    buffer during the active window (reported as ~60 GB on a 26B MoE), so
    running on all 16 ranks adds ~1 TB of host RAM pressure and has caused
    host-OOM kills. Set ``SLIME_PROFILE_ALL_RANKS=1`` to opt into per-rank
    traces when diagnosing cross-rank sync or PP-stage imbalance.

    Returns:
        True when ``SLIME_PROFILE_ALL_RANKS`` is truthy, when no process
        group is available/initialized (single-process run), or when this
        process is global rank 0; False otherwise.
    """
    if _env_flag("SLIME_PROFILE_ALL_RANKS"):
        return True
    # Without an initialized process group there is no rank to compare
    # against, so the lone process is the one that profiles. is_available()
    # is checked first because is_initialized() only exists on torch builds
    # that ship distributed support.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return True
    return torch.distributed.get_rank() == 0
def __init__(self, args):
    """Create the profilers selected by ``args`` for this rank.

    Every profiler handle starts as ``None`` and every "started" flag as
    False. Handles are only populated when ``_should_profile_this_rank()``
    says this process should profile:

    - one torch profiler per name in ``args.profile_target`` among
      ``train_overall``, ``train_actor`` and ``train_log_probs``, when
      ``args.use_pytorch_profiler`` is set;
    - a memory profiler, created and started immediately, when
      ``args.record_memory_history`` is set and ``train_overall`` is a
      target.

    Args:
        args: Parsed run arguments; the attributes read here are
            ``use_pytorch_profiler``, ``record_memory_history`` and
            ``profile_target``.
    """
    self.args = args

    # Rollout-level profilers.
    self._torch_profiler_overall = None
    self._memory_profiler_overall = None
    # Per-step profilers; the *_started flags record whether start() has
    # been called yet, since these are started on their first tick.
    self._torch_profiler_train_actor = None
    self._torch_profiler_train_actor_started = False
    self._torch_profiler_train_log_probs = None
    self._torch_profiler_train_log_probs_started = False

    if not _should_profile_this_rank():
        return

    if args.use_pytorch_profiler:
        if "train_overall" in args.profile_target:
            self._torch_profiler_overall = _create_torch_profiler(args, name="train_overall")
        if "train_actor" in args.profile_target:
            self._torch_profiler_train_actor = _create_torch_profiler(args, name="train_actor")
        if "train_log_probs" in args.profile_target:
            self._torch_profiler_train_log_probs = _create_torch_profiler(args, name="train_log_probs")

    # NOTE(review): memory-history recording is gated by the same rank
    # check as the torch profilers -- confirm that is intended.
    if args.record_memory_history and "train_overall" in args.profile_target:
        memory_profiler = _BaseMemoryProfiler.create(args)
        self._memory_profiler_overall = memory_profiler
        memory_profiler.start()


def on_init_end(self):
    """Start the ``train_overall`` torch profiler, if one was created.

    Only the train_overall profiler starts here; the per-step profilers
    are started by ``step_train_actor`` / ``step_train_log_probs`` on
    their first tick.
    """
    overall_profiler = self._torch_profiler_overall
    if overall_profiler is None:
        return
    overall_profiler.start()
Advances the + train_overall profiler's state machine.""" if self._torch_profiler_overall is not None: self._torch_profiler_overall.step() @@ -38,6 +90,25 @@ def step(self, rollout_id: int): ): self._memory_profiler_overall.stop() + def step_train_actor(self): + """Called at each grad-accum step boundary inside ``actor_train``. + Each call advances the train_actor profiler by one tick.""" + if self._torch_profiler_train_actor is None: + return + if not self._torch_profiler_train_actor_started: + self._torch_profiler_train_actor.start() + self._torch_profiler_train_actor_started = True + self._torch_profiler_train_actor.step() + + def step_train_log_probs(self): + """Called at each log-probs forward-step boundary.""" + if self._torch_profiler_train_log_probs is None: + return + if not self._torch_profiler_train_log_probs_started: + self._torch_profiler_train_log_probs.start() + self._torch_profiler_train_log_probs_started = True + self._torch_profiler_train_log_probs.step() + def iterate_train_actor(self, iterator): return _profile_simple_loop(iterator, self.args, name="train_actor") @@ -46,7 +117,11 @@ def iterate_train_log_probs(self, iterator): def _profile_simple_loop(iterator, args, name): - if not (args.use_pytorch_profiler and (name in args.profile_target)): + if not ( + args.use_pytorch_profiler + and (name in args.profile_target) + and _should_profile_this_rank() + ): yield from iterator return @@ -58,6 +133,14 @@ def _profile_simple_loop(iterator, args, name): def _create_torch_profiler(args, name): + # On large MoE models (26B+), record_shapes / with_flops / with_stack / + # profile_memory are memory amplifiers that can produce 10+ GB traces and + # OOM the host. All off by default; flip via env var for tighter capture + # windows where the extra metadata is worth the cost. 
+ record_shapes = _env_flag("SLIME_PROFILE_RECORD_SHAPES") + with_flops = _env_flag("SLIME_PROFILE_WITH_FLOPS") + with_stack = _env_flag("SLIME_PROFILE_WITH_STACK") + profile_memory = _env_flag("SLIME_PROFILE_MEMORY") return torch.profiler.profile( schedule=torch.profiler.schedule( # TODO the train_actor and train_log_probs ones may need to have different args to control step @@ -71,10 +154,10 @@ def _create_torch_profiler(args, name): worker_name=f"{name}_rank_{torch.distributed.get_rank()}", use_gzip=True, ), - record_shapes=True, - with_stack=True, - profile_memory=True, - with_flops=True, + record_shapes=record_shapes, + with_flops=with_flops, + with_stack=with_stack, + profile_memory=profile_memory, )