diff --git a/docs/en/get_started/usage.md b/docs/en/get_started/usage.md
index 5a59812e8f..2ef62ddac7 100644
--- a/docs/en/get_started/usage.md
+++ b/docs/en/get_started/usage.md
@@ -188,6 +188,7 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When
 Note: On-policy distillation (OPD) is now orthogonal to the advantage estimator. Use `--use-opd` and `--opd-kl-coef` to enable OPD on top of any estimator.
 
 - `--calculate-per-token-loss`: By default, slime calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`.
+- `--policy-loss-type`: Selects the policy-gradient surrogate objective used by `--loss-type policy_loss`. The default `clip` keeps the existing PPO-style hard clipping objective. Set `sapo` to use Soft Adaptive Policy Optimization with a smooth temperature-controlled surrogate.
 - `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl).
 
 #### GRPO Algorithm
@@ -212,6 +213,27 @@ Related parameters:
 - `--normalize-advantages`: Whether to normalize advantages.
 - `--eps-clip`: PPO-style clip range.
 
+#### SAPO Policy Objective
+
+SAPO (Soft Adaptive Policy Optimization) replaces the hard clipping in the policy objective with a smooth temperature-controlled surrogate. It keeps the same advantage estimator and rollout flow as GRPO/GSPO, so it is enabled through `--policy-loss-type` rather than `--advantage-estimator`.
+
+To use SAPO with GRPO, set:
+
+```bash
+--advantage-estimator grpo \
+--policy-loss-type sapo \
+--sapo-tau-pos 1.0 \
+--sapo-tau-neg 1.05
+```
+
+Related parameters:
+
+- `--sapo-tau-pos`: Temperature for tokens with positive advantages. Default is `1.0`.
+- `--sapo-tau-neg`: Temperature for tokens with zero or negative advantages. Default is `1.05`.
+- `--eps-clip` and `--eps-clip-high`: Still used to report the legacy `pg_clipfrac` diagnostic when SAPO is enabled.
+
+With `--advantage-estimator gspo`, SAPO softens slime's existing sequence-level GSPO ratio rather than using token-level ratios.
+
 #### PPO Algorithm
 
 PPO (Proximal Policy Optimization) is a classic RL algorithm that uses a critic model to estimate the value function for computing advantages.
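For intuition, the soft gate described in the documentation above can be sketched standalone. The snippet below mirrors the `_compute_sapo_policy_loss` helper added to `slime/utils/ppo_utils.py` later in this diff (the in-tree version additionally clamps the log-ratio and the sigmoid argument for numerical safety); `sapo_soft_ratio` is a hypothetical name used only for illustration:

```python
import torch

def sapo_soft_ratio(ratio: torch.Tensor, advantages: torch.Tensor,
                    tau_pos: float = 1.0, tau_neg: float = 1.05) -> torch.Tensor:
    # Pick the temperature per token from the sign of the advantage.
    tau = torch.where(
        advantages > 0,
        torch.full_like(advantages, tau_pos),
        torch.full_like(advantages, tau_neg),
    )
    # Smooth gate: sigmoid(tau * (r - 1)) * 4 / tau. At r == 1 its slope is exactly 1,
    # so the gradient matches the unclipped surrogate; for extreme ratios it saturates
    # instead of being hard-clipped.
    return torch.sigmoid(tau * (ratio - 1.0)) * (4.0 / tau)

# The per-token surrogate loss is then -sapo_soft_ratio(ratio, advantages) * advantages.
```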
diff --git a/docs/zh/get_started/usage.md b/docs/zh/get_started/usage.md
index 01746514bc..de8a8555f7 100644
--- a/docs/zh/get_started/usage.md
+++ b/docs/zh/get_started/usage.md
@@ -192,6 +192,7 @@ Loading sglang is very simple; you only need:
 Note: On-policy distillation (OPD) is now orthogonal to the advantage estimator; use `--use-opd` and `--opd-kl-coef` to enable OPD on top of any estimator.
 
 - `--calculate-per-token-loss`: slime defaults to per-sample loss, i.e. `mean(sum(sample_i) / len(sample_i))`; to compute per-token loss instead, i.e. `sum(sum(sample_i)) / sum(len(sample_i))`, enable `--calculate-per-token-loss`;
+- `--policy-loss-type`: selects the policy-gradient surrogate objective used under `--loss-type policy_loss`. The default `clip` keeps the existing PPO-style hard clipping; set `sapo` to use the smooth temperature-controlled objective of Soft Adaptive Policy Optimization;
 - `--use-tis`: enable this setting if you want to use TIS (https://fengyao.notion.site/off-policy-rl);
 
 #### GRPO Algorithm
@@ -216,6 +217,27 @@ The main features of GRPO:
 - `--normalize-advantages`: whether to normalize advantages;
 - `--eps-clip`: PPO-style clip range.
 
+#### SAPO Policy Objective
+
+SAPO (Soft Adaptive Policy Optimization) replaces the hard clipping in the policy objective with a smooth temperature-controlled surrogate. It reuses the advantage estimator and rollout flow of GRPO/GSPO, so it is enabled through `--policy-loss-type` rather than as a new `--advantage-estimator`.
+
+To use SAPO with GRPO, set:
+
+```bash
+--advantage-estimator grpo \
+--policy-loss-type sapo \
+--sapo-tau-pos 1.0 \
+--sapo-tau-neg 1.05
+```
+
+Related parameters:
+
+- `--sapo-tau-pos`: temperature for tokens with positive advantages, default `1.0`;
+- `--sapo-tau-neg`: temperature for tokens with zero or negative advantages, default `1.05`;
+- `--eps-clip` and `--eps-clip-high`: when SAPO is enabled, still used to report the `pg_clipfrac` diagnostic under the legacy hard-clipping semantics.
+
+When used with `--advantage-estimator gspo`, SAPO smooths slime's existing sequence-level GSPO ratio rather than using token-level ratios.
+
 #### PPO Algorithm
 
 PPO (Proximal Policy Optimization) is a classic RL algorithm that uses a critic model to estimate the value function, from which advantages are computed.
diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py
index c7ad36839a..f5789b6c71 100644
--- a/slime/backends/megatron_utils/loss.py
+++ b/slime/backends/megatron_utils/loss.py
@@ -884,7 +884,15 @@ def policy_loss_function(
 
     log_probs = torch.cat(log_probs, dim=0)
     ppo_kl = old_log_probs - log_probs
-    pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
+    pg_loss, pg_clipfrac, policy_loss_aux = compute_policy_loss(
+        ppo_kl=ppo_kl,
+        advantages=advantages,
+        eps_clip=args.eps_clip,
+        eps_clip_high=args.eps_clip_high,
+        policy_loss_type=args.policy_loss_type,
+        sapo_tau_pos=args.sapo_tau_pos,
+        sapo_tau_neg=args.sapo_tau_neg,
+    )
 
     if args.use_opsm:
         pg_loss = pg_loss * opsm_mask
@@ -944,6 +952,7 @@
 
     pg_loss = pg_loss_reducer(pg_loss)
     pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
+    policy_loss_aux = {key: sum_of_sample_mean(value) for key, value in policy_loss_aux.items()}
    ppo_kl = sum_of_sample_mean(ppo_kl)
 
     # entropy loss
@@ -989,6 +998,9 @@
 
     if train_rollout_logprob_abs_diff is not None:
         reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff.clone().detach()
 
+    for key, value in policy_loss_aux.items():
+        reported_loss[key] = value.clone().detach()
+
     if args.use_kl_loss:
         reported_loss["kl_loss"] = kl_loss.clone().detach()
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index c1c25eff89..4ad8f617d5 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -764,6 +764,28 @@ def add_algo_arguments(parser):
 
     parser.add_argument("--eps-clip", type=float, default=0.2, help="PPO clip range")
     parser.add_argument("--eps-clip-high", type=float, default=None, help="PPO clip upper range")
+    parser.add_argument(
+        "--policy-loss-type",
+        type=str,
+        choices=["clip", "sapo"],
+        default="clip",
+        help=(
+            "Policy-gradient surrogate objective. 'clip' uses PPO hard clipping; "
+            "'sapo' uses Soft Adaptive Policy Optimization."
+        ),
+    )
+    parser.add_argument(
+        "--sapo-tau-pos",
+        type=float,
+        default=1.0,
+        help="SAPO temperature for positive advantages.",
+    )
+    parser.add_argument(
+        "--sapo-tau-neg",
+        type=float,
+        default=1.05,
+        help="SAPO temperature for zero or negative advantages.",
+    )
     parser.add_argument(
         "--eps-clip-c",
         type=float,
@@ -1685,6 +1707,10 @@ def slime_validate_args(args):
 
     if args.eps_clip_high is None:
         args.eps_clip_high = args.eps_clip
+    assert args.sapo_tau_pos > 0, "sapo_tau_pos must be positive."
+    assert args.sapo_tau_neg > 0, "sapo_tau_neg must be positive."
+    if args.policy_loss_type == "sapo":
+        assert args.advantage_estimator in ["grpo", "gspo"], "SAPO policy loss is currently only supported for GRPO/GSPO."
 
     if args.eval_reward_key is None:
         args.eval_reward_key = args.reward_key
diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py
index 2a858e7a3f..25ade6e671 100644
--- a/slime/utils/ppo_utils.py
+++ b/slime/utils/ppo_utils.py
@@ -122,13 +122,13 @@ def compute_gspo_kl(
 
 
 @torch.compile(dynamic=True)
-def compute_policy_loss(
+def _compute_clipped_policy_loss(
     ppo_kl: torch.Tensor,
     advantages: torch.Tensor,
     eps_clip: float,
     eps_clip_high: float,
     eps_clip_c: float | None = None,
-):
+) -> tuple[torch.Tensor, torch.Tensor]:
     ratio = (-ppo_kl).exp()
     pg_losses1 = -ratio * advantages
     pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
@@ -148,6 +148,64 @@ def compute_policy_loss(
     return pg_losses, clipfrac
 
 
+@torch.compile(dynamic=True)
+def _compute_sapo_policy_loss(
+    ppo_kl: torch.Tensor,
+    advantages: torch.Tensor,
+    eps_clip: float,
+    eps_clip_high: float,
+    sapo_tau_pos: float,
+    sapo_tau_neg: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    # SAPO saturates for extreme ratios; clamp the log-ratio used by the soft gate.
+    ratio = torch.exp(torch.clamp(-ppo_kl, min=-60.0, max=60.0))
+    tau_pos = torch.full_like(advantages, sapo_tau_pos)
+    tau_neg = torch.full_like(advantages, sapo_tau_neg)
+    tau = torch.where(advantages > 0, tau_pos, tau_neg)
+    soft_ratio = torch.sigmoid(torch.clamp(tau * (ratio - 1.0), min=-60.0, max=60.0)) * (4.0 / tau)
+    pg_losses = -soft_ratio * advantages
+
+    would_clip_low = (ratio < 1 - eps_clip) & (advantages < 0)
+    would_clip_high = (ratio > 1 + eps_clip_high) & (advantages > 0)
+    clipfrac = (would_clip_low | would_clip_high).float()
+    return pg_losses, clipfrac, soft_ratio
+
+
+def compute_policy_loss(
+    ppo_kl: torch.Tensor,
+    advantages: torch.Tensor,
+    eps_clip: float,
+    eps_clip_high: float,
+    eps_clip_c: float | None = None,
+    policy_loss_type: str = "clip",
+    sapo_tau_pos: float = 1.0,
+    sapo_tau_neg: float = 1.05,
+) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
+    """Compute the token-level policy gradient surrogate and auxiliary metrics."""
+    if policy_loss_type == "clip":
+        pg_losses, clipfrac = _compute_clipped_policy_loss(
+            ppo_kl,
+            advantages,
+            eps_clip,
+            eps_clip_high,
+            eps_clip_c,
+        )
+        return pg_losses, clipfrac, {}
+
+    if policy_loss_type == "sapo":
+        pg_losses, clipfrac, soft_ratio = _compute_sapo_policy_loss(
+            ppo_kl,
+            advantages,
+            eps_clip,
+            eps_clip_high,
+            sapo_tau_pos,
+            sapo_tau_neg,
+        )
+        return pg_losses, clipfrac, {"sapo_soft_ratio": soft_ratio}
+
+    raise ValueError(f"Unknown policy_loss_type: {policy_loss_type}")
+
+
 def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None):
     # TODO: when megatron is not installed, fall back to naive implementation
     from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
diff --git a/tests/test_policy_loss.py b/tests/test_policy_loss.py
new file mode 100644
index 0000000000..f0d1a3459f
--- /dev/null
+++ b/tests/test_policy_loss.py
@@ -0,0 +1,123 @@
+import pytest
+import torch
+
+from slime.utils.ppo_utils import compute_policy_loss
+
+
+def _ppo_kl_from_ratio(ratio: torch.Tensor) -> torch.Tensor:
+    return -ratio.log()
+
+
+@pytest.mark.unit
+def test_clipped_policy_loss_matches_existing_formula():
+    ratio = torch.tensor([1.5, 0.7, 1.1, 0.9], dtype=torch.float32)
+    advantages = torch.tensor([2.0, -3.0, -1.0, 0.5], dtype=torch.float32)
+    eps_clip = 0.2
+    eps_clip_high = 0.3
+
+    pg_loss, clipfrac, aux = compute_policy_loss(
+        _ppo_kl_from_ratio(ratio),
+        advantages,
+        eps_clip,
+        eps_clip_high,
+    )
+
+    expected_loss = torch.maximum(
+        -ratio * advantages,
+        -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages,
+    )
+    expected_clipfrac = torch.gt(
+        -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages,
+        -ratio * advantages,
+    ).float()
+
+    torch.testing.assert_close(pg_loss, expected_loss)
+    torch.testing.assert_close(clipfrac, expected_clipfrac)
+    assert aux == {}
+
+
+@pytest.mark.unit
+def test_sapo_policy_loss_matches_soft_gate_formula():
+    ratio = torch.tensor([1.5, 0.7, 1.0, 1.4], dtype=torch.float32)
+    advantages = torch.tensor([2.0, -3.0, 0.0, -1.0], dtype=torch.float32)
+    eps_clip = 0.2
+    eps_clip_high = 0.3
+    tau_pos = 1.0
+    tau_neg = 1.05
+
+    pg_loss, clipfrac, aux = compute_policy_loss(
+        _ppo_kl_from_ratio(ratio),
+        advantages,
+        eps_clip,
+        eps_clip_high,
+        policy_loss_type="sapo",
+        sapo_tau_pos=tau_pos,
+        sapo_tau_neg=tau_neg,
+    )
+
+    tau = torch.where(
+        advantages > 0,
+        torch.full_like(advantages, tau_pos),
+        torch.full_like(advantages, tau_neg),
+    )
+    expected_soft_ratio = torch.sigmoid(tau * (ratio - 1.0)) * (4.0 / tau)
+    expected_loss = -expected_soft_ratio * advantages
+    expected_clipfrac = torch.tensor([1.0, 1.0, 0.0, 0.0], dtype=torch.float32)
+
+    torch.testing.assert_close(pg_loss, expected_loss)
+    torch.testing.assert_close(clipfrac, expected_clipfrac)
+    torch.testing.assert_close(aux["sapo_soft_ratio"], expected_soft_ratio)
+    assert set(aux) == {"sapo_soft_ratio"}
+
+
+@pytest.mark.unit
+def test_sapo_gradient_matches_unclipped_policy_gradient_at_unit_ratio():
+    ppo_kl = torch.zeros(4, dtype=torch.float32, requires_grad=True)
+    advantages = torch.tensor([2.0, -3.0, 0.0, 0.5], dtype=torch.float32)
+
+    pg_loss, _, _ = compute_policy_loss(
+        ppo_kl,
+        advantages,
+        eps_clip=0.2,
+        eps_clip_high=0.2,
+        policy_loss_type="sapo",
+        sapo_tau_pos=1.0,
+        sapo_tau_neg=1.05,
+    )
+    pg_loss.sum().backward()
+
+    torch.testing.assert_close(ppo_kl.grad, advantages)
+
+
+@pytest.mark.unit
+def test_sapo_policy_loss_is_finite_for_large_log_ratios():
+    ppo_kl = torch.tensor([-100.0, 100.0], dtype=torch.float32, requires_grad=True)
+    advantages = torch.tensor([1.0, -1.0], dtype=torch.float32)
+
+    pg_loss, clipfrac, aux = compute_policy_loss(
+        ppo_kl,
+        advantages,
+        eps_clip=0.2,
+        eps_clip_high=0.2,
+        policy_loss_type="sapo",
+        sapo_tau_pos=1.0,
+        sapo_tau_neg=1.05,
+    )
+    pg_loss.sum().backward()
+
+    assert torch.isfinite(pg_loss).all()
+    assert torch.isfinite(clipfrac).all()
+    assert torch.isfinite(aux["sapo_soft_ratio"]).all()
+    assert torch.isfinite(ppo_kl.grad).all()
+
+
+@pytest.mark.unit
+def test_unknown_policy_loss_type_raises():
+    with pytest.raises(ValueError, match="Unknown policy_loss_type"):
+        compute_policy_loss(
+            torch.zeros(1),
+            torch.ones(1),
+            eps_clip=0.2,
+            eps_clip_high=0.2,
+            policy_loss_type="missing",
+        )
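For reference, a minimal sketch of calling the extended `compute_policy_loss` entry point directly, comparing the default clipped surrogate with the new SAPO branch on the same illustrative inputs (tensor values are arbitrary and not taken from the patch):

```python
import torch

from slime.utils.ppo_utils import compute_policy_loss

ratio = torch.tensor([0.5, 1.0, 2.0, 4.0])
ppo_kl = -ratio.log()              # slime passes old_log_probs - log_probs
advantages = torch.ones_like(ratio)

clip_loss, clip_frac, _ = compute_policy_loss(
    ppo_kl, advantages, eps_clip=0.2, eps_clip_high=0.2
)
sapo_loss, sapo_frac, aux = compute_policy_loss(
    ppo_kl, advantages, eps_clip=0.2, eps_clip_high=0.2,
    policy_loss_type="sapo", sapo_tau_pos=1.0, sapo_tau_neg=1.05,
)

# The clipped surrogate flattens once the ratio exceeds 1 + eps_clip_high, while the
# SAPO gate saturates smoothly; "sapo_soft_ratio" is the auxiliary metric that
# policy_loss_function adds to reported_loss when SAPO is enabled.
print(clip_loss, sapo_loss, aux["sapo_soft_ratio"])
```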