8 changes: 8 additions & 0 deletions docs_nnx/api_reference/flax.nnx/training/ema.rst
@@ -0,0 +1,8 @@
EMA
------------------------

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: EMA
  :members: __init__, update
1 change: 1 addition & 0 deletions docs_nnx/api_reference/flax.nnx/training/index.rst
@@ -8,4 +8,5 @@ Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/i

  metrics
  optimizer
  ema

332 changes: 79 additions & 253 deletions docs_nnx/guides/optimization_cookbook.ipynb

Large diffs are not rendered by default.

89 changes: 20 additions & 69 deletions docs_nnx/guides/optimization_cookbook.md
@@ -46,75 +46,29 @@ y = rngs.normal((32, 8))

# Exponential Moving Average

Neural network see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. To modify the standard Flax training loop to accomodate calculating exponential moving averages, we can introduce a new `nnx.Variable` subclass for EMA Params, along with a function that converts all variables in a module to this subclass.
Neural networks see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. NNX provides `nnx.EMA` to make this easy. Simply create an `nnx.EMA` from your model and call `ema.update` after each optimizer step. The averaged parameters are stored in `ema.params`.

```python
class EmaParam(nnx.Variable):
  @classmethod
  def from_variable(cls, var):
    return cls.from_metadata(jnp.copy(var.get_value()), var.get_metadata())

def as_ema_params(node):
  return jax.tree.map(
    EmaParam.from_variable,
    node,
    is_leaf=lambda x: isinstance(x, nnx.Variable),
  )
```

Now, we'll add a method to update the EMA params based on current model values.

```python
class EMA(nnx.Pytree):
  def __init__(self, params, decay: float, *, only: nnx.filterlib.Filter = ...):
    self.decay = decay
    self.filter = only
    self.ema_params = nnx.data(as_ema_params(nnx.state(params, only)))

  def update(self, new_params):
    def _update(ema_param, new_param):
      ema_param[...] = (
        self.decay * ema_param[...] + (1.0 - self.decay) * new_param[...]
      )
    jax.tree.map(
      _update,
      self.ema_params,
      nnx.state(new_params, self.filter),
      is_leaf=lambda x: isinstance(x, nnx.Variable),
    )
```


The training loop proceeds as normal, but calls `ema.update` after each optimizer step.

```python
# initialization
model = make_model(rngs)
ema = EMA(model, decay=0.9)

# simulate parameter update
def double(param):
  param[...] *= 2.0
jax.tree.map(double, model, is_leaf=lambda x: isinstance(x, nnx.Variable))
ema.update(model)
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
ema = nnx.EMA(model, decay=0.9)
ema_model = ema.apply_to(model)

@nnx.jit
def train_step(model, optimizer, ema, x, y):
  loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
  optimizer.update(model, grads)
  ema.update(model)
  return loss

optimizer = nnx.Optimizer(
  model,
  tx=optax.adam(1e-3),
  wrt=nnx.Param,
)
losses = []
@nnx.jit
def eval_step(model, x, y):
  return loss_fn(model, x, y)

for _ in range(50):
  loss = train_step(model, optimizer, ema, x, y)
  losses.append(loss)
plt.plot(losses);
train_step(model, optimizer, ema, x, y)

loss = eval_step(ema_model, x, y)
print(f"final eval loss: {loss}")
```
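
By default `nnx.EMA` tracks every variable in the model. To average only a subset, pass a filter to the constructor's `only` argument. Below is a minimal sketch (the decay value here is arbitrary) that restricts the average to `nnx.Param` variables:

```python
# Average only nnx.Param variables; state excluded by the filter is left
# untouched on the live model and preserved by `apply_to`.
ema = nnx.EMA(model, decay=0.999, only=nnx.Param)
ema.update(model)                # blend the current params into the average
ema_model = ema.apply_to(model)  # view of `model` using the averaged params
```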

# Low Rank Adaptation
@@ -129,23 +83,21 @@ def add_rank2_lora(path, node):
  return node

base_model = make_model(rngs)
lora_model = nnx.recursive_map(add_rank2_lora, base_model)
model = nnx.recursive_map(add_rank2_lora, base_model)
nnx.display(model)
```

The training loop is the same as before, but we pass `wrt=nnx.LoRAParam` to the optimizer so that only the low-rank adaptation parameters are updated while the base model weights remain frozen.

```python
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.LoRAParam)

@nnx.jit
def train_step(model, optimizer, x, y):
  loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
  optimizer.update(model, grads)
  return loss

optimizer = nnx.Optimizer(
  model,
  tx=optax.adam(1e-3),
  wrt=nnx.LoRAParam,
)

losses = []
for _ in range(50):
@@ -174,10 +126,7 @@ def train_step(model, optimizer, x, y):
  return loss

model = make_model(rngs)
optimizer = nnx.Optimizer(
  model,
  tx=optax.lbfgs(1e-3),
  wrt=nnx.Param)
optimizer = nnx.Optimizer(model, optax.lbfgs(1e-3), wrt=nnx.Param)

losses = []
for _ in range(50):
@@ -221,7 +170,8 @@ Gradient accumulation in Flax is easy: just use the `optax.MultiSteps` optimizer

```python
model = make_model(rngs)
optimizer = nnx.Optimizer(model, tx=optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3), wrt=nnx.Param)
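# MultiSteps accumulates gradients across 3 consecutive steps and applies
# the wrapped adam update once every third step (a no-op update otherwise).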
tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

@nnx.jit
def train_step(model, optimizer, x, y):
@@ -276,3 +226,4 @@ But the model is not:
```python
jax.typeof(model.layers[0].kernel[...])
```

1 change: 1 addition & 0 deletions flax/nnx/__init__.py
@@ -181,6 +181,7 @@
from .training.optimizer import Optimizer as Optimizer
from .training.optimizer import ModelAndOptimizer as ModelAndOptimizer
from .training.optimizer import OptState as OptState
from .training.ema import EMA as EMA
from .transforms.autodiff import DiffState as DiffState
from .transforms.autodiff import grad as grad
from .transforms.autodiff import value_and_grad as value_and_grad
179 changes: 179 additions & 0 deletions flax/nnx/training/ema.py
@@ -0,0 +1,179 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import typing as tp

from flax.nnx import filterlib
from flax.nnx import graphlib
from flax.nnx import pytreelib
from flax.nnx import statelib
from flax.nnx import variablelib
import jax
import jax.numpy as jnp

A = tp.TypeVar('A')


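# Copies each Variable leaf so the EMA state is an independent snapshot of
# the parameters it shadows.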
def _to_ema_param(node: tp.Any):
  def ema_param(path, x):
    if not isinstance(x, variablelib.Variable):
      path_str = '/'.join(str(k) for k in path)
      raise TypeError(
        f"EMA only supports Variable leaves, got {type(x).__name__} at "
        f"path '{path_str}'. Use the `only` filter to select Variable leaves."
      )
    ema_metadata = x.get_metadata()
    value = jnp.copy(x.get_value())
    return type(x)(value, **ema_metadata)

  return jax.tree.map_with_path(
    ema_param, node, is_leaf=lambda x: isinstance(x, variablelib.Variable)
  )


class EMA(pytreelib.Pytree):
  """Exponential Moving Average (EMA) of parameters.

  Maintains a shadow copy of model Variables that is updated as an
  exponentially weighted moving average on each call to :meth:`update`.
  This is commonly used to stabilize training and improve evaluation
  performance by applying the averaged parameters at inference time.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    >>> import optax
    ...
    >>> model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))
    >>> optimizer = nnx.Optimizer(model, optax.sgd(0.1), wrt=nnx.Param)
    >>> ema = nnx.EMA(model, decay=0.9)
    >>> ema_model = ema.apply_to(model)
    ...
    >>> def loss_fn(model, x, y):
    ...   return jnp.mean((model(x) - y) ** 2)
    ...
    >>> @nnx.jit
    ... def train_step(model, optimizer, ema, x, y):
    ...   grads = nnx.grad(loss_fn)(model, x, y)
    ...   optimizer.update(model, grads)
    ...   ema.update(model)
    ...
    >>> @nnx.jit
    ... def eval_step(model, x, y):
    ...   return loss_fn(model, x, y)
    ...
    >>> x, y = jnp.ones((1, 2)), jnp.ones((1, 2))
    >>> train_step(model, optimizer, ema, x, y)
    >>> loss = eval_step(ema_model, x, y)

  In this example, ``ema.update`` computes the moving average and updates
  the internal state of ``ema``. ``ema.apply_to`` creates a new model
  instance (``ema_model``) that shares its Variables with ``ema``.
  Therefore, ``ema_model`` will automatically reflect the updates performed
  by ``ema.update`` and can be used directly in ``eval_step``.

  Attributes:
    decay: The decay rate for the exponential moving average.
    filter: The filter used to select which variables to track.
    params: A pytree of variables holding the current moving average
      values.
  """

  def __init__(
    self,
    params: tp.Any,
    decay: float,
    *,
    only: filterlib.Filter = ...,
    graph: bool | None = None,
  ):
    """Initializes the EMA module.

    Args:
      params: Any object, typically an NNX module/node, whose parameters
        will be tracked.
      decay: The decay rate for the moving average.
      only: A filter indicating which variables should be included in the
        EMA tracking. Defaults to matching everything. Note that EMA only
        tracks ``nnx.Variable`` instances.
      graph: If ``True``, uses graph-mode which supports the full NNX
        feature set including shared references. If ``False``, uses
        tree-mode which treats Modules as regular JAX pytrees, avoiding
        the overhead of the graph protocol. If ``None`` (default), the
        value is determined by the current ``nnx.set_graph_mode`` context.
    """
    self.graph = graph
    self.decay = decay
    self.filter = only
    self.params: graphlib.State = pytreelib.data(
      _to_ema_param(graphlib.state(params, only, graph=graph))
    )

  def update(self, updates: tp.Any) -> None:
    """Updates the EMA parameters towards the given new parameters.

    The update rule for each parameter is::

      ema = decay * ema + (1 - decay) * update

    Args:
      updates: The new parameters or module to blend into the current EMA.
        This should have the same structure as the ``params`` object passed
        during initialization.
    """
    def _update_ema(ema: variablelib.Variable, update: tp.Any) -> tp.Any:
      ema[...] = self.decay * ema + (1.0 - self.decay) * update

    jax.tree.map(
      _update_ema,
      self.params,
      graphlib.state(updates, self.filter, graph=self.graph),
      is_leaf=lambda x: isinstance(x, variablelib.Variable),
    )

  def apply_to(self, model: A) -> A:
    """Returns a view of the model using the EMA parameters.

    Constructs a new model instance with the same structure as ``model``
    but whose tracked parameters are replaced by their exponential moving
    average values. Non-tracked state (e.g. variables excluded by the
    ``only`` filter) is preserved from the original ``model``.

    This is typically used at evaluation time to obtain a model whose
    parameters reflect the smoothed training trajectory.

    Example usage::

      >>> from flax import nnx
      >>> import jax.numpy as jnp
      ...
      >>> model = nnx.Linear(2, 2, use_bias=False, rngs=nnx.Rngs(0))
      >>> ema = nnx.EMA(model, decay=0.9)
      >>> ema_model = ema.apply_to(model)
      >>> assert ema_model.kernel is ema.params.kernel

    Args:
      model: A model instance whose graph structure is used to build
        the output. The model should have the same structure as the
        ``params`` originally passed to :class:`EMA`.

    Returns:
      A new model of the same type as ``model`` with tracked parameters
      replaced by the current EMA values.
    """
    graphdef, state = graphlib.split(model, graph=self.graph)
    merged_state = statelib.merge_state(state, self.params)
    return graphlib.merge(graphdef, merged_state)
2 changes: 1 addition & 1 deletion flax/traceback_util.py
@@ -25,7 +25,7 @@
# Whether to filter flax frames from traceback.
_flax_filter_tracebacks = config.flax_filter_frames
# Flax specific set of paths to exclude from tracebacks.
_flax_exclusions = set()
_flax_exclusions: set[str] = set()


# re-import JAX symbol for convenience.
2 changes: 1 addition & 1 deletion flax/training/checkpoints.py
@@ -1163,7 +1163,7 @@ def read_chunk(i):
else:
  checkpoint_contents = fp.read()

state_dict = serialization.msgpack_restore(checkpoint_contents)
state_dict = serialization.msgpack_restore(checkpoint_contents)  # type: ignore[arg-type]
state_dict = _restore_mpas(
  state_dict,
  target,