22 changes: 22 additions & 0 deletions docs/en/get_started/usage.md
@@ -188,6 +188,7 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When

Note: On-policy distillation (OPD) is now orthogonal to the advantage estimator. Use `--use-opd` and `--opd-kl-coef` to enable OPD on top of any estimator.
- `--calculate-per-token-loss`: By default, slime calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))` (see the sketch after this list).
- `--policy-loss-type`: Selects the policy-gradient surrogate objective used by `--loss-type policy_loss`. The default `clip` keeps the existing PPO-style hard clipping objective. Set `sapo` to use Soft Adaptive Policy Optimization with a smooth temperature-controlled surrogate.
- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling; https://fengyao.notion.site/off-policy-rl).
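
For reference, here is a minimal sketch of the two reduction modes described above (illustration only, with made-up per-token losses; this is not slime's actual reduction code):

```python
import torch

# Three samples of different lengths; each entry is a per-token loss.
per_token_losses = [
    torch.tensor([0.2, 0.4]),
    torch.tensor([0.1, 0.3, 0.5]),
    torch.tensor([0.6]),
]

# Default per-sample reduction: mean(sum(sample_i) / len(sample_i)).
per_sample = torch.stack([t.sum() / t.numel() for t in per_token_losses]).mean()

# With --calculate-per-token-loss: sum(sum(sample_i)) / sum(len(sample_i)).
per_token = sum(t.sum() for t in per_token_losses) / sum(t.numel() for t in per_token_losses)

print(per_sample.item(), per_token.item())  # 0.4 vs. 0.35
```

The two modes only differ when samples have different lengths; longer responses carry proportionally more weight under per-token reduction.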

#### GRPO Algorithm
@@ -212,6 +213,27 @@ Related parameters:
- `--normalize-advantages`: Whether to normalize advantages.
- `--eps-clip`: PPO-style clip range.

#### SAPO Policy Objective

SAPO (Soft Adaptive Policy Optimization) replaces the hard clipping in the policy objective with a smooth temperature-controlled surrogate. It keeps the same advantage estimator and rollout flow as GRPO/GSPO, so it is enabled through `--policy-loss-type` rather than `--advantage-estimator`.

To use SAPO with GRPO, set:

```bash
--advantage-estimator grpo \
--policy-loss-type sapo \
--sapo-tau-pos 1.0 \
--sapo-tau-neg 1.05
```

Related parameters:

- `--sapo-tau-pos`: Temperature for tokens with positive advantages. Default is `1.0`.
- `--sapo-tau-neg`: Temperature for tokens with zero or negative advantages. Default is `1.05`.
- `--eps-clip` and `--eps-clip-high`: Still used when SAPO is enabled to report the legacy `pg_clipfrac` diagnostic, i.e., the fraction of tokens that the hard-clipping objective would have clipped.

With `--advantage-estimator gspo`, SAPO softens slime's existing sequence-level GSPO ratio rather than using token-level ratios.
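
The soft surrogate added in `slime/utils/ppo_utils.py` boils down to a temperature-controlled sigmoid gate on the importance ratio. Below is a minimal illustrative sketch of that gate (the helper name is made up; the numeric clamps, `torch.compile` decoration, and `pg_clipfrac` bookkeeping of the real implementation are omitted):

```python
import torch

def sapo_surrogate(ratio: torch.Tensor, advantages: torch.Tensor,
                   tau_pos: float = 1.0, tau_neg: float = 1.05) -> torch.Tensor:
    # Temperature depends on the sign of the advantage.
    tau = torch.where(advantages > 0,
                      torch.full_like(advantages, tau_pos),
                      torch.full_like(advantages, tau_neg))
    # Smooth gate: its slope w.r.t. the ratio is 1 at ratio == 1 (matching the
    # unclipped objective), and it saturates for extreme ratios instead of
    # being hard-clipped.
    soft_ratio = torch.sigmoid(tau * (ratio - 1.0)) * (4.0 / tau)
    return -soft_ratio * advantages

ratio = torch.tensor([0.5, 1.0, 2.0])
advantages = torch.tensor([1.0, -1.0, 1.0])
print(sapo_surrogate(ratio, advantages))
```

Unlike hard clipping, the gradient never drops to exactly zero for out-of-range ratios; it decays smoothly at a rate set by `--sapo-tau-pos` / `--sapo-tau-neg`.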

#### PPO Algorithm

PPO (Proximal Policy Optimization) is a classic RL algorithm that uses a critic model to estimate the value function for computing advantages.
22 changes: 22 additions & 0 deletions docs/zh/get_started/usage.md
@@ -192,6 +192,7 @@ Loading sglang is very simple; you only need:

Note: On-policy distillation (OPD) is now orthogonal to the advantage estimator; use `--use-opd` and `--opd-kl-coef` to enable OPD on top of any estimator.
- `--calculate-per-token-loss`: slime defaults to per-sample loss, i.e., `mean(sum(sample_i) / len(sample_i))`. To compute per-token loss instead, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`, enable `--calculate-per-token-loss`.
- `--policy-loss-type`: Selects the policy-gradient surrogate objective used under `--loss-type policy_loss`. The default `clip` keeps the existing PPO-style hard clipping; set `sapo` to use the smooth temperature-controlled objective of Soft Adaptive Policy Optimization.
- `--use-tis`: Enable this setting to use TIS (https://fengyao.notion.site/off-policy-rl).

#### GRPO Algorithm
@@ -216,6 +217,27 @@ Main features of GRPO:
- `--normalize-advantages`: Whether to normalize advantages.
- `--eps-clip`: PPO-style clip range.

#### SAPO Policy Objective

SAPO (Soft Adaptive Policy Optimization) replaces the hard clipping in the policy objective with a smooth temperature-controlled surrogate objective. It reuses the advantage estimator and rollout flow of GRPO/GSPO, so it is enabled through `--policy-loss-type` rather than as a new `--advantage-estimator`.

To use SAPO with GRPO, set:

```bash
--advantage-estimator grpo \
--policy-loss-type sapo \
--sapo-tau-pos 1.0 \
--sapo-tau-neg 1.05
```

Related parameters:

- `--sapo-tau-pos`: Temperature for tokens with positive advantages. Default is `1.0`.
- `--sapo-tau-neg`: Temperature for tokens with zero or negative advantages. Default is `1.05`.
- `--eps-clip` and `--eps-clip-high`: Still used when SAPO is enabled to report the `pg_clipfrac` diagnostic under the legacy hard-clipping semantics.

When used with `--advantage-estimator gspo`, SAPO smooths slime's existing sequence-level GSPO ratio rather than using token-level ratios.

#### PPO Algorithm

PPO (Proximal Policy Optimization) is a classic RL algorithm that uses a critic model to estimate the value function, from which advantages are computed.
14 changes: 13 additions & 1 deletion slime/backends/megatron_utils/loss.py
@@ -884,7 +884,15 @@ def policy_loss_function(
    log_probs = torch.cat(log_probs, dim=0)
    ppo_kl = old_log_probs - log_probs

    pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
    pg_loss, pg_clipfrac, policy_loss_aux = compute_policy_loss(
        ppo_kl=ppo_kl,
        advantages=advantages,
        eps_clip=args.eps_clip,
        eps_clip_high=args.eps_clip_high,
        policy_loss_type=args.policy_loss_type,
        sapo_tau_pos=args.sapo_tau_pos,
        sapo_tau_neg=args.sapo_tau_neg,
    )

    if args.use_opsm:
        pg_loss = pg_loss * opsm_mask
@@ -944,6 +952,7 @@ def policy_loss_function(

    pg_loss = pg_loss_reducer(pg_loss)
    pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
    policy_loss_aux = {key: sum_of_sample_mean(value) for key, value in policy_loss_aux.items()}
    ppo_kl = sum_of_sample_mean(ppo_kl)

    # entropy loss
@@ -989,6 +998,9 @@ def policy_loss_function(
    if train_rollout_logprob_abs_diff is not None:
        reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff.clone().detach()

    for key, value in policy_loss_aux.items():
        reported_loss[key] = value.clone().detach()

    if args.use_kl_loss:
        reported_loss["kl_loss"] = kl_loss.clone().detach()

26 changes: 26 additions & 0 deletions slime/utils/arguments.py
@@ -764,6 +764,28 @@ def add_algo_arguments(parser):

parser.add_argument("--eps-clip", type=float, default=0.2, help="PPO clip range")
parser.add_argument("--eps-clip-high", type=float, default=None, help="PPO clip upper range")
parser.add_argument(
"--policy-loss-type",
type=str,
choices=["clip", "sapo"],
default="clip",
help=(
"Policy-gradient surrogate objective. 'clip' uses PPO hard clipping; "
"'sapo' uses Soft Adaptive Policy Optimization."
),
)
parser.add_argument(
"--sapo-tau-pos",
type=float,
default=1.0,
help="SAPO temperature for positive advantages.",
)
parser.add_argument(
"--sapo-tau-neg",
type=float,
default=1.05,
help="SAPO temperature for zero or negative advantages.",
)
parser.add_argument(
"--eps-clip-c",
type=float,
@@ -1685,6 +1707,10 @@ def slime_validate_args(args):

    if args.eps_clip_high is None:
        args.eps_clip_high = args.eps_clip
    assert args.sapo_tau_pos > 0, "sapo_tau_pos must be positive."
    assert args.sapo_tau_neg > 0, "sapo_tau_neg must be positive."
    if args.policy_loss_type == "sapo":
        assert args.advantage_estimator in ["grpo", "gspo"], "SAPO policy loss is currently supported for GRPO/GSPO."

    if args.eval_reward_key is None:
        args.eval_reward_key = args.reward_key
62 changes: 60 additions & 2 deletions slime/utils/ppo_utils.py
@@ -122,13 +122,13 @@ def compute_gspo_kl(


@torch.compile(dynamic=True)
def compute_policy_loss(
def _compute_clipped_policy_loss(
    ppo_kl: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float,
    eps_clip_high: float,
    eps_clip_c: float | None = None,
):
) -> tuple[torch.Tensor, torch.Tensor]:
    ratio = (-ppo_kl).exp()
    pg_losses1 = -ratio * advantages
    pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
@@ -148,6 +148,64 @@ def compute_policy_loss(
    return pg_losses, clipfrac


@torch.compile(dynamic=True)
def _compute_sapo_policy_loss(
    ppo_kl: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float,
    eps_clip_high: float,
    sapo_tau_pos: float,
    sapo_tau_neg: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # SAPO saturates for extreme ratios; clamp the log-ratio used by the soft gate.
    ratio = torch.exp(torch.clamp(-ppo_kl, min=-60.0, max=60.0))
    tau_pos = torch.full_like(advantages, sapo_tau_pos)
    tau_neg = torch.full_like(advantages, sapo_tau_neg)
    tau = torch.where(advantages > 0, tau_pos, tau_neg)
    soft_ratio = torch.sigmoid(torch.clamp(tau * (ratio - 1.0), min=-60.0, max=60.0)) * (4.0 / tau)
    pg_losses = -soft_ratio * advantages

    would_clip_low = (ratio < 1 - eps_clip) & (advantages < 0)
    would_clip_high = (ratio > 1 + eps_clip_high) & (advantages > 0)
    clipfrac = (would_clip_low | would_clip_high).float()
    return pg_losses, clipfrac, soft_ratio


def compute_policy_loss(
    ppo_kl: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float,
    eps_clip_high: float,
    eps_clip_c: float | None = None,
    policy_loss_type: str = "clip",
    sapo_tau_pos: float = 1.0,
    sapo_tau_neg: float = 1.05,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
    """Compute the token-level policy gradient surrogate and auxiliary metrics."""
    if policy_loss_type == "clip":
        pg_losses, clipfrac = _compute_clipped_policy_loss(
            ppo_kl,
            advantages,
            eps_clip,
            eps_clip_high,
            eps_clip_c,
        )
        return pg_losses, clipfrac, {}

    if policy_loss_type == "sapo":
        pg_losses, clipfrac, soft_ratio = _compute_sapo_policy_loss(
            ppo_kl,
            advantages,
            eps_clip,
            eps_clip_high,
            sapo_tau_pos,
            sapo_tau_neg,
        )
        return pg_losses, clipfrac, {"sapo_soft_ratio": soft_ratio}

    raise ValueError(f"Unknown policy_loss_type: {policy_loss_type}")


def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None):
    # TODO: when megatron is not installed, fall back to naive implementation
    from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
123 changes: 123 additions & 0 deletions tests/test_policy_loss.py
@@ -0,0 +1,123 @@
import pytest
import torch

from slime.utils.ppo_utils import compute_policy_loss


def _ppo_kl_from_ratio(ratio: torch.Tensor) -> torch.Tensor:
    return -ratio.log()


@pytest.mark.unit
def test_clipped_policy_loss_matches_existing_formula():
    ratio = torch.tensor([1.5, 0.7, 1.1, 0.9], dtype=torch.float32)
    advantages = torch.tensor([2.0, -3.0, -1.0, 0.5], dtype=torch.float32)
    eps_clip = 0.2
    eps_clip_high = 0.3

    pg_loss, clipfrac, aux = compute_policy_loss(
        _ppo_kl_from_ratio(ratio),
        advantages,
        eps_clip,
        eps_clip_high,
    )

    expected_loss = torch.maximum(
        -ratio * advantages,
        -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages,
    )
    expected_clipfrac = torch.gt(
        -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages,
        -ratio * advantages,
    ).float()

    torch.testing.assert_close(pg_loss, expected_loss)
    torch.testing.assert_close(clipfrac, expected_clipfrac)
    assert aux == {}


@pytest.mark.unit
def test_sapo_policy_loss_matches_soft_gate_formula():
    ratio = torch.tensor([1.5, 0.7, 1.0, 1.4], dtype=torch.float32)
    advantages = torch.tensor([2.0, -3.0, 0.0, -1.0], dtype=torch.float32)
    eps_clip = 0.2
    eps_clip_high = 0.3
    tau_pos = 1.0
    tau_neg = 1.05

    pg_loss, clipfrac, aux = compute_policy_loss(
        _ppo_kl_from_ratio(ratio),
        advantages,
        eps_clip,
        eps_clip_high,
        policy_loss_type="sapo",
        sapo_tau_pos=tau_pos,
        sapo_tau_neg=tau_neg,
    )

    tau = torch.where(
        advantages > 0,
        torch.full_like(advantages, tau_pos),
        torch.full_like(advantages, tau_neg),
    )
    expected_soft_ratio = torch.sigmoid(tau * (ratio - 1.0)) * (4.0 / tau)
    expected_loss = -expected_soft_ratio * advantages
    expected_clipfrac = torch.tensor([1.0, 1.0, 0.0, 0.0], dtype=torch.float32)

    torch.testing.assert_close(pg_loss, expected_loss)
    torch.testing.assert_close(clipfrac, expected_clipfrac)
    torch.testing.assert_close(aux["sapo_soft_ratio"], expected_soft_ratio)
    assert set(aux) == {"sapo_soft_ratio"}


@pytest.mark.unit
def test_sapo_gradient_matches_unclipped_policy_gradient_at_unit_ratio():
    ppo_kl = torch.zeros(4, dtype=torch.float32, requires_grad=True)
    advantages = torch.tensor([2.0, -3.0, 0.0, 0.5], dtype=torch.float32)

    pg_loss, _, _ = compute_policy_loss(
        ppo_kl,
        advantages,
        eps_clip=0.2,
        eps_clip_high=0.2,
        policy_loss_type="sapo",
        sapo_tau_pos=1.0,
        sapo_tau_neg=1.05,
    )
    pg_loss.sum().backward()

    torch.testing.assert_close(ppo_kl.grad, advantages)


@pytest.mark.unit
def test_sapo_policy_loss_is_finite_for_large_log_ratios():
    ppo_kl = torch.tensor([-100.0, 100.0], dtype=torch.float32, requires_grad=True)
    advantages = torch.tensor([1.0, -1.0], dtype=torch.float32)

    pg_loss, clipfrac, aux = compute_policy_loss(
        ppo_kl,
        advantages,
        eps_clip=0.2,
        eps_clip_high=0.2,
        policy_loss_type="sapo",
        sapo_tau_pos=1.0,
        sapo_tau_neg=1.05,
    )
    pg_loss.sum().backward()

    assert torch.isfinite(pg_loss).all()
    assert torch.isfinite(clipfrac).all()
    assert torch.isfinite(aux["sapo_soft_ratio"]).all()
    assert torch.isfinite(ppo_kl.grad).all()


@pytest.mark.unit
def test_unknown_policy_loss_type_raises():
    with pytest.raises(ValueError, match="Unknown policy_loss_type"):
        compute_policy_loss(
            torch.zeros(1),
            torch.ones(1),
            eps_clip=0.2,
            eps_clip_high=0.2,
            policy_loss_type="missing",
        )