diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py
index 0bbe5bf49b..6a4f207c41 100644
--- a/slime/backends/megatron_utils/model.py
+++ b/slime/backends/megatron_utils/model.py
@@ -431,33 +431,39 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
         forward_only=False,
     )
 
+    # CI check: verify only MTP parameters have non-zero gradients when truncation happens
+    # This check must happen before optimizer.step() as gradients may be modified during step
+    if args.ci_test and args.enable_mtp_training:
+        from slime.backends.megatron_utils.ci_utils import check_mtp_only_grad
+
+        check_mtp_only_grad(model, step_id)
+
+    # Update parameters.
     valid_step = True
-    grad_norm = float("nan")
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
+
     if not getattr(args, "check_for_nan_in_loss_and_grad", True):
-        found_inf_flag = optimizer.prepare_grads()
+        found_inf_flag = not update_successful and grad_norm is None and num_zeros_in_grad is None
         if found_inf_flag:
             valid_step = False
+            current_scale = optimizer.get_loss_scale().item()
+            logger.warning(
+                "Inf found in gradients (step_id=%d, loss_scale=%s), skipping parameter update (dynamic loss scaling will reduce scale)",
+                step_id,
+                current_scale,
+            )
         else:
-            grad_norm = optimizer.get_grad_norm()
             if isinstance(grad_norm, torch.Tensor):
                 valid_step = not (torch.isnan(grad_norm) or torch.isinf(grad_norm))
             else:
                 valid_step = not (math.isnan(grad_norm) or math.isinf(grad_norm))
 
-    # CI check: verify only MTP parameters have non-zero gradients when truncation happens
-    # This check must happen before optimizer.step() as gradients may be modified during step
-    if args.ci_test and args.enable_mtp_training:
-        from slime.backends.megatron_utils.ci_utils import check_mtp_only_grad
-
-        check_mtp_only_grad(model, step_id)
-
     if valid_step:
-        # Update parameters.
-        update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
-
         # Update learning rate.
         assert update_successful
         opt_param_scheduler.step(increment=args.global_batch_size)
+    else:
+        grad_norm = float("nan")
 
     # release grad
     for model_chunk in model:
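
For context, a minimal standalone sketch of the reordered update logic (not the patch itself). It assumes Megatron-style optimizer semantics, where `optimizer.step()` returns `(update_successful, grad_norm, num_zeros_in_grad)` and yields `(False, None, None)` when the loss scaler detects inf/NaN gradients and skips the update; `run_optimizer_step`, `_SkippingOptimizer`, and `_Scheduler` are hypothetical names used only for illustration.

```python
import logging
import math
from types import SimpleNamespace

import torch

logger = logging.getLogger(__name__)


def run_optimizer_step(optimizer, opt_param_scheduler, args, step_id: int):
    """Sketch of the patched control flow: call step() first, then decide validity."""
    valid_step = True
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()

    if not getattr(args, "check_for_nan_in_loss_and_grad", True):
        # Assumption: a Megatron-style mixed-precision optimizer returns
        # (False, None, None) when the loss scaler finds inf/NaN gradients.
        found_inf_flag = not update_successful and grad_norm is None and num_zeros_in_grad is None
        if found_inf_flag:
            valid_step = False
            logger.warning(
                "Inf found in gradients (step_id=%d, loss_scale=%s), skipping parameter update",
                step_id,
                optimizer.get_loss_scale().item(),
            )
        elif isinstance(grad_norm, torch.Tensor):
            valid_step = not (torch.isnan(grad_norm) or torch.isinf(grad_norm))
        else:
            valid_step = not (math.isnan(grad_norm) or math.isinf(grad_norm))

    if valid_step:
        # Only advance the learning-rate schedule when the update actually happened.
        assert update_successful
        opt_param_scheduler.step(increment=args.global_batch_size)
    else:
        grad_norm = float("nan")
    return valid_step, grad_norm


if __name__ == "__main__":
    # Hypothetical stand-ins emulating a skipped step under dynamic loss scaling.
    class _SkippingOptimizer:
        def step(self):
            return False, None, None  # overflow detected, update skipped

        def get_loss_scale(self):
            return torch.tensor(4096.0)

    class _Scheduler:
        def step(self, increment):
            raise AssertionError("lr schedule must not advance on a skipped step")

    args = SimpleNamespace(check_for_nan_in_loss_and_grad=False, global_batch_size=256)
    valid, norm = run_optimizer_step(_SkippingOptimizer(), _Scheduler(), args, step_id=0)
    print(valid, norm)  # False nan
```

The sketch illustrates why the step is invoked unconditionally before the validity check: the skip decision is derived from `optimizer.step()`'s own return values rather than from a separate `prepare_grads()`/`get_grad_norm()` pass, and `grad_norm` is only reset to NaN when the step is rejected.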