slime/backends/megatron_utils/model.py: 32 changes (19 additions & 13 deletions)
@@ -431,33 +431,39 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
        forward_only=False,
    )

    # CI check: verify only MTP parameters have non-zero gradients when truncation happens
    # This check must happen before optimizer.step() as gradients may be modified during step
    if args.ci_test and args.enable_mtp_training:
        from slime.backends.megatron_utils.ci_utils import check_mtp_only_grad

        check_mtp_only_grad(model, step_id)

    # Update parameters.
    valid_step = True
    grad_norm = float("nan")
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()

    if not getattr(args, "check_for_nan_in_loss_and_grad", True):
        found_inf_flag = optimizer.prepare_grads()
        found_inf_flag = not update_successful and grad_norm is None and num_zeros_in_grad is None
        if found_inf_flag:
            valid_step = False
            current_scale = optimizer.get_loss_scale().item()
            logger.warning(
                "Inf found in gradients (step_id=%d, loss_scale=%s), skipping parameter update (dynamic loss scaling will reduce scale)",
                step_id,
                current_scale,
            )
        else:
            grad_norm = optimizer.get_grad_norm()
            if isinstance(grad_norm, torch.Tensor):
                valid_step = not (torch.isnan(grad_norm) or torch.isinf(grad_norm))
            else:
                valid_step = not (math.isnan(grad_norm) or math.isinf(grad_norm))

    # CI check: verify only MTP parameters have non-zero gradients when truncation happens
    # This check must happen before optimizer.step() as gradients may be modified during step
    if args.ci_test and args.enable_mtp_training:
        from slime.backends.megatron_utils.ci_utils import check_mtp_only_grad

        check_mtp_only_grad(model, step_id)

    if valid_step:
        # Update parameters.
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step()

        # Update learning rate.
        assert update_successful
        opt_param_scheduler.step(increment=args.global_batch_size)
    else:
        grad_norm = float("nan")

    # release grad
    for model_chunk in model:
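The hunk above appears to gate `optimizer.step()` and the LR-scheduler update behind a `valid_step` flag computed from `optimizer.prepare_grads()` and a finite-`grad_norm` check, instead of stepping unconditionally and deriving `found_inf_flag` from the step's outputs afterwards. As a minimal, self-contained sketch of that gating pattern in plain PyTorch — not the PR's code, and independent of Megatron's optimizer API (`prepare_grads()`, `get_loss_scale()`):

```python
import torch


def guarded_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> bool:
    """Apply the optimizer update only when every gradient is finite.

    Returns True if parameters were updated, False if the step was skipped.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    # One NaN/Inf anywhere invalidates the whole step, mirroring valid_step above.
    all_finite = all(bool(torch.isfinite(g).all()) for g in grads)
    if not all_finite:
        # Skip the update; a dynamic loss scaler would also shrink its scale here.
        optimizer.zero_grad(set_to_none=True)
        return False
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return True
```

In the shape of the diff, a caller would run the LR scheduler only when this returns True, and leave `grad_norm` as NaN otherwise.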
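`check_mtp_only_grad` itself lives in `slime.backends.megatron_utils.ci_utils` and is not part of this hunk. Purely to illustrate the invariant named in the comment ("only MTP parameters have non-zero gradients"), here is a hypothetical sketch; the name-based MTP filter and the use of `param.grad` are assumptions (Megatron's distributed optimizer typically accumulates into `main_grad`, so the real helper likely reads gradients differently):

```python
from typing import Callable, Iterable

import torch


def assert_mtp_only_grad(
    model_chunks: Iterable[torch.nn.Module],
    step_id: int,
    is_mtp_param: Callable[[str], bool] = lambda name: "mtp" in name.lower(),
) -> None:
    """Hypothetical CI assertion: no non-MTP parameter may carry a non-zero gradient."""
    for chunk in model_chunks:
        for name, param in chunk.named_parameters():
            if param.grad is None or is_mtp_param(name):
                continue
            max_abs = param.grad.abs().max().item()
            assert max_abs == 0.0, (
                f"step {step_id}: non-MTP parameter {name!r} has non-zero grad "
                f"(max |g| = {max_abs:.3e})"
            )
```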