diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py index 2ab8d6534b..7f9245aeb1 100644 --- a/slime/backends/megatron_utils/model_provider.py +++ b/slime/backends/megatron_utils/model_provider.py @@ -101,6 +101,15 @@ def wrapped_model_provider( provider.num_layers_in_first_pipeline_stage = args.decoder_first_pipeline_num_layers if getattr(args, "decoder_last_pipeline_num_layers", None) is not None: provider.num_layers_in_last_pipeline_stage = args.decoder_last_pipeline_num_layers + + recompute_fields = ( + "recompute_granularity", + "recompute_method", + "recompute_num_layers") + + for field in recompute_fields: + if hasattr(args, field) and getattr(args, field) is not None: + setattr(provider, field, getattr(args, field)) provider.finalize() if role == "critic":