Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def compute_log_prob(
data_iterator,
num_microbatches,
store_prefix=store_prefix,
step_callback=self.prof.step_train_log_probs,
)

def train(self, rollout_id: int, rollout_data_ref: Box, external_data=None):
Expand Down Expand Up @@ -385,7 +386,16 @@ def train_critic(self, rollout_id: int, rollout_data: RolloutBatch):
data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data)

# Compute current critic values (used as old_values for value loss and for actor advantages).
rollout_data.update(forward_only(get_values, self.args, self.model, data_iterator, num_microbatches))
rollout_data.update(
forward_only(
get_values,
self.args,
self.model,
data_iterator,
num_microbatches,
step_callback=self.prof.step_train_log_probs,
)
)

compute_advantages_and_returns(self.args, rollout_data)

Expand Down Expand Up @@ -504,6 +514,7 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data
self.opt_param_scheduler,
data_iterator,
num_microbatches,
step_callback=self.prof.step_train_actor,
)

self.prof.step(rollout_id=rollout_id)
Expand Down
13 changes: 13 additions & 0 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def forward_only(
data_iterator: Sequence[DataIterator],
num_microbatches: Sequence[int],
store_prefix: str = "",
step_callback: Callable[[], None] | None = None,
) -> dict[str, list[torch.Tensor]]:
"""Run forward passes only and collect non-loss outputs (e.g., logprobs).

Expand Down Expand Up @@ -377,6 +378,10 @@ def forward_step(
)
microbatch_pbar.close()

# Advance the train_log_probs profiler (if active) one tick per step.
if step_callback is not None:
step_callback()

# Move model back to the train mode.
for model_module in model:
model_module.train()
Expand Down Expand Up @@ -603,6 +608,7 @@ def train(
opt_param_scheduler: OptimizerParamScheduler,
data_iterator: Sequence[DataIterator],
num_microbatches: Sequence[int],
step_callback: Callable[[], None] | None = None,
) -> None:
"""Run training over a rollout consisting of multiple steps.

Expand Down Expand Up @@ -716,6 +722,13 @@ def train(
microbatch_pbar=microbatch_pbar,
)

# Advance the train_actor profiler (if active) one tick per grad-accum
# step. The torch.profiler schedule (wait/warmup/active/repeat) decides
# which step actually captures; most ticks are no-ops. Kept out of the
# hot path by the callback being None when profiling is disabled.
if step_callback is not None:
step_callback()

if step_id == 0:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
Expand Down
107 changes: 95 additions & 12 deletions slime/utils/profile_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import time
import traceback
from pathlib import Path
Expand All @@ -10,24 +11,75 @@
logger = logging.getLogger(__name__)


def _env_flag(name: str) -> bool:
"""Read a boolean env var. Accepts 1/true/yes (case-insensitive) as truthy."""
return os.environ.get(name, "0").lower() not in ("0", "", "false", "no")


def _should_profile_this_rank() -> bool:
"""Only rank 0 profiles by default. Each rank holds a full torch.profiler
buffer during the active window (~60 GB on a 26B MoE), so running on all
16 ranks adds ~1 TB of host RAM pressure and has caused host-OOM kills.
Set SLIME_PROFILE_ALL_RANKS=1 to opt into per-rank traces when diagnosing
cross-rank sync or PP-stage imbalance.
"""
if _env_flag("SLIME_PROFILE_ALL_RANKS"):
return True
if not torch.distributed.is_initialized():
return True
return torch.distributed.get_rank() == 0


class TrainProfiler:
"""Manages torch.profiler and memory profilers across training phases.

Three profile targets, typically used one at a time via ``--profile-target``:

- ``train_overall`` — one active window covers a full rollout
(ref_log_probs + log_probs + actor_train + update_weights). Useful for
cumulative stats but produces huge traces (8+ GB gzipped on a 26B MoE).
- ``train_actor`` — one active window covers a single grad-accum step
inside ``actor_train``. Trace is ~500× smaller (~15 MB) and actually
openable in Perfetto/Chrome. Hooked via ``step_train_actor`` at the
boundary of each step in ``megatron_utils/model.py::train``.
- ``train_log_probs`` — one active window per log-probs forward. Hooked
via ``step_train_log_probs``.

``--profile-target`` is a list, so multiple targets can be passed at once;
the code paths check membership independently rather than enforcing
exclusivity.
"""

