42 changes: 42 additions & 0 deletions docs/en/advanced/lora-grpo.md
@@ -0,0 +1,42 @@
# Megatron-Bridge LoRA for GRPO

slime supports an initial Megatron-Bridge LoRA path for dense GRPO actor training. This path keeps training in Megatron, merges LoRA adapters into the base weights only during SGLang weight export, and restores the unmerged actor weights immediately after export.
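
Conceptually, weight sync follows a restore, merge, export, restore sequence. The sketch below is illustrative only: the two helpers are the ones added in `slime/backends/megatron_utils/peft.py`, while `export_hf_weights` is a hypothetical stand-in for the Megatron-Bridge HF export call.

```python
from slime.backends.megatron_utils.peft import (
    merge_lora_weights_for_export,
    restore_model_from_named_tensors,
)


def export_merged_actor_weights(model, actor_backup, export_hf_weights):
    # Start from the unmerged actor weights captured by the weight backuper.
    restore_model_from_named_tensors(model, actor_backup)
    try:
        # Fold adapter deltas into the base weights for the duration of the export.
        merge_lora_weights_for_export(model)
        # `export_hf_weights` stands in for the bridge HF export step.
        return list(export_hf_weights(model))
    finally:
        # Training always continues on the unmerged actor weights.
        restore_model_from_named_tensors(model, actor_backup)
```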

## Example

Start from a dense Megatron GRPO script such as `scripts/run-qwen3-4B.sh`, then add the LoRA and bridge flags:

```bash
--enable-lora \
--megatron-to-hf-mode bridge \
--colocate \
--lora-rank 16 \
--lora-alpha 32 \
--lora-dropout 0.0 \
--lora-target-modules linear_qkv linear_proj linear_fc1 linear_fc2
```

`--lora-target-modules` is optional. If it is omitted, slime uses the Megatron-Bridge LoRA defaults.
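
For reference, the flags above correspond to a Megatron-Bridge `LoRA` config along these lines (a sketch of what slime's `build_lora_config` helper constructs; the accepted keyword arguments depend on the installed Megatron-Bridge version):

```python
from megatron.bridge.peft.lora import LoRA

# `dim` is the LoRA rank; `target_modules` falls back to the Megatron-Bridge
# defaults when --lora-target-modules is omitted.
lora_config = LoRA(
    dim=16,
    alpha=32,
    dropout=0.0,
    target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
)
```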

## Supported Scope

The initial LoRA path intentionally supports a narrow, validated configuration:

- Megatron training backend.
- Megatron-Bridge HF weight export mode.
- GRPO actor training.
- Colocated SGLang rollout and weight updates.
- Dense models.
- Default weight backuper enabled.

The following combinations are rejected at startup until they have dedicated parity coverage:

- MoE models.
- PPO or critic-based training.
- Decoupled rollout mode outside `--debug-train-only`.
- Custom model providers.
- `--only-train-params-name-list` or `--freeze-params-name-list`.
- On-policy distillation.
- Reference model update intervals.
- `--disable-weights-backuper`.

1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -42,6 +42,7 @@ slime is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. Apart from models
:caption: Advanced Features

advanced/on-policy-distillation.md
advanced/lora-grpo.md
advanced/speculative-decoding.md
advanced/low-precision.md
advanced/reproducibility.md
3 changes: 3 additions & 0 deletions slime/backends/megatron_utils/model.py
@@ -29,6 +29,7 @@
from .data import DataIterator, get_batch
from .loss import loss_function
from .model_provider import get_model_provider_func, wrap_model_provider_with_freeze
from .peft import log_lora_parameter_summary, lora_enabled

logger = logging.getLogger(__name__)

@@ -108,6 +109,8 @@ def setup_model_and_optimizer(
model = get_model(
wrap_model_provider_with_freeze(get_model_provider_func(args, role), args), ModelType.encoder_or_decoder
)
if role == "actor" and lora_enabled(args):
log_lora_parameter_summary(model)

# Optimizer
kwargs = {}
8 changes: 7 additions & 1 deletion slime/backends/megatron_utils/model_provider.py
@@ -19,6 +19,8 @@

from slime.utils.misc import load_function

from .peft import maybe_apply_lora


# Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82
class LinearForLastLayer(torch.nn.Linear):
@@ -116,7 +118,11 @@ def _critic_provide(pre_process=True, post_process=True, vp_stage=None):

return _critic_provide

return provider.provide
def _actor_provide(pre_process=True, post_process=True, vp_stage=None):
model = provider.provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)
return maybe_apply_lora(model, args, role)

return _actor_provide

def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel:
"""Builds the model.
225 changes: 225 additions & 0 deletions slime/backends/megatron_utils/peft.py
@@ -0,0 +1,225 @@
from __future__ import annotations

import inspect
import logging
from argparse import Namespace
from collections.abc import Iterable

import torch

logger = logging.getLogger(__name__)


class LoRAConfigurationError(RuntimeError):
"""Raised when the runtime cannot support the requested LoRA configuration."""


def lora_enabled(args: Namespace) -> bool:
return bool(getattr(args, "enable_lora", False))


def validate_lora_args(args: Namespace) -> None:
"""Validate the first supported LoRA GRPO path.

The initial implementation intentionally supports only the narrow path that is
safe to reason about: Megatron GRPO actor LoRA through Megatron-Bridge with
colocated SGLang weight updates. Broader combinations should be enabled only
after they have parity tests.
"""

if not lora_enabled(args):
return

errors = []
if getattr(args, "train_backend", "megatron") != "megatron":
errors.append("--enable-lora requires --train-backend megatron")
if getattr(args, "megatron_to_hf_mode", None) != "bridge":
errors.append("--enable-lora requires --megatron-to-hf-mode bridge")
if getattr(args, "advantage_estimator", None) != "grpo":
errors.append("--enable-lora currently supports only --advantage-estimator grpo")
if not getattr(args, "colocate", False) and not getattr(args, "debug_train_only", False):
errors.append("--enable-lora currently requires --colocate outside debug-train-only runs")
if getattr(args, "custom_model_provider_path", None) is not None:
errors.append("--enable-lora does not yet support --custom-model-provider-path")
if getattr(args, "only_train_params_name_list", None):
errors.append("--enable-lora cannot be combined with --only-train-params-name-list")
if getattr(args, "freeze_params_name_list", None):
errors.append("--enable-lora cannot be combined with --freeze-params-name-list")
if not getattr(args, "enable_weights_backuper", True):
errors.append("--enable-lora cannot be combined with --disable-weights-backuper")
if getattr(args, "use_opd", False):
errors.append("--enable-lora does not yet support on-policy distillation")
if getattr(args, "num_experts", None):
errors.append("--enable-lora does not yet support MoE models")
if getattr(args, "ref_update_interval", None) is not None:
errors.append("--enable-lora does not yet support --ref-update-interval")
lora_rank = getattr(args, "lora_rank", None)
if lora_rank is None or lora_rank <= 0:
errors.append("--lora-rank must be positive")
lora_alpha = getattr(args, "lora_alpha", None)
if lora_alpha is None or lora_alpha <= 0:
errors.append("--lora-alpha must be positive")
lora_dropout = getattr(args, "lora_dropout", None)
if lora_dropout is None or not 0.0 <= lora_dropout < 1.0:
errors.append("--lora-dropout must be in [0.0, 1.0)")

if errors:
raise ValueError("; ".join(errors))

ensure_lora_runtime_available()


def ensure_lora_runtime_available() -> None:
"""Fail early if the installed Megatron-Bridge does not expose LoRA PEFT."""

_get_lora_cls()


def build_lora_config(args: Namespace):
LoRA = _get_lora_cls()
signature = inspect.signature(LoRA)
parameters = signature.parameters

for required in ("dim", "alpha", "dropout"):
if required not in parameters:
raise LoRAConfigurationError(
f"Installed megatron.bridge.peft.lora.LoRA does not accept the required '{required}' argument."
)

kwargs = {
"dim": getattr(args, "lora_rank"),
"alpha": getattr(args, "lora_alpha"),
"dropout": getattr(args, "lora_dropout"),
}
target_modules = getattr(args, "lora_target_modules", None)
if target_modules:
if "target_modules" not in parameters:
raise LoRAConfigurationError(
"Installed megatron.bridge.peft.lora.LoRA does not accept target_modules, "
"but --lora-target-modules was set."
)
kwargs["target_modules"] = list(target_modules)

return LoRA(**kwargs)


def maybe_apply_lora(model, args: Namespace, role: str):
if not lora_enabled(args) or role != "actor":
return model

lora_config = build_lora_config(args)
model = lora_config(model, training=True)
setattr(model, "_slime_lora_config", lora_config)
setattr(model, "_slime_lora_enabled", True)
return model


def count_parameters(model) -> tuple[int, int]:
total = 0
trainable = 0
for param in _iter_parameters(model):
numel = param.numel()
total += numel
if getattr(param, "requires_grad", False):
trainable += numel
return total, trainable


def log_lora_parameter_summary(model) -> None:
total, trainable = count_parameters(model)
pct = 0.0 if total == 0 else trainable / total * 100
logger.info("LoRA local parameter summary: trainable=%s total=%s trainable_pct=%.6f", trainable, total, pct)


def merge_lora_weights_for_export(model) -> None:
"""Merge LoRA adapter weights into base weights for a temporary export.

Megatron-Bridge's public PEFT entrypoint freezes models when called directly,
which is not appropriate during weight sync. Calling LoRAMerge.transform on
each module performs only the merge operation and leaves the training wrapper
structure in place; callers must restore the unmerged weights afterwards.
"""

LoRAMerge = _get_lora_merge_cls()
merger = LoRAMerge()
for module in _iter_modules(model):
merger.transform(module)


def _get_lora_cls():
try:
from megatron.bridge.peft.lora import LoRA
except Exception as exc:
raise LoRAConfigurationError(
"--enable-lora requires Megatron-Bridge PEFT LoRA support "
"(megatron.bridge.peft.lora.LoRA). The installed Megatron-Bridge runtime does not expose it."
) from exc
return LoRA


def _get_lora_merge_cls():
try:
from megatron.bridge.peft.lora import LoRAMerge
except Exception as exc:
raise LoRAConfigurationError(
"LoRA weight sync requires megatron.bridge.peft.lora.LoRAMerge, "
"but the installed Megatron-Bridge runtime does not expose it."
) from exc
return LoRAMerge


def _iter_parameters(model) -> Iterable:
if isinstance(model, (list, tuple)):
for model_chunk in model:
yield from model_chunk.parameters()
else:
yield from model.parameters()


def _iter_modules(model) -> Iterable:
if isinstance(model, (list, tuple)):
for model_chunk in model:
yield from model_chunk.modules()
else:
yield from model.modules()


@torch.no_grad()
def restore_model_from_named_tensors(model, named_tensors):
"""Restore TensorBackuper-style tensors before or after temporary LoRA merge export."""

missing_names = []
for name, tensor in _iter_model_named_tensors(model):
if name not in named_tensors:
missing_names.append(name)
continue
tensor.copy_(named_tensors[name].to(device=tensor.device, non_blocking=True), non_blocking=True)
if missing_names:
preview = ", ".join(missing_names[:5])
suffix = "" if len(missing_names) <= 5 else f", ... ({len(missing_names)} total)"
raise KeyError(f"LoRA weight export backup is missing model tensors: {preview}{suffix}")
if torch.cuda.is_available():
torch.cuda.synchronize()


def _iter_model_named_tensors(model):
model_chunks = model if isinstance(model, (list, tuple)) else [model]
for vp_stage, model_module in enumerate(model_chunks):

def _compute_fqn(name, vp_stage=vp_stage):
return f"vp_stages.{vp_stage}.{_strip_param_name_prefix(name)}"

for name, param in model_module.named_parameters():
yield _compute_fqn(name), param

for name, buffer in model_module.named_buffers():
if "expert_bias" not in name:
continue
yield _compute_fqn(name), buffer


def _strip_param_name_prefix(name: str):
prefix = "module."
while name.startswith(prefix):
name = name[len(prefix) :]
return name
@@ -1,6 +1,10 @@
import dataclasses


from slime.backends.megatron_utils.peft import (
lora_enabled,
merge_lora_weights_for_export,
restore_model_from_named_tensors,
)
from slime.utils import megatron_bridge_utils
from slime.utils.misc import chunk_named_params_by_size

@@ -48,6 +52,10 @@ def __init__(self, *args, **kwargs):
_patch_bridge_expert_cache_to_cpu()

def get_hf_weight_chunks(self, megatron_local_weights):
if lora_enabled(self.args):
yield from self._get_lora_hf_weight_chunks(megatron_local_weights)
return

# TODO support quantization (e.g. modify megatron-bridge to provide megatron param name)
renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()}
with megatron_bridge_utils.patch_megatron_model(self.model):
@@ -77,6 +85,43 @@ def _streaming_quantized():
_streaming_quantized(), chunk_size=self.args.update_weight_buffer_size
)

def _get_lora_hf_weight_chunks(self, megatron_local_weights):
# The normal bridge path substitutes per-task weights from TensorBackuper.
# For LoRA we need the effective actor (base + adapter deltas), so restore
# the actor backup into the live model, temporarily merge adapters, export
# through Megatron-Bridge, and always restore the unmerged actor backup.
restore_model_from_named_tensors(self.model, megatron_local_weights)
try:
merge_lora_weights_for_export(self.model)
with megatron_bridge_utils.patch_megatron_model(self.model):
conversion_tasks = self._bridge.get_conversion_tasks(self.model)
named_weights = self._bridge.export_hf_weights(
self.model, cpu=False, conversion_tasks=conversion_tasks
)

def _streaming_quantized():
for hf_param_name, weight, megatron_param_name in named_weights:
processed_weight = postprocess_hf_param(
args=self.args,
megatron_param_name=megatron_param_name,
hf_param_name=hf_param_name,
param=weight,
)
converted_named_params = [(hf_param_name, processed_weight)]
quantized_batch = quantize_params(
args=self.args,
megatron_name=megatron_param_name,
converted_named_params=converted_named_params,
quantization_config=self.quantization_config,
)
yield from quantized_batch

yield from chunk_named_params_by_size(
_streaming_quantized(), chunk_size=self.args.update_weight_buffer_size
)
finally:
restore_model_from_named_tensors(self.model, megatron_local_weights)


def _process_conversion_tasks(vanilla_conversion_tasks, new_weight_dict):
def _handle_one(task):