diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 162c933ce20..f401f0a7543 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -49,3 +49,4 @@ Documentation Sections trainers_basics trainers_loggers trainers_hooks + trainers_learners diff --git a/docs/source/reference/trainers_learners.rst b/docs/source/reference/trainers_learners.rst new file mode 100644 index 00000000000..7e97c187245 --- /dev/null +++ b/docs/source/reference/trainers_learners.rst @@ -0,0 +1,36 @@ +.. currentmodule:: torchrl.trainers + +Learners +======== + +.. _ref_learners: + +A :class:`~torchrl.trainers.Learner` owns a trainable model and exposes a +single, backend-agnostic entry point -- :meth:`~torchrl.trainers.Learner.update` +-- for taking one optimization step on a tensordict batch with a given +:class:`~torchrl.objectives.common.LossModule`. It plays the same role for +training that :class:`~torchrl.collectors.Collector` plays for data collection +and :class:`~torchrl.modules.llm.LLMWrapperBase` plays for generation/scoring: +a fixed contract with interchangeable backends, so algorithm code does not +need to know whether the update runs on one device, under sharded training, or +on a remote training process. + +:class:`~torchrl.trainers.LocalLearner` is the single-process reference +implementation. :class:`~torchrl.trainers.FSDP2Learner` shards the same model +with :func:`torch.distributed._composable.fsdp.fully_shard` and reuses +:meth:`~torchrl.trainers.Learner.update` unchanged -- FSDP2's sharding is +transparent to the training step; only construction (the caller wraps the +model before handing it to the learner) and :meth:`~torchrl.trainers.Learner.get_weights` +(which gathers sharded parameters into plain tensors) differ. Either +learner's :meth:`~torchrl.trainers.Learner.get_weights` output is accepted +as-is by :class:`~torchrl.weight_update.WeightSyncScheme`, so a ``Learner`` +composes with the existing weight-sync path without changes on either side. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Learner + LearnerCapabilities + LocalLearner + FSDP2Learner diff --git a/test/test_trainer.py b/test/test_trainer.py index e520c9ee773..cdcc8808423 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -29,8 +29,15 @@ TensorDictReplayBuffer, ) from torchrl.envs.libs.gym import _has_gym +from torchrl.objectives.common import LossModule from torchrl.testing import PONG_VERSIONED -from torchrl.trainers import LogValidationReward, Trainer +from torchrl.trainers import ( + FSDP2Learner, + Learner, + LocalLearner, + LogValidationReward, + Trainer, +) from torchrl.trainers.algorithms.cql import CQLTrainer from torchrl.trainers.algorithms.ddpg import DDPGTrainer from torchrl.trainers.algorithms.dqn import DQNTrainer @@ -1562,6 +1569,330 @@ def test_train_stops_early(self): assert collector.shutdown_calls == 1 +class _ToyRegressionLoss(LossModule): + """Minimal LossModule: MSE loss plus a non-loss logging metric.""" + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, batch: TensorDict) -> TensorDict: + pred = self.model(batch["x"]) + loss = (pred - batch["y"]).pow(2).mean() + with torch.no_grad(): + metric = pred.mean() # a non-"loss"-prefixed field, for logging only + return TensorDict({"loss_mse": loss, "metric": metric}) + + +class _NoLossKeyModule(LossModule): + """A LossModule whose output has no 'loss'-prefixed key (misconfigured).""" + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, batch: TensorDict) -> TensorDict: + pred = self.model(batch["x"]) + return TensorDict({"mse": (pred - batch["y"]).pow(2).mean()}) + + +class TestLocalLearner: + @staticmethod + def _make(clip_grad_norm=None, grad_accum_steps=1, lr=0.1): + torch.manual_seed(0) + model = nn.Linear(4, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + learner = LocalLearner( + model, + optimizer, + clip_grad_norm=clip_grad_norm, + grad_accum_steps=grad_accum_steps, + ) + loss_module = _ToyRegressionLoss(model) + return learner, loss_module, model + + @staticmethod + def _batch(n=8): + return TensorDict( + {"x": torch.randn(n, 4), "y": torch.randn(n, 1)}, batch_size=[n] + ) + + def test_update_returns_loss_and_metric(self): + learner, loss_module, _ = self._make() + out = learner.update(self._batch(), loss_module) + assert "loss_mse" in out.keys() + assert "metric" in out.keys() + assert out.get("loss_mse").shape == () + + def test_update_steps_the_optimizer(self): + learner, loss_module, model = self._make(lr=0.5) + before = model.weight.clone() + learner.update(self._batch(), loss_module) + assert not torch.equal(before, model.weight) + + def test_loss_decreases_over_steps(self): + learner, loss_module, _ = self._make(lr=0.1) + torch.manual_seed(1) + batch = self._batch(64) + first = learner.update(batch, loss_module).get("loss_mse").item() + for _ in range(20): + last = learner.update(batch, loss_module).get("loss_mse").item() + assert last < first + + def test_clip_grad_norm_writes_grad_norm(self): + learner, loss_module, _ = self._make(clip_grad_norm=0.5) + out = learner.update(self._batch(), loss_module) + assert "grad_norm" in out.keys() + assert out.get("grad_norm") >= 0 + + def test_no_clip_grad_norm_absent(self): + learner, loss_module, _ = self._make(clip_grad_norm=None) + out = learner.update(self._batch(), loss_module) + assert "grad_norm" not in out.keys() + + def test_missing_loss_key_raises(self): + learner, _, model = self._make() + with pytest.raises(ValueError, match="no keys starting with 'loss'"): + learner.update(self._batch(), _NoLossKeyModule(model)) + + def test_grad_accumulation_defers_step(self): + learner, loss_module, model = self._make(grad_accum_steps=2, lr=1.0) + before = model.weight.clone() + learner.update(self._batch(), loss_module) + # optimizer has not stepped yet: weights unchanged after the 1st of 2 accum steps + assert torch.equal(before, model.weight) + learner.update(self._batch(), loss_module) + # after the 2nd accum step, the optimizer has stepped + assert not torch.equal(before, model.weight) + + def test_get_weights_matches_model_params(self): + learner, _, model = self._make() + weights = learner.get_weights() + torch.testing.assert_close(weights["weight"], model.weight) + torch.testing.assert_close(weights["bias"], model.bias) + + def test_invalid_grad_accum_steps_raises(self): + model = nn.Linear(4, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + with pytest.raises(ValueError, match="grad_accum_steps must be"): + LocalLearner(model, optimizer, grad_accum_steps=0) + + def test_base_learner_get_weights_is_abstract(self): + # update() is concrete on Learner (shared by LocalLearner/FSDP2Learner); + # only get_weights() -- the backend-specific gather -- remains abstract. + with pytest.raises(NotImplementedError): + Learner().get_weights() + + def test_state_dict_round_trip_includes_optimizer(self): + # plain nn.Module.state_dict() would silently drop the optimizer's + # state (Optimizer is not an nn.Module) -- this is the regression test + # for that gap. Momentum must be nonzero, or SGD never populates a + # momentum_buffer in its state to begin with. + torch.manual_seed(0) + model = nn.Linear(4, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + learner = LocalLearner(model, optimizer) + loss_module = _ToyRegressionLoss(model) + torch.manual_seed(2) + for _ in range(3): + learner.update(self._batch(), loss_module) + checkpoint = learner.state_dict() + saved_weight = model.weight.clone() + saved_momentum = next(iter(learner.optimizer.state.values()))[ + "momentum_buffer" + ].clone() + + with torch.no_grad(): + model.weight.zero_() + for state in learner.optimizer.state.values(): + state["momentum_buffer"].zero_() + + learner.load_state_dict(checkpoint) + torch.testing.assert_close(model.weight, saved_weight) + restored_momentum = next(iter(learner.optimizer.state.values()))[ + "momentum_buffer" + ] + torch.testing.assert_close(restored_momentum, saved_momentum) + + +@pytest.mark.skipif( + not torch.distributed.is_available(), reason="torch.distributed required" +) +class TestFSDP2Learner: + """FSDP2Learner exercised on a single-rank process group (gloo/CPU). + + world_size=1 does not exercise cross-rank sharding, but it runs the exact + fully_shard()/DTensor code path FSDP2Learner is built on, so these tests + catch real API breakage, not just interface stubs. + """ + + _PORT = "29601" + + @pytest.fixture(autouse=True) + def _process_group(self): + import torch.distributed as dist + + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ["MASTER_PORT"] = self._PORT + dist.init_process_group(backend="gloo", rank=0, world_size=1) + try: + yield + finally: + dist.destroy_process_group() + + @staticmethod + def _make(clip_grad_norm=None, grad_accum_steps=1, lr=0.1, seed=0): + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + + torch.manual_seed(seed) + model = nn.Linear(4, 1) + mesh = init_device_mesh("cpu", (1,)) + fully_shard(model, mesh=mesh) + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + learner = FSDP2Learner( + model, + optimizer, + clip_grad_norm=clip_grad_norm, + grad_accum_steps=grad_accum_steps, + ) + loss_module = _ToyRegressionLoss(model) + return learner, loss_module, model + + @staticmethod + def _batch(n=8, seed=None): + if seed is not None: + torch.manual_seed(seed) + return TensorDict( + {"x": torch.randn(n, 4), "y": torch.randn(n, 1)}, batch_size=[n] + ) + + def test_update_returns_loss_and_metric(self): + learner, loss_module, _ = self._make() + out = learner.update(self._batch(seed=1), loss_module) + assert "loss_mse" in out.keys() + assert "metric" in out.keys() + + def test_clip_grad_norm_writes_grad_norm(self): + learner, loss_module, _ = self._make(clip_grad_norm=0.5) + out = learner.update(self._batch(seed=1), loss_module) + assert "grad_norm" in out.keys() + assert out.get("grad_norm") >= 0 + + def test_get_weights_returns_plain_tensors(self): + learner, _, _ = self._make() + weights = learner.get_weights() + for leaf in weights.values(True, True): + assert not isinstance(leaf, torch.distributed.tensor.DTensor) + + def test_capabilities_report_sharded(self): + learner, _, _ = self._make() + assert learner.capabilities.sharded + assert not learner.capabilities.remote + + def test_matches_local_learner_bit_exact(self): + """The whole point of the abstraction: same update() logic, same + result, whether the model is sharded or not.""" + batch = self._batch(n=8, seed=42) + + fsdp2_learner, fsdp2_loss, fsdp2_model = self._make( + clip_grad_norm=1.0, lr=0.5, seed=0 + ) + fsdp2_out = fsdp2_learner.update(batch, fsdp2_loss) + + torch.manual_seed(0) + local_model = nn.Linear(4, 1) + local_optimizer = torch.optim.SGD(local_model.parameters(), lr=0.5) + local_learner = LocalLearner(local_model, local_optimizer, clip_grad_norm=1.0) + local_loss = _ToyRegressionLoss(local_model) + local_out = local_learner.update(batch, local_loss) + + assert fsdp2_out.get("loss_mse").item() == local_out.get("loss_mse").item() + gathered = fsdp2_learner.get_weights() + torch.testing.assert_close(gathered["weight"], local_model.weight) + torch.testing.assert_close(gathered["bias"], local_model.bias) + + def test_grad_accumulation_defers_step(self): + learner, loss_module, model = self._make(grad_accum_steps=2, lr=1.0) + before = model.weight.full_tensor().clone() + learner.update(self._batch(seed=1), loss_module) + assert torch.equal(before, model.weight.full_tensor()) + learner.update(self._batch(seed=2), loss_module) + assert not torch.equal(before, model.weight.full_tensor()) + + def test_grad_accumulation_matches_non_sharded_reference(self): + """Regression test for the set_requires_gradient_sync(False) toggle + update() uses to skip communication on non-final accumulation steps: + the accumulated gradient must still equal a plain (non-sharded) + reference that accumulates the same two micro-batches.""" + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + + batch1, batch2 = self._batch(seed=10), self._batch(seed=11) + + # update() divides each micro-batch's loss by grad_accum_steps before + # backward (accumulation averages, rather than sums, the microbatches), + # so the reference must apply the same scaling. + torch.manual_seed(0) + ref_model = nn.Linear(4, 1) + ref_loss = _ToyRegressionLoss(ref_model) + ref_model.zero_grad() + (ref_loss(batch1).get("loss_mse") / 2).backward() + (ref_loss(batch2).get("loss_mse") / 2).backward() + ref_grad = ref_model.weight.grad.clone() + + torch.manual_seed(0) + model = nn.Linear(4, 1) + mesh = init_device_mesh("cpu", (1,)) + fully_shard(model, mesh=mesh) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + learner = FSDP2Learner(model, optimizer, grad_accum_steps=2) + loss_module = _ToyRegressionLoss(model) + learner.update(batch1, loss_module) # sync disabled; .grad stays None + learner.update(batch2, loss_module) # sync re-enabled; optimizer.step() + # optimizer.step() does not clear .grad (only zero_grad() does), so the + # fully-accumulated, synced gradient is still readable here. + accumulated_grad = model.weight.grad.full_tensor() + torch.testing.assert_close(accumulated_grad, ref_grad) + + def test_state_dict_round_trip_includes_optimizer(self): + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + + # momentum must be nonzero, or SGD never populates optimizer state to + # begin with, and the DCP loader has nothing meaningful to restore. + torch.manual_seed(0) + model = nn.Linear(4, 1) + mesh = init_device_mesh("cpu", (1,)) + fully_shard(model, mesh=mesh) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + learner = FSDP2Learner(model, optimizer) + loss_module = _ToyRegressionLoss(model) + + torch.manual_seed(3) + for _ in range(3): + learner.update(self._batch(), loss_module) + checkpoint = learner.state_dict() + saved_weight = model.weight.full_tensor().clone() + saved_momentum = ( + next(iter(optimizer.state.values()))["momentum_buffer"] + .full_tensor() + .clone() + ) + + with torch.no_grad(): + model.weight.to_local().zero_() + for state in optimizer.state.values(): + state["momentum_buffer"].to_local().zero_() + + learner.load_state_dict(checkpoint) + torch.testing.assert_close(model.weight.full_tensor(), saved_weight) + restored_momentum = next(iter(optimizer.state.values()))[ + "momentum_buffer" + ].full_tensor() + torch.testing.assert_close(restored_momentum, saved_momentum) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 6f4456e23bf..2b111d7ff82 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .learners import FSDP2Learner, Learner, LearnerCapabilities, LocalLearner from .trainers import ( BatchSubSampler, ClearCudaCache, @@ -31,6 +32,10 @@ "CountFramesLog", "DefaultOptimizationStepper", "EarlyStopping", + "FSDP2Learner", + "Learner", + "LearnerCapabilities", + "LocalLearner", "LogScalar", "LogTiming", "LogValidationReward", diff --git a/torchrl/trainers/learners/__init__.py b/torchrl/trainers/learners/__init__.py new file mode 100644 index 00000000000..b44117bbccf --- /dev/null +++ b/torchrl/trainers/learners/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from torchrl.trainers.learners.common import Learner, LearnerCapabilities +from torchrl.trainers.learners.fsdp2 import FSDP2Learner +from torchrl.trainers.learners.local import LocalLearner + +__all__ = [ + "FSDP2Learner", + "Learner", + "LearnerCapabilities", + "LocalLearner", +] diff --git a/torchrl/trainers/learners/common.py b/torchrl/trainers/learners/common.py new file mode 100644 index 00000000000..0bdb6a03536 --- /dev/null +++ b/torchrl/trainers/learners/common.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from tensordict import is_tensorclass, TensorDictBase +from torch import nn + +from torchrl.objectives.common import LossModule + + +def _clone_tensors(obj): + """Recursively clone every tensor in a (possibly nested) state-dict-like structure. + + ``nn.Module.state_dict()`` and ``Optimizer.state_dict()`` both return + views onto their live tensors, not independent copies, so holding onto + their output as a "checkpoint" while training continues silently mutates + that checkpoint too. + """ + if isinstance(obj, torch.Tensor): + return obj.clone() + if isinstance(obj, dict): + return {k: _clone_tensors(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_clone_tensors(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_clone_tensors(v) for v in obj) + return obj + + +@dataclass +class LearnerCapabilities: + """Declares what a :class:`Learner` implementation supports. + + Algorithm code is not expected to branch on these (that would defeat the + point of the abstraction); they exist for logging, validation, and for + orchestration code that needs to know, e.g., whether a learner can be + checkpointed independently on every rank. + + Attributes: + sharded (bool): whether the learner's parameters are sharded across + multiple devices/processes (e.g. FSDP2). Defaults to ``False``. + remote (bool): whether :meth:`~Learner.update` dispatches to a + separate process rather than running in-line. Defaults to + ``False``. + """ + + sharded: bool = False + remote: bool = False + + +class Learner(nn.Module): + """Base class for the trainable-policy role. + + A :class:`Learner` owns a trainable model and an optimizer and exposes a + single, backend-agnostic entry point, :meth:`update`, for taking one + optimization step on a :class:`~tensordict.TensorDictBase` batch with a + given :class:`~torchrl.objectives.common.LossModule`. Algorithm code calls + ``learner.update(batch, loss_module)`` without knowing whether the update + runs locally on one device, under sharded (e.g. FSDP2) training, or on a + separate remote training process -- that placement is the ``Learner`` + subclass's responsibility, not the algorithm's. + + This mirrors the role :class:`~torchrl.collectors.Collector` plays for + data collection and :class:`~torchrl.modules.llm.LLMWrapperBase` plays for + generation/scoring: a fixed, TensorDict-native contract with multiple + interchangeable backends. + + :meth:`update` is implemented once, here, and is intentionally backend + agnostic: it only touches ``self.model``, ``self.optimizer``, + ``self.clip_grad_norm``, and ``self.grad_accum_steps``, all of which a + subclass sets in its constructor. This is what lets + :class:`~torchrl.trainers.learners.FSDP2Learner` reuse the exact same + training step as :class:`~torchrl.trainers.learners.LocalLearner`: sharded + training only changes how the model is constructed (wrapped with + ``fully_shard``) and how :meth:`get_weights` gathers the result, not how a + step is taken. + + During gradient accumulation, if ``self.model`` exposes + ``set_requires_gradient_sync`` (as an FSDP2 ``fully_shard``-wrapped module + does), :meth:`update` disables cross-rank gradient synchronization on + every accumulation step except the last: gradients still accumulate + correctly (verified against a non-sharded reference), but the + communication (reduce-scatter) only happens once per accumulation window + instead of once per micro-batch. This is a no-op for + :class:`~torchrl.trainers.learners.LocalLearner`, whose model has no such + method. + + .. note:: + ``update`` requires ``loss_module.forward`` to follow the + :class:`~torchrl.objectives.common.LossModule` convention: every + differentiable loss term is returned under a key starting with + ``"loss"`` (these are summed and used for the backward pass); any + other returned entry (e.g. ``accuracy``, a KL for logging) is treated + as a non-differentiable metric and left untouched. + + .. seealso:: :class:`~torchrl.weight_update.WeightSyncScheme` consumes + :meth:`get_weights` to synchronize a learner's parameters to remote + inference workers, so a ``Learner`` composes with the existing + weight-sync machinery without changes on either side. + + .. warning:: + :meth:`state_dict` / :meth:`load_state_dict` are overridden (relative + to plain :class:`torch.nn.Module`) to checkpoint the optimizer state + alongside the model: a bare :class:`~torch.optim.Optimizer` is not an + ``nn.Module``, so the default ``nn.Module.state_dict()`` would + silently omit it, corrupting a training resume (Adam's moments, etc., + would reset). Subclasses whose model/optimizer need + sharding-aware (DTensor) checkpointing, e.g. + :class:`~torchrl.trainers.learners.FSDP2Learner`, override these two + methods again with a DTensor-aware implementation. + """ + + capabilities: LearnerCapabilities = LearnerCapabilities() + + model: nn.Module + optimizer: torch.optim.Optimizer + clip_grad_norm: float | None + grad_accum_steps: int + _accum_step: int + + def update(self, batch: TensorDictBase, loss_module: LossModule) -> TensorDictBase: + """Take one optimization step on ``batch`` using ``loss_module``. + + Args: + batch (TensorDictBase): a batch, in the format expected by + ``loss_module``. + loss_module (LossModule): computes the loss(es) for ``batch``. + Its output's ``"loss"``-prefixed keys are summed and + backpropagated; other keys are passed through for logging. + + Returns: + TensorDictBase: the tensordict returned by ``loss_module``, + augmented with a ``"grad_norm"`` entry when gradient clipping is + enabled. + """ + if self._accum_step == 0: + self.optimizer.zero_grad(set_to_none=True) + + # Sharded (e.g. FSDP2) models can defer the cross-rank gradient + # reduction until the last micro-batch of an accumulation window; + # LocalLearner's plain model has no such method, so this is a no-op. + set_grad_sync = getattr(self.model, "set_requires_gradient_sync", None) + if set_grad_sync is not None: + set_grad_sync(self._accum_step == self.grad_accum_steps - 1) + + loss_td = loss_module(batch) + loss_keys = [k for k in loss_td.keys() if k.startswith("loss")] + if not loss_keys: + raise ValueError( + "loss_module returned no keys starting with 'loss': " + f"{list(loss_td.keys())}. LossModule.forward must return at " + "least one 'loss'-prefixed entry." + ) + total_loss = sum(loss_td.get(k) for k in loss_keys) / self.grad_accum_steps + total_loss.backward() + + self._accum_step += 1 + if self._accum_step < self.grad_accum_steps: + return loss_td + self._accum_step = 0 + + if self.clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.clip_grad_norm + ) + # loss_module may return a strict TensorClass (e.g. RewardModelLossOutput) + # that rejects undeclared keys; convert to a writable TensorDict first. + if is_tensorclass(loss_td): + loss_td = loss_td.to_tensordict() + loss_td.set("grad_norm", grad_norm) + + self.optimizer.step() + return loss_td + + def get_weights(self) -> TensorDictBase: + """Return the learner's current parameters as a tensordict. + + The returned tensordict holds plain (fully materialized) tensors, even + when the learner's parameters are internally sharded, so it is + accepted as-is by :meth:`~torchrl.weight_update.WeightSyncScheme.send`. + This is the seam between the training role and the weight-sync / + inference roles. + """ + raise NotImplementedError + + def state_dict(self, *args, **kwargs) -> dict: # noqa: D417 + """Return a checkpoint covering the model, optimizer, and accumulation state. + + See the class-level warning: this intentionally does not reuse plain + :meth:`torch.nn.Module.state_dict`, which would silently drop the + optimizer's state. The returned tensors are independent clones (both + ``nn.Module.state_dict()`` and ``Optimizer.state_dict()`` otherwise + return views onto the live parameters/state, so an in-place update + after checkpointing would silently corrupt the "saved" checkpoint too). + """ + return { + "model": _clone_tensors(self.model.state_dict()), + "optimizer": _clone_tensors(self.optimizer.state_dict()), + "accum_step": self._accum_step, + } + + def load_state_dict(self, state_dict: dict, *args, **kwargs) -> None: # noqa: D417 + """Restore a checkpoint produced by :meth:`state_dict`.""" + self.model.load_state_dict(state_dict["model"]) + self.optimizer.load_state_dict(state_dict["optimizer"]) + self._accum_step = state_dict.get("accum_step", 0) diff --git a/torchrl/trainers/learners/fsdp2.py b/torchrl/trainers/learners/fsdp2.py new file mode 100644 index 00000000000..a6101060f53 --- /dev/null +++ b/torchrl/trainers/learners/fsdp2.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict import TensorDict, TensorDictBase +from torch import nn + +from torchrl.trainers.learners.common import Learner, LearnerCapabilities + +try: + from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + get_state_dict, + set_state_dict, + StateDictOptions, + ) + + _has_dist_checkpoint = True +except ImportError: # pragma: no cover - torch without distributed.checkpoint + _has_dist_checkpoint = False + + +class FSDP2Learner(Learner): + """A :class:`~torchrl.trainers.learners.Learner` for FSDP2-sharded models. + + Accepts a model that the caller has already wrapped with + :func:`torch.distributed._composable.fsdp.fully_shard` (typically + per-submodule, then on the root module), and an optimizer built on that + (sharded) model's parameters. ``fully_shard`` must be applied *before* the + optimizer is constructed, so the optimizer holds the sharded + (:class:`~torch.distributed.tensor.DTensor`) parameters rather than the + original ones. + + :meth:`~torchrl.trainers.learners.Learner.update` is inherited unchanged + from :class:`~torchrl.trainers.learners.Learner`: FSDP2's sharding is + transparent to the training step (forward/backward/optimizer-step all + dispatch through ``DTensor`` the same way they would through a regular + tensor). Only two things differ from + :class:`~torchrl.trainers.learners.LocalLearner`: how the model arrives + (already sharded, by the caller) and how :meth:`get_weights` reports it. + + :meth:`get_weights` gathers every sharded leaf into a plain tensor via + :func:`torch.distributed.checkpoint.state_dict.get_model_state_dict`, so + the returned tensordict holds regular tensors and is consumable by + :meth:`~torchrl.weight_update.WeightSyncScheme.send` exactly like + :class:`~torchrl.trainers.learners.LocalLearner`'s, with no changes needed + on the receiving (inference) side. This gather is the seam between + sharded training and (typically replicated) inference, and is the one + place sharding is not transparent. By default the gather targets rank 0 + only (other ranks receive an empty tensordict): gathering the full model + to *every* rank, as a naive per-leaf ``DTensor.full_tensor()`` would, + replicates the whole model in every rank's memory for no benefit, which + does not scale to large sharded models. + + Args: + model (torch.nn.Module): a model already wrapped with ``fully_shard``. + optimizer (torch.optim.Optimizer): an optimizer constructed on + ``model``'s (already-sharded) parameters. + + Keyword Args: + clip_grad_norm (float, optional): as in + :class:`~torchrl.trainers.learners.LocalLearner`. Gradient-norm + clipping dispatches through ``DTensor`` transparently. Defaults to + ``None``. + grad_accum_steps (int, optional): as in + :class:`~torchrl.trainers.learners.LocalLearner`. Defaults to + ``1``. + + .. warning:: + ``FSDP2Learner`` does not itself decide what to shard, at what + granularity, or on what device mesh -- those are model-specific + choices that belong in the caller's model-construction code, exactly + as they would with bare ``fully_shard``. This keeps + ``FSDP2Learner`` a thin adapter rather than a second place those + decisions are made. + + Examples: + >>> import torch + >>> import torch.distributed as dist + >>> from torch import nn + >>> from torch.distributed._composable.fsdp import fully_shard + >>> from torch.distributed.device_mesh import init_device_mesh + >>> from tensordict import TensorDict + >>> from torchrl.objectives.common import LossModule + >>> from torchrl.trainers.learners import FSDP2Learner + >>> + >>> dist.init_process_group(backend="gloo", rank=0, world_size=1) + >>> mesh = init_device_mesh("cpu", (1,)) + >>> model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1)) + >>> for layer in model: + ... fully_shard(layer, mesh=mesh) + >>> fully_shard(model, mesh=mesh) + >>> optimizer = torch.optim.Adam(model.parameters()) + >>> + >>> class ToyLoss(LossModule): + ... def forward(self, batch): + ... pred = self.actor(batch["x"]) + ... return TensorDict({"loss_mse": (pred - batch["y"]).pow(2).mean()}) + >>> loss_module = ToyLoss() + >>> loss_module.actor = model + >>> + >>> learner = FSDP2Learner(model, optimizer) + >>> batch = TensorDict({"x": torch.randn(8, 4), "y": torch.randn(8, 1)}, [8]) + >>> out = learner.update(batch, loss_module) + >>> weights = learner.get_weights() # gathered, plain tensors + >>> dist.destroy_process_group() + """ + + def __init__( + self, + model: nn.Module, + optimizer: torch.optim.Optimizer, + *, + clip_grad_norm: float | None = None, + grad_accum_steps: int = 1, + ) -> None: + if not _has_dist_checkpoint: + raise RuntimeError( + "FSDP2Learner requires torch.distributed.checkpoint.state_dict, " + "which is not available in this torch build." + ) + super().__init__() + self.model = model + self.optimizer = optimizer + self.clip_grad_norm = clip_grad_norm + if grad_accum_steps < 1: + raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}.") + self.grad_accum_steps = grad_accum_steps + self._accum_step = 0 + self.capabilities = LearnerCapabilities(sharded=True, remote=False) + + def get_weights(self, *, cpu_offload: bool = True) -> TensorDictBase: + """Gather the sharded model into a plain-tensor tensordict. + + Keyword Args: + cpu_offload (bool, optional): if ``True`` (the default), the + gathered weights are returned only on rank 0 (other ranks get + an empty tensordict) and moved to CPU, avoiding an all-rank + replication of the full model. Set to ``False`` to instead + gather full GPU-resident copies onto every rank. + """ + options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload) + state_dict = get_model_state_dict(self.model, options=options) + return TensorDict(state_dict).unflatten_keys(".") + + def state_dict(self, *args, **kwargs) -> dict: + """DTensor-aware checkpoint: gathers model + optimizer state to rank 0. + + Overrides :meth:`~torchrl.trainers.learners.Learner.state_dict`, which + uses plain (non-distributed) ``state_dict()`` calls that would return + raw, per-rank ``DTensor`` shards rather than a portable checkpoint. + """ + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + model_state_dict, optim_state_dict = get_state_dict( + self.model, self.optimizer, options=options + ) + return { + "model": model_state_dict, + "optimizer": optim_state_dict, + "accum_step": self._accum_step, + } + + def load_state_dict(self, state_dict: dict, *args, **kwargs) -> None: + """Restore a checkpoint produced by :meth:`state_dict`. + + ``broadcast_from_rank0=True`` lets rank 0 hold the full checkpoint + (as produced by :meth:`state_dict`) while every rank reshards it + according to its local shards -- the counterpart of the + ``cpu_offload``-gathered save. + """ + options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True) + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optimizer"], + options=options, + ) + self._accum_step = state_dict.get("accum_step", 0) diff --git a/torchrl/trainers/learners/local.py b/torchrl/trainers/learners/local.py new file mode 100644 index 00000000000..6c4703d606e --- /dev/null +++ b/torchrl/trainers/learners/local.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict import TensorDict, TensorDictBase +from torch import nn + +from torchrl.trainers.learners.common import Learner, LearnerCapabilities + + +class LocalLearner(Learner): + """A single-process, single-device :class:`~torchrl.trainers.learners.Learner`. + + Holds a model and an optimizer, unwrapped and unsharded. This is the + reference implementation the :class:`~torchrl.trainers.learners.Learner` + contract is designed around -- + :class:`~torchrl.trainers.learners.FSDP2Learner` shards the same model + with ``fully_shard`` and reuses :meth:`~torchrl.trainers.learners.Learner.update` + unchanged, so algorithm code written against ``LocalLearner`` runs + unmodified under sharded training. + + Args: + model (torch.nn.Module): the trainable module. Also the source for + :meth:`get_weights`. + optimizer (torch.optim.Optimizer): the optimizer stepping ``model``'s + parameters. + + Keyword Args: + clip_grad_norm (float, optional): if set, gradients are clipped to + this max norm via :func:`torch.nn.utils.clip_grad_norm_` before + the optimizer step, and the resulting norm is written to the + output tensordict under ``"grad_norm"``. Defaults to ``None``. + grad_accum_steps (int, optional): number of :meth:`update` calls to + accumulate gradients over before stepping the optimizer and + zeroing gradients. Defaults to ``1`` (step on every call). + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from torch import nn + >>> from torchrl.trainers.learners import LocalLearner + >>> from torchrl.objectives.common import LossModule + >>> + >>> class ToyLoss(LossModule): + ... def forward(self, batch): + ... pred = self.actor(batch["x"]) + ... return TensorDict({"loss_mse": (pred - batch["y"]).pow(2).mean()}) + >>> + >>> model = nn.Linear(4, 1) + >>> loss_module = ToyLoss() + >>> loss_module.actor = model + >>> learner = LocalLearner(model, torch.optim.Adam(model.parameters())) + >>> batch = TensorDict({"x": torch.randn(8, 4), "y": torch.randn(8, 1)}, [8]) + >>> out = learner.update(batch, loss_module) + >>> out.get("loss_mse").item() > 0 + True + """ + + def __init__( + self, + model: nn.Module, + optimizer: torch.optim.Optimizer, + *, + clip_grad_norm: float | None = None, + grad_accum_steps: int = 1, + ) -> None: + super().__init__() + self.model = model + self.optimizer = optimizer + self.clip_grad_norm = clip_grad_norm + if grad_accum_steps < 1: + raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}.") + self.grad_accum_steps = grad_accum_steps + self._accum_step = 0 + self.capabilities = LearnerCapabilities(sharded=False, remote=False) + + def get_weights(self) -> TensorDictBase: + return TensorDict.from_module(self.model)