[Feature] Add Learner primitive (LocalLearner, FSDP2Learner)#3926
Open
theap06 wants to merge 3 commits into
Introduces torchrl.trainers.Learner: a backend-agnostic entry point for taking one optimization step on a tensordict batch with a given LossModule. LocalLearner is the single-process reference implementation. Mirrors the role Collector plays for data collection and LLMWrapperBase plays for generation/scoring, so training placement (local / sharded / remote) becomes a swappable backend behind a fixed contract instead of a hand-rolled loop per recipe. get_weights() returns a TensorDictBase, matching what WeightSyncScheme.send() already accepts, so this composes with the existing weight-sync path unchanged.
Refactors Learner.update() to live in the base class as concrete, generic step logic (zero_grad -> forward -> sum loss_* keys -> backward -> clip -> step), operating only on self.model/self.optimizer/self.clip_grad_norm/self.grad_accum_steps. This is what lets FSDP2Learner reuse the exact same training step as LocalLearner: sharded training only changes how the model is constructed (wrapped with fully_shard by the caller) and how get_weights() gathers the result.
FSDP2Learner.get_weights() gathers every DTensor leaf via full_tensor() into a plain tensor, so its output is consumable by WeightSyncScheme exactly like LocalLearner's, with no changes on the receiving side.
Verified on a single-rank (world_size=1) gloo process group, which exercises the real fully_shard()/DTensor code path without a cluster: forward/backward/clip_grad_norm_/optimizer.step() dispatch correctly through DTensor, TensorDict.from_module()/.apply() handle DTensor leaves transparently, and FSDP2Learner produces bit-exact losses and gathered weights vs LocalLearner given the same seed/data/lr.
…d-sync
Fixes three real gaps in FSDP2Learner identified after the initial PR:
1. get_weights() previously gathered every DTensor leaf to EVERY rank via
full_tensor(), replicating the whole model in every rank's memory for
no benefit -- does not scale to large sharded models. Now uses
torch.distributed.checkpoint.state_dict.get_model_state_dict with
StateDictOptions(full_state_dict=True, cpu_offload=True), which
gathers to rank 0 only (other ranks get an empty tensordict) by
documented DCP semantics. Verified: correct nested key shape via
unflatten_keys('.'), matches the prior full_tensor()-based output.
2. No sharded-checkpoint path existed. Learner (base) gains real
state_dict()/load_state_dict() covering model + optimizer state
(a bare Optimizer is not an nn.Module, so plain nn.Module.state_dict()
silently drops it -- a real, previously-latent bug for LocalLearner
too). Both overrides clone their tensors: nn.Module.state_dict() and
Optimizer.state_dict() return views onto live tensors, not copies, so
without cloning, further training after checkpointing would silently
corrupt the saved checkpoint (caught by round-trip tests). FSDP2Learner
overrides these two methods again with get_state_dict/set_state_dict
(DCP-aware, handles DTensor optimizer state, rank0-only via
cpu_offload). Verified end to end: model weights AND Adam/SGD-momentum
optimizer state round-trip correctly through save/perturb/load.
3. grad_accum_steps>1 was untested on FSDP2Learner, and update() did no
FSDP2-specific optimization during accumulation. Learner.update() now
toggles model.set_requires_gradient_sync(...) when the model exposes
it (FSDP2-wrapped models do; LocalLearner's plain model doesn't, so
this is a no-op there), deferring the cross-rank gradient reduction
until the last micro-batch of an accumulation window instead of
reducing on every micro-batch. Verified against a non-sharded
reference: the accumulated, synced gradient after 2 microbatches
matches a plain model accumulating the same 2 microbatches exactly.
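The deferred-sync behaviour described in item 3 can be sketched without torch. This is a minimal illustration, not the PR's code: FakeShardedModel, PlainModel, and run_accumulation_window are hypothetical names, and the stub only records the flag that an FSDP2-wrapped module would receive through set_requires_gradient_sync.

```python
from typing import Callable, List, Sequence


class FakeShardedModel:
    """Hypothetical stand-in for an FSDP2-wrapped module; records sync toggles."""

    def __init__(self) -> None:
        self.sync_calls: List[bool] = []

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        self.sync_calls.append(requires_gradient_sync)


class PlainModel:
    """Hypothetical non-sharded model: exposes no gradient-sync toggle."""


def run_accumulation_window(
    model: object,
    micro_batches: Sequence[object],
    backward: Callable[[object], None],
) -> None:
    """Run backward on each micro-batch, syncing gradients only on the last one."""
    last_index = len(micro_batches) - 1
    for index, micro_batch in enumerate(micro_batches):
        # Only models that expose the toggle are touched, so a plain
        # model goes through the same loop with no extra work.
        if hasattr(model, "set_requires_gradient_sync"):
            model.set_requires_gradient_sync(index == last_index)
        backward(micro_batch)


sharded = FakeShardedModel()
seen: List[object] = []
run_accumulation_window(sharded, ["mb0", "mb1"], seen.append)
```

With two micro-batches the stub sees the toggle set to False and then True, so the cross-rank reduction would run once per accumulation window rather than once per micro-batch.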
Still unverified (unchanged from the original PR): all of the above is
tested at world_size=1 (single-rank gloo), which exercises the real
fully_shard()/DTensor/DCP code paths but not actual cross-rank
communication or memory behavior.
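Item 1 above mentions restoring the nested key shape with unflatten_keys('.'). The plain-dict sketch below illustrates what that nesting step does to a flat, dot-separated state dict. It is an analog written for illustration only: unflatten_dotted_keys is a hypothetical helper, not the TensorDict method, and the integers stand in for tensors.

```python
from typing import Any, Dict


def unflatten_dotted_keys(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Turn {"a.b.c": v} into {"a": {"b": {"c": v}}} for every entry."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for parent in parents:
            # Create intermediate levels on demand.
            node = node.setdefault(parent, {})
        node[leaf] = value
    return nested


# Flat keys as a state dict would name them (values are placeholders).
rank0_flat = {"policy.0.weight": 1, "policy.0.bias": 2, "value.weight": 3}
rank0_nested = unflatten_dotted_keys(rank0_flat)

# A rank that received nothing from the gather ends up with an empty mapping.
other_rank_nested = unflatten_dotted_keys({})
```

An empty input yields an empty result, which mirrors the "other ranks get an empty tensordict" behaviour described in item 1.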
Summary
Introduces torchrl.trainers.Learner: a backend-agnostic entry point for taking one optimization step on a tensordict batch with a given LossModule. LocalLearner is the single-process reference implementation; FSDP2Learner shards the same model with fully_shard and reuses the training step unchanged.
Design
Learner.update() is concrete, in the base class, and touches only self.model/self.optimizer/self.clip_grad_norm/self.grad_accum_steps: zero_grad -> forward -> sum the loss_module's "loss"-prefixed output keys -> backward -> optional grad-norm clip -> optimizer step. This is what lets FSDP2Learner reuse the exact same step as LocalLearner -- sharded training only changes model construction and weight gathering, not the step itself.
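The step order described above can be traced with a framework-free toy. This is a sketch under stated assumptions, not the PR's implementation: ToyLearner, toy_loss_module, and sum_loss_terms are hypothetical names, the "model" is a single scalar weight, "backward" is replaced by a central finite difference, and gradient accumulation is left out.

```python
import math
from typing import Dict, Optional


def toy_loss_module(w: float, batch: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical loss module: two loss terms plus one non-loss metric."""
    error = w * batch["x"] - batch["y"]
    return {
        "loss_mse": error ** 2,
        "loss_reg": 0.1 * w ** 2,
        "abs_error": abs(error),  # metric only, never summed into the loss
    }


def sum_loss_terms(outputs: Dict[str, float]) -> float:
    """Sum every output whose key starts with "loss"."""
    return sum(v for k, v in outputs.items() if k.startswith("loss"))


class ToyLearner:
    def __init__(self, w: float, lr: float, clip_grad_norm: Optional[float] = None):
        self.w = w
        self.lr = lr
        self.clip_grad_norm = clip_grad_norm
        self.grad = 0.0

    def update(self, batch: Dict[str, float]) -> float:
        self.grad = 0.0  # zero_grad
        total = sum_loss_terms(toy_loss_module(self.w, batch))  # forward + sum
        # Stand-in for backward: central finite difference of the summed loss.
        eps = 1e-4
        plus = sum_loss_terms(toy_loss_module(self.w + eps, batch))
        minus = sum_loss_terms(toy_loss_module(self.w - eps, batch))
        self.grad += (plus - minus) / (2 * eps)
        # Optional clip: for a single scalar the gradient norm is abs(grad).
        if self.clip_grad_norm is not None and abs(self.grad) > self.clip_grad_norm:
            self.grad = math.copysign(self.clip_grad_norm, self.grad)
        self.w -= self.lr * self.grad  # optimizer step (plain SGD)
        return total


learner = ToyLearner(w=0.0, lr=0.5, clip_grad_norm=1.0)
total_loss = learner.update({"x": 1.0, "y": 2.0})
```

Because update() only reads attributes stored on the learner, a different backend can supply a differently constructed model without changing the method, which is the property the design relies on.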
get_weights() is the one place sharding is not transparent: it must return plain tensors (for WeightSyncScheme.send, which already accepts a TensorDictBase), even when the learner's parameters are sharded. FSDP2Learner.get_weights() gathers every DTensor leaf via full_tensor().
FSDP2Learner does not decide sharding granularity or device mesh -- it accepts a model the caller has already wrapped with fully_shard, exactly as bare FSDP2 usage works. Keeping that decision in caller code avoids FSDP2Learner becoming a second place those choices are made.
Planned follow-ups (not in this PR)
- FSDP2Learner verification on real multi-GPU hardware -- the one thing I could not test here.
- reward-model recipe, once [Feature] Add RewardModelLoss objective for RLHF reward-model training #3922 lands) onto LocalLearner, as the first real consumer.
- RemoteLearner design writeup scoping one external backend (TorchTitan is the most PyTorch-native of the candidates and the most plausible first integration target) before any implementation.