diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index 4a13e2f9b4..e911660e6d 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -1,5 +1,5 @@ diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index 691f06411d..671ac81c48 100644 +index 691f064..671ac81 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -294,6 +294,7 @@ class ModelConfig: @@ -11,7 +11,7 @@ index 691f06411d..671ac81c48 100644 ]: self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py -index f7d4092d85..3aae51c849 100644 +index f7d4092..3aae51c 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -17,6 +17,7 @@ class KVArgs: @@ -23,7 +23,7 @@ index f7d4092d85..3aae51c849 100644 aux_data_lens: List[int] aux_item_lens: List[int] diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index f54c882cc2..03832002f0 100644 +index f54c882..0383200 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server @@ -209,7 +209,7 @@ index f54c882cc2..03832002f0 100644 if not hasattr(self, "polling_count"): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index 64d97f5c69..4ef08446aa 100644 +index 64d97f5..4ef0844 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -31,6 +31,7 @@ from sglang.srt.disaggregation.mooncake.utils import ( @@ -341,7 +341,7 @@ index 64d97f5c69..4ef08446aa 100644 # Only the last chunk we need to send the aux data ret = self.send_aux( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index 8eadf81954..c180ce79f3 100644 +index 8eadf81..c180ce7 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -20,6 +20,8 @@ Life cycle of a request in the prefill server @@ -484,7 +484,7 @@ index 8eadf81954..c180ce79f3 100644 release_kv_cache(req, self.tree_cache) # unlock the tree req.finished_reason = FINISH_LENGTH(length=0) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py -index d7956a6048..0ced278713 100644 +index d7956a6..0ced278 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -28,6 +28,17 @@ if TYPE_CHECKING: @@ -691,7 +691,7 @@ index d7956a6048..0ced278713 100644 ######################### diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index d864e4abaa..3a000a80f2 100644 +index d864e4a..3a000a8 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -69,6 +69,7 @@ from sglang.srt.managers.io_struct import ( @@ -728,7 +728,7 @@ index d864e4abaa..3a000a80f2 100644 """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 6978e0c062..80dc159e8f 100644 +index 6978e0c..80dc159 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ 
b/python/sglang/srt/entrypoints/http_server.py @@ -127,6 +127,7 @@ from sglang.srt.managers.io_struct import ( @@ -800,7 +800,7 @@ index 6978e0c062..80dc159e8f 100644 @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py -index dfc5507de0..be9501b05a 100644 +index dfc5507..be9501b 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -242,6 +242,7 @@ class Envs: @@ -812,7 +812,7 @@ index dfc5507de0..be9501b05a 100644 # Extra slots in req_to_token_pool for decode workers (only effective when # max_num_reqs > 32). Increases pool capacity so more KV cache transfers diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py -index 02ef4e2440..fd5a43cce8 100644 +index 02ef4e2..fd5a43c 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -1,6 +1,7 @@ @@ -890,7 +890,7 @@ index 02ef4e2440..fd5a43cce8 100644 if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -index 72483f4ea6..2e1148d189 100644 +index 72483f4..2e1148d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -702,6 +702,7 @@ class FusedMoE(torch.nn.Module): @@ -910,7 +910,7 @@ index 72483f4ea6..2e1148d189 100644 ) diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py -index 00bd687555..12d5577af2 100644 +index 00bd687..12d5577 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -8,10 +8,15 @@ import torch @@ -971,7 +971,7 @@ index 00bd687555..12d5577af2 100644 def get_routed_experts( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index a13c53af4d..1d80d06b13 100644 +index a13c53a..1d80d06 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -500,7 +500,7 @@ class CompressedTensorsConfig(QuantizationConfig): @@ -994,7 +994,7 @@ index a13c53af4d..1d80d06b13 100644 self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py -index 7a8fb65421..f1c85899cd 100644 +index 7a8fb65..f1c8589 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -17,7 +17,10 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( @@ -1106,10 +1106,74 @@ index 7a8fb65421..f1c85899cd 100644 is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index bd97965345..e6a147c1b4 100644 +index bd97965..ba8976a 100644 --- a/python/sglang/srt/managers/io_struct.py +++ 
b/python/sglang/srt/managers/io_struct.py -@@ -1449,6 +1449,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): +@@ -1254,6 +1254,54 @@ class UpdateWeightFromDiskReqOutput(BaseReq): + num_paused_requests: Optional[int] = 0 + + ++class PartialWeightEncoding(str, Enum): ++ # Wire encoding for a partial weight-update broadcast (used by load_format= ++ # "delta" and load_format="selective"). Choosing SPARSE_INDICES vs SPARSE_BITMASK: ++ # with n=numel, k=nnz, v=value bytes, the wire sizes are k*(4+v) and ceil(n/8)+k*v, ++ # so the break-even density k/n = 1/32 ≈ 3.125% (independent of v). Pick INDICES ++ # below ~3%, BITMASK above. ++ ++ # One tensor per parameter, broadcast identical (name, dtype, shape) to the ++ # parent request's lists. PartialWeightSpec.params is None. ++ DENSE = "dense" ++ # Two tensors broadcast: __packed_keys__ (int32 nonzero offsets) and __packed_values__. ++ # Per-param slicing is in PartialWeightSpec.params. ++ SPARSE_INDICES = "sparse_indices" ++ # Two tensors broadcast: __packed_keys__ (uint8 packed bitmask, 1 bit per element, ++ # LSB-first within each byte) and __packed_values__. Per-param slicing is in ++ # PartialWeightSpec.params. ++ SPARSE_BITMASK = "sparse_bitmask" ++ ++ ++@dataclass ++class PartialWeightParam: ++ # Decoding manifest entry for one parameter in a sparse partial-update broadcast. ++ # Unused (and absent) for DENSE — the parent request's names/dtypes/shapes ++ # already describe each param 1:1. ++ name: str ++ dtype: str ++ shape: List[int] ++ # Half-open slice [keys_start, keys_end) into __packed_keys__ ++ # (int32 indices for SPARSE_INDICES, packed bits for SPARSE_BITMASK). ++ keys_start: int ++ keys_end: int ++ # Half-open slice [values_start, values_end) into __packed_values__. ++ values_start: int ++ values_end: int ++ ++ ++@dataclass ++class PartialWeightSpec: ++ # Wire-encoding descriptor for a partial weight-update broadcast. Required iff ++ # the parent UpdateWeightsFromDistributedReqInput has load_format in ++ # ("delta", "selective"). For load_format="delta" the values are deltas applied ++ # additively; for load_format="selective" the values are new param values with ++ # NaN as the "unchanged" sentinel. ++ encoding: PartialWeightEncoding ++ # Per-param decoding manifest. Required for sparse encodings, None for DENSE. ++ params: Optional[List[PartialWeightParam]] = None ++ ++ + @dataclass + class UpdateWeightsFromDistributedReqInput(BaseReq): + names: List[str] +@@ -1269,6 +1317,8 @@ class UpdateWeightsFromDistributedReqInput(BaseReq): + weight_version: Optional[str] = None + # Optional format specification for loading + load_format: Optional[str] = None ++ # Required iff load_format in ("delta", "selective"). 
++ partial: Optional[PartialWeightSpec] = None + + + @dataclass +@@ -1449,6 +1499,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): pass @@ -1128,7 +1192,7 @@ index bd97965345..e6a147c1b4 100644 @dataclass class CheckWeightsReqInput(BaseReq): action: str -@@ -1753,6 +1765,8 @@ class GetLoadReqOutput(BaseReq): +@@ -1753,6 +1815,8 @@ class GetLoadReqOutput(BaseReq): num_waiting_reqs: int num_tokens: int ts_tic: float @@ -1138,7 +1202,7 @@ index bd97965345..e6a147c1b4 100644 @dataclass diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py -index e0a1669fb3..fbbb6bb12b 100644 +index e0a1669..fbbb6bb 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -496,6 +496,35 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10): @@ -1178,7 +1242,7 @@ index e0a1669fb3..fbbb6bb12b 100644 class SenderWrapper: def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index 0b26be6c6d..2ea1042cf9 100644 +index 0b26be6..2ea1042 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1972,7 +1972,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): @@ -1194,7 +1258,7 @@ index 0b26be6c6d..2ea1042cf9 100644 break diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 67af2d0de9..122ddb3874 100644 +index 67af2d0..122ddb3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -120,6 +120,7 @@ from sglang.srt.managers.io_struct import ( @@ -1214,7 +1278,7 @@ index 67af2d0de9..122ddb3874 100644 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index 496cd96656..cf2d43015a 100644 +index 496cd96..cf2d430 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1154,7 +1154,7 @@ class SchedulerOutputProcessorMixin: @@ -1227,7 +1291,7 @@ index 496cd96656..cf2d43015a 100644 BatchTokenIDOutput( rids=rids, diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py -index c02ed7997d..61733c4127 100644 +index c02ed79..61733c4 100644 --- a/python/sglang/srt/managers/scheduler_profiler_mixin.py +++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py @@ -349,7 +349,7 @@ class SchedulerProfilerMixin: @@ -1240,7 +1304,7 @@ index c02ed7997d..61733c4127 100644 if self.profile_in_progress: # force trace flush diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -index abcda67946..a53848b79d 100644 +index abcda67..a53848b 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -12,6 +12,7 @@ from sglang.srt.constants import ( @@ -1305,7 +1369,7 @@ index abcda67946..a53848b79d 100644 def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py 
b/python/sglang/srt/managers/tokenizer_communicator_mixin.py -index 544c609401..841658c30e 100644 +index 544c609..841658c 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import ( @@ -1357,7 +1421,7 @@ index 544c609401..841658c30e 100644 self: TokenizerManager, obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 81424329a0..2c132be63d 100644 +index 8142432..2c132be 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1383,7 +1383,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerScoreMixin): @@ -1412,7 +1476,7 @@ index 81424329a0..2c132be63d 100644 if state.finished: retraction_count = ( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py -index 7f63610da8..fb56de1583 100644 +index 7f63610..da6c2b2 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ from sglang.srt.managers.io_struct import ( @@ -1423,7 +1487,15 @@ index 7f63610da8..fb56de1583 100644 SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, -@@ -170,6 +171,11 @@ class BaseTpWorker(ABC): +@@ -151,6 +152,7 @@ class BaseTpWorker(ABC): + recv_req.shapes, + recv_req.group_name, + recv_req.load_format, ++ recv_req.partial, + ) + return success, message + +@@ -170,6 +172,11 @@ class BaseTpWorker(ABC): success, message = self.model_runner.update_weights_from_ipc(recv_req) return success, message @@ -1436,7 +1508,7 @@ index 7f63610da8..fb56de1583 100644 parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py -index 3c1e97daab..e5128e5ee2 100644 +index 3c1e97d..e5128e5 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -755,9 +755,8 @@ class HiRadixCache(RadixCache): @@ -1476,7 +1548,7 @@ index 3c1e97daab..e5128e5ee2 100644 self._inc_hit_count(new_node, chunked) total_prefix_length += prefix_len diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py -index e4c158cda9..cf7333235f 100644 +index e4c158c..cf73332 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1854,9 +1854,12 @@ class NSATokenToKVPool(MLATokenToKVPool): @@ -1559,7 +1631,7 @@ index e4c158cda9..cf7333235f 100644 kv_size_bytes = super().get_kv_size_bytes() for index_k_cache in self.index_k_with_scale_buffer: diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py -index 7d16160372..70fbdc702f 100644 +index 7d16160..70fbdc7 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -512,7 +512,17 @@ class RadixCache(BasePrefixCache): @@ -1594,10 +1666,32 @@ index 7d16160372..70fbdc702f 100644 return DecLockRefResult(delta=delta) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index a59742b943..a7347c15b8 100644 +index a59742b..8b43055 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ 
b/python/sglang/srt/model_executor/model_runner.py -@@ -406,7 +406,12 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -15,11 +15,13 @@ + + from __future__ import annotations + ++import contextlib + import datetime + import gc + import inspect + import json + import logging ++import math + import os + import socket + import threading +@@ -118,6 +120,7 @@ from sglang.srt.layers.sampler import create_sampler + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model + from sglang.srt.lora.lora_manager import LoRAManager + from sglang.srt.lora.lora_registry import LoRARef ++from sglang.srt.managers.io_struct import PartialWeightEncoding, PartialWeightSpec + from sglang.srt.managers.schedule_batch import sanity_check_mm_pad_shift_value + from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +@@ -406,7 +409,12 @@ class ModelRunner(ModelRunnerKVCacheMixin): self.forward_stream = torch.get_device_module(self.device).Stream() # CPU offload @@ -1611,7 +1705,7 @@ index a59742b943..a7347c15b8 100644 self._weight_checker = WeightChecker(model_runner=self) -@@ -646,7 +651,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -646,7 +654,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): ) # Init routed experts capturer @@ -1621,7 +1715,181 @@ index a59742b943..a7347c15b8 100644 if self.device == "cuda" or self.device == "musa": self.init_cublas() -@@ -2767,11 +2773,19 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -1563,6 +1572,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): + shapes, + group_name, + load_format: Optional[str] = None, ++ partial: Optional[PartialWeightSpec] = None, + ): + """ + Update specific parameter in the model weights online +@@ -1583,27 +1593,13 @@ class ModelRunner(ModelRunnerKVCacheMixin): + return self._update_bucketed_weights_from_distributed( + names, dtypes, shapes, group_name + ) ++ if load_format in ("delta", "selective"): ++ return self._update_partial_weights_from_distributed( ++ names, dtypes, shapes, group_name, partial, load_format ++ ) + try: +- weights = [] +- handles = [] +- for name, dtype, shape in zip(names, dtypes, shapes): +- target_dtype = ( +- dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) +- ) +- weight = torch.empty(shape, dtype=target_dtype, device=self.device) +- handles.append( +- torch.distributed.broadcast( +- weight, +- src=0, +- group=self._model_update_group[group_name], +- async_op=True, +- ) +- ) +- weights.append((name, weight)) +- for handle in handles: +- handle.wait() +- +- self.model.load_weights(weights) ++ received = self._broadcast_named_tensors(names, dtypes, shapes, group_name) ++ self.model.load_weights(received) + return True, "Succeeded to update parameter online." 
+ + except Exception as e: +@@ -1615,6 +1611,28 @@ class ModelRunner(ModelRunnerKVCacheMixin): + logger.error(error_msg) + return False, error_msg + ++ def _broadcast_named_tensors(self, names, dtypes, shapes, group_name): ++ """Returns [(name, tensor)] in input order, broadcast on the weight-update group.""" ++ weights = [] ++ handles = [] ++ for name, dtype, shape in zip(names, dtypes, shapes): ++ target_dtype = ( ++ dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) ++ ) ++ weight = torch.empty(shape, dtype=target_dtype, device=self.device) ++ handles.append( ++ torch.distributed.broadcast( ++ weight, ++ src=0, ++ group=self._model_update_group[group_name], ++ async_op=True, ++ ) ++ ) ++ weights.append((name, weight)) ++ for handle in handles: ++ handle.wait() ++ return weights ++ + def _update_bucketed_weights_from_distributed( + self, names, dtypes, shapes, group_name + ): +@@ -1646,6 +1664,102 @@ class ModelRunner(ModelRunnerKVCacheMixin): + logger.error(error_msg) + return False, error_msg + ++ def _update_partial_weights_from_distributed( ++ self, names, dtypes, shapes, group_name, partial: Optional[PartialWeightSpec], mode: str, ++ ): ++ if partial is None: ++ return False, f"load_format='{mode}' requires `partial` (PartialWeightSpec) in the request" ++ # mode 'delta' → fill unchanged with 0, apply via _additive_load_context. ++ # mode 'selective'→ fill unchanged with NaN, apply via _selective_load_context ++ # (NaN positions are left untouched on the receiver). ++ if mode == "delta": ++ fill_value = 0.0 ++ apply_ctx = _additive_load_context ++ else: # "selective" ++ fill_value = float("nan") ++ apply_ctx = _selective_load_context ++ encoding = partial.encoding ++ try: ++ encoded = self._broadcast_named_tensors(names, dtypes, shapes, group_name) ++ encoded_tensors = [t for _, t in encoded] ++ ++ # Decode must run OUTSIDE the apply context: its initial buffer fill ++ # and index_copy_ must overwrite, not accumulate / mask-overwrite. ++ # Yielding one decoded tensor at a time bounds peak HBM during apply. ++ if encoding is PartialWeightEncoding.DENSE: ++ decoded_iter = zip(names, encoded_tensors) ++ elif encoding in (PartialWeightEncoding.SPARSE_INDICES, PartialWeightEncoding.SPARSE_BITMASK): ++ if partial.params is None: ++ return False, f"encoding={encoding.value!r} requires partial.params" ++ decoded_iter = self._decode_sparse_partial( ++ encoded_tensors, partial.params, encoding, fill_value, ++ ) ++ else: ++ return False, f"unknown partial-update encoding: {encoding!r}" ++ ++ # Bigger chunk_byte_cap amortizes per-call cost (name resolution, ++ # MoE expert remap, fp8 scale repacking) but raises peak HBM. ++ # Sweep at GLM-4.7-355B H100 64-rollout: ++ # 96 MiB → 1110 calls / 37.8s ++ # 512 MiB → 128 calls / 30.3s ← default ++ # 1024 MiB → OOM on some engines ++ chunk_byte_cap = self.server_args.update_weight_partial_chunk_bytes ++ with apply_ctx(self.model): ++ chunk = [] ++ chunk_bytes = 0 ++ for name, t in decoded_iter: ++ tensor_bytes = t.numel() * t.element_size() ++ if chunk_bytes + tensor_bytes > chunk_byte_cap and chunk: ++ self.model.load_weights(chunk) ++ chunk = [] ++ chunk_bytes = 0 ++ chunk.append((name, t)) ++ chunk_bytes += tensor_bytes ++ if chunk: ++ self.model.load_weights(chunk) ++ ++ return True, f"Succeeded to apply weight {mode} update online." ++ except Exception as e: ++ error_msg = ( ++ f"Failed to apply weight {mode} update online: {e}. " ++ f"The model weights may be in an inconsistent state. " ++ f"Please discard the whole weights." 
++ ) ++ logger.error(error_msg) ++ return False, error_msg ++ ++ def _decode_sparse_partial(self, encoded_tensors, params, encoding, fill_value): ++ """ ++ Generator: yield one decoded (name, tensor) per param. Decoded tensors ++ are not retained across yields, so the consumer's per-chunk byte cap ++ also bounds the receiver's peak HBM during decode. ``fill_value`` is ++ the sentinel used at unchanged positions — 0 for delta mode, NaN for ++ selective mode. ++ """ ++ # encoded_tensors[0] is __packed_keys__: either int32 nonzero offsets ++ # (SPARSE_INDICES) or a packed bitmask (SPARSE_BITMASK). index_copy_ ++ # needs int64, so we cast once up front for the indices path. ++ if encoding is PartialWeightEncoding.SPARSE_INDICES: ++ packed_keys = encoded_tensors[0].to(dtype=torch.long) ++ elif encoding is PartialWeightEncoding.SPARSE_BITMASK: ++ packed_keys = encoded_tensors[0] ++ else: ++ raise ValueError(f"unsupported sparse encoding: {encoding!r}") ++ packed_values = encoded_tensors[1] ++ ++ for p in params: ++ numel = math.prod(p.shape) ++ param_dtype = getattr(torch, p.dtype) ++ flat = torch.full((numel,), fill_value, dtype=param_dtype, device=self.device) ++ if p.keys_end > p.keys_start: ++ keys = packed_keys[p.keys_start : p.keys_end] ++ values = packed_values[p.values_start : p.values_end] ++ if encoding is PartialWeightEncoding.SPARSE_INDICES: ++ flat.index_copy_(0, keys, values) ++ elif encoding is PartialWeightEncoding.SPARSE_BITMASK: ++ flat[_unpack_bitmask(keys, numel, self.device)] = values ++ yield p.name, flat.view(tuple(p.shape)) ++ + def update_weights_from_tensor( + self, + named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]], +@@ -2767,11 +2881,19 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.expert_distribution_metrics = recorder_outputs.get("metrics") # Copy cached routing experts' buffers back to CPU cache @@ -1646,7 +1914,7 @@ index a59742b943..a7347c15b8 100644 if self.eplb_manager is not None: self.eplb_manager.on_forward_pass_end() -@@ -3021,6 +3035,42 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3021,6 +3143,192 @@ class ModelRunner(ModelRunnerKVCacheMixin): device=self.device, ) @@ -1685,12 +1953,162 @@ index a59742b943..a7347c15b8 100644 + quant_method.process_weights_after_loading(module) + + return True, "Success" ++ ++ ++def _param_storage_predicate(model): ++ """Return a callable that tests whether a tensor's storage falls inside one ++ of ``model``'s parameter/buffer storage ranges, captured at call time. ++ Used by the partial-update apply contexts to scope their patched copy_/ ++ fill_ rewrites to writes that hit param memory.""" ++ import bisect ++ ++ # Skip meta tensors — no real storage to track. 
++ starts: List[int] = [] ++ ends: List[int] = [] ++ seen: set = set() ++ for tensors in (model.named_parameters(), model.named_buffers()): ++ for _, t in tensors: ++ if t.is_meta: ++ continue ++ try: ++ ptr = t.data_ptr() ++ except RuntimeError: ++ continue ++ if ptr == 0 or ptr in seen: ++ continue ++ seen.add(ptr) ++ sz = t.numel() * t.element_size() ++ starts.append(ptr) ++ ends.append(ptr + sz) ++ order = sorted(range(len(starts)), key=lambda i: starts[i]) ++ starts = [starts[i] for i in order] ++ ends = [ends[i] for i in order] ++ ++ def is_param_write(dst): ++ try: ++ ptr = dst.data_ptr() ++ except RuntimeError: ++ return False ++ idx = bisect.bisect_right(starts, ptr) - 1 ++ return 0 <= idx < len(starts) and starts[idx] <= ptr < ends[idx] ++ ++ return is_param_write ++ ++ ++@contextlib.contextmanager ++def _patched_in_place_writes(model, patched_copy_factory, patched_fill_factory): ++ """Shared scaffolding for the partial-update apply contexts. Patches ++ Tensor.copy_ and Tensor.fill_ with caller-provided variants, scoped to ++ writes whose destination is inside ``model``'s param storage. Also wraps ++ ``model.post_load_weights`` (if defined) to run in the unmodified op ++ environment, so derived tensors (fp8 scales, MoE biases, DeepSeek ++ w_kc/w_vc) overwrite correctly even when invoked from inside ++ ``load_weights``. ++ """ ++ is_param_write = _param_storage_predicate(model) ++ original_copy_ = torch.Tensor.copy_ ++ original_fill_ = torch.Tensor.fill_ ++ ++ patched_copy_ = patched_copy_factory(is_param_write, original_copy_) ++ patched_fill_ = patched_fill_factory(is_param_write, original_fill_) ++ ++ original_post_load = getattr(model, "post_load_weights", None) ++ if original_post_load is not None: ++ def wrapped_post_load(*args, **kwargs): ++ current_copy = torch.Tensor.copy_ ++ current_fill = torch.Tensor.fill_ ++ torch.Tensor.copy_ = original_copy_ ++ torch.Tensor.fill_ = original_fill_ ++ try: ++ return original_post_load(*args, **kwargs) ++ finally: ++ torch.Tensor.copy_ = current_copy ++ torch.Tensor.fill_ = current_fill ++ ++ model.post_load_weights = wrapped_post_load ++ ++ torch.Tensor.copy_ = patched_copy_ ++ torch.Tensor.fill_ = patched_fill_ ++ try: ++ yield ++ finally: ++ torch.Tensor.copy_ = original_copy_ ++ torch.Tensor.fill_ = original_fill_ ++ if original_post_load is not None: ++ model.post_load_weights = original_post_load ++ ++ ++def _additive_load_context(model): ++ """Make in-place writes to ``model``'s parameter memory accumulate (``add_``) ++ instead of overwriting (``copy_``/``fill_``). Writes to non-param storage ++ (scratch, dtype temps) are unaffected.""" ++ ++ def make_copy(is_param_write, original_copy_): ++ def patched_copy_(self, src, *args, **kwargs): ++ if is_param_write(self): ++ # In-place add_ promotes operands to common dtype for the math, then ++ # casts the result to self.dtype on store — so bf16 += fp32 keeps ++ # fp32 precision without an explicit src.to(self.dtype). 
++ return self.add_(src.to(device=self.device)) ++ return original_copy_(self, src, *args, **kwargs) ++ return patched_copy_ ++ ++ def make_fill(is_param_write, original_fill_): ++ def patched_fill_(self, value): ++ if is_param_write(self): ++ return self.add_(value) ++ return original_fill_(self, value) ++ return patched_fill_ ++ ++ return _patched_in_place_writes(model, make_copy, make_fill) ++ ++ ++def _selective_load_context(model): ++ """Make in-place writes to ``model``'s parameter memory overwrite only at ++ positions where the source is *not* NaN; positions whose source is NaN are ++ left untouched. This is the apply path for load_format="selective", where ++ the sender uses NaN as the "unchanged" sentinel in the dense decoded tensor. ++ Non-param writes go through the unmodified ops.""" ++ ++ def make_copy(is_param_write, original_copy_): ++ def patched_copy_(self, src, *args, **kwargs): ++ if is_param_write(self): ++ src_aligned = src.to(device=self.device, dtype=self.dtype) if src.dtype != self.dtype else src ++ mask = ~torch.isnan(src_aligned) ++ self[mask] = src_aligned[mask] ++ return self ++ return original_copy_(self, src, *args, **kwargs) ++ return patched_copy_ ++ ++ def make_fill(is_param_write, original_fill_): ++ # fill_ takes a scalar. In selective mode, a NaN scalar means "don't ++ # change the param" (matches the per-element semantics of patched_copy_). ++ # Non-NaN scalars fall through to a normal fill. ++ def patched_fill_(self, value): ++ if is_param_write(self): ++ try: ++ if math.isnan(value): ++ return self ++ except TypeError: ++ pass ++ return original_fill_(self, value) ++ return patched_fill_ ++ ++ return _patched_in_place_writes(model, make_copy, make_fill) ++ ++ ++def _unpack_bitmask(packed, numel, device): ++ if numel == 0: ++ return torch.empty(0, dtype=torch.bool, device=device) ++ shifts = torch.arange(8, dtype=torch.uint8, device=device) ++ expanded = ((packed.unsqueeze(1) >> shifts) & 1).reshape(-1) ++ return expanded[:numel].to(dtype=torch.bool) + def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): params_dict = dict(model.named_parameters()) diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py -index 2f0074924d..1f991932c6 100644 +index 2f00749..1f99193 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -52,11 +52,31 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): @@ -1795,7 +2213,7 @@ index 2f0074924d..1f991932c6 100644 continue diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py -index 912891b6a7..fd67a7b580 100644 +index 912891b..fd67a7b 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -325,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): @@ -1808,7 +2226,7 @@ index 912891b6a7..fd67a7b580 100644 if ( diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index 7746b24459..57b65fe06f 100644 +index 7746b24..57b65fe 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -1005,14 +1005,19 @@ class Qwen3LLMModel(Qwen3Model): @@ -1836,7 +2254,7 @@ index 7746b24459..57b65fe06f 100644 positions, hidden_states, diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py -index a44f14b6ca..6d6c65ea49 100644 +index a44f14b..6d6c65e 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py 
+++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -1,7 +1,13 @@ @@ -1904,7 +2322,7 @@ index a44f14b6ca..6d6c65ea49 100644 image_grid_thw = None video_grid_thw = None diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py -index 3f102567d0..6fb3899021 100644 +index 3f10256..6fb3899 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -499,7 +499,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor): @@ -1917,7 +2335,7 @@ index 3f102567d0..6fb3899021 100644 image_data=image_data, video_data=request_obj.video_data, diff --git a/python/sglang/srt/observability/req_time_stats.py b/python/sglang/srt/observability/req_time_stats.py -index 8caf21c320..51d1edc584 100644 +index 8caf21c..51d1edc 100644 --- a/python/sglang/srt/observability/req_time_stats.py +++ b/python/sglang/srt/observability/req_time_stats.py @@ -21,7 +21,10 @@ import uuid @@ -2157,7 +2575,7 @@ index 8caf21c320..51d1edc584 100644 def format_duration(self, duration: float) -> str: diff --git a/python/sglang/srt/observability/scheduler_metrics_mixin.py b/python/sglang/srt/observability/scheduler_metrics_mixin.py -index ff5695ce2e..588379a85d 100644 +index ff5695c..588379a 100644 --- a/python/sglang/srt/observability/scheduler_metrics_mixin.py +++ b/python/sglang/srt/observability/scheduler_metrics_mixin.py @@ -883,12 +883,42 @@ class SchedulerMetricsMixin: @@ -2204,7 +2622,7 @@ index ff5695ce2e..588379a85d 100644 def get_loads(self: Scheduler, req: GetLoadsReqInput = None) -> GetLoadsReqOutput: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index d91ced805f..4c8774bb64 100644 +index d91ced8..023085d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -670,6 +670,7 @@ class ServerArgs: @@ -2215,7 +2633,15 @@ index d91ced805f..4c8774bb64 100644 enable_fused_qk_norm_rope: bool = False enable_precise_embedding_interpolation: bool = False enable_fused_moe_sum_all_reduce: bool = False -@@ -5659,6 +5660,12 @@ class ServerArgs: +@@ -711,6 +712,7 @@ class ServerArgs: + # For model weight update and weight loading + custom_weight_loader: Optional[List[str]] = None + weight_loader_disable_mmap: bool = False ++ update_weight_partial_chunk_bytes: int = 512 * 1024 * 1024 + remote_instance_weight_loader_seed_instance_ip: Optional[str] = None + remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None + remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None +@@ -5659,6 +5661,12 @@ class ServerArgs: help="Token splitting mode for the prefill phase of DeepSeek v3.2 under context parallelism. Optional values: 'round-robin-split'(default), 'in-seq-split' " "'round-robin-split' distributes tokens across ranks based on token_idx %% cp_size. It supports multi-batch prefill, fused MoE, and FP8 KV cache.", ) @@ -2228,8 +2654,26 @@ index d91ced805f..4c8774bb64 100644 parser.add_argument( "--enable-prefill-context-parallel", action="store_true", +@@ -5825,6 +5833,17 @@ class ServerArgs: + default=None, + help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func", + ) ++ parser.add_argument( ++ "--update-weight-partial-chunk-bytes", ++ type=int, ++ default=ServerArgs.update_weight_partial_chunk_bytes, ++ help=( ++ "Byte cap per model.load_weights call when applying a partial " ++ "weight update (delta or selective). 
Bigger amortizes per-call " ++ "cost (name resolution, MoE expert remap, fp8 scale repacking) " ++ "but raises peak HBM during decode." ++ ), ++ ) + parser.add_argument( + "--weight-loader-disable-mmap", + action="store_true", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -index 40e859b2d6..2604ae037c 100644 +index 40e859b..2604ae0 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -377,6 +377,10 @@ class EAGLEDraftCudaGraphRunner: @@ -2259,7 +2703,7 @@ index 40e859b2d6..2604ae037c 100644 buffers.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py -index dbb91f555e..a04caefc34 100644 +index dbb91f5..a04caef 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -776,6 +776,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): @@ -2302,7 +2746,7 @@ index dbb91f555e..a04caefc34 100644 @dataclass diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py -index b0be70d751..44a78d684e 100644 +index b0be70d..44a78d6 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2157,6 +2157,7 @@ class SafeUnpickler(pickle.Unpickler): @@ -2314,7 +2758,7 @@ index b0be70d751..44a78d684e 100644 DENY_CLASSES = { diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py -index 3be16446e0..1b2371c839 100644 +index 3be1644..1b2371c 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -69,6 +69,9 @@ def _check_tensors( diff --git a/docs/en/advanced/partial-weight-sync.md b/docs/en/advanced/partial-weight-sync.md new file mode 100644 index 0000000000..596fe2d70d --- /dev/null +++ b/docs/en/advanced/partial-weight-sync.md @@ -0,0 +1,114 @@ +# Partial Weight Sync (Selective / Delta) + +- [Overview](#overview) +- [Quick Start](#quick-start) +- [Modes: selective vs delta](#modes-selective-vs-delta) +- [How it Works](#how-it-works) +- [Choosing the Wire Encoding](#choosing-the-wire-encoding) +- [Precision Behaviour](#precision-behaviour) +- [Periodic Base Sync](#periodic-base-sync) +- [Why Not Colocated](#why-not-colocated) + +## Overview + +For **non-colocated** runs, slime's default weight sync broadcasts every parameter on every training step. The full broadcast scales linearly with model size and dominates the sync phase even when only a small fraction of weights actually change between steps. Partial-update modes keep a pinned-CPU snapshot of the last sync's weights and broadcast only the changed-position payload; the SGLang receiver applies it without re-touching unchanged params. During typical RL fine-tuning at conservative learning rates the per-step diff is sparse — a few percent of weights — so the wire shrinks proportionally. + +**Inspiration / prior art.** `selective` is inspired by [arXiv:2509.19128](https://arxiv.org/abs/2509.19128). `delta` is informed by the additive-update approach in [Cursor Composer 2](https://cursor.com/resources/Composer2.pdf) and [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think). 
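+
+Concretely, the changed-position payload the sender builds per parameter looks like the sketch below (a minimal illustration with made-up tensor names and sizes — not slime's actual sender code; the real pipeline buckets many parameters into the `__packed_keys__` / `__packed_values__` tensors described under "How it Works"):
+
+```python
+import torch
+
+# Illustrative sketch: one bf16 parameter, sparse_indices encoding.
+snapshot = torch.randn(4096, dtype=torch.bfloat16)        # pinned-CPU copy from the last sync
+current = snapshot.clone()
+current[torch.randint(0, 4096, (128,))] += 1e-2           # pretend a few % of weights moved
+
+changed = current != snapshot                             # bf16 compare, no arithmetic
+keys = changed.nonzero(as_tuple=True)[0].to(torch.int32)  # would go into __packed_keys__
+# selective ships the new bf16 values; delta ships (current - snapshot) in fp32
+selective_values = current[changed]
+delta_values = (current.float() - snapshot.float())[changed]
+```
+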
+ +## Quick Start + +Enable a partial mode on the trainer side: + +```bash +--update-weight-mode selective # 'selective' / 'delta' / 'full' (default) +--update-weight-partial-encoding sparse_indices +--update-weight-delta-dtype fp32 # delta mode only +--update-weight-base-sync-interval 9999 # default. Both partial modes are lossless under + # their defaults (selective by construction, delta + # with fp32 math), so 9999 effectively disables + # periodic base syncs. Set lower (e.g. 30) to + # verify against periodic full broadcasts, or + # if your workload has a custom base-sync need. +``` + +And one knob on the SGLang side (auto-mirrored by slime as `--sglang-update-weight-partial-chunk-bytes`): + +```bash +--sglang-update-weight-partial-chunk-bytes $((2 * 1024 * 1024 * 1024)) +``` + +See [examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh](../../../examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh) for a complete non-colocated launcher. + +## Modes: selective vs delta + +Both modes share the same sender pipeline (snapshot, mask determination, sparse encoding, bucketed broadcast) and the same wire format. They differ only in what the values mean and how the receiver applies them: + +| | `--update-weight-mode selective` | `--update-weight-mode delta` | +|---|---|---| +| Values on wire | new param values at changed positions, in the snapshot's dtype | `(current − snapshot)` cast to `--update-weight-delta-dtype` (default fp32) | +| "Unchanged" signal at receiver | NaN sentinel in the decoded dense tensor | implicit (zero delta at unchanged positions) | +| Receiver apply | `param[~isnan(src)] = src[~isnan(src)]` (selective overwrite) | `param += delta` (in-place add, auto-promotes for fp32 math, casts back to param dtype) | +| Wire bytes (values portion) | 2×nnz @ bf16 (½× delta) | 4×nnz @ fp32 | +| Lossless? | yes by construction (no arithmetic) | yes when `delta-dtype` > param dtype | + +Pick `selective` when you want the smaller wire and don't need fp32 arithmetic margin; pick `delta` when you'd rather keep the arithmetic path for compatibility or want to amplify sub-bf16 deltas via the fp32 subtraction. + +## How it Works + +Per sync, on the trainer (PP-source rank only): + +1. **Compute the payload**: for selective, take the bf16 mask `current != snapshot` and emit new values with NaN at unchanged positions; for delta, lift current weights and pinned-CPU snapshot to delta_dtype and subtract. +2. **Encode**: sparse-encode active positions into two flat packed tensors (`__packed_keys__`, `__packed_values__`) plus a per-param manifest (`PartialWeightSpec.params`). +3. **Bucket and broadcast**: pack multiple parameters per NCCL broadcast (`--update-weight-buffer-size` controls the bucket cap). +4. **Snapshot new prev**: D2H copy of the just-sent weights onto a side stream so it overlaps with downstream broadcast/encode work. + +On the SGLang receiver: + +1. **Broadcast**: receive the two packed tensors per bucket. +2. **Decode lazily**: yield one decoded dense tensor per parameter; unchanged positions are filled with the mode's sentinel (NaN for selective, 0 for delta). The consumer's `chunk_byte_cap` bounds peak HBM during decode (`encoded_buffers + in-flight chunk`). +3. 
**Apply**: route the decoded tensors through the model's normal `load_weights` path, but with `Tensor.copy_` / `fill_` rewired by a context manager: + - For `selective`: `_selective_load_context` redirects writes that target param storage to a masked overwrite (`param[~isnan(src)] = src[~isnan(src)]`), leaving NaN positions untouched. + - For `delta`: `_additive_load_context` redirects writes that target param storage to `add_` (PyTorch auto-promotes for fp32 math and casts back on store, so deltas keep fp32 precision). + +Auxiliary writes (scratch buffers, dtype temporaries, `post_load_weights` for fp8-scale recompute or MoE bias materialization) keep their normal overwriting semantics in both contexts. + +The wire protocol — `PartialWeightSpec` (encoding + per-param manifest), and per-param `PartialWeightParam` (name, dtype, shape, key/value slice ranges) — is defined in `sglang.srt.managers.io_struct` (added by the slime SGLang patch). + +## Choosing the Wire Encoding + +`--update-weight-partial-encoding` accepts three values: + +| value | wire layout | when to pick | +|---|---|---| +| `sparse_indices` | int32 active offsets + values | low change density (< ~3%) | +| `sparse_bitmask` | 1 bit per element + values | moderate change density (> ~3%) | +| `dense` | identity, one tensor per param | debugging the apply path | + +The break-even density between the two sparse encodings is independent of the value dtype. With `n = numel`, `k = nnz`, `v = value bytes`: + +``` +sparse_indices wire = k * (4 + v) +sparse_bitmask wire = ceil(n / 8) + k * v +``` + +Equal when `4k = n/8`, i.e. `k/n = 1/32 ≈ 3.125%`. Below that, indices is smaller; above, bitmask is smaller. For typical RL fine-tuning at moderate learning rates, `sparse_indices` wins; for early-training high-LR phases where most weights move every step, switch to `sparse_bitmask`. + +## Precision Behaviour + +For `delta` mode, `--update-weight-delta-dtype` is the *math* dtype, not just the wire dtype. The subtraction is performed at `delta_dtype` on both operands (after promoting from the param dtype), and the receiver's `param.data.add_(fp32_delta)` lets PyTorch do the addition at the common dtype (fp32) before casting the result back into the bf16 param. This recovers small-magnitude deltas that would otherwise round to zero through a bf16 subtraction. + +For `selective` mode there is no arithmetic — the receiver overwrites changed positions with the trainer's exact bf16 values — so precision is bit-perfect regardless of `--update-weight-delta-dtype` (the flag is silently ignored). + +The CPU snapshot occupies only the param dtype's bytes in both modes (no fp32 inflation of pinned memory). + +## Periodic Base Sync + +The first sync of every job is always a *base sync* (a full broadcast that re-establishes the snapshot). After that, slime sends partial syncs until `committed_syncs % --update-weight-base-sync-interval == 0`, at which point a base sync runs again. + +In selective mode or with `--update-weight-delta-dtype fp32` (delta mode), the partial apply is **lossless**: every bf16 value is exactly representable in fp32, the subtraction `current_fp32 − snapshot_fp32` produces the exact difference between the two stored bf16 values, and the receiver's in-place `bf16_param.add_(fp32_delta)` reconstructs the trainer's bf16 state bit-for-bit when the fp32 result is rounded back to bf16. Selective is lossless by construction (direct overwrite). 
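+
+A standalone sanity check of the delta round-trip (not slime code; it relies on the `add_` promotion behaviour described above, and on the operands having comparable magnitude so the fp32 subtraction is exact):
+
+```python
+import torch
+
+old = torch.randn(4096, dtype=torch.bfloat16)               # receiver state == snapshot
+new = (old.float() + 1e-3 * torch.randn(4096)).bfloat16()   # trainer state after a step
+delta = new.float() - old.float()                           # exact in fp32 for bf16 operands of similar magnitude
+rebuilt = old.clone()
+rebuilt.add_(delta)                                         # fp32 math, cast back to bf16 on store
+assert torch.equal(rebuilt, new)                            # bit-for-bit
+```
+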
Because no error accumulates across partial syncs, receiver state never drifts from a base-sync reference no matter how many partial syncs elapse — periodic base sync is not needed for correctness. The default `--update-weight-base-sync-interval 9999` effectively disables it and is the recommended setting; set lower (e.g. `30`) if you want periodic full broadcasts to verify correctness or your workload has a custom base-sync requirement. + +The only operational reason to keep an occasional base sync is recovery — e.g. a rollout engine that joins mid-training and needs a complete state before it can apply partial updates. If you set `--update-weight-delta-dtype bf16` (delta only, not higher than the param dtype) to save wire bytes, the delta apply is no longer lossless and a finite interval starts to matter. + +## Why Not Colocated + +Colocated weight sync uses CUDA IPC: the engine maps the trainer's parameter storage directly into its own process. There is no NCCL broadcast, and "wire size" is one IPC handle per param (~64 B). Partial encoding's `bytes saved on the wire` benefit is zero, while the partial-update bookkeeping (snapshot + subtract/mask + sparse encode) is pure overhead. Slime rejects `--update-weight-mode selective --colocate` and `--update-weight-mode delta --colocate` at argparse time. diff --git a/docs/en/index.rst b/docs/en/index.rst index d5ca098e06..1811be8a9c 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -44,6 +44,7 @@ slime is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. Apart from models advanced/on-policy-distillation.md advanced/speculative-decoding.md advanced/low-precision.md + advanced/partial-weight-sync.md advanced/reproducibility.md advanced/fault-tolerance.md advanced/pd-disaggregation.md diff --git a/docs/zh/advanced/partial-weight-sync.md b/docs/zh/advanced/partial-weight-sync.md new file mode 100644 index 0000000000..22f9f674f5 --- /dev/null +++ b/docs/zh/advanced/partial-weight-sync.md @@ -0,0 +1,113 @@ +# 增量权重同步(Selective / Delta) + +- [概述](#概述) +- [快速开始](#快速开始) +- [两种 partial 模式:selective 与 delta](#两种-partial-模式selective-与-delta) +- [工作原理](#工作原理) +- [选择 wire 编码](#选择-wire-编码) +- [精度行为](#精度行为) +- [周期性 Base Sync](#周期性-base-sync) +- [为什么 colocate 模式不需要](#为什么-colocate-模式不需要) + +## 概述 + +在**非 colocate**(non-colocated)模式下,slime 默认会在每一步训练时把所有参数完整地广播给 SGLang。完整广播的体积随模型规模线性增长,即使两步之间实际变化的权重比例很小,broadcast 仍然主导整个权重同步阶段。Partial-update 模式会把上一次同步时的权重在 pinned CPU 内存里保留一份 snapshot,每步只广播变化位置的数据,SGLang 接收端只更新这些位置。在 RL fine-tuning 阶段、学习率不大的常见设置里,每步 diff 都很稀疏(只有百分之几的权重发生变化),wire 体积也按比例减少。 + +**参考资料 / 先验工作。** `selective` 模式的灵感来自 [arXiv:2509.19128](https://arxiv.org/abs/2509.19128)。`delta` 模式的加性更新思路参考了 [Cursor Composer 2](https://cursor.com/resources/Composer2.pdf) 和 [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think)。 + +## 快速开始 + +训练端开关与传输编码: + +```bash +--update-weight-mode selective # 'selective' / 'delta' / 'full'(默认) +--update-weight-partial-encoding sparse_indices +--update-weight-delta-dtype fp32 # 仅 delta 模式生效 +--update-weight-base-sync-interval 9999 # 默认值。两种 partial 模式在默认设置下都是 lossless 的 + # (selective 按构造,delta 配 fp32 算术),所以 9999 + # 实际上关闭了周期性 base sync。若想用周期性全量广播 + # 来验证正确性、或有自定义的 base sync 需求,可调小 + # (例如 30)。 +``` + +SGLang 端唯一的旋钮(slime 通过 `--sglang-update-weight-partial-chunk-bytes` 自动转发): + +```bash +--sglang-update-weight-partial-chunk-bytes $((2 * 1024 * 1024 * 1024)) +``` + +完整非 colocate 启动脚本见 
[examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh](../../../examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh)。 + +## 两种 partial 模式:selective 与 delta + +两种模式共用 sender 流水线(snapshot、mask 计算、稀疏编码、桶式广播)和 wire 格式,区别只在 values 的语义以及 receiver 的 apply 方式: + +| | `--update-weight-mode selective` | `--update-weight-mode delta` | +|---|---|---| +| wire 上的 values | 变化位置的新权重,dtype 同 snapshot | `(current − snapshot)`,cast 到 `--update-weight-delta-dtype`(默认 fp32) | +| 接收端"未变化"信号 | 解码后的稠密张量在未变化位置填 NaN | 隐式(delta 在未变化位置为 0) | +| 接收端 apply | `param[~isnan(src)] = src[~isnan(src)]`(selective overwrite) | `param += delta`(in-place add,自动提升到 fp32 计算后再 cast 回 param dtype) | +| values 部分 wire 字节 | 2×nnz @ bf16(½× delta) | 4×nnz @ fp32 | +| 是否 lossless | 永远 lossless(无算术) | 当 `delta-dtype` 高于 param dtype 时 lossless | + +当你想要更小的 wire、不需要 fp32 算术余量时选 `selective`;当你需要 fp32 减法去保住 sub-bf16 级别的小 delta 时选 `delta`。 + +## 工作原理 + +每次同步,训练端(仅 PP-source rank): + +1. **计算 payload**:selective 模式下,先在 bf16 上取 `current != snapshot` 的 mask,再生成新权重值并在 unchanged 位置填 NaN;delta 模式下,将当前权重与 pinned-CPU snapshot 同时提升到 delta_dtype 然后相减。 +2. **编码**:将 active 位置稀疏编码为两条扁平张量(`__packed_keys__`、`__packed_values__`)和一份 per-param manifest(`PartialWeightSpec.params`)。 +3. **分桶广播**:多个参数共享一次 NCCL 广播,桶大小由 `--update-weight-buffer-size` 控制。 +4. **异步刷新 snapshot**:把当前权重通过独立 CUDA stream 拷贝到 pinned CPU,与下一轮的广播、编码计算重叠。 + +SGLang 接收端: + +1. **接收**:每个桶接收两条 packed 张量。 +2. **懒解码**:以生成器逐参数 yield 解码后的稠密张量;unchanged 位置按模式填入 sentinel(selective 模式填 NaN,delta 模式填 0)。下游 chunking 的 `chunk_byte_cap` 同时为 decode 阶段的峰值 HBM 设上限(`encoded_buffers + in-flight chunk`)。 +3. **加性写入**:仍走模型 `load_weights` 主路径,但通过一个 context manager 重写 `Tensor.copy_` / `fill_`: + - `selective` 模式下 `_selective_load_context` 把落入 param storage 的 copy_ 重写为 mask-overwrite(`param[~isnan(src)] = src[~isnan(src)]`),unchanged 位置保持不动。 + - `delta` 模式下 `_additive_load_context` 把落入 param storage 的 copy_ 重写为 `add_`(PyTorch 自动提升到 fp32 完成加法、再 cast 回 store,保留 fp32 精度)。 + +非 param 的写入(scratch buffer、dtype 转换、`post_load_weights` 中的 FP8 scale 重计算 / MoE bias 物化等)在两种 context 下都保持原始覆盖语义。 + +Wire protocol —— `PartialWeightSpec`(encoding + per-param manifest)和 `PartialWeightParam`(name、dtype、shape、keys/values slice)—— 定义在 `sglang.srt.managers.io_struct`(由 slime 的 SGLang patch 注入)。 + +## 选择 wire 编码 + +`--update-weight-partial-encoding` 接受三个值: + +| 值 | wire 排布 | 适用场景 | +|---|---|---| +| `sparse_indices` | int32 active 下标 + 值 | 低变化率(< ~3%) | +| `sparse_bitmask` | 每元素 1 bit 的 mask + 值 | 中等变化率(> ~3%) | +| `dense` | 每参数一条张量 | 调试 apply 路径 | + +两种稀疏编码的等价点和值 dtype 无关。令 `n = numel`,`k = nnz`,`v = 值字节数`: + +``` +sparse_indices wire = k * (4 + v) +sparse_bitmask wire = ceil(n / 8) + k * v +``` + +二者相等时 `4k = n/8`,即 `k/n = 1/32 ≈ 3.125%`。低于该 density 选 indices,高于则选 bitmask。常见的小学习率 RL fine-tuning 阶段 `sparse_indices` 更省,训练早期大 LR 阶段几乎所有权重都在动时换 `sparse_bitmask`。 + +## 精度行为 + +`delta` 模式下 `--update-weight-delta-dtype` 控制的是**计算 dtype**,不仅仅是 wire dtype。减法在两个操作数都被提升到 `delta_dtype` 之后进行;接收端的 `param.data.add_(fp32_delta)` 让 PyTorch 内部以共同 dtype(fp32)做加法,然后再 cast 回 bf16 写入 param。这样可以保留那些在 bf16 减法下会直接舍入为零的小幅度 delta。 + +`selective` 模式下没有算术,接收端直接把 trainer 的精确 bf16 值写回 param,因此精度天然 bit-perfect,与 `--update-weight-delta-dtype` 无关(该 flag 在 selective 模式下被静默忽略)。 + +CPU snapshot 在两种模式下都只占用 param dtype 的字节数(不会因此膨胀到 fp32 的存储)。 + +## 周期性 Base Sync + +每次任务的第一次同步永远是 *base sync*(一次完整广播,重建 snapshot)。之后每当 `committed_syncs % --update-weight-base-sync-interval == 0` 再触发一次 base sync。 + +在 selective 模式或 `--update-weight-delta-dtype fp32`(delta 
模式)下,partial apply 都是**无损(lossless)**的:selective 模式因为直接覆盖而天然无损;delta 模式下每个 bf16 值都可以精确表示为 fp32,`current_fp32 − snapshot_fp32` 得到两个 bf16 值的精确差,接收端的 `bf16_param.add_(fp32_delta)` 在自动提升到 fp32 完成加法、再 cast 回 bf16 之后,会逐比特地复现 trainer 的 bf16 状态。因为不会有误差累积,无论中间累积了多少次 partial 同步,接收端的状态都不会偏离对应的 base sync 结果,从正确性角度并不需要周期性 base sync。默认 `--update-weight-base-sync-interval 9999` 实际上已关闭周期性 base sync,是推荐设置;若希望用周期性全量广播来验证正确性或有自定义需求,可设成较小的值(例如 `30`)。 + +保留少量 base sync 的运营性理由主要是恢复点——例如一个中途加入的 rollout engine 需要先拿到完整状态才能应用后续 partial 更新。如果你为了进一步压缩 wire 体积而把 `--update-weight-delta-dtype` 设为 `bf16`(不高于 param dtype 的精度,仅对 delta 模式有意义),apply 就不再 lossless,这时 interval 才需要给一个合理的有限值。 + +## 为什么 colocate 模式不需要 + +Colocate 模式的权重同步走的是 CUDA IPC:SGLang 直接把 trainer 进程的参数 storage 映射到自己进程,wire 上只交换一个 IPC handle(~64 B),完全没有 NCCL 广播。Partial 编码的「wire 体积」优势归零,而 partial 更新的额外开销(snapshot 维护、减法/取 mask、稀疏编码)反而是纯开销。所以 slime 在 argparse 阶段就拒绝 `--update-weight-mode selective --colocate` 和 `--update-weight-mode delta --colocate` 的组合。 diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 36d3d79eb2..5b286d3402 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -44,6 +44,7 @@ slime 是 GLM-4.7、GLM-4.6、GLM-4.5 背后的 RL 训练框架。除此之外 advanced/on-policy-distillation.md advanced/speculative-decoding.md advanced/low-precision.md + advanced/partial-weight-sync.md advanced/reproducibility.md advanced/fault-tolerance.md advanced/pd-disaggregation.md diff --git a/examples/README.md b/examples/README.md index a1e3bd251e..1b417ff932 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,6 +4,7 @@ These examples provide concrete examples to leverage slime in your own RL workfl ## Directory Structure +- **[partial_weight_sync](./partial_weight_sync)**: Non-colocated weight sync that broadcasts sparse-encoded partial updates (selective or delta) instead of full weights. - **[eval_multi_task](./eval_multi_task)**: Example for supporting evaluation multiple tasks with different configs. - **[fully_async](./fully_async)**: Demonstrates fully asynchronous rollout generation for higher efficiency. - **[geo3k_vlm](./geo3k_vlm)**: Training VLMs on a single-turn reasoning task using GRPO on the GEO3K dataset. diff --git a/examples/partial_weight_sync/README.md b/examples/partial_weight_sync/README.md new file mode 100644 index 0000000000..0c8588e7fc --- /dev/null +++ b/examples/partial_weight_sync/README.md @@ -0,0 +1,121 @@ +# Partial Weight Sync (selective / delta) + +This example demonstrates non-colocated weight sync with **partial-update modes**: instead of broadcasting every parameter on every sync, slime broadcasts only the changed-position payload, and the SGLang receiver applies it without rebroadcasting the unchanged majority of the weights. Two sub-modes: + +- **`selective`** — broadcast new values at changed positions only (with NaN as the "unchanged" sentinel); receiver overwrites those positions, leaves others alone. Lossless by construction (no arithmetic), wire ~½ the size of fp32 delta. Inspired by [arXiv:2509.19128](https://arxiv.org/abs/2509.19128). +- **`delta`** — broadcast `(current − snapshot)` sparse-encoded; receiver applies additively (`param += delta`). Inspired by [Cursor Composer 2](https://cursor.com/resources/Composer2.pdf) and [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think). + +For non-colocated runs the wire shrinks roughly in proportion to the change density, which is typically a few percent during RL fine-tuning at conservative learning rates. 
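+
+As a rough worked example with assumed numbers (2.5% change density, bf16 values, using the wire formulas from the "Why `sparse_indices` for this run" section below):
+
+```
+full bf16 broadcast : 2.0 bytes / param
+sparse_indices wire : 0.025 * (4 + 2) = 0.15 bytes / param   (~13x smaller)
+```
+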
The broadcast that previously dominated the sync phase becomes a small fraction of it. Colocated runs share GPU memory via CUDA IPC and have no wire — partial-update modes buy nothing there and are rejected at argparse time. + +## Files + +- `run-glm4.7-355B-A32B-partial.sh`: 16-node (8 actor + 8 rollout) GLM-4.7-355B-A32B launcher with partial-update flags set. + +## Usage + +Set up the same checkpoint and dataset paths as a standard non-colocated GLM-4.7 run (see [docs/en/examples/glm4.7-355B-A32B.md](../../docs/en/examples/glm4.7-355B-A32B.md)), then launch: + +```bash +bash examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh +``` + +The script has two pre-built `PARTIAL_ARGS` blocks; the selective block is active by default and the delta block is commented out. Comment one out to switch. + +**Selective mode:** + +```bash +PARTIAL_ARGS=( + --update-weight-mode selective + --update-weight-partial-encoding sparse_indices + --update-weight-base-sync-interval 9999 +) +``` + +**Delta mode:** + +```bash +PARTIAL_ARGS=( + --update-weight-mode delta + --update-weight-partial-encoding sparse_indices + --update-weight-delta-dtype fp32 + --update-weight-base-sync-interval 9999 +) +``` + +Notes: +- `--update-weight-delta-dtype` is delta-only (silently ignored in selective mode — no arithmetic happens there). +- `--update-weight-base-sync-interval` defaults to `9999` — effectively disables periodic base syncs because both modes are lossless under their defaults (selective by construction, delta with fp32 math). Set lower (e.g. `30`) if you want to verify correctness against periodic full broadcasts, or if your workload has a custom base-sync requirement. +- `--update-weight-partial-encoding` accepts `sparse_indices` / `sparse_bitmask` / `dense`. + +And one receiver-side flag in `SGLANG_ARGS`: + +```bash +--sglang-update-weight-partial-chunk-bytes $((2 * 1024 * 1024 * 1024)) +``` + +See [docs/en/advanced/partial-weight-sync.md](../../docs/en/advanced/partial-weight-sync.md) for the wire protocol, encoding choice, and precision behaviour. + +## Results + +### Selective mode + +W&B traces comparing `selective` mode against the full-sync baseline. + + + +*Placeholder — selective experiment numbers and traces pending.* + +![Update weights density](./update_weights_density.png) + +*Per-sync change density (`perf/update_weights_density`) — fraction of weight positions that moved between consecutive syncs. Step 0 is omitted: it's always the warmup base sync with density = 1.0, which would compress the y-axis and hide the partial-sync values.* + +### Delta mode + +W&B traces comparing `delta` mode against the full-sync baseline on the run above. + +![Raw reward](./raw_reward.png) + +*Raw reward over training steps — delta and full match.* + +![Train/rollout logprob abs diff](./train_rollout_logprob_abs_diff.png) + +*Absolute logprob difference between train and rollout — delta and full match.* + +![Update weights time](./update_weights_time.png) + +*Per-step weight-update wall-clock — delta is substantially faster.* + +## Reading the curves + +The reward / logprob-diff curves track each other closely between modes, but they don't sit pixel-on-pixel. That divergence is **not** evidence that partial-update modes lose information — both modes shipped here are mathematically lossless under their respective recipes: + +- **`selective`**: no arithmetic at all — the receiver overwrites changed positions with the trainer's exact bf16 values. Lossless by construction. 
+- **`delta` with `--update-weight-delta-dtype fp32`**: every bf16 value is exactly representable in fp32; the subtraction is exact within fp32; the receiver's in-place `bf16 += fp32` add casts back identically to what a full sync would store. Receiver state matches a full-sync reference bit-for-bit per step. + +The small curve-to-curve divergence comes from **non-determinism elsewhere in the training/rollout stack** (cuBLAS reductions, FlashAttention split-K, NCCL all-reduce ordering, dynamic-batch token assignment). Two identically-configured *full*-sync runs would diverge the same way. What's "matching" between partial and full here is the trajectory, not the bits. + +## Why `sparse_indices` for this run + +Per-sync weight-change density during RL fine-tuning at conservative learning rates is typically a few percent — see for instance [arXiv:2602.03839](https://arxiv.org/pdf/2602.03839), which reports that only on the order of 1% of weights change per RL update. Our own logs on the GLM-4.7-355B run measured roughly **2–3% density per sync**. + +The break-even density between the two sparse encodings is independent of the value dtype. With `n = numel`, `k = nnz`, `v = value-dtype bytes`: + +``` +sparse_indices wire = k * (4 + v) (int32 indices + values) +sparse_bitmask wire = ceil(n / 8) + k * v (1 bit per element + values) +``` + +Equal when `4k = n/8`, i.e. `k/n = 1/32 ≈ 3.125%`. Below 3.125% `sparse_indices` is smaller; above, `sparse_bitmask` wins. Our 2–3% observed density sits below the break-even — hence `sparse_indices` is the right pick for this workload. (`dense` is the no-compression option, kept around for debugging the additive / selective apply path independently of the sparse encoding.) + +## Composes with any communication optimization in slime + +This feature only changes *what bytes get shipped*; it does not touch the NCCL broadcast itself, the Ray lock around it, the bucket scheduling, or any send/receive layer. So any future slime improvement to the weight-update communication path — better compute/broadcast overlap, NIC-level optimizations, pipeline-parallel sends, deduplicated metadata — stacks additively on top of the speedups shown above. Both `selective` and `delta` inherit those gains for free. + +## Two modes, one feature + +`selective` and `delta` are both *lossless* partial-update modes. They differ only in what they put on the wire and how the receiver applies it: + +- **`selective`** carries new values directly (~½ the values-wire at bf16) and applies them by selective overwrite. No arithmetic on either side, so the receiver is bit-exact with the trainer regardless of dtype. Pick this when wire size is the binding constraint. +- **`delta`** keeps an arithmetic path (`receiver += sender's delta`). Wire-values portion is 4 bytes/element at fp32. Pick this when you want the arithmetic semantics (e.g. for compatibility with future ideas that compose with additive apply). + +Both are exposed so you can pick the trade-off that fits your run. diff --git a/examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh b/examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh new file mode 100755 index 0000000000..b11d42e5da --- /dev/null +++ b/examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh @@ -0,0 +1,204 @@ +#!/bin/bash + +# Non-colocated GLM-4.7-355B-A32B with partial weight sync. +# 8 actor nodes (TP=8, PP=4, EP=16) + 64 rollout GPUs (8 H100 nodes worth), +# 16 nodes total. 
Two modes available — see PARTIAL_ARGS below; the default +# block is `selective`, the alternate `delta` block is commented out. + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +source "/root/slime/scripts/models/glm4.5-355B-A32B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/GLM-4.7-355B-A32B + --ref-load /root/GLM-4.7-355B-A32B_torch_dist/ +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + --rollout-batch-size 64 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 1 + + --num-steps-per-rollout 4 + --balance-data + --rollout-stop-token-ids 151329 151336 151338 +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 8 + --eval-max-response-len 8192 + --eval-top-p 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 8 + --sequence-parallel + --pipeline-model-parallel-size 4 + --context-parallel-size 2 + --expert-model-parallel-size 16 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 16384 +) + +GRPO_ARGS=( + --advantage-estimator gspo + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 1e-4 + --eps-clip-high 2e-4 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=( + # --use-wandb + # --wandb-project slime-delta + # --wandb-group glm4.7-355B-delta +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 32 + --sglang-mem-fraction-static 0.7 + --sglang-enable-dp-attention + --sglang-dp-size 4 + --sglang-ep-size 32 + --sglang-enable-dp-lm-head + --sglang-moe-dense-tp-size 1 + + # Receiver batches up to this many bytes per model.load_weights call. Bigger + # amortizes per-call cost (name resolution, MoE expert remap) but raises + # peak HBM during decode. + --sglang-update-weight-partial-chunk-bytes $((2 * 1024 * 1024 * 1024)) + + # mtp + --sglang-speculative-algorithm EAGLE + --sglang-speculative-num-steps 3 + --sglang-speculative-eagle-topk 1 + --sglang-speculative-num-draft-tokens 4 +) + +# Partial weight sync (sender side). Pick one of the two blocks below. +# +# `--update-weight-base-sync-interval` defaults to 9999 — effectively disables +# periodic base syncs because both modes are lossless under their defaults +# (selective by construction, delta with fp32 math). Set lower (e.g. 30) if +# you want to verify against periodic full broadcasts, or if your workload +# has a custom base-sync requirement. + +# ── Mode 1: selective — broadcast new values at changed positions ───────── +# `--update-weight-delta-dtype` is silently ignored here (no arithmetic; apply +# is lossless by construction). 
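+# Encoding rule of thumb: `sparse_indices` ships fewer bytes while the per-sync
+# change density (logged as perf/update_weights_density) stays below
+# ~1/32 ≈ 3.1%; above that, `sparse_bitmask` is smaller. Our GLM-4.7 runs
+# measured ~2–3%, hence sparse_indices here.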
+PARTIAL_ARGS=( + --update-weight-mode selective + --update-weight-partial-encoding sparse_indices + --update-weight-base-sync-interval 9999 +) + +# ── Mode 2: delta — broadcast (current − snapshot), receiver += delta ────── +# Uncomment to run delta instead (and comment out the selective block above). +# PARTIAL_ARGS=( +# --update-weight-mode delta +# --update-weight-partial-encoding sparse_indices +# --update-weight-delta-dtype fp32 +# --update-weight-base-sync-interval 9999 +# ) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + --moe-token-dispatcher-type flex + --moe-enable-deepep + --update-weight-buffer-size $((2 * 1024 * 1024 * 1024)) +) + +# launch the master node of ray +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON +RUNTIME_ENV_JSON=$(cat < None: diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 19db1f475a..d25b1d93e9 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -595,7 +595,7 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) - gather_log_data("passrate", args, rollout_id, log_dict) -def log_perf_data(rollout_id: int, args: Namespace) -> None: +def log_perf_data(rollout_id: int, args: Namespace, extra_metrics: dict | None = None) -> None: train_metric_utils.log_perf_data_raw( rollout_id=rollout_id, args=args, @@ -607,6 +607,7 @@ def log_perf_data(rollout_id: int, args: Namespace) -> None: compute_total_fwd_flops=lambda seq_lens: calculate_fwd_flops(seqlens=seq_lens, args=args) / dist.get_world_size() / 1e12, + extra_metrics=extra_metrics, ) diff --git a/slime/backends/megatron_utils/sglang.py b/slime/backends/megatron_utils/sglang.py index 97c82a31cd..5c586e990b 100644 --- a/slime/backends/megatron_utils/sglang.py +++ b/slime/backends/megatron_utils/sglang.py @@ -13,6 +13,7 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions +from sglang.srt.managers.io_struct import PartialWeightEncoding, PartialWeightParam, PartialWeightSpec from sglang.srt.utils import MultiprocessingSerializer @@ -28,4 +29,7 @@ "monkey_patch_torch_reductions", "MultiprocessingSerializer", "FlattenedTensorBucket", + "PartialWeightEncoding", + "PartialWeightParam", + "PartialWeightSpec", ] diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 822b801776..5836da64ec 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -1,7 +1,7 @@ import socket import time from argparse import Namespace -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence import ray import torch @@ -14,6 +14,7 @@ from slime.utils.distributed_utils import get_gloo_group, init_process_group from ..megatron_to_hf import convert_to_hf +from ..sglang import PartialWeightSpec from .common import all_gather_param, named_params_and_buffers @@ -21,6 +22,7 @@ class UpdateWeightFromDistributed: """ Update distributed engines via NCCL. Each PP rank: group "slime-pp_{pp_rank}", only DP=TP=0 broadcasts. 
Non-expert (TP) and expert (EP) params separate. + Subclasses override ``_send_weights`` to inject per-mode behaviour (e.g. delta). """ def __init__( @@ -41,6 +43,15 @@ def __init__( self.quantization_config = quantization_config self.weight_version = 0 self._model_update_groups = None + self.update_weight_metrics: dict[str, float] = {} + + def pop_metrics(self) -> dict[str, float]: + """ + Return and clear ``update_weight_metrics``. Drained by the actor onto + the rollout/step flush each step. + """ + out, self.update_weight_metrics = self.update_weight_metrics, {} + return out def connect_rollout_engines( self, @@ -89,7 +100,7 @@ def disconnect_rollout_engines(self) -> None: @torch.no_grad() def update_weights(self) -> None: """ - Pause → flush → non-expert (TP) → expert (EP) → continue. Progress on PP source. + Pause → flush → _send_weights → continue. Progress on PP source. """ self.weight_version += 1 @@ -106,34 +117,8 @@ def update_weights(self) -> None: ) dist.barrier(group=get_gloo_group()) - buffer_size = 0 - converted_named_tensors = [] - # non expert params pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - - for name, param in named_params_and_buffers(self.args, self.model): - if ".experts." in name: - continue - buffer_size = self._update_weight_from_distributed( - name, param, converted_named_tensors, buffer_size, pbar=pbar - ) - - if converted_named_tensors: - self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) - - dist.barrier(group=get_gloo_group()) - - buffer_size = 0 - named_tensors = [] - for name, param in named_params_and_buffers(self.args, self.model): - if ".experts." not in name: - continue - buffer_size = self._update_expert_weight_from_distributed( - name, param, named_tensors, buffer_size, pbar=pbar - ) - - if named_tensors: - self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) + self._send_weights(pbar) dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: @@ -147,59 +132,83 @@ def update_weights(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) - def _update_weight_from_distributed( - self, - name: str, - param: torch.nn.Parameter, - converted_named_tensors: list[tuple[str, torch.Tensor]], - buffer_size: int, - pbar: tqdm | None = None, - ) -> int | None: + def _send_weights(self, pbar: tqdm | None) -> None: """ - Non-expert: gather TP → rm pad → HF → buffer (flush if full). All gather, PP source buffers. - Returns updated bytes on source, None on non-source. + Non-expert (TP) loop → barrier → expert (EP) loop. Subclasses should + override ``_on_chunk`` to extend; override this only when the bucketing + strategy differs. 
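+
+        A minimal sketch of the hook (illustrative only; ``LoggingSender`` is a
+        made-up name — the real override lives in
+        ``update_weight_from_distributed_partial.py``)::
+
+            class LoggingSender(UpdateWeightFromDistributed):
+                def _on_chunk(self, hf_chunk):
+                    # runs once per HF chunk, before it is bucketed and broadcast
+                    total = sum(t.numel() for _, t in hf_chunk)
+                    print(f"chunk: {len(hf_chunk)} tensors, {total} elements")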
""" - param = all_gather_param(name, param) - if not self._is_pp_src_rank: - return + buffer_size = 0 + buffer: list[tuple[str, torch.Tensor]] = [] + for hf_chunk in self._iter_non_expert_chunks(): + self._on_chunk(hf_chunk) + chunk_bytes = sum(t.numel() * t.element_size() for _, t in hf_chunk) + if buffer_size + chunk_bytes > self.args.update_weight_buffer_size: + self._update_bucket_weights_from_distributed(buffer, pbar=pbar) + buffer = [] + buffer_size = 0 + buffer.extend(hf_chunk) + buffer_size += chunk_bytes + if buffer: + self._update_bucket_weights_from_distributed(buffer, pbar=pbar) - param_size = param.numel() * param.element_size() - if buffer_size + param_size > self.args.update_weight_buffer_size: - self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) - buffer_size = 0 - converted_named_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) - buffer_size += param_size - return buffer_size + dist.barrier(group=get_gloo_group()) - def _update_expert_weight_from_distributed( - self, - name: str, - param: torch.nn.Parameter, - named_tensors: list[tuple[str, torch.Tensor]], - buffer_size: int, - pbar: tqdm | None = None, - ) -> int: + for hf_chunk in self._iter_expert_chunks(): + self._on_chunk(hf_chunk) + self._update_bucket_weights_from_distributed(hf_chunk, pbar=pbar) + + def _on_chunk(self, hf_chunk: list[tuple[str, torch.Tensor]]) -> None: """ - Expert: gather TP → rm pad → buffer. EP gather + HF deferred. Threshold × EP size. + Hook for each HF chunk in ``_send_weights`` before its broadcast. No-op by default. """ - param = all_gather_param(name, param) - - param_size = param.numel() * param.element_size() - if ( - buffer_size + param_size - ) * mpu.get_expert_model_parallel_world_size() > self.args.update_weight_buffer_size: - self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) - buffer_size = 0 - named_tensors.append((name, param)) - buffer_size += param_size - return buffer_size + def _iter_non_expert_chunks(self) -> Iterator[list[tuple[str, torch.Tensor]]]: + """ + Yield one HF chunk per non-expert param after TP all-gather + HF convert. + Empty generator on non-PP-src ranks (they still participate in all_gather_param). + """ + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." in name: + continue + param = all_gather_param(name, param) + if not self._is_pp_src_rank: + continue + yield convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) - def _update_expert_bucket_weights_from_distributed( - self, named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None - ) -> None: + def _iter_expert_chunks(self) -> Iterator[list[tuple[str, torch.Tensor]]]: + """ + Yield one HF chunk per EP-weighted flush of expert params: TP gather + buffer + until threshold, then EP gather + HF convert. All ranks participate in the EP + collective; only PP source yields a non-empty chunk and only on those ranks + does the generator advance. + """ + buffer_size = 0 + batch: list[tuple[str, torch.Tensor]] = [] + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." 
not in name: + continue + param = all_gather_param(name, param) + param_size = param.numel() * param.element_size() + if ( + buffer_size + param_size + ) * mpu.get_expert_model_parallel_world_size() > self.args.update_weight_buffer_size: + hf_chunk = self._ep_gather_and_convert(batch) + if hf_chunk: + yield hf_chunk + batch = [] + buffer_size = 0 + batch.append((name, param)) + buffer_size += param_size + if batch: + hf_chunk = self._ep_gather_and_convert(batch) + if hf_chunk: + yield hf_chunk + + def _ep_gather_and_convert(self, named_tensors: list[tuple[str, torch.Tensor]]) -> list[tuple[str, torch.Tensor]]: """ - Gather EP → HF → broadcast. Clears buffer. + EP all-gather a buffered batch + HF convert on PP source. Clears ``named_tensors``. + Returns HF tensors on PP source, [] elsewhere. """ names = [name for name, _ in named_tensors] all_names = [None] * mpu.get_expert_model_parallel_world_size() @@ -224,20 +233,24 @@ def _update_expert_bucket_weights_from_distributed( named_tensors.clear() if not self._is_pp_src_rank: - return + return [] all_gathered_params = sum(all_gathered_params, []) converted_hf_tensors = [] for name, param in all_gathered_params: converted_hf_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) - - self._update_bucket_weights_from_distributed(converted_hf_tensors, pbar) + return converted_hf_tensors def _update_bucket_weights_from_distributed( - self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None + self, + converted_named_tensors: list[tuple[str, torch.Tensor]], + pbar: tqdm | None = None, + load_format: str | None = None, + partial: PartialWeightSpec | None = None, ) -> None: """ Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock. + Partial-update modes (selective/delta) pass ``load_format`` and a ``PartialWeightSpec``. """ # lock the rollout engines to prevent dead lock on broadcast. while not ray.get(self.rollout_engine_lock.acquire.remote()): @@ -249,6 +262,8 @@ def _update_bucket_weights_from_distributed( self.weight_version, self.rollout_engines, converted_named_tensors, + load_format=load_format, + partial=partial, ) ray.get(refs) @@ -321,9 +336,13 @@ def update_weights_from_distributed( weight_version: int, rollout_engines: Sequence[ActorHandle], converted_named_tensors: Sequence[tuple[str, torch.Tensor]], + load_format: str | None = None, + partial: PartialWeightSpec | None = None, ) -> list[ObjectRef]: """ Send metadata (Ray), broadcast tensors (NCCL rank 0 → engines). + Partial-update modes pass ``load_format`` (``"selective"`` / ``"delta"``) + and ``partial`` (PartialWeightSpec). 
""" refs = [ engine.update_weights_from_distributed.remote( @@ -332,6 +351,8 @@ def update_weights_from_distributed( shapes=[param.shape for _, param in converted_named_tensors], group_name=group_name, weight_version=str(weight_version), + load_format=load_format, + partial=partial, ) for engine in rollout_engines ] diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_partial.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_partial.py new file mode 100644 index 0000000000..794ab79652 --- /dev/null +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_partial.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +import os +import threading +from argparse import Namespace +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass, field, replace +from queue import Queue + +import torch +import torch.distributed as dist +from safetensors.torch import save, save_file +from tqdm import tqdm + +from slime.utils.distributed_utils import get_gloo_group + +from ..sglang import PartialWeightEncoding, PartialWeightParam, PartialWeightSpec +from .update_weight_from_distributed import UpdateWeightFromDistributed + +try: + import zstandard +except ImportError: + zstandard = None + + +_DELTA_DTYPE_MAP = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, +} + + +@dataclass +class PartialChunk: + """ + One encoded chunk awaiting PartialSendBucket. + dense → tensors=[(name, payload)], params=None. + sparse_* → tensors=[("__packed_keys__", ...), ("__packed_values__", ...)], + params holds per-param decoding with chunk-local offsets. + nnz is the active-position count across all params (used for density). + """ + + tensors: list[tuple[str, torch.Tensor]] + params: list[PartialWeightParam] | None + byte_size: int + nnz: int = 0 + + +@dataclass +class PartialPayload: + """ + Per-param compute output flowing into the encoder. ``payload`` carries the + full-size values at every position (deltas for ``delta``, new param values + for ``selective``); ``mask`` is a bool tensor marking active positions. + The encoder consumes (payload, mask) directly — no mode-specific predicate + re-derives the mask, and no sentinel value is materialized on the sender. + """ + + name: str + payload: torch.Tensor + mask: torch.Tensor + + +def encode_partial( + named_payloads: list[PartialPayload], + encoding: PartialWeightEncoding, + mode: str, +) -> PartialChunk: + """ + Encode partial-update payloads per wire encoding. ``mode`` is needed only + by the dense encoder, which has to materialize the receiver-side sentinel. + Sparse paths read the mask directly from each PartialPayload. + """ + if encoding is PartialWeightEncoding.DENSE: + return _encode_dense(named_payloads, mode) + if encoding is PartialWeightEncoding.SPARSE_INDICES: + return _encode_sparse(named_payloads, _indices_kv) + if encoding is PartialWeightEncoding.SPARSE_BITMASK: + return _encode_sparse(named_payloads, _bitmask_kv) + raise ValueError(f"unknown partial-update encoding: {encoding!r}") + + +def _encode_dense(named_payloads: list[PartialPayload], mode: str) -> PartialChunk: + """ + Dense wire: send a full-size tensor per param with a receiver-side + sentinel at unchanged positions. Delta's payload already has 0 at + unchanged (current − snapshot is 0 there); selective re-materializes a + NaN-marked tensor here — lazy, since dense is the debug-only encoding. 
+ """ + tensors: list[tuple[str, torch.Tensor]] = [] + nnz = 0 + for pp in named_payloads: + nnz += int(pp.mask.sum()) + if mode == "selective": + nan = torch.full_like(pp.payload, float("nan")) + tensors.append((pp.name, torch.where(pp.mask, pp.payload, nan))) + else: # "delta" + tensors.append((pp.name, pp.payload)) + size = sum(t.numel() * t.element_size() for _, t in tensors) + return PartialChunk(tensors=tensors, params=None, byte_size=size, nnz=nnz) + + +def _encode_sparse( + named_payloads: list[PartialPayload], + kv_fn: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]], +) -> PartialChunk: + """ + Walk named_payloads, ask kv_fn for (keys, values) per param, pack into a + single (packed_keys, packed_values) PartialChunk with a per-param manifest. + Params with zero active positions are skipped. + """ + keys_chunks: list[torch.Tensor] = [] + values_chunks: list[torch.Tensor] = [] + params: list[PartialWeightParam] = [] + keys_off = values_off = 0 + for pp in named_payloads: + flat_payload = pp.payload.contiguous().view(-1) + flat_mask = pp.mask.contiguous().view(-1) + keys, values = kv_fn(flat_payload, flat_mask) + nnz = int(values.numel()) + if nnz == 0: + continue + keys_count = int(keys.numel()) + params.append( + PartialWeightParam( + name=pp.name, + dtype=str(pp.payload.dtype).replace("torch.", ""), + shape=list(pp.payload.shape), + keys_start=keys_off, + keys_end=keys_off + keys_count, + values_start=values_off, + values_end=values_off + nnz, + ) + ) + keys_chunks.append(keys) + values_chunks.append(values) + keys_off += keys_count + values_off += nnz + if not params: + return PartialChunk(tensors=[], params=[], byte_size=0) + packed_keys = torch.cat(keys_chunks, dim=0) + packed_values = torch.cat(values_chunks, dim=0) + size = packed_keys.numel() * packed_keys.element_size() + packed_values.numel() * packed_values.element_size() + return PartialChunk( + tensors=[("__packed_keys__", packed_keys), ("__packed_values__", packed_values)], + params=params, + byte_size=size, + nnz=values_off, + ) + + +def _indices_kv(flat_payload: torch.Tensor, flat_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + idx_long = flat_mask.nonzero(as_tuple=False).view(-1) + return idx_long.to(dtype=torch.int32), flat_payload[idx_long] + + +def _bitmask_kv(flat_payload: torch.Tensor, flat_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + values = flat_payload[flat_mask] + if flat_mask.numel() == 0: + return torch.empty(0, dtype=torch.uint8, device=flat_mask.device), values + mask_u8 = flat_mask.to(dtype=torch.uint8) + pad = (-mask_u8.numel()) % 8 + if pad: + mask_u8 = torch.cat([mask_u8, torch.zeros(pad, dtype=torch.uint8, device=mask_u8.device)]) + bits = mask_u8.view(-1, 8) + weights = (2 ** torch.arange(8, dtype=torch.uint8, device=mask_u8.device)).view(1, 8) + return torch.sum(bits * weights, dim=1, dtype=torch.uint8), values + + +@dataclass +class PartialSendBucket: + """ + Accumulates PartialChunks for one batched broadcast. Sparse params are + eagerly shifted into merged-buffer coordinates on add(). 
+ """ + + tensors: list[tuple[str, torch.Tensor]] = field(default_factory=list) + params: list[PartialWeightParam] = field(default_factory=list) + byte_size: int = 0 + keys_total: int = 0 + values_total: int = 0 + + @property + def has_updates(self) -> bool: + return bool(self.tensors) + + def should_flush_before_add(self, update: PartialChunk, byte_limit: int) -> bool: + return self.has_updates and self.byte_size + update.byte_size > byte_limit + + def add(self, update: PartialChunk) -> None: + if update.params is not None: + for p in update.params: + self.params.append( + replace( + p, + keys_start=p.keys_start + self.keys_total, + keys_end=p.keys_end + self.keys_total, + values_start=p.values_start + self.values_total, + values_end=p.values_end + self.values_total, + ) + ) + self.keys_total += update.tensors[0][1].numel() + self.values_total += update.tensors[1][1].numel() + self.tensors.extend(update.tensors) + self.byte_size += update.byte_size + + def flush_payload( + self, + ) -> tuple[list[tuple[str, torch.Tensor]], list[PartialWeightParam] | None]: + """ + Sparse: concat per-chunk packed tensors into one pair. Dense: tensors as-is. + """ + if not self.params: + return list(self.tensors), None + keys = [t for n, t in self.tensors if n == "__packed_keys__"] + values = [t for n, t in self.tensors if n == "__packed_values__"] + merged = [ + ("__packed_keys__", torch.cat(keys, dim=0)), + ("__packed_values__", torch.cat(values, dim=0)), + ] + return merged, list(self.params) + + def clear(self) -> None: + self.tensors.clear() + self.params.clear() + self.byte_size = 0 + self.keys_total = 0 + self.values_total = 0 + + +class PartialSync: + """ + Owns pinned-CPU snapshot of last broadcast's tensors and the base-vs-partial + decision. PP-source-rank only. ``compute_payload`` produces per-param + PartialPayloads for either mode against the same snapshot. + """ + + def __init__(self, args: Namespace) -> None: + self.delta_dtype = _DELTA_DTYPE_MAP[args.update_weight_delta_dtype] + self.base_sync_interval = args.update_weight_base_sync_interval + if self.base_sync_interval < 1: + raise ValueError("--update-weight-base-sync-interval must be >= 1") + self.snapshot: dict[str, torch.Tensor] = {} + self.committed_syncs = 0 + self.d2h_stream: torch.cuda.Stream | None = None + self.snapshot_dirty = False + + def should_send_base(self) -> bool: + return self.committed_syncs == 0 or self.committed_syncs % self.base_sync_interval == 0 + + def compute_payload(self, named_tensors: list[tuple[str, torch.Tensor]], mode: str) -> list[PartialPayload]: + """ + For each param produce a PartialPayload. Both modes share the snapshot + preamble (wait for in-flight D2H, batch H2D the pinned snapshot, sync); + they differ only in what counts as the payload and how the mask is + derived: + + delta — payload = (new − snapshot) at delta_dtype; mask = payload != 0 + selective — payload = new (reference, no copy); mask = new != snapshot + + Caller advances snapshot after. 
+ """ + if mode == "delta": + + def per_param(name, tensor, prev): + payload = tensor.to(self.delta_dtype) - prev.to(self.delta_dtype) + return payload, payload != 0 + + elif mode == "selective": + + def per_param(name, tensor, prev): + if not tensor.dtype.is_floating_point: + raise TypeError(f"selective mode requires float param dtype; got {tensor.dtype} for {name!r}") + return tensor, tensor != prev + + else: + raise ValueError(f"unknown partial-update mode: {mode!r}") + self.flush_snapshot() + prev_gpu = [] + for name, tensor in named_tensors: + if name not in self.snapshot: + raise KeyError(f"missing snapshot for {name!r}; need a base sync first") + prev_gpu.append(self.snapshot[name].to(device=tensor.device, non_blocking=True)) + torch.cuda.synchronize() + result: list[PartialPayload] = [] + for (name, tensor), prev in zip(named_tensors, prev_gpu, strict=True): + payload, mask = per_param(name, tensor, prev) + result.append(PartialPayload(name=name, payload=payload, mask=mask)) + del prev + return result + + def update_snapshot_async(self, named_tensors: list[tuple[str, torch.Tensor]]) -> None: + """ + D2H snapshot copy on a side stream so it overlaps downstream broadcast/encode. + """ + if self.d2h_stream is None: + self.d2h_stream = torch.cuda.Stream() + event = torch.cuda.current_stream().record_event() + with torch.cuda.stream(self.d2h_stream): + self.d2h_stream.wait_event(event) + for name, tensor in named_tensors: + if name not in self.snapshot: + self.snapshot[name] = torch.empty_like(tensor, device=torch.device("cpu"), pin_memory=True) + self.snapshot[name].copy_(tensor.detach(), non_blocking=True) + self.snapshot_dirty = True + + def flush_snapshot(self) -> None: + if self.snapshot_dirty: + if self.d2h_stream is not None: + self.d2h_stream.synchronize() + else: + torch.cuda.synchronize() + self.snapshot_dirty = False + + def on_sync_succeeded(self) -> None: + self.flush_snapshot() + self.committed_syncs += 1 + + +class PartialArtifactWriter: + """ + Async background writer for per-chunk partial-update artifacts. Active iff + ``--update-weight-partial-artifact-dir`` is set. Output is per-chunk safetensors + (zstd-wrapped if ``zstandard`` is installed). + """ + + def __init__(self, artifact_dir: str) -> None: + self.artifact_dir = artifact_dir + os.makedirs(artifact_dir, exist_ok=True) + self.work_queue: Queue = Queue() + self.compressor = zstandard.ZstdCompressor() if zstandard is not None else None + self.thread = threading.Thread(target=self._run, name="partial-artifact-writer", daemon=True) + self.thread.start() + + def enqueue( + self, + weight_version: int, + chunk_idx: int, + named_tensors: list[tuple[str, torch.Tensor]], + ) -> None: + self.work_queue.put((weight_version, chunk_idx, {name: t.contiguous() for name, t in named_tensors})) + + def _run(self) -> None: + while True: + item = self.work_queue.get() + if item is None: + return + weight_version, chunk_idx, tensors = item + stem = f"weight_version_{weight_version:06d}_chunk_{chunk_idx:06d}.safetensors" + if self.compressor is None: + save_file(tensors, os.path.join(self.artifact_dir, stem)) + continue + payload = save(tensors) + with open(os.path.join(self.artifact_dir, f"{stem}.zst"), "wb") as f: + f.write(self.compressor.compress(payload)) + + +class UpdateWeightFromDistributedPartial(UpdateWeightFromDistributed): + """ + Partial-update variant. Sends a sparse-encoded payload per named tensor and + has SGLang apply it. 
Two sub-modes, selected by ``--update-weight-mode``: + + * ``selective``: payload values are the new param values at changed positions, + with NaN as the "unchanged" sentinel; receiver overwrites the non-NaN + positions only. + * ``delta``: payload values are ``(current − snapshot)`` cast to delta_dtype; + receiver applies additively (``param += delta``). + + Periodic base syncs (full broadcasts) refresh the snapshot. The ``_on_chunk`` + hook on the base class is used to keep the snapshot in lockstep during base + syncs. + """ + + def __init__( + self, + args: Namespace, + model: Sequence[torch.nn.Module], + weights_getter: Callable[[], Mapping[str, torch.Tensor]], + *, + model_name: str, + quantization_config: dict[str, int | str | list[str]] | None, + ) -> None: + super().__init__( + args, + model, + weights_getter, + model_name=model_name, + quantization_config=quantization_config, + ) + self.mode = args.update_weight_mode # "selective" or "delta" + self.partial_sync = PartialSync(args) + self.artifact_writer = ( + PartialArtifactWriter(args.update_weight_partial_artifact_dir) + if args.update_weight_partial_artifact_dir is not None + else None + ) + self.artifact_chunk_idx = 0 + self.pending_artifacts: list[list[tuple[str, torch.Tensor]]] = [] + self.density_nnz = 0 + self.density_numel = 0 + self.wire_bytes = 0 + + def _send_weights(self, pbar: tqdm | None) -> None: + is_base = self.partial_sync.should_send_base() + self.density_nnz = 0 + self.density_numel = 0 + self.wire_bytes = 0 + if is_base: + super()._send_weights(pbar) + else: + self._send_partial_weights(pbar) + # Increment on all ranks so should_send_base() stays in lockstep across + # the PP group. flush_snapshot() is a no-op on non-PP-src ranks. + self.partial_sync.on_sync_succeeded() + self._record_metrics(is_base) + + def _on_chunk(self, hf_chunk: list[tuple[str, torch.Tensor]]) -> None: + """ + Base-sync hook: snapshot this chunk so the next partial sync has prev to diff against. + Also count dense wire bytes so base syncs share the wire_bytes metric axis. + """ + self.partial_sync.update_snapshot_async(hf_chunk) + self.wire_bytes += sum(t.numel() * t.element_size() for _, t in hf_chunk) + + def _send_partial_weights(self, pbar: tqdm | None) -> None: + """ + non-expert (TP) loop → barrier → expert (EP) loop. Each HF chunk is + converted to a partial-update payload (selective or delta) and bucketed. + """ + encoding = PartialWeightEncoding(self.args.update_weight_partial_encoding) + bucket = PartialSendBucket() + for hf_chunk in self._iter_non_expert_chunks(): + self._enqueue_partial_chunk(hf_chunk, encoding, bucket, pbar) + self._flush_partial_bucket(bucket, encoding, pbar) + + dist.barrier(group=get_gloo_group()) + + for hf_chunk in self._iter_expert_chunks(): + self._enqueue_partial_chunk(hf_chunk, encoding, bucket, pbar) + self._flush_partial_bucket(bucket, encoding, pbar) + + def _enqueue_partial_chunk( + self, + hf_chunk: list[tuple[str, torch.Tensor]], + encoding: PartialWeightEncoding, + bucket: PartialSendBucket, + pbar: tqdm | None, + ) -> None: + """ + compute payloads (mode-specific) → snapshot new prev → encode → bucket.add. + """ + if not hf_chunk: + return + payloads = self.partial_sync.compute_payload(hf_chunk, self.mode) + self.partial_sync.update_snapshot_async(hf_chunk) + chunk = encode_partial(payloads, encoding, self.mode) + # numel from input payload so zero-nnz params still count — otherwise + # the ratio biases toward params that did change. 
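+        # (Worked example: two 100-element params where only one changed at 3
+        #  positions — density should read 3/200 = 1.5%, not 3/100 = 3%.)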
+ self.density_numel += sum(pp.payload.numel() for pp in payloads) + self.density_nnz += chunk.nnz + self.wire_bytes += chunk.byte_size + if not chunk.tensors: + return + if bucket.should_flush_before_add(chunk, self.args.update_weight_buffer_size): + self._flush_partial_bucket(bucket, encoding, pbar) + # Append AFTER the flush check so this chunk's artifact lands in the + # same flush as its broadcast (and not at all if its encoding skipped). + if self.artifact_writer is not None: + self.pending_artifacts.append([(pp.name, pp.payload.cpu()) for pp in payloads]) + bucket.add(chunk) + + def _flush_partial_bucket( + self, + bucket: PartialSendBucket, + encoding: PartialWeightEncoding, + pbar: tqdm | None, + ) -> None: + """ + Lock → broadcast (with PartialWeightSpec) → unlock → pbar++. Drains + pending artifacts to the async writer once the broadcast lands. + load_format is "selective" or "delta" per self.mode. + """ + if not bucket.has_updates: + return + wire_tensors, params = bucket.flush_payload() + spec = PartialWeightSpec(encoding=encoding, params=params) + bucket.clear() + self._update_bucket_weights_from_distributed(wire_tensors, pbar=pbar, load_format=self.mode, partial=spec) + if self.artifact_writer is not None: + for artifact in self.pending_artifacts: + self.artifact_writer.enqueue(self.weight_version, self.artifact_chunk_idx, artifact) + self.artifact_chunk_idx += 1 + self.pending_artifacts.clear() + + def _record_metrics(self, is_base: bool) -> None: + """ + Base sync sends every position → density 1.0 by definition. + """ + counts = torch.tensor( + [self.density_nnz, self.density_numel, self.wire_bytes], + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + dist.all_reduce(counts) + nnz, numel, wire_bytes = counts.tolist() + self.update_weight_metrics["perf/update_weights_is_base_sync"] = 1.0 if is_base else 0.0 + self.update_weight_metrics["perf/update_weights_density"] = 1.0 if is_base else nnz / max(numel, 1) + self.update_weight_metrics["perf/update_weights_wire_bytes"] = wire_bytes diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py index dbba6aeb57..986d5e69a6 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -48,6 +48,7 @@ def __init__( self.model_name = model_name self.quantization_config = quantization_config self.weight_version = 0 + self.update_weight_metrics: dict[str, float] = {} self._hf_weight_iterator = HfWeightIteratorBase.create( args=args, model=model, model_name=model_name, quantization_config=quantization_config @@ -134,6 +135,14 @@ def connect_rollout_engines( if start <= dist.get_rank() < end: self._ipc_engine = engine + def pop_metrics(self) -> dict[str, float]: + """ + Return and clear ``update_weight_metrics``. Empty under colocate (no + per-sync metrics today); kept symmetric with UpdateWeightFromDistributed. 
+ """ + out, self.update_weight_metrics = self.update_weight_metrics, {} + return out + @torch.no_grad() def update_weights(self) -> None: """ diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index c28e13d5f9..035ea7152b 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -404,7 +404,15 @@ def destroy_weights_update_group(self, group_name): pass def update_weights_from_distributed( - self, names, dtypes, shapes, group_name, flush_cache=False, weight_version: str | None = None + self, + names, + dtypes, + shapes, + group_name, + flush_cache=False, + weight_version: str | None = None, + load_format: str | None = None, + partial=None, ): payload = { "names": names, @@ -415,6 +423,10 @@ def update_weights_from_distributed( } if weight_version is not None: payload["weight_version"] = weight_version + if load_format is not None: + payload["load_format"] = load_format + if partial is not None: + payload["partial"] = dataclasses.asdict(partial) return self._make_request( "update_weights_from_distributed", payload, diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index e8a1730782..ea77c8932f 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -136,6 +136,67 @@ def add_train_arguments(parser): default="raw", help="The method to convert megatron weights to hugging face weights for SGLang.", ) + # Partial weight sync (selective + delta modes). Receiver chunking is mirrored + # automatically to SGLang as --sglang-update-weight-partial-chunk-bytes. + parser.add_argument( + "--update-weight-mode", + choices=["full", "selective", "delta"], + default="full", + help=( + "Weight sync strategy. 'full' broadcasts every parameter on every " + "sync. 'selective' broadcasts the new values at the changed positions " + "only and the receiver overwrites those positions (unchanged positions " + "are signaled by NaN in the wire payload). 'delta' broadcasts " + "(new − snapshot) and the receiver applies it additively. " + "The first sync is always full regardless." + ), + ) + parser.add_argument( + "--update-weight-delta-dtype", + choices=["fp16", "bf16", "fp32"], + default="fp32", + help=( + "Math dtype for the delta subtraction and additive apply (delta mode " + "only; ignored otherwise). Higher than the param dtype makes the " + "apply lossless." + ), + ) + parser.add_argument( + "--update-weight-partial-encoding", + choices=["sparse_indices", "sparse_bitmask", "dense"], + default="sparse_indices", + help=( + "Wire encoding for partial broadcasts (selective + delta modes). " + "'sparse_indices' sends (indices, values) for active entries; " + "'sparse_bitmask' sends a packed bitmask + values; 'dense' sends " + "the full per-param tensor." + ), + ) + parser.add_argument( + "--update-weight-base-sync-interval", + type=int, + default=9999, + help=( + "Run a base sync (a full broadcast that re-establishes the snapshot) " + "every N successful partial syncs (selective + delta modes). The " + "first sync is always a base sync. Both modes are lossless under " + "their default settings (selective by construction, delta with fp32 " + "math), so the default 9999 effectively disables periodic " + "base syncs — receiver state doesn't drift from a base-sync " + "reference no matter how many partial syncs elapse. Set lower " + "(e.g. 30) to verify correctness against periodic full broadcasts, " + "or if your workload has a custom base-sync requirement." 
+ ), + ) + parser.add_argument( + "--update-weight-partial-artifact-dir", + type=str, + default=None, + help=( + "Optional directory for asynchronously saving per-broadcast partial-" + "update artifacts (selective + delta modes). Off by default." + ), + ) parser.add_argument( "--custom-model-provider-path", type=str, @@ -1842,3 +1903,13 @@ def slime_validate_args(args): if args.only_train_params_name_list and args.freeze_params_name_list: raise ValueError("You can only specify ONE of: --only-train-params-name-list, or --freeze-params-name-list.") + + if args.update_weight_mode in ("selective", "delta") and args.colocate: + raise ValueError( + f"--update-weight-mode={args.update_weight_mode} is not supported with " + "--colocate. Colocate transfers weights via CUDA IPC: only a memory " + "handle (~64 B) crosses processes, never bytes. Partial-update modes " + "shrink bytes on the wire, of which there are none here, so the partial-" + "update bookkeeping (snapshot + subtract/mask + sparse encode) is pure " + "overhead." + ) diff --git a/slime/utils/train_metric_utils.py b/slime/utils/train_metric_utils.py index 9bec049d15..8ecefcfbf3 100644 --- a/slime/utils/train_metric_utils.py +++ b/slime/utils/train_metric_utils.py @@ -11,7 +11,11 @@ def log_perf_data_raw( - rollout_id: int, args: Namespace, is_primary_rank: bool, compute_total_fwd_flops: Callable + rollout_id: int, + args: Namespace, + is_primary_rank: bool, + compute_total_fwd_flops: Callable, + extra_metrics: dict | None = None, ) -> None: timer_instance = Timer() log_dict_raw = deepcopy(timer_instance.log_dict()) @@ -21,6 +25,8 @@ def log_perf_data_raw( return log_dict = {f"perf/{key}_time": val for key, val in log_dict_raw.items()} + if extra_metrics: + log_dict.update(extra_metrics) if ("perf/actor_train_time" in log_dict) and (compute_total_fwd_flops is not None): total_fwd_flops = compute_total_fwd_flops(seq_lens=timer_instance.seq_lens)