Skip to content
Open
540 changes: 492 additions & 48 deletions docker/patch/latest/sglang.patch

Large diffs are not rendered by default.

114 changes: 114 additions & 0 deletions docs/en/advanced/partial-weight-sync.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Partial Weight Sync (Selective / Delta)

- [Overview](#overview)
- [Quick Start](#quick-start)
- [Modes: selective vs delta](#modes-selective-vs-delta)
- [How it Works](#how-it-works)
- [Choosing the Wire Encoding](#choosing-the-wire-encoding)
- [Precision Behaviour](#precision-behaviour)
- [Periodic Base Sync](#periodic-base-sync)
- [Why Not Colocated](#why-not-colocated)

## Overview

For **non-colocated** runs, slime's default weight sync broadcasts every parameter on every training step. The full broadcast scales linearly with model size and dominates the sync phase even when only a small fraction of weights actually change between steps. Partial-update modes keep a pinned-CPU snapshot of the last sync's weights and broadcast only the changed-position payload; the SGLang receiver applies it without re-touching unchanged params. During typical RL fine-tuning at conservative learning rates the per-step diff is sparse — a few percent of weights — so the wire shrinks proportionally.

**Inspiration / prior art.** `selective` is inspired by [arXiv:2509.19128](https://arxiv.org/abs/2509.19128). `delta` is informed by the additive-update approach in [Cursor Composer 2](https://cursor.com/resources/Composer2.pdf) and [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think).

## Quick Start

Enable a partial mode on the trainer side:

```bash
--update-weight-mode selective # 'selective' / 'delta' / 'full' (default)
--update-weight-partial-encoding sparse_indices
--update-weight-delta-dtype fp32 # delta mode only
--update-weight-base-sync-interval 9999 # default. Both partial modes are lossless under
# their defaults (selective by construction, delta
# with fp32 math), so 9999 effectively disables
# periodic base syncs. Set lower (e.g. 30) to
# verify against periodic full broadcasts, or
# if your workload has a custom base-sync need.
```

And one knob on the SGLang side (auto-mirrored by slime as `--sglang-update-weight-partial-chunk-bytes`):

```bash
--sglang-update-weight-partial-chunk-bytes $((2 * 1024 * 1024 * 1024))
```

See [examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh](../../../examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh) for a complete non-colocated launcher.

## Modes: selective vs delta

Both modes share the same sender pipeline (snapshot, mask determination, sparse encoding, bucketed broadcast) and the same wire format. They differ only in what the values mean and how the receiver applies them:

| | `--update-weight-mode selective` | `--update-weight-mode delta` |
|---|---|---|
| Values on wire | new param values at changed positions, in the snapshot's dtype | `(current − snapshot)` cast to `--update-weight-delta-dtype` (default fp32) |
| "Unchanged" signal at receiver | NaN sentinel in the decoded dense tensor | implicit (zero delta at unchanged positions) |
| Receiver apply | `param[~isnan(src)] = src[~isnan(src)]` (selective overwrite) | `param += delta` (in-place add, auto-promotes for fp32 math, casts back to param dtype) |
| Wire bytes (values portion) | 2×nnz @ bf16 (½× delta) | 4×nnz @ fp32 |
| Lossless? | yes by construction (no arithmetic) | yes when `delta-dtype` > param dtype |

Pick `selective` when you want the smaller wire and don't need fp32 arithmetic margin; pick `delta` when you'd rather keep the arithmetic path for compatibility or want to amplify sub-bf16 deltas via the fp32 subtraction.

## How it Works

Per sync, on the trainer (PP-source rank only):

1. **Compute the payload**: for selective, take the bf16 mask `current != snapshot` and emit new values with NaN at unchanged positions; for delta, lift current weights and pinned-CPU snapshot to delta_dtype and subtract.
2. **Encode**: sparse-encode active positions into two flat packed tensors (`__packed_keys__`, `__packed_values__`) plus a per-param manifest (`PartialWeightSpec.params`).
3. **Bucket and broadcast**: pack multiple parameters per NCCL broadcast (`--update-weight-buffer-size` controls the bucket cap).
4. **Snapshot new prev**: D2H copy of the just-sent weights onto a side stream so it overlaps with downstream broadcast/encode work.

On the SGLang receiver:

1. **Broadcast**: receive the two packed tensors per bucket.
2. **Decode lazily**: yield one decoded dense tensor per parameter; unchanged positions are filled with the mode's sentinel (NaN for selective, 0 for delta). The consumer's `chunk_byte_cap` bounds peak HBM during decode (`encoded_buffers + in-flight chunk`).
3. **Apply**: route the decoded tensors through the model's normal `load_weights` path, but with `Tensor.copy_` / `fill_` rewired by a context manager:
- For `selective`: `_selective_load_context` redirects writes that target param storage to a masked overwrite (`param[~isnan(src)] = src[~isnan(src)]`), leaving NaN positions untouched.
- For `delta`: `_additive_load_context` redirects writes that target param storage to `add_` (PyTorch auto-promotes for fp32 math and casts back on store, so deltas keep fp32 precision).

Auxiliary writes (scratch buffers, dtype temporaries, `post_load_weights` for fp8-scale recompute or MoE bias materialization) keep their normal overwriting semantics in both contexts.

The wire protocol — `PartialWeightSpec` (encoding + per-param manifest), and per-param `PartialWeightParam` (name, dtype, shape, key/value slice ranges) — is defined in `sglang.srt.managers.io_struct` (added by the slime SGLang patch).

## Choosing the Wire Encoding

`--update-weight-partial-encoding` accepts three values:

| value | wire layout | when to pick |
|---|---|---|
| `sparse_indices` | int32 active offsets + values | low change density (< ~3%) |
| `sparse_bitmask` | 1 bit per element + values | moderate change density (> ~3%) |
| `dense` | identity, one tensor per param | debugging the apply path |

The break-even density between the two sparse encodings is independent of the value dtype. With `n = numel`, `k = nnz`, `v = value bytes`:

```
sparse_indices wire = k * (4 + v)
sparse_bitmask wire = ceil(n / 8) + k * v
```

Equal when `4k = n/8`, i.e. `k/n = 1/32 ≈ 3.125%`. Below that, indices is smaller; above, bitmask is smaller. For typical RL fine-tuning at moderate learning rates, `sparse_indices` wins; for early-training high-LR phases where most weights move every step, switch to `sparse_bitmask`.

## Precision Behaviour

For `delta` mode, `--update-weight-delta-dtype` is the *math* dtype, not just the wire dtype. The subtraction is performed at `delta_dtype` on both operands (after promoting from the param dtype), and the receiver's `param.data.add_(fp32_delta)` lets PyTorch do the addition at the common dtype (fp32) before casting the result back into the bf16 param. This recovers small-magnitude deltas that would otherwise round to zero through a bf16 subtraction.

For `selective` mode there is no arithmetic — the receiver overwrites changed positions with the trainer's exact bf16 values — so precision is bit-perfect regardless of `--update-weight-delta-dtype` (the flag is silently ignored).

The CPU snapshot occupies only the param dtype's bytes in both modes (no fp32 inflation of pinned memory).

## Periodic Base Sync

The first sync of every job is always a *base sync* (a full broadcast that re-establishes the snapshot). After that, slime sends partial syncs until `committed_syncs % --update-weight-base-sync-interval == 0`, at which point a base sync runs again.

In selective mode or with `--update-weight-delta-dtype fp32` (delta mode), the partial apply is **lossless**: every bf16 value is exactly representable in fp32, the subtraction `current_fp32 − snapshot_fp32` produces the exact difference between the two stored bf16 values, and the receiver's in-place `bf16_param.add_(fp32_delta)` reconstructs the trainer's bf16 state bit-for-bit when the fp32 result is rounded back to bf16. Selective is lossless by construction (direct overwrite). Because no error accumulates across partial syncs, receiver state never drifts from a base-sync reference no matter how many partial syncs elapse — periodic base sync is not needed for correctness. The default `--update-weight-base-sync-interval 9999` effectively disables it and is the recommended setting; set lower (e.g. `30`) if you want periodic full broadcasts to verify correctness or your workload has a custom base-sync requirement.

The only operational reason to keep an occasional base sync is recovery — e.g. a rollout engine that joins mid-training and needs a complete state before it can apply partial updates. If you set `--update-weight-delta-dtype bf16` (delta only, not higher than the param dtype) to save wire bytes, the delta apply is no longer lossless and a finite interval starts to matter.

## Why Not Colocated

Colocated weight sync uses CUDA IPC: the engine maps the trainer's parameter storage directly into its own process. There is no NCCL broadcast, and "wire size" is one IPC handle per param (~64 B). Partial encoding's `bytes saved on the wire` benefit is zero, while the partial-update bookkeeping (snapshot + subtract/mask + sparse encode) is pure overhead. Slime rejects `--update-weight-mode selective --colocate` and `--update-weight-mode delta --colocate` at argparse time.
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ slime is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. Apart from models
advanced/on-policy-distillation.md
advanced/speculative-decoding.md
advanced/low-precision.md
advanced/partial-weight-sync.md
advanced/reproducibility.md
advanced/fault-tolerance.md
advanced/pd-disaggregation.md
Expand Down
113 changes: 113 additions & 0 deletions docs/zh/advanced/partial-weight-sync.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# 增量权重同步(Selective / Delta)

- [概述](#概述)
- [快速开始](#快速开始)
- [两种 partial 模式:selective 与 delta](#两种-partial-模式selective-与-delta)
- [工作原理](#工作原理)
- [选择 wire 编码](#选择-wire-编码)
- [精度行为](#精度行为)
- [周期性 Base Sync](#周期性-base-sync)
- [为什么 colocate 模式不需要](#为什么-colocate-模式不需要)

## 概述

在**非 colocate**(non-colocated)模式下,slime 默认会在每一步训练时把所有参数完整地广播给 SGLang。完整广播的体积随模型规模线性增长,即使两步之间实际变化的权重比例很小,broadcast 仍然主导整个权重同步阶段。Partial-update 模式会把上一次同步时的权重在 pinned CPU 内存里保留一份 snapshot,每步只广播变化位置的数据,SGLang 接收端只更新这些位置。在 RL fine-tuning 阶段、学习率不大的常见设置里,每步 diff 都很稀疏(只有百分之几的权重发生变化),wire 体积也按比例减少。

**参考资料 / 先验工作。** `selective` 模式的灵感来自 [arXiv:2509.19128](https://arxiv.org/abs/2509.19128)。`delta` 模式的加性更新思路参考了 [Cursor Composer 2](https://cursor.com/resources/Composer2.pdf) 和 [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think)。

## 快速开始

训练端开关与传输编码:

```bash
--update-weight-mode selective # 'selective' / 'delta' / 'full'(默认)
--update-weight-partial-encoding sparse_indices
--update-weight-delta-dtype fp32 # 仅 delta 模式生效
--update-weight-base-sync-interval 9999 # 默认值。两种 partial 模式在默认设置下都是 lossless 的
# (selective 按构造,delta 配 fp32 算术),所以 9999
# 实际上关闭了周期性 base sync。若想用周期性全量广播
# 来验证正确性、或有自定义的 base sync 需求,可调小
# (例如 30)。
```

SGLang 端唯一的旋钮(slime 通过 `--sglang-update-weight-partial-chunk-bytes` 自动转发):

```bash
--sglang-update-weight-partial-chunk-bytes $((2 * 1024 * 1024 * 1024))
```

完整非 colocate 启动脚本见 [examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh](../../../examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh)。

## 两种 partial 模式:selective 与 delta

两种模式共用 sender 流水线(snapshot、mask 计算、稀疏编码、桶式广播)和 wire 格式,区别只在 values 的语义以及 receiver 的 apply 方式:

| | `--update-weight-mode selective` | `--update-weight-mode delta` |
|---|---|---|
| wire 上的 values | 变化位置的新权重,dtype 同 snapshot | `(current − snapshot)`,cast 到 `--update-weight-delta-dtype`(默认 fp32) |
| 接收端"未变化"信号 | 解码后的稠密张量在未变化位置填 NaN | 隐式(delta 在未变化位置为 0) |
| 接收端 apply | `param[~isnan(src)] = src[~isnan(src)]`(selective overwrite) | `param += delta`(in-place add,自动提升到 fp32 计算后再 cast 回 param dtype) |
| values 部分 wire 字节 | 2×nnz @ bf16(½× delta) | 4×nnz @ fp32 |
| 是否 lossless | 永远 lossless(无算术) | 当 `delta-dtype` 高于 param dtype 时 lossless |

当你想要更小的 wire、不需要 fp32 算术余量时选 `selective`;当你需要 fp32 减法去保住 sub-bf16 级别的小 delta 时选 `delta`。

## 工作原理

每次同步,训练端(仅 PP-source rank):

1. **计算 payload**:selective 模式下,先在 bf16 上取 `current != snapshot` 的 mask,再生成新权重值并在 unchanged 位置填 NaN;delta 模式下,将当前权重与 pinned-CPU snapshot 同时提升到 delta_dtype 然后相减。
2. **编码**:将 active 位置稀疏编码为两条扁平张量(`__packed_keys__`、`__packed_values__`)和一份 per-param manifest(`PartialWeightSpec.params`)。
3. **分桶广播**:多个参数共享一次 NCCL 广播,桶大小由 `--update-weight-buffer-size` 控制。
4. **异步刷新 snapshot**:把当前权重通过独立 CUDA stream 拷贝到 pinned CPU,与下一轮的广播、编码计算重叠。

SGLang 接收端:

1. **接收**:每个桶接收两条 packed 张量。
2. **懒解码**:以生成器逐参数 yield 解码后的稠密张量;unchanged 位置按模式填入 sentinel(selective 模式填 NaN,delta 模式填 0)。下游 chunking 的 `chunk_byte_cap` 同时为 decode 阶段的峰值 HBM 设上限(`encoded_buffers + in-flight chunk`)。
3. **加性写入**:仍走模型 `load_weights` 主路径,但通过一个 context manager 重写 `Tensor.copy_` / `fill_`:
- `selective` 模式下 `_selective_load_context` 把落入 param storage 的 copy_ 重写为 mask-overwrite(`param[~isnan(src)] = src[~isnan(src)]`),unchanged 位置保持不动。
- `delta` 模式下 `_additive_load_context` 把落入 param storage 的 copy_ 重写为 `add_`(PyTorch 自动提升到 fp32 完成加法、再 cast 回 store,保留 fp32 精度)。

非 param 的写入(scratch buffer、dtype 转换、`post_load_weights` 中的 FP8 scale 重计算 / MoE bias 物化等)在两种 context 下都保持原始覆盖语义。

Wire protocol —— `PartialWeightSpec`(encoding + per-param manifest)和 `PartialWeightParam`(name、dtype、shape、keys/values slice)—— 定义在 `sglang.srt.managers.io_struct`(由 slime 的 SGLang patch 注入)。

## 选择 wire 编码

`--update-weight-partial-encoding` 接受三个值:

| 值 | wire 排布 | 适用场景 |
|---|---|---|
| `sparse_indices` | int32 active 下标 + 值 | 低变化率(< ~3%) |
| `sparse_bitmask` | 每元素 1 bit 的 mask + 值 | 中等变化率(> ~3%) |
| `dense` | 每参数一条张量 | 调试 apply 路径 |

两种稀疏编码的等价点和值 dtype 无关。令 `n = numel`,`k = nnz`,`v = 值字节数`:

```
sparse_indices wire = k * (4 + v)
sparse_bitmask wire = ceil(n / 8) + k * v
```

二者相等时 `4k = n/8`,即 `k/n = 1/32 ≈ 3.125%`。低于该 density 选 indices,高于则选 bitmask。常见的小学习率 RL fine-tuning 阶段 `sparse_indices` 更省,训练早期大 LR 阶段几乎所有权重都在动时换 `sparse_bitmask`。

## 精度行为

`delta` 模式下 `--update-weight-delta-dtype` 控制的是**计算 dtype**,不仅仅是 wire dtype。减法在两个操作数都被提升到 `delta_dtype` 之后进行;接收端的 `param.data.add_(fp32_delta)` 让 PyTorch 内部以共同 dtype(fp32)做加法,然后再 cast 回 bf16 写入 param。这样可以保留那些在 bf16 减法下会直接舍入为零的小幅度 delta。

`selective` 模式下没有算术,接收端直接把 trainer 的精确 bf16 值写回 param,因此精度天然 bit-perfect,与 `--update-weight-delta-dtype` 无关(该 flag 在 selective 模式下被静默忽略)。

CPU snapshot 在两种模式下都只占用 param dtype 的字节数(不会因此膨胀到 fp32 的存储)。

## 周期性 Base Sync

每次任务的第一次同步永远是 *base sync*(一次完整广播,重建 snapshot)。之后每当 `committed_syncs % --update-weight-base-sync-interval == 0` 再触发一次 base sync。

在 selective 模式或 `--update-weight-delta-dtype fp32`(delta 模式)下,partial apply 都是**无损(lossless)**的:selective 模式因为直接覆盖而天然无损;delta 模式下每个 bf16 值都可以精确表示为 fp32,`current_fp32 − snapshot_fp32` 得到两个 bf16 值的精确差,接收端的 `bf16_param.add_(fp32_delta)` 在自动提升到 fp32 完成加法、再 cast 回 bf16 之后,会逐比特地复现 trainer 的 bf16 状态。因为不会有误差累积,无论中间累积了多少次 partial 同步,接收端的状态都不会偏离对应的 base sync 结果,从正确性角度并不需要周期性 base sync。默认 `--update-weight-base-sync-interval 9999` 实际上已关闭周期性 base sync,是推荐设置;若希望用周期性全量广播来验证正确性或有自定义需求,可设成较小的值(例如 `30`)。

保留少量 base sync 的运营性理由主要是恢复点——例如一个中途加入的 rollout engine 需要先拿到完整状态才能应用后续 partial 更新。如果你为了进一步压缩 wire 体积而把 `--update-weight-delta-dtype` 设为 `bf16`(不高于 param dtype 的精度,仅对 delta 模式有意义),apply 就不再 lossless,这时 interval 才需要给一个合理的有限值。

## 为什么 colocate 模式不需要

Colocate 模式的权重同步走的是 CUDA IPC:SGLang 直接把 trainer 进程的参数 storage 映射到自己进程,wire 上只交换一个 IPC handle(~64 B),完全没有 NCCL 广播。Partial 编码的「wire 体积」优势归零,而 partial 更新的额外开销(snapshot 维护、减法/取 mask、稀疏编码)反而是纯开销。所以 slime 在 argparse 阶段就拒绝 `--update-weight-mode selective --colocate` 和 `--update-weight-mode delta --colocate` 的组合。
1 change: 1 addition & 0 deletions docs/zh/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ slime 是 GLM-4.7、GLM-4.6、GLM-4.5 背后的 RL 训练框架。除此之外
advanced/on-policy-distillation.md
advanced/speculative-decoding.md
advanced/low-precision.md
advanced/partial-weight-sync.md
advanced/reproducibility.md
advanced/fault-tolerance.md
advanced/pd-disaggregation.md
Expand Down
Loading
Loading