def __init__(self, args):
self.args = args
self._torch_profiler_overall = None
self._memory_profiler_overall = None

if args.use_pytorch_profiler and ("train_overall" in args.profile_target):
self._torch_profiler_overall = _create_torch_profiler(args, name="train_overall")

if args.record_memory_history and ("train_overall" in args.profile_target):
self._memory_profiler_overall = _BaseMemoryProfiler.create(args)
self._memory_profiler_overall.start()
self._torch_profiler_train_actor = None
self._torch_profiler_train_actor_started = False
self._torch_profiler_train_log_probs = None
self._torch_profiler_train_log_probs_started = False

if _should_profile_this_rank():
if args.use_pytorch_profiler and ("train_overall" in args.profile_target):
self._torch_profiler_overall = _create_torch_profiler(args, name="train_overall")
if args.use_pytorch_profiler and ("train_actor" in args.profile_target):
self._torch_profiler_train_actor = _create_torch_profiler(args, name="train_actor")
if args.use_pytorch_profiler and ("train_log_probs" in args.profile_target):
self._torch_profiler_train_log_probs = _create_torch_profiler(args, name="train_log_probs")
if args.record_memory_history and ("train_overall" in args.profile_target):
self._memory_profiler_overall = _BaseMemoryProfiler.create(args)
self._memory_profiler_overall.start()

def on_init_end(self):
# Only the train_overall profiler starts at init; the per-step ones
# start lazily on their first tick so they don't waste a warmup slot
# on code that runs before training begins.
if self._torch_profiler_overall is not None:
self._torch_profiler_overall.start()

def step(self, rollout_id: int):
"""Called once per rollout from the actor loop. Advances the
train_overall profiler's state machine."""
if self._torch_profiler_overall is not None:
self._torch_profiler_overall.step()

Expand All @@ -38,6 +90,25 @@ def step(self, rollout_id: int):
):
self._memory_profiler_overall.stop()

def step_train_actor(self):
"""Called at each grad-accum step boundary inside ``actor_train``.
Each call advances the train_actor profiler by one tick."""
if self._torch_profiler_train_actor is None:
return
if not self._torch_profiler_train_actor_started:
self._torch_profiler_train_actor.start()
self._torch_profiler_train_actor_started = True
self._torch_profiler_train_actor.step()

def step_train_log_probs(self):
"""Called at each log-probs forward-step boundary."""
if self._torch_profiler_train_log_probs is None:
return
if not self._torch_profiler_train_log_probs_started:
self._torch_profiler_train_log_probs.start()
self._torch_profiler_train_log_probs_started = True
self._torch_profiler_train_log_probs.step()

def iterate_train_actor(self, iterator):
return _profile_simple_loop(iterator, self.args, name="train_actor")

Expand All @@ -46,7 +117,11 @@ def iterate_train_log_probs(self, iterator):


def _profile_simple_loop(iterator, args, name):
if not (args.use_pytorch_profiler and (name in args.profile_target)):
if not (
args.use_pytorch_profiler
and (name in args.profile_target)
and _should_profile_this_rank()
):
yield from iterator
return

Expand All @@ -58,6 +133,14 @@ def _profile_simple_loop(iterator, args, name):


def _create_torch_profiler(args, name):
# On large MoE models (26B+), record_shapes / with_flops / with_stack /
# profile_memory are memory amplifiers that can produce 10+ GB traces and
# OOM the host. All off by default; flip via env var for tighter capture
# windows where the extra metadata is worth the cost.
record_shapes = _env_flag("SLIME_PROFILE_RECORD_SHAPES")
with_flops = _env_flag("SLIME_PROFILE_WITH_FLOPS")
with_stack = _env_flag("SLIME_PROFILE_WITH_STACK")
profile_memory = _env_flag("SLIME_PROFILE_MEMORY")
return torch.profiler.profile(
schedule=torch.profiler.schedule(
# TODO the train_actor and train_log_probs ones may need to have different args to control step
Expand All @@ -71,10 +154,10 @@ def _create_torch_profiler(args, name):
worker_name=f"{name}_rank_{torch.distributed.get_rank()}",
use_gzip=True,
),
record_shapes=True,
with_stack=True,
profile_memory=True,
with_flops=True,
record_shapes=record_shapes,
with_flops=with_flops,
with_stack=with_stack,
profile_memory=profile_memory,
)


Expand Down
Loading