Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ Documentation Sections
trainers_basics
trainers_loggers
trainers_hooks
trainers_learners
36 changes: 36 additions & 0 deletions docs/source/reference/trainers_learners.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
.. currentmodule:: torchrl.trainers

Learners
========

.. _ref_learners:

A :class:`~torchrl.trainers.Learner` owns a trainable model and exposes a
single, backend-agnostic entry point -- :meth:`~torchrl.trainers.Learner.update`
-- for taking one optimization step on a tensordict batch with a given
:class:`~torchrl.objectives.common.LossModule`. It plays the same role for
training that :class:`~torchrl.collectors.Collector` plays for data collection
and :class:`~torchrl.modules.llm.LLMWrapperBase` plays for generation/scoring:
a fixed contract with interchangeable backends, so algorithm code does not
need to know whether the update runs on one device, under sharded training, or
on a remote training process.

:class:`~torchrl.trainers.LocalLearner` is the single-process reference
implementation. :class:`~torchrl.trainers.FSDP2Learner` shards the same model
with :func:`torch.distributed._composable.fsdp.fully_shard` and reuses
:meth:`~torchrl.trainers.Learner.update` unchanged -- FSDP2's sharding is
transparent to the training step; only construction (the caller wraps the
model before handing it to the learner) and :meth:`~torchrl.trainers.Learner.get_weights`
(which gathers sharded parameters into plain tensors) differ. Either
learner's :meth:`~torchrl.trainers.Learner.get_weights` output is accepted
as-is by :class:`~torchrl.weight_update.WeightSyncScheme`, so a ``Learner``
composes with the existing weight-sync path without changes on either side.

.. autosummary::
:toctree: generated/
:template: rl_template.rst

Learner
LearnerCapabilities
LocalLearner
FSDP2Learner
333 changes: 332 additions & 1 deletion test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@
TensorDictReplayBuffer,
)
from torchrl.envs.libs.gym import _has_gym
from torchrl.objectives.common import LossModule
from torchrl.testing import PONG_VERSIONED
from torchrl.trainers import LogValidationReward, Trainer
from torchrl.trainers import (
FSDP2Learner,
Learner,
LocalLearner,
LogValidationReward,
Trainer,
)
from torchrl.trainers.algorithms.cql import CQLTrainer
from torchrl.trainers.algorithms.ddpg import DDPGTrainer
from torchrl.trainers.algorithms.dqn import DQNTrainer
Expand Down Expand Up @@ -1562,6 +1569,330 @@ def test_train_stops_early(self):
assert collector.shutdown_calls == 1


class _ToyRegressionLoss(LossModule):
    """Toy LossModule for the learner tests.

    Produces a mean-squared-error term under ``"loss_mse"`` together with a
    gradient-free ``"metric"`` entry that exists only to be logged.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, batch: TensorDict) -> TensorDict:
        """Run the model on ``batch["x"]`` and score it against ``batch["y"]``."""
        prediction = self.model(batch["x"])
        residual = prediction - batch["y"]
        outputs = {"loss_mse": residual.pow(2).mean()}
        # This key deliberately lacks the "loss" prefix: it is reported, never
        # optimized, so it is computed outside of autograd.
        with torch.no_grad():
            outputs["metric"] = prediction.mean()
        return TensorDict(outputs)


class _NoLossKeyModule(LossModule):
    """Deliberately misconfigured LossModule.

    Its output carries a single ``"mse"`` entry, i.e. nothing whose key starts
    with ``"loss"``, which lets tests exercise the learner's error path.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, batch: TensorDict) -> TensorDict:
        """Return the batch MSE under a key that has no ``loss`` prefix."""
        residual = self.model(batch["x"]) - batch["y"]
        mean_squared_error = residual.pow(2).mean()
        return TensorDict({"mse": mean_squared_error})


