From 4243c8339a092dfc09b81a9002186d91c84d5d3a Mon Sep 17 00:00:00 2001 From: Achintya P Date: Thu, 2 Jul 2026 01:02:38 -0700 Subject: [PATCH 1/3] [Feature] Add Learner primitive (Learner, LocalLearner) Introduces torchrl.trainers.Learner: a backend-agnostic entry point for taking one optimization step on a tensordict batch with a given LossModule. LocalLearner is the single-process reference implementation. Mirrors the role Collector plays for data collection and LLMWrapperBase plays for generation/scoring, so training placement (local / sharded / remote) becomes a swappable backend behind a fixed contract instead of a hand-rolled loop per recipe. get_weights() returns a TensorDictBase, matching what WeightSyncScheme.send() already accepts, so this composes with the existing weight-sync path unchanged. --- docs/source/reference/trainers.rst | 1 + docs/source/reference/trainers_learners.rst | 30 +++++ test/test_trainer.py | 119 +++++++++++++++++++- torchrl/trainers/__init__.py | 4 + torchrl/trainers/learners/__init__.py | 14 +++ torchrl/trainers/learners/common.py | 93 +++++++++++++++ torchrl/trainers/learners/local.py | 116 +++++++++++++++++++ 7 files changed, 376 insertions(+), 1 deletion(-) create mode 100644 docs/source/reference/trainers_learners.rst create mode 100644 torchrl/trainers/learners/__init__.py create mode 100644 torchrl/trainers/learners/common.py create mode 100644 torchrl/trainers/learners/local.py diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 162c933ce20..f401f0a7543 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -49,3 +49,4 @@ Documentation Sections trainers_basics trainers_loggers trainers_hooks + trainers_learners diff --git a/docs/source/reference/trainers_learners.rst b/docs/source/reference/trainers_learners.rst new file mode 100644 index 00000000000..615ce208ba8 --- /dev/null +++ b/docs/source/reference/trainers_learners.rst @@ -0,0 +1,30 @@ +.. currentmodule:: torchrl.trainers + +Learners +======== + +.. _ref_learners: + +A :class:`~torchrl.trainers.Learner` owns a trainable model and exposes a +single, backend-agnostic entry point -- :meth:`~torchrl.trainers.Learner.update` +-- for taking one optimization step on a tensordict batch with a given +:class:`~torchrl.objectives.common.LossModule`. It plays the same role for +training that :class:`~torchrl.collectors.Collector` plays for data collection +and :class:`~torchrl.modules.llm.LLMWrapperBase` plays for generation/scoring: +a fixed contract with interchangeable backends, so algorithm code does not +need to know whether the update runs on one device, under sharded training, or +on a remote training process. + +:class:`~torchrl.trainers.LocalLearner` is the single-process reference +implementation. Its :meth:`~torchrl.trainers.Learner.get_weights` output is +accepted as-is by :class:`~torchrl.weight_update.WeightSyncScheme`, so a +``Learner`` composes with the existing weight-sync path without changes on +either side. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Learner + LearnerCapabilities + LocalLearner diff --git a/test/test_trainer.py b/test/test_trainer.py index e520c9ee773..ef0993cb152 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -29,8 +29,9 @@ TensorDictReplayBuffer, ) from torchrl.envs.libs.gym import _has_gym +from torchrl.objectives.common import LossModule from torchrl.testing import PONG_VERSIONED -from torchrl.trainers import LogValidationReward, Trainer +from torchrl.trainers import Learner, LocalLearner, LogValidationReward, Trainer from torchrl.trainers.algorithms.cql import CQLTrainer from torchrl.trainers.algorithms.ddpg import DDPGTrainer from torchrl.trainers.algorithms.dqn import DQNTrainer @@ -1562,6 +1563,122 @@ def test_train_stops_early(self): assert collector.shutdown_calls == 1 +class _ToyRegressionLoss(LossModule): + """Minimal LossModule: MSE loss plus a non-loss logging metric.""" + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, batch: TensorDict) -> TensorDict: + pred = self.model(batch["x"]) + loss = (pred - batch["y"]).pow(2).mean() + with torch.no_grad(): + metric = pred.mean() # a non-"loss"-prefixed field, for logging only + return TensorDict({"loss_mse": loss, "metric": metric}) + + +class _NoLossKeyModule(LossModule): + """A LossModule whose output has no 'loss'-prefixed key (misconfigured).""" + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, batch: TensorDict) -> TensorDict: + pred = self.model(batch["x"]) + return TensorDict({"mse": (pred - batch["y"]).pow(2).mean()}) + + +class TestLocalLearner: + @staticmethod + def _make(clip_grad_norm=None, grad_accum_steps=1, lr=0.1): + torch.manual_seed(0) + model = nn.Linear(4, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + learner = LocalLearner( + model, + optimizer, + clip_grad_norm=clip_grad_norm, + grad_accum_steps=grad_accum_steps, + ) + loss_module = _ToyRegressionLoss(model) + return learner, loss_module, model + + @staticmethod + def _batch(n=8): + return TensorDict( + {"x": torch.randn(n, 4), "y": torch.randn(n, 1)}, batch_size=[n] + ) + + def test_update_returns_loss_and_metric(self): + learner, loss_module, _ = self._make() + out = learner.update(self._batch(), loss_module) + assert "loss_mse" in out.keys() + assert "metric" in out.keys() + assert out.get("loss_mse").shape == () + + def test_update_steps_the_optimizer(self): + learner, loss_module, model = self._make(lr=0.5) + before = model.weight.clone() + learner.update(self._batch(), loss_module) + assert not torch.equal(before, model.weight) + + def test_loss_decreases_over_steps(self): + learner, loss_module, _ = self._make(lr=0.1) + torch.manual_seed(1) + batch = self._batch(64) + first = learner.update(batch, loss_module).get("loss_mse").item() + for _ in range(20): + last = learner.update(batch, loss_module).get("loss_mse").item() + assert last < first + + def test_clip_grad_norm_writes_grad_norm(self): + learner, loss_module, _ = self._make(clip_grad_norm=0.5) + out = learner.update(self._batch(), loss_module) + assert "grad_norm" in out.keys() + assert out.get("grad_norm") >= 0 + + def test_no_clip_grad_norm_absent(self): + learner, loss_module, _ = self._make(clip_grad_norm=None) + out = learner.update(self._batch(), loss_module) + assert "grad_norm" not in out.keys() + + def test_missing_loss_key_raises(self): + learner, _, model = self._make() + with pytest.raises(ValueError, match="no keys starting with 'loss'"): + learner.update(self._batch(), _NoLossKeyModule(model)) + + def test_grad_accumulation_defers_step(self): + learner, loss_module, model = self._make(grad_accum_steps=2, lr=1.0) + before = model.weight.clone() + learner.update(self._batch(), loss_module) + # optimizer has not stepped yet: weights unchanged after the 1st of 2 accum steps + assert torch.equal(before, model.weight) + learner.update(self._batch(), loss_module) + # after the 2nd accum step, the optimizer has stepped + assert not torch.equal(before, model.weight) + + def test_get_weights_matches_model_params(self): + learner, _, model = self._make() + weights = learner.get_weights() + torch.testing.assert_close(weights["weight"], model.weight) + torch.testing.assert_close(weights["bias"], model.bias) + + def test_invalid_grad_accum_steps_raises(self): + model = nn.Linear(4, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + with pytest.raises(ValueError, match="grad_accum_steps must be"): + LocalLearner(model, optimizer, grad_accum_steps=0) + + def test_base_learner_methods_are_abstract(self): + learner = Learner() + with pytest.raises(NotImplementedError): + learner.update(self._batch(), _ToyRegressionLoss(nn.Linear(4, 1))) + with pytest.raises(NotImplementedError): + learner.get_weights() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 6f4456e23bf..2203bfdb730 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .learners import Learner, LearnerCapabilities, LocalLearner from .trainers import ( BatchSubSampler, ClearCudaCache, @@ -31,6 +32,9 @@ "CountFramesLog", "DefaultOptimizationStepper", "EarlyStopping", + "Learner", + "LearnerCapabilities", + "LocalLearner", "LogScalar", "LogTiming", "LogValidationReward", diff --git a/torchrl/trainers/learners/__init__.py b/torchrl/trainers/learners/__init__.py new file mode 100644 index 00000000000..c08ab9cb5c4 --- /dev/null +++ b/torchrl/trainers/learners/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from torchrl.trainers.learners.common import Learner, LearnerCapabilities +from torchrl.trainers.learners.local import LocalLearner + +__all__ = [ + "Learner", + "LearnerCapabilities", + "LocalLearner", +] diff --git a/torchrl/trainers/learners/common.py b/torchrl/trainers/learners/common.py new file mode 100644 index 00000000000..457538b2fb5 --- /dev/null +++ b/torchrl/trainers/learners/common.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from dataclasses import dataclass + +from tensordict import TensorDictBase +from torch import nn + +from torchrl.objectives.common import LossModule + + +@dataclass +class LearnerCapabilities: + """Declares what a :class:`Learner` implementation supports. + + Algorithm code is not expected to branch on these (that would defeat the + point of the abstraction); they exist for logging, validation, and for + orchestration code that needs to know, e.g., whether a learner can be + checkpointed independently on every rank. + + Attributes: + sharded (bool): whether the learner's parameters are sharded across + multiple devices/processes (e.g. FSDP2). Defaults to ``False``. + remote (bool): whether :meth:`~Learner.update` dispatches to a + separate process rather than running in-line. Defaults to + ``False``. + """ + + sharded: bool = False + remote: bool = False + + +class Learner(nn.Module): + """Base class for the trainable-policy role. + + A :class:`Learner` owns a trainable model and exposes a single, + backend-agnostic entry point, :meth:`update`, for taking one optimization + step on a :class:`~tensordict.TensorDictBase` batch with a given + :class:`~torchrl.objectives.common.LossModule`. Algorithm code calls + ``learner.update(batch, loss_module)`` without knowing whether the update + runs locally on one device, under sharded (e.g. FSDP2) training, or on a + separate remote training process -- that placement is the ``Learner`` + subclass's responsibility, not the algorithm's. + + This mirrors the role :class:`~torchrl.collectors.Collector` plays for + data collection and :class:`~torchrl.modules.llm.LLMWrapperBase` plays for + generation/scoring: a fixed, TensorDict-native contract with multiple + interchangeable backends. + + .. note:: + ``update`` requires ``loss_module.forward`` to follow the + :class:`~torchrl.objectives.common.LossModule` convention: every + differentiable loss term is returned under a key starting with + ``"loss"`` (these are summed and used for the backward pass); any + other returned entry (e.g. ``accuracy``, a KL for logging) is treated + as a non-differentiable metric and left untouched. + + .. seealso:: :class:`~torchrl.weight_update.WeightSyncScheme` consumes + :meth:`get_weights` to synchronize a learner's parameters to remote + inference workers, so a ``Learner`` composes with the existing + weight-sync machinery without changes on either side. + """ + + capabilities: LearnerCapabilities = LearnerCapabilities() + + def update(self, batch: TensorDictBase, loss_module: LossModule) -> TensorDictBase: + """Take one optimization step on ``batch`` using ``loss_module``. + + Args: + batch (TensorDictBase): a batch, in the format expected by + ``loss_module``. + loss_module (LossModule): computes the loss(es) for ``batch``. + Its output's ``"loss"``-prefixed keys are summed and + backpropagated; other keys are passed through for logging. + + Returns: + TensorDictBase: the tensordict returned by ``loss_module``, + augmented with a ``"grad_norm"`` entry when gradient clipping is + enabled. + """ + raise NotImplementedError + + def get_weights(self) -> TensorDictBase: + """Return the learner's current parameters as a tensordict. + + The returned tensordict is accepted as-is by + :meth:`~torchrl.weight_update.WeightSyncScheme.send`, so it is the + seam between the training role and the weight-sync/inference roles. + """ + raise NotImplementedError diff --git a/torchrl/trainers/learners/local.py b/torchrl/trainers/learners/local.py new file mode 100644 index 00000000000..b8d93c53bc7 --- /dev/null +++ b/torchrl/trainers/learners/local.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict import is_tensorclass, TensorDict, TensorDictBase +from torch import nn + +from torchrl.objectives.common import LossModule +from torchrl.trainers.learners.common import Learner, LearnerCapabilities + + +class LocalLearner(Learner): + """A single-process, single-device :class:`~torchrl.trainers.learners.Learner`. + + Wraps a model and an optimizer and performs the update in-process: forward + through ``loss_module``, sum the outputs' ``"loss"``-prefixed entries, + backward, optionally clip gradients, and step the optimizer. This is the + reference implementation the :class:`~torchrl.trainers.learners.Learner` + contract is designed around -- a sharded (e.g. FSDP2-backed) or remote + learner implements the same two methods, :meth:`update` and + :meth:`get_weights`, so that algorithm code written against + ``LocalLearner`` runs unchanged against those backends. + + Args: + model (torch.nn.Module): the trainable module. Also the source for + :meth:`get_weights`. + optimizer (torch.optim.Optimizer): the optimizer stepping ``model``'s + parameters. + + Keyword Args: + clip_grad_norm (float, optional): if set, gradients are clipped to + this max norm via :func:`torch.nn.utils.clip_grad_norm_` before + the optimizer step, and the resulting norm is written to the + output tensordict under ``"grad_norm"``. Defaults to ``None``. + grad_accum_steps (int, optional): number of :meth:`update` calls to + accumulate gradients over before stepping the optimizer and + zeroing gradients. Defaults to ``1`` (step on every call). + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from torch import nn + >>> from torchrl.trainers.learners import LocalLearner + >>> from torchrl.objectives.common import LossModule + >>> + >>> class ToyLoss(LossModule): + ... def forward(self, batch): + ... pred = self.actor(batch["x"]) + ... return TensorDict({"loss_mse": (pred - batch["y"]).pow(2).mean()}) + >>> + >>> model = nn.Linear(4, 1) + >>> loss_module = ToyLoss() + >>> loss_module.actor = model + >>> learner = LocalLearner(model, torch.optim.Adam(model.parameters())) + >>> batch = TensorDict({"x": torch.randn(8, 4), "y": torch.randn(8, 1)}, [8]) + >>> out = learner.update(batch, loss_module) + >>> out.get("loss_mse").item() > 0 + True + """ + + def __init__( + self, + model: nn.Module, + optimizer: torch.optim.Optimizer, + *, + clip_grad_norm: float | None = None, + grad_accum_steps: int = 1, + ) -> None: + super().__init__() + self.model = model + self.optimizer = optimizer + self.clip_grad_norm = clip_grad_norm + if grad_accum_steps < 1: + raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}.") + self.grad_accum_steps = grad_accum_steps + self._accum_step = 0 + self.capabilities = LearnerCapabilities(sharded=False, remote=False) + + def update(self, batch: TensorDictBase, loss_module: LossModule) -> TensorDictBase: + if self._accum_step == 0: + self.optimizer.zero_grad(set_to_none=True) + + loss_td = loss_module(batch) + loss_keys = [k for k in loss_td.keys() if k.startswith("loss")] + if not loss_keys: + raise ValueError( + "loss_module returned no keys starting with 'loss': " + f"{list(loss_td.keys())}. LossModule.forward must return at " + "least one 'loss'-prefixed entry." + ) + total_loss = sum(loss_td.get(k) for k in loss_keys) / self.grad_accum_steps + total_loss.backward() + + self._accum_step += 1 + if self._accum_step < self.grad_accum_steps: + return loss_td + self._accum_step = 0 + + if self.clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.clip_grad_norm + ) + # loss_module may return a strict TensorClass (e.g. RewardModelLossOutput) + # that rejects undeclared keys; convert to a writable TensorDict first. + if is_tensorclass(loss_td): + loss_td = loss_td.to_tensordict() + loss_td.set("grad_norm", grad_norm) + + self.optimizer.step() + return loss_td + + def get_weights(self) -> TensorDictBase: + return TensorDict.from_module(self.model) From d5a0c876a59f9e0f65324c58c1f29e044275942e Mon Sep 17 00:00:00 2001 From: Achintya P Date: Thu, 2 Jul 2026 01:18:12 -0700 Subject: [PATCH 2/3] [Feature] Add FSDP2Learner; make Learner.update backend-agnostic Refactors Learner.update() to live in the base class as concrete, generic step logic (zero_grad -> forward -> sum loss_* keys -> backward -> clip -> step), operating only on self.model/self.optimizer/self.clip_grad_norm/ self.grad_accum_steps. This is what lets FSDP2Learner reuse the exact same training step as LocalLearner: sharded training only changes how the model is constructed (wrapped with fully_shard by the caller) and how get_weights() gathers the result. FSDP2Learner.get_weights() gathers every DTensor leaf via full_tensor() into a plain tensor, so its output is consumable by WeightSyncScheme exactly like LocalLearner's, with no changes on the receiving side. Verified on a single-rank (world_size=1) gloo process group, which exercises the real fully_shard()/DTensor code path without a cluster: forward/backward/clip_grad_norm_/optimizer.step() dispatch correctly through DTensor, TensorDict.from_module()/.apply() handle DTensor leaves transparently, and FSDP2Learner produces bit-exact losses and gathered weights vs LocalLearner given the same seed/data/lr. --- docs/source/reference/trainers_learners.rst | 14 ++- test/test_trainer.py | 115 +++++++++++++++++- torchrl/trainers/__init__.py | 3 +- torchrl/trainers/learners/__init__.py | 2 + torchrl/trainers/learners/common.py | 67 ++++++++-- torchrl/trainers/learners/fsdp2.py | 128 ++++++++++++++++++++ torchrl/trainers/learners/local.py | 49 ++------ 7 files changed, 316 insertions(+), 62 deletions(-) create mode 100644 torchrl/trainers/learners/fsdp2.py diff --git a/docs/source/reference/trainers_learners.rst b/docs/source/reference/trainers_learners.rst index 615ce208ba8..7e97c187245 100644 --- a/docs/source/reference/trainers_learners.rst +++ b/docs/source/reference/trainers_learners.rst @@ -16,10 +16,15 @@ need to know whether the update runs on one device, under sharded training, or on a remote training process. :class:`~torchrl.trainers.LocalLearner` is the single-process reference -implementation. Its :meth:`~torchrl.trainers.Learner.get_weights` output is -accepted as-is by :class:`~torchrl.weight_update.WeightSyncScheme`, so a -``Learner`` composes with the existing weight-sync path without changes on -either side. +implementation. :class:`~torchrl.trainers.FSDP2Learner` shards the same model +with :func:`torch.distributed._composable.fsdp.fully_shard` and reuses +:meth:`~torchrl.trainers.Learner.update` unchanged -- FSDP2's sharding is +transparent to the training step; only construction (the caller wraps the +model before handing it to the learner) and :meth:`~torchrl.trainers.Learner.get_weights` +(which gathers sharded parameters into plain tensors) differ. Either +learner's :meth:`~torchrl.trainers.Learner.get_weights` output is accepted +as-is by :class:`~torchrl.weight_update.WeightSyncScheme`, so a ``Learner`` +composes with the existing weight-sync path without changes on either side. .. autosummary:: :toctree: generated/ @@ -28,3 +33,4 @@ either side. Learner LearnerCapabilities LocalLearner + FSDP2Learner diff --git a/test/test_trainer.py b/test/test_trainer.py index ef0993cb152..4690a625756 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -31,7 +31,13 @@ from torchrl.envs.libs.gym import _has_gym from torchrl.objectives.common import LossModule from torchrl.testing import PONG_VERSIONED -from torchrl.trainers import Learner, LocalLearner, LogValidationReward, Trainer +from torchrl.trainers import ( + FSDP2Learner, + Learner, + LocalLearner, + LogValidationReward, + Trainer, +) from torchrl.trainers.algorithms.cql import CQLTrainer from torchrl.trainers.algorithms.ddpg import DDPGTrainer from torchrl.trainers.algorithms.dqn import DQNTrainer @@ -1671,12 +1677,109 @@ def test_invalid_grad_accum_steps_raises(self): with pytest.raises(ValueError, match="grad_accum_steps must be"): LocalLearner(model, optimizer, grad_accum_steps=0) - def test_base_learner_methods_are_abstract(self): - learner = Learner() - with pytest.raises(NotImplementedError): - learner.update(self._batch(), _ToyRegressionLoss(nn.Linear(4, 1))) + def test_base_learner_get_weights_is_abstract(self): + # update() is concrete on Learner (shared by LocalLearner/FSDP2Learner); + # only get_weights() -- the backend-specific gather -- remains abstract. with pytest.raises(NotImplementedError): - learner.get_weights() + Learner().get_weights() + + +@pytest.mark.skipif( + not torch.distributed.is_available(), reason="torch.distributed required" +) +class TestFSDP2Learner: + """FSDP2Learner exercised on a single-rank process group (gloo/CPU). + + world_size=1 does not exercise cross-rank sharding, but it runs the exact + fully_shard()/DTensor code path FSDP2Learner is built on, so these tests + catch real API breakage, not just interface stubs. + """ + + _PORT = "29601" + + @pytest.fixture(autouse=True) + def _process_group(self): + import torch.distributed as dist + + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ["MASTER_PORT"] = self._PORT + dist.init_process_group(backend="gloo", rank=0, world_size=1) + try: + yield + finally: + dist.destroy_process_group() + + @staticmethod + def _make(clip_grad_norm=None, grad_accum_steps=1, lr=0.1, seed=0): + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + + torch.manual_seed(seed) + model = nn.Linear(4, 1) + mesh = init_device_mesh("cpu", (1,)) + fully_shard(model, mesh=mesh) + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + learner = FSDP2Learner( + model, + optimizer, + clip_grad_norm=clip_grad_norm, + grad_accum_steps=grad_accum_steps, + ) + loss_module = _ToyRegressionLoss(model) + return learner, loss_module, model + + @staticmethod + def _batch(n=8, seed=None): + if seed is not None: + torch.manual_seed(seed) + return TensorDict( + {"x": torch.randn(n, 4), "y": torch.randn(n, 1)}, batch_size=[n] + ) + + def test_update_returns_loss_and_metric(self): + learner, loss_module, _ = self._make() + out = learner.update(self._batch(seed=1), loss_module) + assert "loss_mse" in out.keys() + assert "metric" in out.keys() + + def test_clip_grad_norm_writes_grad_norm(self): + learner, loss_module, _ = self._make(clip_grad_norm=0.5) + out = learner.update(self._batch(seed=1), loss_module) + assert "grad_norm" in out.keys() + assert out.get("grad_norm") >= 0 + + def test_get_weights_returns_plain_tensors(self): + learner, _, _ = self._make() + weights = learner.get_weights() + for leaf in weights.values(True, True): + assert not isinstance(leaf, torch.distributed.tensor.DTensor) + + def test_capabilities_report_sharded(self): + learner, _, _ = self._make() + assert learner.capabilities.sharded + assert not learner.capabilities.remote + + def test_matches_local_learner_bit_exact(self): + """The whole point of the abstraction: same update() logic, same + result, whether the model is sharded or not.""" + batch = self._batch(n=8, seed=42) + + fsdp2_learner, fsdp2_loss, fsdp2_model = self._make( + clip_grad_norm=1.0, lr=0.5, seed=0 + ) + fsdp2_out = fsdp2_learner.update(batch, fsdp2_loss) + + torch.manual_seed(0) + local_model = nn.Linear(4, 1) + local_optimizer = torch.optim.SGD(local_model.parameters(), lr=0.5) + local_learner = LocalLearner(local_model, local_optimizer, clip_grad_norm=1.0) + local_loss = _ToyRegressionLoss(local_model) + local_out = local_learner.update(batch, local_loss) + + assert fsdp2_out.get("loss_mse").item() == local_out.get("loss_mse").item() + gathered = fsdp2_learner.get_weights() + torch.testing.assert_close(gathered["weight"], local_model.weight) + torch.testing.assert_close(gathered["bias"], local_model.bias) if __name__ == "__main__": diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 2203bfdb730..2b111d7ff82 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .learners import Learner, LearnerCapabilities, LocalLearner +from .learners import FSDP2Learner, Learner, LearnerCapabilities, LocalLearner from .trainers import ( BatchSubSampler, ClearCudaCache, @@ -32,6 +32,7 @@ "CountFramesLog", "DefaultOptimizationStepper", "EarlyStopping", + "FSDP2Learner", "Learner", "LearnerCapabilities", "LocalLearner", diff --git a/torchrl/trainers/learners/__init__.py b/torchrl/trainers/learners/__init__.py index c08ab9cb5c4..b44117bbccf 100644 --- a/torchrl/trainers/learners/__init__.py +++ b/torchrl/trainers/learners/__init__.py @@ -5,9 +5,11 @@ from __future__ import annotations from torchrl.trainers.learners.common import Learner, LearnerCapabilities +from torchrl.trainers.learners.fsdp2 import FSDP2Learner from torchrl.trainers.learners.local import LocalLearner __all__ = [ + "FSDP2Learner", "Learner", "LearnerCapabilities", "LocalLearner", diff --git a/torchrl/trainers/learners/common.py b/torchrl/trainers/learners/common.py index 457538b2fb5..0cd1b8b062a 100644 --- a/torchrl/trainers/learners/common.py +++ b/torchrl/trainers/learners/common.py @@ -6,7 +6,8 @@ from dataclasses import dataclass -from tensordict import TensorDictBase +import torch +from tensordict import is_tensorclass, TensorDictBase from torch import nn from torchrl.objectives.common import LossModule @@ -36,10 +37,10 @@ class LearnerCapabilities: class Learner(nn.Module): """Base class for the trainable-policy role. - A :class:`Learner` owns a trainable model and exposes a single, - backend-agnostic entry point, :meth:`update`, for taking one optimization - step on a :class:`~tensordict.TensorDictBase` batch with a given - :class:`~torchrl.objectives.common.LossModule`. Algorithm code calls + A :class:`Learner` owns a trainable model and an optimizer and exposes a + single, backend-agnostic entry point, :meth:`update`, for taking one + optimization step on a :class:`~tensordict.TensorDictBase` batch with a + given :class:`~torchrl.objectives.common.LossModule`. Algorithm code calls ``learner.update(batch, loss_module)`` without knowing whether the update runs locally on one device, under sharded (e.g. FSDP2) training, or on a separate remote training process -- that placement is the ``Learner`` @@ -50,6 +51,16 @@ class Learner(nn.Module): generation/scoring: a fixed, TensorDict-native contract with multiple interchangeable backends. + :meth:`update` is implemented once, here, and is intentionally backend + agnostic: it only touches ``self.model``, ``self.optimizer``, + ``self.clip_grad_norm``, and ``self.grad_accum_steps``, all of which a + subclass sets in its constructor. This is what lets + :class:`~torchrl.trainers.learners.FSDP2Learner` reuse the exact same + training step as :class:`~torchrl.trainers.learners.LocalLearner`: sharded + training only changes how the model is constructed (wrapped with + ``fully_shard``) and how :meth:`get_weights` gathers the result, not how a + step is taken. + .. note:: ``update`` requires ``loss_module.forward`` to follow the :class:`~torchrl.objectives.common.LossModule` convention: every @@ -66,6 +77,12 @@ class Learner(nn.Module): capabilities: LearnerCapabilities = LearnerCapabilities() + model: nn.Module + optimizer: torch.optim.Optimizer + clip_grad_norm: float | None + grad_accum_steps: int + _accum_step: int + def update(self, batch: TensorDictBase, loss_module: LossModule) -> TensorDictBase: """Take one optimization step on ``batch`` using ``loss_module``. @@ -81,13 +98,45 @@ def update(self, batch: TensorDictBase, loss_module: LossModule) -> TensorDictBa augmented with a ``"grad_norm"`` entry when gradient clipping is enabled. """ - raise NotImplementedError + if self._accum_step == 0: + self.optimizer.zero_grad(set_to_none=True) + + loss_td = loss_module(batch) + loss_keys = [k for k in loss_td.keys() if k.startswith("loss")] + if not loss_keys: + raise ValueError( + "loss_module returned no keys starting with 'loss': " + f"{list(loss_td.keys())}. LossModule.forward must return at " + "least one 'loss'-prefixed entry." + ) + total_loss = sum(loss_td.get(k) for k in loss_keys) / self.grad_accum_steps + total_loss.backward() + + self._accum_step += 1 + if self._accum_step < self.grad_accum_steps: + return loss_td + self._accum_step = 0 + + if self.clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.clip_grad_norm + ) + # loss_module may return a strict TensorClass (e.g. RewardModelLossOutput) + # that rejects undeclared keys; convert to a writable TensorDict first. + if is_tensorclass(loss_td): + loss_td = loss_td.to_tensordict() + loss_td.set("grad_norm", grad_norm) + + self.optimizer.step() + return loss_td def get_weights(self) -> TensorDictBase: """Return the learner's current parameters as a tensordict. - The returned tensordict is accepted as-is by - :meth:`~torchrl.weight_update.WeightSyncScheme.send`, so it is the - seam between the training role and the weight-sync/inference roles. + The returned tensordict holds plain (fully materialized) tensors, even + when the learner's parameters are internally sharded, so it is + accepted as-is by :meth:`~torchrl.weight_update.WeightSyncScheme.send`. + This is the seam between the training role and the weight-sync / + inference roles. """ raise NotImplementedError diff --git a/torchrl/trainers/learners/fsdp2.py b/torchrl/trainers/learners/fsdp2.py new file mode 100644 index 00000000000..8fc4b6a2a42 --- /dev/null +++ b/torchrl/trainers/learners/fsdp2.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict import TensorDict, TensorDictBase +from torch import nn + +from torchrl.trainers.learners.common import Learner, LearnerCapabilities + +try: + from torch.distributed.tensor import DTensor + + _has_dtensor = True +except ImportError: # pragma: no cover - torch without distributed.tensor + _has_dtensor = False + + +class FSDP2Learner(Learner): + """A :class:`~torchrl.trainers.learners.Learner` for FSDP2-sharded models. + + Accepts a model that the caller has already wrapped with + :func:`torch.distributed._composable.fsdp.fully_shard` (typically + per-submodule, then on the root module), and an optimizer built on that + (sharded) model's parameters. ``fully_shard`` must be applied *before* the + optimizer is constructed, so the optimizer holds the sharded + (:class:`~torch.distributed.tensor.DTensor`) parameters rather than the + original ones. + + :meth:`~torchrl.trainers.learners.Learner.update` is inherited unchanged + from :class:`~torchrl.trainers.learners.Learner`: FSDP2's sharding is + transparent to the training step (forward/backward/optimizer-step all + dispatch through ``DTensor`` the same way they would through a regular + tensor). Only two things differ from + :class:`~torchrl.trainers.learners.LocalLearner`: how the model arrives + (already sharded, by the caller) and how :meth:`get_weights` reports it. + + :meth:`get_weights` gathers every sharded leaf into a plain tensor via + ``DTensor.full_tensor()``, so the returned tensordict holds regular + tensors and is consumable by + :meth:`~torchrl.weight_update.WeightSyncScheme.send` exactly like + :class:`~torchrl.trainers.learners.LocalLearner`'s, with no changes needed + on the receiving (inference) side. This gather is the seam between + sharded training and (typically replicated) inference, and is the one + place sharding is not transparent. + + Args: + model (torch.nn.Module): a model already wrapped with ``fully_shard``. + optimizer (torch.optim.Optimizer): an optimizer constructed on + ``model``'s (already-sharded) parameters. + + Keyword Args: + clip_grad_norm (float, optional): as in + :class:`~torchrl.trainers.learners.LocalLearner`. Gradient-norm + clipping dispatches through ``DTensor`` transparently. Defaults to + ``None``. + grad_accum_steps (int, optional): as in + :class:`~torchrl.trainers.learners.LocalLearner`. Defaults to + ``1``. + + .. warning:: + ``FSDP2Learner`` does not itself decide what to shard, at what + granularity, or on what device mesh -- those are model-specific + choices that belong in the caller's model-construction code, exactly + as they would with bare ``fully_shard``. This keeps + ``FSDP2Learner`` a thin adapter rather than a second place those + decisions are made. + + Examples: + >>> import torch + >>> import torch.distributed as dist + >>> from torch import nn + >>> from torch.distributed._composable.fsdp import fully_shard + >>> from torch.distributed.device_mesh import init_device_mesh + >>> from tensordict import TensorDict + >>> from torchrl.objectives.common import LossModule + >>> from torchrl.trainers.learners import FSDP2Learner + >>> + >>> dist.init_process_group(backend="gloo", rank=0, world_size=1) + >>> mesh = init_device_mesh("cpu", (1,)) + >>> model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1)) + >>> for layer in model: + ... fully_shard(layer, mesh=mesh) + >>> fully_shard(model, mesh=mesh) + >>> optimizer = torch.optim.Adam(model.parameters()) + >>> + >>> class ToyLoss(LossModule): + ... def forward(self, batch): + ... pred = self.actor(batch["x"]) + ... return TensorDict({"loss_mse": (pred - batch["y"]).pow(2).mean()}) + >>> loss_module = ToyLoss() + >>> loss_module.actor = model + >>> + >>> learner = FSDP2Learner(model, optimizer) + >>> batch = TensorDict({"x": torch.randn(8, 4), "y": torch.randn(8, 1)}, [8]) + >>> out = learner.update(batch, loss_module) + >>> weights = learner.get_weights() # gathered, plain tensors + >>> dist.destroy_process_group() + """ + + def __init__( + self, + model: nn.Module, + optimizer: torch.optim.Optimizer, + *, + clip_grad_norm: float | None = None, + grad_accum_steps: int = 1, + ) -> None: + if not _has_dtensor: + raise RuntimeError( + "FSDP2Learner requires torch.distributed.tensor (DTensor), " + "which is not available in this torch build." + ) + super().__init__() + self.model = model + self.optimizer = optimizer + self.clip_grad_norm = clip_grad_norm + if grad_accum_steps < 1: + raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}.") + self.grad_accum_steps = grad_accum_steps + self._accum_step = 0 + self.capabilities = LearnerCapabilities(sharded=True, remote=False) + + def get_weights(self) -> TensorDictBase: + td = TensorDict.from_module(self.model) + return td.apply(lambda t: t.full_tensor() if isinstance(t, DTensor) else t) diff --git a/torchrl/trainers/learners/local.py b/torchrl/trainers/learners/local.py index b8d93c53bc7..6c4703d606e 100644 --- a/torchrl/trainers/learners/local.py +++ b/torchrl/trainers/learners/local.py @@ -5,24 +5,22 @@ from __future__ import annotations import torch -from tensordict import is_tensorclass, TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase from torch import nn -from torchrl.objectives.common import LossModule from torchrl.trainers.learners.common import Learner, LearnerCapabilities class LocalLearner(Learner): """A single-process, single-device :class:`~torchrl.trainers.learners.Learner`. - Wraps a model and an optimizer and performs the update in-process: forward - through ``loss_module``, sum the outputs' ``"loss"``-prefixed entries, - backward, optionally clip gradients, and step the optimizer. This is the + Holds a model and an optimizer, unwrapped and unsharded. This is the reference implementation the :class:`~torchrl.trainers.learners.Learner` - contract is designed around -- a sharded (e.g. FSDP2-backed) or remote - learner implements the same two methods, :meth:`update` and - :meth:`get_weights`, so that algorithm code written against - ``LocalLearner`` runs unchanged against those backends. + contract is designed around -- + :class:`~torchrl.trainers.learners.FSDP2Learner` shards the same model + with ``fully_shard`` and reuses :meth:`~torchrl.trainers.learners.Learner.update` + unchanged, so algorithm code written against ``LocalLearner`` runs + unmodified under sharded training. Args: model (torch.nn.Module): the trainable module. Also the source for @@ -79,38 +77,5 @@ def __init__( self._accum_step = 0 self.capabilities = LearnerCapabilities(sharded=False, remote=False) - def update(self, batch: TensorDictBase, loss_module: LossModule) -> TensorDictBase: - if self._accum_step == 0: - self.optimizer.zero_grad(set_to_none=True) - - loss_td = loss_module(batch) - loss_keys = [k for k in loss_td.keys() if k.startswith("loss")] - if not loss_keys: - raise ValueError( - "loss_module returned no keys starting with 'loss': " - f"{list(loss_td.keys())}. LossModule.forward must return at " - "least one 'loss'-prefixed entry." - ) - total_loss = sum(loss_td.get(k) for k in loss_keys) / self.grad_accum_steps - total_loss.backward() - - self._accum_step += 1 - if self._accum_step < self.grad_accum_steps: - return loss_td - self._accum_step = 0 - - if self.clip_grad_norm is not None: - grad_norm = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.clip_grad_norm - ) - # loss_module may return a strict TensorClass (e.g. RewardModelLossOutput) - # that rejects undeclared keys; convert to a writable TensorDict first. - if is_tensorclass(loss_td): - loss_td = loss_td.to_tensordict() - loss_td.set("grad_norm", grad_norm) - - self.optimizer.step() - return loss_td - def get_weights(self) -> TensorDictBase: return TensorDict.from_module(self.model) From 070409dcbd6595a81526a74b9aefe5b2b98b7bfe Mon Sep 17 00:00:00 2001 From: Achintya P Date: Thu, 2 Jul 2026 01:49:48 -0700 Subject: [PATCH 3/3] [BugFix] Fix FSDP2Learner: rank0-only gather, real checkpointing, grad-sync Fixes three real gaps in FSDP2Learner identified after the initial PR: 1. get_weights() previously gathered every DTensor leaf to EVERY rank via full_tensor(), replicating the whole model in every rank's memory for no benefit -- does not scale to large sharded models. Now uses torch.distributed.checkpoint.state_dict.get_model_state_dict with StateDictOptions(full_state_dict=True, cpu_offload=True), which gathers to rank 0 only (other ranks get an empty tensordict) by documented DCP semantics. Verified: correct nested key shape via unflatten_keys('.'), matches the prior full_tensor()-based output. 2. No sharded-checkpoint path existed. Learner (base) gains real state_dict()/load_state_dict() covering model + optimizer state (a bare Optimizer is not an nn.Module, so plain nn.Module.state_dict() silently drops it -- a real, previously-latent bug for LocalLearner too). Both overrides clone their tensors: nn.Module.state_dict() and Optimizer.state_dict() return views onto live tensors, not copies, so without cloning, further training after checkpointing would silently corrupt the saved checkpoint (caught by round-trip tests). FSDP2Learner overrides these two methods again with get_state_dict/set_state_dict (DCP-aware, handles DTensor optimizer state, rank0-only via cpu_offload). Verified end to end: model weights AND Adam/SGD-momentum optimizer state round-trip correctly through save/perturb/load. 3. grad_accum_steps>1 was untested on FSDP2Learner, and update() did no FSDP2-specific optimization during accumulation. Learner.update() now toggles model.set_requires_gradient_sync(...) when the model exposes it (FSDP2-wrapped models do; LocalLearner's plain model doesn't, so this is a no-op there), deferring the cross-rank gradient reduction until the last micro-batch of an accumulation window instead of reducing on every micro-batch. Verified against a non-sharded reference: the accumulated, synced gradient after 2 microbatches matches a plain model accumulating the same 2 microbatches exactly. Still unverified (unchanged from the original PR): all of the above is tested at world_size=1 (single-rank gloo), which exercises the real fully_shard()/DTensor/DCP code paths but not actual cross-rank communication or memory behavior. --- test/test_trainer.py | 111 ++++++++++++++++++++++++++++ torchrl/trainers/learners/common.py | 69 +++++++++++++++++ torchrl/trainers/learners/fsdp2.py | 78 ++++++++++++++++--- 3 files changed, 246 insertions(+), 12 deletions(-) diff --git a/test/test_trainer.py b/test/test_trainer.py index 4690a625756..cdcc8808423 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -1683,6 +1683,37 @@ def test_base_learner_get_weights_is_abstract(self): with pytest.raises(NotImplementedError): Learner().get_weights() + def test_state_dict_round_trip_includes_optimizer(self): + # plain nn.Module.state_dict() would silently drop the optimizer's + # state (Optimizer is not an nn.Module) -- this is the regression test + # for that gap. Momentum must be nonzero, or SGD never populates a + # momentum_buffer in its state to begin with. + torch.manual_seed(0) + model = nn.Linear(4, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + learner = LocalLearner(model, optimizer) + loss_module = _ToyRegressionLoss(model) + torch.manual_seed(2) + for _ in range(3): + learner.update(self._batch(), loss_module) + checkpoint = learner.state_dict() + saved_weight = model.weight.clone() + saved_momentum = next(iter(learner.optimizer.state.values()))[ + "momentum_buffer" + ].clone() + + with torch.no_grad(): + model.weight.zero_() + for state in learner.optimizer.state.values(): + state["momentum_buffer"].zero_() + + learner.load_state_dict(checkpoint) + torch.testing.assert_close(model.weight, saved_weight) + restored_momentum = next(iter(learner.optimizer.state.values()))[ + "momentum_buffer" + ] + torch.testing.assert_close(restored_momentum, saved_momentum) + @pytest.mark.skipif( not torch.distributed.is_available(), reason="torch.distributed required" @@ -1781,6 +1812,86 @@ def test_matches_local_learner_bit_exact(self): torch.testing.assert_close(gathered["weight"], local_model.weight) torch.testing.assert_close(gathered["bias"], local_model.bias) + def test_grad_accumulation_defers_step(self): + learner, loss_module, model = self._make(grad_accum_steps=2, lr=1.0) + before = model.weight.full_tensor().clone() + learner.update(self._batch(seed=1), loss_module) + assert torch.equal(before, model.weight.full_tensor()) + learner.update(self._batch(seed=2), loss_module) + assert not torch.equal(before, model.weight.full_tensor()) + + def test_grad_accumulation_matches_non_sharded_reference(self): + """Regression test for the set_requires_gradient_sync(False) toggle + update() uses to skip communication on non-final accumulation steps: + the accumulated gradient must still equal a plain (non-sharded) + reference that accumulates the same two micro-batches.""" + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + + batch1, batch2 = self._batch(seed=10), self._batch(seed=11) + + # update() divides each micro-batch's loss by grad_accum_steps before + # backward (accumulation averages, rather than sums, the microbatches), + # so the reference must apply the same scaling. + torch.manual_seed(0) + ref_model = nn.Linear(4, 1) + ref_loss = _ToyRegressionLoss(ref_model) + ref_model.zero_grad() + (ref_loss(batch1).get("loss_mse") / 2).backward() + (ref_loss(batch2).get("loss_mse") / 2).backward() + ref_grad = ref_model.weight.grad.clone() + + torch.manual_seed(0) + model = nn.Linear(4, 1) + mesh = init_device_mesh("cpu", (1,)) + fully_shard(model, mesh=mesh) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + learner = FSDP2Learner(model, optimizer, grad_accum_steps=2) + loss_module = _ToyRegressionLoss(model) + learner.update(batch1, loss_module) # sync disabled; .grad stays None + learner.update(batch2, loss_module) # sync re-enabled; optimizer.step() + # optimizer.step() does not clear .grad (only zero_grad() does), so the + # fully-accumulated, synced gradient is still readable here. + accumulated_grad = model.weight.grad.full_tensor() + torch.testing.assert_close(accumulated_grad, ref_grad) + + def test_state_dict_round_trip_includes_optimizer(self): + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + + # momentum must be nonzero, or SGD never populates optimizer state to + # begin with, and the DCP loader has nothing meaningful to restore. + torch.manual_seed(0) + model = nn.Linear(4, 1) + mesh = init_device_mesh("cpu", (1,)) + fully_shard(model, mesh=mesh) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + learner = FSDP2Learner(model, optimizer) + loss_module = _ToyRegressionLoss(model) + + torch.manual_seed(3) + for _ in range(3): + learner.update(self._batch(), loss_module) + checkpoint = learner.state_dict() + saved_weight = model.weight.full_tensor().clone() + saved_momentum = ( + next(iter(optimizer.state.values()))["momentum_buffer"] + .full_tensor() + .clone() + ) + + with torch.no_grad(): + model.weight.to_local().zero_() + for state in optimizer.state.values(): + state["momentum_buffer"].to_local().zero_() + + learner.load_state_dict(checkpoint) + torch.testing.assert_close(model.weight.full_tensor(), saved_weight) + restored_momentum = next(iter(optimizer.state.values()))[ + "momentum_buffer" + ].full_tensor() + torch.testing.assert_close(restored_momentum, saved_momentum) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/trainers/learners/common.py b/torchrl/trainers/learners/common.py index 0cd1b8b062a..0bdb6a03536 100644 --- a/torchrl/trainers/learners/common.py +++ b/torchrl/trainers/learners/common.py @@ -13,6 +13,25 @@ from torchrl.objectives.common import LossModule +def _clone_tensors(obj): + """Recursively clone every tensor in a (possibly nested) state-dict-like structure. + + ``nn.Module.state_dict()`` and ``Optimizer.state_dict()`` both return + views onto their live tensors, not independent copies, so holding onto + their output as a "checkpoint" while training continues silently mutates + that checkpoint too. + """ + if isinstance(obj, torch.Tensor): + return obj.clone() + if isinstance(obj, dict): + return {k: _clone_tensors(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_clone_tensors(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_clone_tensors(v) for v in obj) + return obj + + @dataclass class LearnerCapabilities: """Declares what a :class:`Learner` implementation supports. @@ -61,6 +80,16 @@ class Learner(nn.Module): ``fully_shard``) and how :meth:`get_weights` gathers the result, not how a step is taken. + During gradient accumulation, if ``self.model`` exposes + ``set_requires_gradient_sync`` (as an FSDP2 ``fully_shard``-wrapped module + does), :meth:`update` disables cross-rank gradient synchronization on + every accumulation step except the last: gradients still accumulate + correctly (verified against a non-sharded reference), but the + communication (reduce-scatter) only happens once per accumulation window + instead of once per micro-batch. This is a no-op for + :class:`~torchrl.trainers.learners.LocalLearner`, whose model has no such + method. + .. note:: ``update`` requires ``loss_module.forward`` to follow the :class:`~torchrl.objectives.common.LossModule` convention: every @@ -73,6 +102,17 @@ class Learner(nn.Module): :meth:`get_weights` to synchronize a learner's parameters to remote inference workers, so a ``Learner`` composes with the existing weight-sync machinery without changes on either side. + + .. warning:: + :meth:`state_dict` / :meth:`load_state_dict` are overridden (relative + to plain :class:`torch.nn.Module`) to checkpoint the optimizer state + alongside the model: a bare :class:`~torch.optim.Optimizer` is not an + ``nn.Module``, so the default ``nn.Module.state_dict()`` would + silently omit it, corrupting a training resume (Adam's moments, etc., + would reset). Subclasses whose model/optimizer need + sharding-aware (DTensor) checkpointing, e.g. + :class:`~torchrl.trainers.learners.FSDP2Learner`, override these two + methods again with a DTensor-aware implementation. """ capabilities: LearnerCapabilities = LearnerCapabilities() @@ -101,6 +141,13 @@ def update(self, batch: TensorDictBase, loss_module: LossModule) -> TensorDictBa if self._accum_step == 0: self.optimizer.zero_grad(set_to_none=True) + # Sharded (e.g. FSDP2) models can defer the cross-rank gradient + # reduction until the last micro-batch of an accumulation window; + # LocalLearner's plain model has no such method, so this is a no-op. + set_grad_sync = getattr(self.model, "set_requires_gradient_sync", None) + if set_grad_sync is not None: + set_grad_sync(self._accum_step == self.grad_accum_steps - 1) + loss_td = loss_module(batch) loss_keys = [k for k in loss_td.keys() if k.startswith("loss")] if not loss_keys: @@ -140,3 +187,25 @@ def get_weights(self) -> TensorDictBase: inference roles. """ raise NotImplementedError + + def state_dict(self, *args, **kwargs) -> dict: # noqa: D417 + """Return a checkpoint covering the model, optimizer, and accumulation state. + + See the class-level warning: this intentionally does not reuse plain + :meth:`torch.nn.Module.state_dict`, which would silently drop the + optimizer's state. The returned tensors are independent clones (both + ``nn.Module.state_dict()`` and ``Optimizer.state_dict()`` otherwise + return views onto the live parameters/state, so an in-place update + after checkpointing would silently corrupt the "saved" checkpoint too). + """ + return { + "model": _clone_tensors(self.model.state_dict()), + "optimizer": _clone_tensors(self.optimizer.state_dict()), + "accum_step": self._accum_step, + } + + def load_state_dict(self, state_dict: dict, *args, **kwargs) -> None: # noqa: D417 + """Restore a checkpoint produced by :meth:`state_dict`.""" + self.model.load_state_dict(state_dict["model"]) + self.optimizer.load_state_dict(state_dict["optimizer"]) + self._accum_step = state_dict.get("accum_step", 0) diff --git a/torchrl/trainers/learners/fsdp2.py b/torchrl/trainers/learners/fsdp2.py index 8fc4b6a2a42..a6101060f53 100644 --- a/torchrl/trainers/learners/fsdp2.py +++ b/torchrl/trainers/learners/fsdp2.py @@ -11,11 +11,16 @@ from torchrl.trainers.learners.common import Learner, LearnerCapabilities try: - from torch.distributed.tensor import DTensor + from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + get_state_dict, + set_state_dict, + StateDictOptions, + ) - _has_dtensor = True -except ImportError: # pragma: no cover - torch without distributed.tensor - _has_dtensor = False + _has_dist_checkpoint = True +except ImportError: # pragma: no cover - torch without distributed.checkpoint + _has_dist_checkpoint = False class FSDP2Learner(Learner): @@ -38,13 +43,17 @@ class FSDP2Learner(Learner): (already sharded, by the caller) and how :meth:`get_weights` reports it. :meth:`get_weights` gathers every sharded leaf into a plain tensor via - ``DTensor.full_tensor()``, so the returned tensordict holds regular - tensors and is consumable by + :func:`torch.distributed.checkpoint.state_dict.get_model_state_dict`, so + the returned tensordict holds regular tensors and is consumable by :meth:`~torchrl.weight_update.WeightSyncScheme.send` exactly like :class:`~torchrl.trainers.learners.LocalLearner`'s, with no changes needed on the receiving (inference) side. This gather is the seam between sharded training and (typically replicated) inference, and is the one - place sharding is not transparent. + place sharding is not transparent. By default the gather targets rank 0 + only (other ranks receive an empty tensordict): gathering the full model + to *every* rank, as a naive per-leaf ``DTensor.full_tensor()`` would, + replicates the whole model in every rank's memory for no benefit, which + does not scale to large sharded models. Args: model (torch.nn.Module): a model already wrapped with ``fully_shard``. @@ -108,9 +117,9 @@ def __init__( clip_grad_norm: float | None = None, grad_accum_steps: int = 1, ) -> None: - if not _has_dtensor: + if not _has_dist_checkpoint: raise RuntimeError( - "FSDP2Learner requires torch.distributed.tensor (DTensor), " + "FSDP2Learner requires torch.distributed.checkpoint.state_dict, " "which is not available in this torch build." ) super().__init__() @@ -123,6 +132,51 @@ def __init__( self._accum_step = 0 self.capabilities = LearnerCapabilities(sharded=True, remote=False) - def get_weights(self) -> TensorDictBase: - td = TensorDict.from_module(self.model) - return td.apply(lambda t: t.full_tensor() if isinstance(t, DTensor) else t) + def get_weights(self, *, cpu_offload: bool = True) -> TensorDictBase: + """Gather the sharded model into a plain-tensor tensordict. + + Keyword Args: + cpu_offload (bool, optional): if ``True`` (the default), the + gathered weights are returned only on rank 0 (other ranks get + an empty tensordict) and moved to CPU, avoiding an all-rank + replication of the full model. Set to ``False`` to instead + gather full GPU-resident copies onto every rank. + """ + options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload) + state_dict = get_model_state_dict(self.model, options=options) + return TensorDict(state_dict).unflatten_keys(".") + + def state_dict(self, *args, **kwargs) -> dict: + """DTensor-aware checkpoint: gathers model + optimizer state to rank 0. + + Overrides :meth:`~torchrl.trainers.learners.Learner.state_dict`, which + uses plain (non-distributed) ``state_dict()`` calls that would return + raw, per-rank ``DTensor`` shards rather than a portable checkpoint. + """ + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + model_state_dict, optim_state_dict = get_state_dict( + self.model, self.optimizer, options=options + ) + return { + "model": model_state_dict, + "optimizer": optim_state_dict, + "accum_step": self._accum_step, + } + + def load_state_dict(self, state_dict: dict, *args, **kwargs) -> None: + """Restore a checkpoint produced by :meth:`state_dict`. + + ``broadcast_from_rank0=True`` lets rank 0 hold the full checkpoint + (as produced by :meth:`state_dict`) while every rank reshards it + according to its local shards -- the counterpart of the + ``cpu_offload``-gathered save. + """ + options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True) + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optimizer"], + options=options, + ) + self._accum_step = state_dict.get("accum_step", 0)