Skip to content

[Feature] Add Learner primitive (LocalLearner, FSDP2Learner)#3926

Open
theap06 wants to merge 3 commits into
pytorch:mainfrom
theap06:learner-primitive-clean
Open

[Feature] Add Learner primitive (LocalLearner, FSDP2Learner)#3926
theap06 wants to merge 3 commits into
pytorch:mainfrom
theap06:learner-primitive-clean

Conversation

@theap06

@theap06 theap06 commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Summary

Introduces torchrl.trainers.Learner: a backend-agnostic entry point for
taking one optimization step on a tensordict batch with a given
LossModule. LocalLearner is the single-process reference implementation;
FSDP2Learner shards the same model with fully_shard and reuses the
training step unchanged.

Design

  • Learner.update() is concrete, in the base class, and touches only
    self.model / self.optimizer / self.clip_grad_norm /
    self.grad_accum_steps: zero_grad -> forward -> sum the loss_module's
    "loss"-prefixed output keys -> backward -> optional grad-norm clip ->
    optimizer step. This is what lets FSDP2Learner reuse the exact same step
    as LocalLearner -- sharded training only changes model construction and
    weight gathering, not the step itself.
  • get_weights() is the one place sharding is not transparent: it must
    return plain tensors (for WeightSyncScheme.send, which already accepts a
    TensorDictBase), even when the learner's parameters are sharded.
    FSDP2Learner.get_weights() gathers every DTensor leaf via
    full_tensor().
  • FSDP2Learner does not decide sharding granularity or device mesh -- it
    accepts a model the caller has already wrapped with fully_shard, exactly
    as bare FSDP2 usage works. Keeping that decision in caller code avoids
    FSDP2Learner becoming a second place those choices are made.

Planned follow-ups (not in this PR)

  1. Multi-rank FSDP2Learner verification on real multi-GPU hardware --
    the one thing I could not test here.
  2. Refactor an existing recipe's hand-rolled training loop (e.g. the
    reward-model recipe, once [Feature] Add RewardModelLoss objective for RLHF reward-model training #3922 lands) onto LocalLearner, as the
    first real consumer.
  3. A RemoteLearner design writeup scoping one external backend
    (TorchTitan is the most PyTorch-native of the candidates and the
    most plausible first integration target) before any implementation.

theap06 added 2 commits July 2, 2026 01:22
Introduces torchrl.trainers.Learner: a backend-agnostic entry point for
taking one optimization step on a tensordict batch with a given LossModule.
LocalLearner is the single-process reference implementation.

Mirrors the role Collector plays for data collection and LLMWrapperBase
plays for generation/scoring, so training placement (local / sharded /
remote) becomes a swappable backend behind a fixed contract instead of a
hand-rolled loop per recipe.

get_weights() returns a TensorDictBase, matching what
WeightSyncScheme.send() already accepts, so this composes with the
existing weight-sync path unchanged.
Refactors Learner.update() to live in the base class as concrete, generic
step logic (zero_grad -> forward -> sum loss_* keys -> backward -> clip ->
step), operating only on self.model/self.optimizer/self.clip_grad_norm/
self.grad_accum_steps. This is what lets FSDP2Learner reuse the exact same
training step as LocalLearner: sharded training only changes how the model
is constructed (wrapped with fully_shard by the caller) and how
get_weights() gathers the result.

FSDP2Learner.get_weights() gathers every DTensor leaf via full_tensor()
into a plain tensor, so its output is consumable by WeightSyncScheme
exactly like LocalLearner's, with no changes on the receiving side.

Verified on a single-rank (world_size=1) gloo process group, which
exercises the real fully_shard()/DTensor code path without a cluster:
forward/backward/clip_grad_norm_/optimizer.step() dispatch correctly
through DTensor, TensorDict.from_module()/.apply() handle DTensor leaves
transparently, and FSDP2Learner produces bit-exact losses and gathered
weights vs LocalLearner given the same seed/data/lr.
@pytorch-bot

pytorch-bot Bot commented Jul 2, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3926

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 2, 2026
@github-actions github-actions Bot added Feature New feature Documentation Improvements or additions to documentation Trainers labels Jul 2, 2026
…d-sync

Fixes three real gaps in FSDP2Learner identified after the initial PR:

1. get_weights() previously gathered every DTensor leaf to EVERY rank via
   full_tensor(), replicating the whole model in every rank's memory for
   no benefit -- does not scale to large sharded models. Now uses
   torch.distributed.checkpoint.state_dict.get_model_state_dict with
   StateDictOptions(full_state_dict=True, cpu_offload=True), which
   gathers to rank 0 only (other ranks get an empty tensordict) by
   documented DCP semantics. Verified: correct nested key shape via
   unflatten_keys('.'), matches the prior full_tensor()-based output.

2. No sharded-checkpoint path existed. Learner (base) gains real
   state_dict()/load_state_dict() covering model + optimizer state
   (a bare Optimizer is not an nn.Module, so plain nn.Module.state_dict()
   silently drops it -- a real, previously-latent bug for LocalLearner
   too). Both overrides clone their tensors: nn.Module.state_dict() and
   Optimizer.state_dict() return views onto live tensors, not copies, so
   without cloning, further training after checkpointing would silently
   corrupt the saved checkpoint (caught by round-trip tests). FSDP2Learner
   overrides these two methods again with get_state_dict/set_state_dict
   (DCP-aware, handles DTensor optimizer state, rank0-only via
   cpu_offload). Verified end to end: model weights AND Adam/SGD-momentum
   optimizer state round-trip correctly through save/perturb/load.

3. grad_accum_steps>1 was untested on FSDP2Learner, and update() did no
   FSDP2-specific optimization during accumulation. Learner.update() now
   toggles model.set_requires_gradient_sync(...) when the model exposes
   it (FSDP2-wrapped models do; LocalLearner's plain model doesn't, so
   this is a no-op there), deferring the cross-rank gradient reduction
   until the last micro-batch of an accumulation window instead of
   reducing on every micro-batch. Verified against a non-sharded
   reference: the accumulated, synced gradient after 2 microbatches
   matches a plain model accumulating the same 2 microbatches exactly.

Still unverified (unchanged from the original PR): all of the above is
tested at world_size=1 (single-rank gloo), which exercises the real
fully_shard()/DTensor/DCP code paths but not actual cross-rank
communication or memory behavior.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Feature New feature Trainers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant