From 899044cd2e63e1347f4bdde6f85a5e5d5829d120 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Sun, 26 Apr 2026 04:30:08 -0700 Subject: [PATCH] Add Megatron-Bridge LoRA GRPO support --- docs/en/advanced/lora-grpo.md | 42 ++++ docs/en/index.rst | 1 + slime/backends/megatron_utils/model.py | 3 + .../backends/megatron_utils/model_provider.py | 8 +- slime/backends/megatron_utils/peft.py | 225 ++++++++++++++++++ .../hf_weight_iterator_bridge.py | 47 +++- slime/utils/arguments.py | 41 ++++ tests/test_lora_support.py | 184 ++++++++++++++ 8 files changed, 549 insertions(+), 2 deletions(-) create mode 100644 docs/en/advanced/lora-grpo.md create mode 100644 slime/backends/megatron_utils/peft.py create mode 100644 tests/test_lora_support.py diff --git a/docs/en/advanced/lora-grpo.md b/docs/en/advanced/lora-grpo.md new file mode 100644 index 0000000000..c7e14c2d61 --- /dev/null +++ b/docs/en/advanced/lora-grpo.md @@ -0,0 +1,42 @@ +# Megatron-Bridge LoRA for GRPO + +slime supports a first Megatron-Bridge LoRA path for dense GRPO actor training. This path keeps training in Megatron, merges LoRA adapters only during SGLang weight export, and restores the unmerged actor weights immediately after export. + +## Example + +Start from a dense Megatron GRPO script such as `scripts/run-qwen3-4B.sh`, then add the LoRA and bridge flags: + +```bash +--enable-lora \ +--megatron-to-hf-mode bridge \ +--colocate \ +--lora-rank 16 \ +--lora-alpha 32 \ +--lora-dropout 0.0 \ +--lora-target-modules linear_qkv linear_proj linear_fc1 linear_fc2 +``` + +`--lora-target-modules` is optional. If it is omitted, slime uses the Megatron-Bridge LoRA defaults. + +## Supported Scope + +The initial LoRA path intentionally supports a narrow, validated configuration: + +- Megatron training backend. +- Megatron-Bridge HF weight export mode. +- GRPO actor training. +- Colocated SGLang rollout and weight updates. +- Dense models. +- Default weight backuper enabled. + +The following combinations are rejected at startup until they have dedicated parity coverage: + +- MoE models. +- PPO or critic-based training. +- Decoupled rollout mode outside `--debug-train-only`. +- Custom model providers. +- `--only-train-params-name-list` or `--freeze-params-name-list`. +- On-policy distillation. +- Reference model update intervals. +- `--disable-weights-backuper`. + diff --git a/docs/en/index.rst b/docs/en/index.rst index 0c0b7521f9..08ef03c899 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -42,6 +42,7 @@ slime is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. 
Apart from models :caption: Advanced Features advanced/on-policy-distillation.md + advanced/lora-grpo.md advanced/speculative-decoding.md advanced/low-precision.md advanced/reproducibility.md diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 0bbe5bf49b..40b6bca02b 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -29,6 +29,7 @@ from .data import DataIterator, get_batch from .loss import loss_function from .model_provider import get_model_provider_func, wrap_model_provider_with_freeze +from .peft import log_lora_parameter_summary, lora_enabled logger = logging.getLogger(__name__) @@ -108,6 +109,8 @@ def setup_model_and_optimizer( model = get_model( wrap_model_provider_with_freeze(get_model_provider_func(args, role), args), ModelType.encoder_or_decoder ) + if role == "actor" and lora_enabled(args): + log_lora_parameter_summary(model) # Optimizer kwargs = {} diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py index 2ab8d6534b..649d1e39f0 100644 --- a/slime/backends/megatron_utils/model_provider.py +++ b/slime/backends/megatron_utils/model_provider.py @@ -19,6 +19,8 @@ from slime.utils.misc import load_function +from .peft import maybe_apply_lora + # Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82 class LinearForLastLayer(torch.nn.Linear): @@ -116,7 +118,11 @@ def _critic_provide(pre_process=True, post_process=True, vp_stage=None): return _critic_provide - return provider.provide + def _actor_provide(pre_process=True, post_process=True, vp_stage=None): + model = provider.provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + return maybe_apply_lora(model, args, role) + + return _actor_provide def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel: """Builds the model. diff --git a/slime/backends/megatron_utils/peft.py b/slime/backends/megatron_utils/peft.py new file mode 100644 index 0000000000..aea6bd3e84 --- /dev/null +++ b/slime/backends/megatron_utils/peft.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import inspect +import logging +from argparse import Namespace +from collections.abc import Iterable + +import torch + +logger = logging.getLogger(__name__) + + +class LoRAConfigurationError(RuntimeError): + """Raised when the runtime cannot support the requested LoRA configuration.""" + + +def lora_enabled(args: Namespace) -> bool: + return bool(getattr(args, "enable_lora", False)) + + +def validate_lora_args(args: Namespace) -> None: + """Validate the first supported LoRA GRPO path. + + The initial implementation intentionally supports only the narrow path that is + safe to reason about: Megatron GRPO actor LoRA through Megatron-Bridge with + colocated SGLang weight updates. Broader combinations should be enabled only + after they have parity tests. 
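+
+    All violations are collected and raised together in a single ValueError,
+    and the Megatron-Bridge LoRA runtime is probed up front so a missing PEFT
+    dependency fails at startup rather than during weight export.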
+ """ + + if not lora_enabled(args): + return + + errors = [] + if getattr(args, "train_backend", "megatron") != "megatron": + errors.append("--enable-lora requires --train-backend megatron") + if getattr(args, "megatron_to_hf_mode", None) != "bridge": + errors.append("--enable-lora requires --megatron-to-hf-mode bridge") + if getattr(args, "advantage_estimator", None) != "grpo": + errors.append("--enable-lora currently supports only --advantage-estimator grpo") + if not getattr(args, "colocate", False) and not getattr(args, "debug_train_only", False): + errors.append("--enable-lora currently requires --colocate outside debug-train-only runs") + if getattr(args, "custom_model_provider_path", None) is not None: + errors.append("--enable-lora does not yet support --custom-model-provider-path") + if getattr(args, "only_train_params_name_list", None): + errors.append("--enable-lora cannot be combined with --only-train-params-name-list") + if getattr(args, "freeze_params_name_list", None): + errors.append("--enable-lora cannot be combined with --freeze-params-name-list") + if not getattr(args, "enable_weights_backuper", True): + errors.append("--enable-lora cannot be combined with --disable-weights-backuper") + if getattr(args, "use_opd", False): + errors.append("--enable-lora does not yet support on-policy distillation") + if getattr(args, "num_experts", None): + errors.append("--enable-lora does not yet support MoE models") + if getattr(args, "ref_update_interval", None) is not None: + errors.append("--enable-lora does not yet support --ref-update-interval") + lora_rank = getattr(args, "lora_rank", None) + if lora_rank is None or lora_rank <= 0: + errors.append("--lora-rank must be positive") + lora_alpha = getattr(args, "lora_alpha", None) + if lora_alpha is None or lora_alpha <= 0: + errors.append("--lora-alpha must be positive") + lora_dropout = getattr(args, "lora_dropout", None) + if lora_dropout is None or not 0.0 <= lora_dropout < 1.0: + errors.append("--lora-dropout must be in [0.0, 1.0)") + + if errors: + raise ValueError("; ".join(errors)) + + ensure_lora_runtime_available() + + +def ensure_lora_runtime_available() -> None: + """Fail early if the installed Megatron-Bridge does not expose LoRA PEFT.""" + + _get_lora_cls() + + +def build_lora_config(args: Namespace): + LoRA = _get_lora_cls() + signature = inspect.signature(LoRA) + parameters = signature.parameters + + for required in ("dim", "alpha", "dropout"): + if required not in parameters: + raise LoRAConfigurationError( + f"Installed megatron.bridge.peft.lora.LoRA does not accept the required '{required}' argument." + ) + + kwargs = { + "dim": getattr(args, "lora_rank"), + "alpha": getattr(args, "lora_alpha"), + "dropout": getattr(args, "lora_dropout"), + } + target_modules = getattr(args, "lora_target_modules", None) + if target_modules: + if "target_modules" not in parameters: + raise LoRAConfigurationError( + "Installed megatron.bridge.peft.lora.LoRA does not accept target_modules, " + "but --lora-target-modules was set." 
+ ) + kwargs["target_modules"] = list(target_modules) + + return LoRA(**kwargs) + + +def maybe_apply_lora(model, args: Namespace, role: str): + if not lora_enabled(args) or role != "actor": + return model + + lora_config = build_lora_config(args) + model = lora_config(model, training=True) + setattr(model, "_slime_lora_config", lora_config) + setattr(model, "_slime_lora_enabled", True) + return model + + +def count_parameters(model) -> tuple[int, int]: + total = 0 + trainable = 0 + for param in _iter_parameters(model): + numel = param.numel() + total += numel + if getattr(param, "requires_grad", False): + trainable += numel + return total, trainable + + +def log_lora_parameter_summary(model) -> None: + total, trainable = count_parameters(model) + pct = 0.0 if total == 0 else trainable / total * 100 + logger.info("LoRA local parameter summary: trainable=%s total=%s trainable_pct=%.6f", trainable, total, pct) + + +def merge_lora_weights_for_export(model) -> None: + """Merge LoRA adapter weights into base weights for a temporary export. + + Megatron-Bridge's public PEFT entrypoint freezes models when called directly, + which is not appropriate during weight sync. Calling LoRAMerge.transform on + each module performs only the merge operation and leaves the training wrapper + structure in place; callers must restore the unmerged weights afterwards. + """ + + LoRAMerge = _get_lora_merge_cls() + merger = LoRAMerge() + for module in _iter_modules(model): + merger.transform(module) + + +def _get_lora_cls(): + try: + from megatron.bridge.peft.lora import LoRA + except Exception as exc: + raise LoRAConfigurationError( + "--enable-lora requires Megatron-Bridge PEFT LoRA support " + "(megatron.bridge.peft.lora.LoRA). The installed Megatron-Bridge runtime does not expose it." + ) from exc + return LoRA + + +def _get_lora_merge_cls(): + try: + from megatron.bridge.peft.lora import LoRAMerge + except Exception as exc: + raise LoRAConfigurationError( + "LoRA weight sync requires megatron.bridge.peft.lora.LoRAMerge, " + "but the installed Megatron-Bridge runtime does not expose it." + ) from exc + return LoRAMerge + + +def _iter_parameters(model) -> Iterable: + if isinstance(model, (list, tuple)): + for model_chunk in model: + yield from model_chunk.parameters() + else: + yield from model.parameters() + + +def _iter_modules(model) -> Iterable: + if isinstance(model, (list, tuple)): + for model_chunk in model: + yield from model_chunk.modules() + else: + yield from model.modules() + + +@torch.no_grad() +def restore_model_from_named_tensors(model, named_tensors): + """Restore TensorBackuper-style tensors before or after temporary LoRA merge export.""" + + missing_names = [] + for name, tensor in _iter_model_named_tensors(model): + if name not in named_tensors: + missing_names.append(name) + continue + tensor.copy_(named_tensors[name].to(device=tensor.device, non_blocking=True), non_blocking=True) + if missing_names: + preview = ", ".join(missing_names[:5]) + suffix = "" if len(missing_names) <= 5 else f", ... 
({len(missing_names)} total)" + raise KeyError(f"LoRA weight export backup is missing model tensors: {preview}{suffix}") + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def _iter_model_named_tensors(model): + model_chunks = model if isinstance(model, (list, tuple)) else [model] + for vp_stage, model_module in enumerate(model_chunks): + + def _compute_fqn(name, vp_stage=vp_stage): + return f"vp_stages.{vp_stage}.{_strip_param_name_prefix(name)}" + + for name, param in model_module.named_parameters(): + yield _compute_fqn(name), param + + for name, buffer in model_module.named_buffers(): + if "expert_bias" not in name: + continue + yield _compute_fqn(name), buffer + + +def _strip_param_name_prefix(name: str): + prefix = "module." + while name.startswith(prefix): + name = name[len(prefix) :] + return name diff --git a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 638d8fd1a9..153d8532a2 100644 --- a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -1,6 +1,10 @@ import dataclasses - +from slime.backends.megatron_utils.peft import ( + lora_enabled, + merge_lora_weights_for_export, + restore_model_from_named_tensors, +) from slime.utils import megatron_bridge_utils from slime.utils.misc import chunk_named_params_by_size @@ -48,6 +52,10 @@ def __init__(self, *args, **kwargs): _patch_bridge_expert_cache_to_cpu() def get_hf_weight_chunks(self, megatron_local_weights): + if lora_enabled(self.args): + yield from self._get_lora_hf_weight_chunks(megatron_local_weights) + return + # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name) renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} with megatron_bridge_utils.patch_megatron_model(self.model): @@ -77,6 +85,43 @@ def _streaming_quantized(): _streaming_quantized(), chunk_size=self.args.update_weight_buffer_size ) + def _get_lora_hf_weight_chunks(self, megatron_local_weights): + # The normal bridge path substitutes per-task weights from TensorBackuper. + # For LoRA we need the effective actor (base + adapter deltas), so restore + # the actor backup into the live model, temporarily merge adapters, export + # through Megatron-Bridge, and always restore the unmerged actor backup. 
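+        # The finally block below guarantees the unmerged backup is restored even
+        # if merging or export raises, so training never resumes on merged weights.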
+ restore_model_from_named_tensors(self.model, megatron_local_weights) + try: + merge_lora_weights_for_export(self.model) + with megatron_bridge_utils.patch_megatron_model(self.model): + conversion_tasks = self._bridge.get_conversion_tasks(self.model) + named_weights = self._bridge.export_hf_weights( + self.model, cpu=False, conversion_tasks=conversion_tasks + ) + + def _streaming_quantized(): + for hf_param_name, weight, megatron_param_name in named_weights: + processed_weight = postprocess_hf_param( + args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + ) + converted_named_params = [(hf_param_name, processed_weight)] + quantized_batch = quantize_params( + args=self.args, + megatron_name=megatron_param_name, + converted_named_params=converted_named_params, + quantization_config=self.quantization_config, + ) + yield from quantized_batch + + yield from chunk_named_params_by_size( + _streaming_quantized(), chunk_size=self.args.update_weight_buffer_size + ) + finally: + restore_model_from_named_tensors(self.model, megatron_local_weights) + def _process_conversion_tasks(vanilla_conversion_tasks, new_weight_dict): def _handle_one(task): diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index c1c25eff89..94501fc881 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -148,6 +148,43 @@ def add_train_arguments(parser): "Example: 'my_module.my_model_provider'." ), ) + parser.add_argument( + "--enable-lora", + action="store_true", + default=False, + help=( + "Enable Megatron-Bridge LoRA for GRPO actor training. " + "The initial supported path requires bridge mode and colocated rollout." + ), + ) + parser.add_argument( + "--lora-target-modules", + type=str, + nargs="*", + default=None, + help=( + "Megatron-Bridge LoRA target modules. If unset, use Megatron-Bridge defaults " + "(typically linear_qkv, linear_proj, linear_fc1, linear_fc2)." 
+ ), + ) + parser.add_argument( + "--lora-rank", + type=int, + default=16, + help="LoRA adapter rank passed to Megatron-Bridge LoRA(dim=...).", + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=32, + help="LoRA alpha passed to Megatron-Bridge LoRA(alpha=...).", + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.0, + help="LoRA dropout passed to Megatron-Bridge LoRA(dropout=...).", + ) parser.add_argument( "--recompute-loss-function", action="store_true", @@ -1705,6 +1742,10 @@ def slime_validate_args(args): args.critic_num_gpus_per_node = args.actor_num_gpus_per_node args.critic_num_nodes = args.actor_num_nodes + from slime.backends.megatron_utils.peft import validate_lora_args + + validate_lora_args(args) + if args.offload: args.offload_train = True args.offload_rollout = True diff --git a/tests/test_lora_support.py b/tests/test_lora_support.py new file mode 100644 index 0000000000..b5b3a2c70e --- /dev/null +++ b/tests/test_lora_support.py @@ -0,0 +1,184 @@ +import sys +import types +from argparse import Namespace + +import pytest +import torch + +from slime.backends.megatron_utils import peft + + +def _args(**overrides): + values = { + "enable_lora": True, + "train_backend": "megatron", + "megatron_to_hf_mode": "bridge", + "advantage_estimator": "grpo", + "colocate": True, + "debug_train_only": False, + "custom_model_provider_path": None, + "only_train_params_name_list": None, + "freeze_params_name_list": None, + "enable_weights_backuper": True, + "use_opd": False, + "num_experts": None, + "ref_update_interval": None, + "lora_target_modules": None, + "lora_rank": 16, + "lora_alpha": 32, + "lora_dropout": 0.0, + } + values.update(overrides) + return Namespace(**values) + + +@pytest.fixture +def fake_lora_runtime(monkeypatch): + megatron_mod = types.ModuleType("megatron") + bridge_mod = types.ModuleType("megatron.bridge") + peft_mod = types.ModuleType("megatron.bridge.peft") + lora_mod = types.ModuleType("megatron.bridge.peft.lora") + + class FakeLoRA: + instances = [] + + def __init__(self, target_modules=None, dim=32, alpha=32, dropout=0.0): + self.target_modules = target_modules + self.dim = dim + self.alpha = alpha + self.dropout = dropout + FakeLoRA.instances.append(self) + + def __call__(self, model, training=True): + model.lora_applied = True + model.lora_training = training + model.lora_config = self + return model + + class FakeLoRAMerge: + def transform(self, module): + module.merge_count = getattr(module, "merge_count", 0) + 1 + return module + + lora_mod.LoRA = FakeLoRA + lora_mod.LoRAMerge = FakeLoRAMerge + + monkeypatch.setitem(sys.modules, "megatron", megatron_mod) + monkeypatch.setitem(sys.modules, "megatron.bridge", bridge_mod) + monkeypatch.setitem(sys.modules, "megatron.bridge.peft", peft_mod) + monkeypatch.setitem(sys.modules, "megatron.bridge.peft.lora", lora_mod) + + return FakeLoRA, FakeLoRAMerge + + +def test_validate_lora_args_disabled_does_not_require_runtime(): + peft.validate_lora_args(_args(enable_lora=False)) + + +def test_validate_lora_args_accepts_supported_first_slice(fake_lora_runtime): + peft.validate_lora_args(_args()) + + +@pytest.mark.parametrize( + ("override", "expected"), + [ + ({"megatron_to_hf_mode": "raw"}, "--megatron-to-hf-mode bridge"), + ({"advantage_estimator": "ppo"}, "--advantage-estimator grpo"), + ({"colocate": False}, "--colocate"), + ({"custom_model_provider_path": "custom.provider"}, "--custom-model-provider-path"), + ({"only_train_params_name_list": ["adapter"]}, 
"--only-train-params-name-list"), + ({"freeze_params_name_list": ["linear"]}, "--freeze-params-name-list"), + ({"enable_weights_backuper": False}, "--disable-weights-backuper"), + ({"use_opd": True}, "on-policy distillation"), + ({"num_experts": 8}, "MoE models"), + ({"ref_update_interval": 10}, "--ref-update-interval"), + ({"lora_rank": 0}, "--lora-rank"), + ({"lora_alpha": 0}, "--lora-alpha"), + ({"lora_dropout": 1.0}, "--lora-dropout"), + ({"lora_dropout": -0.1}, "--lora-dropout"), + ], +) +def test_validate_lora_args_rejects_unsupported_combinations(fake_lora_runtime, override, expected): + with pytest.raises(ValueError, match=expected): + peft.validate_lora_args(_args(**override)) + + +def test_validate_lora_args_allows_non_colocated_debug_train_only(fake_lora_runtime): + peft.validate_lora_args(_args(colocate=False, debug_train_only=True)) + + +def test_build_lora_config_maps_cli_args(fake_lora_runtime): + FakeLoRA, _ = fake_lora_runtime + + config = peft.build_lora_config( + _args( + lora_target_modules=["linear_qkv", "linear_proj"], + lora_rank=8, + lora_alpha=16, + lora_dropout=0.1, + ) + ) + + assert isinstance(config, FakeLoRA) + assert config.target_modules == ["linear_qkv", "linear_proj"] + assert config.dim == 8 + assert config.alpha == 16 + assert config.dropout == 0.1 + + +def test_maybe_apply_lora_only_applies_to_actor(fake_lora_runtime): + actor_model = types.SimpleNamespace() + critic_model = types.SimpleNamespace() + + applied = peft.maybe_apply_lora(actor_model, _args(), role="actor") + untouched = peft.maybe_apply_lora(critic_model, _args(), role="critic") + + assert applied is actor_model + assert actor_model.lora_applied is True + assert actor_model._slime_lora_enabled is True + assert untouched is critic_model + assert not hasattr(critic_model, "lora_applied") + + +def test_merge_lora_weights_for_export_visits_modules(fake_lora_runtime): + class FakeModule: + def __init__(self, children=None): + self.children = children or [] + + def modules(self): + yield self + for child in self.children: + yield child + + child = FakeModule() + root = FakeModule(children=[child]) + + peft.merge_lora_weights_for_export(root) + + assert root.merge_count == 1 + assert child.merge_count == 1 + + +class _FakeVpStage(torch.nn.Module): + def __init__(self): + super().__init__() + self.module = torch.nn.Module() + self.module.module = torch.nn.Module() + self.module.module.linear = torch.nn.Linear(2, 2, bias=False) + + +def test_restore_model_from_named_tensors_restores_vanilla_backup_names(): + stage = _FakeVpStage() + stage.module.module.linear.weight.data.zero_() + backup = {"vp_stages.0.linear.weight": torch.ones_like(stage.module.module.linear.weight)} + + peft.restore_model_from_named_tensors([stage], backup) + + assert torch.equal(stage.module.module.linear.weight, backup["vp_stages.0.linear.weight"]) + + +def test_restore_model_from_named_tensors_rejects_missing_backup_tensor(): + stage = _FakeVpStage() + + with pytest.raises(KeyError, match="missing model tensors"): + peft.restore_model_from_named_tensors([stage], {})