class TestLocalLearner:
    """Single-process tests for ``LocalLearner`` on a toy linear regression."""

    @staticmethod
    def _make(clip_grad_norm=None, grad_accum_steps=1, lr=0.1):
        """Build a seeded ``nn.Linear(4, 1)`` + SGD wrapped in a ``LocalLearner``.

        Args:
            clip_grad_norm: forwarded to ``LocalLearner``.
            grad_accum_steps: forwarded to ``LocalLearner``.
            lr: SGD learning rate.

        Returns:
            A ``(learner, loss_module, model)`` tuple; the loss module wraps
            the same model instance that the learner trains.
        """
        torch.manual_seed(0)
        model = nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        learner = LocalLearner(
            model,
            optimizer,
            clip_grad_norm=clip_grad_norm,
            grad_accum_steps=grad_accum_steps,
        )
        loss_module = _ToyRegressionLoss(model)
        return learner, loss_module, model

    @staticmethod
    def _batch(n=8):
        """Return a random regression batch of ``n`` rows (``x``: n x 4, ``y``: n x 1)."""
        return TensorDict(
            {"x": torch.randn(n, 4), "y": torch.randn(n, 1)}, batch_size=[n]
        )

    def test_update_returns_loss_and_metric(self):
        learner, loss_module, _ = self._make()
        out = learner.update(self._batch(), loss_module)
        assert "loss_mse" in out.keys()
        assert "metric" in out.keys()
        assert out.get("loss_mse").shape == ()

    def test_update_steps_the_optimizer(self):
        learner, loss_module, model = self._make(lr=0.5)
        before = model.weight.clone()
        learner.update(self._batch(), loss_module)
        assert not torch.equal(before, model.weight)

    def test_loss_decreases_over_steps(self):
        learner, loss_module, _ = self._make(lr=0.1)
        torch.manual_seed(1)
        batch = self._batch(64)
        first = learner.update(batch, loss_module).get("loss_mse").item()
        for _ in range(20):
            last = learner.update(batch, loss_module).get("loss_mse").item()
        assert last < first

    def test_clip_grad_norm_writes_grad_norm(self):
        learner, loss_module, _ = self._make(clip_grad_norm=0.5)
        out = learner.update(self._batch(), loss_module)
        assert "grad_norm" in out.keys()
        assert out.get("grad_norm") >= 0

    def test_no_clip_grad_norm_absent(self):
        learner, loss_module, _ = self._make(clip_grad_norm=None)
        out = learner.update(self._batch(), loss_module)
        assert "grad_norm" not in out.keys()

    def test_missing_loss_key_raises(self):
        learner, _, model = self._make()
        with pytest.raises(ValueError, match="no keys starting with 'loss'"):
            learner.update(self._batch(), _NoLossKeyModule(model))

    def test_grad_accumulation_defers_step(self):
        learner, loss_module, model = self._make(grad_accum_steps=2, lr=1.0)
        before = model.weight.clone()
        learner.update(self._batch(), loss_module)
        # optimizer has not stepped yet: weights unchanged after the 1st of 2 accum steps
        assert torch.equal(before, model.weight)
        learner.update(self._batch(), loss_module)
        # after the 2nd accum step, the optimizer has stepped
        assert not torch.equal(before, model.weight)

    def test_get_weights_matches_model_params(self):
        learner, _, model = self._make()
        weights = learner.get_weights()
        torch.testing.assert_close(weights["weight"], model.weight)
        torch.testing.assert_close(weights["bias"], model.bias)

    def test_invalid_grad_accum_steps_raises(self):
        model = nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        with pytest.raises(ValueError, match="grad_accum_steps must be"):
            LocalLearner(model, optimizer, grad_accum_steps=0)

    def test_base_learner_get_weights_is_abstract(self):
        # update() is concrete on Learner (shared by LocalLearner/FSDP2Learner);
        # only get_weights() -- the backend-specific gather -- remains abstract.
        with pytest.raises(NotImplementedError):
            Learner().get_weights()

    def test_state_dict_round_trip_includes_optimizer(self):
        # Local import: only this test needs it.
        import copy

        # plain nn.Module.state_dict() would silently drop the optimizer's
        # state (Optimizer is not an nn.Module) -- this is the regression test
        # for that gap. Momentum must be nonzero, or SGD never populates a
        # momentum_buffer in its state to begin with.
        torch.manual_seed(0)
        model = nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        learner = LocalLearner(model, optimizer)
        loss_module = _ToyRegressionLoss(model)
        torch.manual_seed(2)
        for _ in range(3):
            learner.update(self._batch(), loss_module)
        # Snapshot the checkpoint before mutating anything. If the learner's
        # state_dict() follows plain PyTorch semantics, its tensors alias the
        # live parameter / momentum storage, and the in-place zeroing below
        # would wipe the checkpoint too, failing the round trip for the wrong
        # reason. A deep copy makes the test independent of that detail.
        checkpoint = copy.deepcopy(learner.state_dict())
        saved_weight = model.weight.clone()
        saved_momentum = next(iter(learner.optimizer.state.values()))[
            "momentum_buffer"
        ].clone()

        # Corrupt both the model and the optimizer state so that a successful
        # restore is observable.
        with torch.no_grad():
            model.weight.zero_()
            for state in learner.optimizer.state.values():
                state["momentum_buffer"].zero_()

        learner.load_state_dict(checkpoint)
        torch.testing.assert_close(model.weight, saved_weight)
        restored_momentum = next(iter(learner.optimizer.state.values()))[
            "momentum_buffer"
        ]
        torch.testing.assert_close(restored_momentum, saved_momentum)


@pytest.mark.skipif(
    not torch.distributed.is_available(), reason="torch.distributed required"
)
class TestFSDP2Learner:
    """FSDP2Learner exercised on a single-rank process group (gloo/CPU).

    world_size=1 does not exercise cross-rank sharding, but it runs the exact
    fully_shard()/DTensor code path FSDP2Learner is built on, so these tests
    catch real API breakage, not just interface stubs.
    """

    @staticmethod
    def _find_free_port() -> int:
        """Ask the OS for a currently unused TCP port and return its number."""
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            # Port 0 lets the kernel choose; the socket is closed on exit so
            # the rendezvous store can bind the same port right afterwards.
            sock.bind(("", 0))
            return sock.getsockname()[1]

    @pytest.fixture(autouse=True)
    def _process_group(self):
        """Create a 1-rank gloo process group around every test in the class.

        A fresh OS-assigned port is used for each test instead of a fixed one,
        so overlapping runs (e.g. parallel workers, or a stale process still
        holding a port) do not collide. ``MASTER_ADDR`` / ``MASTER_PORT`` are
        restored afterwards so the rendezvous settings do not leak into
        unrelated tests.
        """
        import torch.distributed as dist

        env_keys = ("MASTER_ADDR", "MASTER_PORT")
        saved_env = {key: os.environ.get(key) for key in env_keys}
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ["MASTER_PORT"] = str(self._find_free_port())
        try:
            dist.init_process_group(backend="gloo", rank=0, world_size=1)
            try:
                yield
            finally:
                dist.destroy_process_group()
        finally:
            # Put the environment back exactly as it was found.
            for key, value in saved_env.items():
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value

    @staticmethod
    def _make(clip_grad_norm=None, grad_accum_steps=1, lr=0.1, seed=0):
        """Build a seeded, fully-sharded ``nn.Linear(4, 1)`` + SGD ``FSDP2Learner``.

        The model is wrapped with ``fully_shard`` on a 1-device CPU mesh before
        the optimizer is created, so the optimizer sees the sharded parameters.

        Returns:
            A ``(learner, loss_module, model)`` tuple sharing one model instance.
        """
        from torch.distributed._composable.fsdp import fully_shard
        from torch.distributed.device_mesh import init_device_mesh

        torch.manual_seed(seed)
        model = nn.Linear(4, 1)
        mesh = init_device_mesh("cpu", (1,))
        fully_shard(model, mesh=mesh)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        learner = FSDP2Learner(
            model,
            optimizer,
            clip_grad_norm=clip_grad_norm,
            grad_accum_steps=grad_accum_steps,
        )
        loss_module = _ToyRegressionLoss(model)
        return learner, loss_module, model

    @staticmethod
    def _batch(n=8, seed=None):
        """Return a random regression batch of ``n`` rows, optionally seeded."""
        if seed is not None:
            torch.manual_seed(seed)
        return TensorDict(
            {"x": torch.randn(n, 4), "y": torch.randn(n, 1)}, batch_size=[n]
        )

    def test_update_returns_loss_and_metric(self):
        learner, loss_module, _ = self._make()
        out = learner.update(self._batch(seed=1), loss_module)
        assert "loss_mse" in out.keys()
        assert "metric" in out.keys()

    def test_clip_grad_norm_writes_grad_norm(self):
        learner, loss_module, _ = self._make(clip_grad_norm=0.5)
        out = learner.update(self._batch(seed=1), loss_module)
        assert "grad_norm" in out.keys()
        assert out.get("grad_norm") >= 0

    def test_get_weights_returns_plain_tensors(self):
        # Import the class explicitly rather than reaching through
        # ``torch.distributed.tensor`` as an attribute, which only works if
        # some other import has already loaded that submodule.
        from torch.distributed.tensor import DTensor

        learner, _, _ = self._make()
        weights = learner.get_weights()
        # NOTE(review): assumes get_weights() returns a TensorDict, whose
        # values(True, True) walks nested leaves -- confirm against Learner.
        for leaf in weights.values(True, True):
            assert not isinstance(leaf, DTensor)

    def test_capabilities_report_sharded(self):
        learner, _, _ = self._make()
        assert learner.capabilities.sharded
        assert not learner.capabilities.remote

    def test_matches_local_learner_bit_exact(self):
        """The whole point of the abstraction: same update() logic, same
        result, whether the model is sharded or not."""
        batch = self._batch(n=8, seed=42)

        fsdp2_learner, fsdp2_loss, fsdp2_model = self._make(
            clip_grad_norm=1.0, lr=0.5, seed=0
        )
        fsdp2_out = fsdp2_learner.update(batch, fsdp2_loss)

        # Same seed, same hyper-parameters, but without sharding.
        torch.manual_seed(0)
        local_model = nn.Linear(4, 1)
        local_optimizer = torch.optim.SGD(local_model.parameters(), lr=0.5)
        local_learner = LocalLearner(local_model, local_optimizer, clip_grad_norm=1.0)
        local_loss = _ToyRegressionLoss(local_model)
        local_out = local_learner.update(batch, local_loss)

        assert fsdp2_out.get("loss_mse").item() == local_out.get("loss_mse").item()
        gathered = fsdp2_learner.get_weights()
        torch.testing.assert_close(gathered["weight"], local_model.weight)
        torch.testing.assert_close(gathered["bias"], local_model.bias)

    def test_grad_accumulation_defers_step(self):
        learner, loss_module, model = self._make(grad_accum_steps=2, lr=1.0)
        before = model.weight.full_tensor().clone()
        learner.update(self._batch(seed=1), loss_module)
        assert torch.equal(before, model.weight.full_tensor())
        learner.update(self._batch(seed=2), loss_module)
        assert not torch.equal(before, model.weight.full_tensor())

    def test_grad_accumulation_matches_non_sharded_reference(self):
        """Regression test for the set_requires_gradient_sync(False) toggle
        update() uses to skip communication on non-final accumulation steps:
        the accumulated gradient must still equal a plain (non-sharded)
        reference that accumulates the same two micro-batches."""
        from torch.distributed._composable.fsdp import fully_shard
        from torch.distributed.device_mesh import init_device_mesh

        batch1, batch2 = self._batch(seed=10), self._batch(seed=11)

        # update() divides each micro-batch's loss by grad_accum_steps before
        # backward (accumulation averages, rather than sums, the microbatches),
        # so the reference must apply the same scaling.
        torch.manual_seed(0)
        ref_model = nn.Linear(4, 1)
        ref_loss = _ToyRegressionLoss(ref_model)
        ref_model.zero_grad()
        (ref_loss(batch1).get("loss_mse") / 2).backward()
        (ref_loss(batch2).get("loss_mse") / 2).backward()
        ref_grad = ref_model.weight.grad.clone()

        torch.manual_seed(0)
        model = nn.Linear(4, 1)
        mesh = init_device_mesh("cpu", (1,))
        fully_shard(model, mesh=mesh)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        learner = FSDP2Learner(model, optimizer, grad_accum_steps=2)
        loss_module = _ToyRegressionLoss(model)
        learner.update(batch1, loss_module)  # sync disabled; .grad stays None
        learner.update(batch2, loss_module)  # sync re-enabled; optimizer.step()
        # optimizer.step() does not clear .grad (only zero_grad() does), so the
        # fully-accumulated, synced gradient is still readable here.
        accumulated_grad = model.weight.grad.full_tensor()
        torch.testing.assert_close(accumulated_grad, ref_grad)

    def test_state_dict_round_trip_includes_optimizer(self):
        from torch.distributed._composable.fsdp import fully_shard
        from torch.distributed.device_mesh import init_device_mesh

        # momentum must be nonzero, or SGD never populates optimizer state to
        # begin with, and the DCP loader has nothing meaningful to restore.
        torch.manual_seed(0)
        model = nn.Linear(4, 1)
        mesh = init_device_mesh("cpu", (1,))
        fully_shard(model, mesh=mesh)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        learner = FSDP2Learner(model, optimizer)
        loss_module = _ToyRegressionLoss(model)

        torch.manual_seed(3)
        for _ in range(3):
            learner.update(self._batch(), loss_module)
        checkpoint = learner.state_dict()
        saved_weight = model.weight.full_tensor().clone()
        saved_momentum = (
            next(iter(optimizer.state.values()))["momentum_buffer"]
            .full_tensor()
            .clone()
        )

        # Corrupt the local shards of both the weight and the momentum buffer
        # so that a successful restore is observable.
        with torch.no_grad():
            model.weight.to_local().zero_()
            for state in optimizer.state.values():
                state["momentum_buffer"].to_local().zero_()

        learner.load_state_dict(checkpoint)
        torch.testing.assert_close(model.weight.full_tensor(), saved_weight)
        restored_momentum = next(iter(optimizer.state.values()))[
            "momentum_buffer"
        ].full_tensor()
        torch.testing.assert_close(restored_momentum, saved_momentum)


if __name__ == "__main__":
    # Script entry point: run this file under pytest, forwarding any
    # command-line flags argparse does not recognise straight to pytest.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])
Loading
Loading