diff --git a/docs/reliability.md b/docs/reliability.md index 1a2d3d184..dbbf5cf45 100644 --- a/docs/reliability.md +++ b/docs/reliability.md @@ -21,6 +21,31 @@ training step count. By default, checkpointing is disabled if `checkpoint_root_directory` is not specified. Users can further customize checkpointing behavior via `checkpointing_options` in the config. +The individual components of `checkpointing_options` can be configured +separately: + +* **Save Decision Policies**: Determine when a checkpoint is saved, based on + a step interval or elapsed time. Supported policies include + `FixedIntervalPolicy` and `ContinuousCheckpointingPolicy`. The default is + `ContinuousCheckpointingPolicy(minimum_interval_secs=180)` (saves at most + once every 180 seconds). See Orbax v1 [`save_decision_policies.py`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/experimental/v1/_src/training/save_decision_policies.py) + for the complete interface contracts. +* **Preservation Policies**: Determine which saved checkpoints are + retained (e.g., `LatestN`). The default is + `LatestN(n=3)` (keeps the latest 3 checkpoints). See Orbax v1 + [`preservation_policies.py`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/experimental/v1/_src/training/preservation_policies.py) + for the complete interface contracts. +* **Step Name Format**: Defines how the directory of each step checkpoint + is named. The default is `ocp.path.step.standard_name_format()` (uses + simple integer step names). +* **Asynchronous Processing**: Controls asynchronous saving through: + * `enable_async_checkpointing`: Whether to use async checkpointing. + Defaults to `True`. **It is recommended to keep this enabled** to + prevent the main thread from blocking during training runs while + checkpoints are written to storage. + * `timeout_secs`: The timeout, in seconds, for asynchronous operations. + Defaults to `1200` seconds. 
+ ## Fault Tolerance Tunix ensures fault tolerance primarily through its checkpointing mechanism, diff --git a/tests/sft/checkpoint_manager_test.py b/tests/sft/checkpoint_manager_test.py index 63d120744..8c80c592a 100644 --- a/tests/sft/checkpoint_manager_test.py +++ b/tests/sft/checkpoint_manager_test.py @@ -16,6 +16,7 @@ import os import tempfile +from unittest import mock from absl.testing import absltest from absl.testing import parameterized from etils import epath @@ -28,6 +29,7 @@ import optax import qwix from tunix.sft import checkpoint_manager +from tunix.sft import checkpoint_options os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' @@ -110,12 +112,25 @@ def test_empty_root_directory(self): def test_checkpoint_manager_options_none_sets_default(self): cp_path = f'{self.temp_path}/{self.id()}' cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=None) - self.assertIsNotNone(cp_manager._checkpoint_manager) + self.assertIsNotNone(cp_manager._checkpointer) self.assertEqual( - cp_manager._checkpoint_manager._options, # pytype: disable=attribute-error - checkpoint_manager._DEFAULT_CHECKPOINTING_OPTIONS, + cp_manager._options, + checkpoint_options.DEFAULT_CHECKPOINTING_OPTIONS, ) + def test_context_property(self): + cp_path = f'{self.temp_path}/{self.id()}' + cp_manager = checkpoint_manager.CheckpointManager(cp_path) + self.assertIsNotNone(cp_manager._context) + + def test_context_property_with_pathways(self): + with mock.patch.dict(os.environ, {'JAX_PLATFORMS': 'proxy'}): + cp_path = f'{self.temp_path}/{self.id()}' + cp_manager = checkpoint_manager.CheckpointManager(cp_path) + self.assertIsNotNone(cp_manager._context) + self.assertFalse(cp_manager._context.array_options.saving.use_ocdbt) + self.assertFalse(cp_manager._context.array_options.saving.use_zarr3) + def test_save(self): cp_path = f'{self.temp_path}/{self.id()}' cp_manager = checkpoint_manager.CheckpointManager(cp_path) @@ -123,7 +138,8 @@ def test_save(self): # Save the 
model state. self.assertTrue(cp_manager.save(1, model)) - cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() self.assertEqual(cp_manager.latest_step(), 1) cp_manager.close() @@ -139,7 +155,8 @@ def test_restore(self): # Save the model params. self.assertTrue(cp_manager.save(1, model)) - cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() # Change the model state. changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model)) @@ -162,7 +179,8 @@ def test_restore_different_sharding(self): # Save the model params. self.assertTrue(cp_manager.save(1, unsharded_model)) - cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() # Restore the model without shardings. self.assertEqual(cp_manager.maybe_restore(unsharded_model), (1, {})) @@ -211,7 +229,8 @@ def test_restore_with_lora(self): # Save the model params. self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True)) - cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() # Change the model state. 
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model)) @@ -235,13 +254,86 @@ def test_restore_with_lora(self): nnx.state(model, nnx.filterlib.Not(nnx.LoRAParam)), ) + def test_restore_only_lora_params(self): + cp_path = f'{self.temp_path}/{self.id()}' + cp_manager = checkpoint_manager.CheckpointManager(cp_path) + model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh) + lora_provider = qwix.LoraProvider( + module_path='.*w1', + rank=4, + alpha=2.0, + ) + dummy_model_input = { + 'x': jnp.ones(2, dtype=jnp.int32), + } + model = qwix.apply_lora_to_model(model, lora_provider, **dummy_model_input) + expected_lora_state = nnx.clone(nnx.state(model, nnx.LoRAParam)) + changed_non_lora_state = jax.tree.map( + lambda x: x + 2, nnx.state(model, (nnx.filterlib.Not(nnx.LoRAParam))) + ) + + # Save the model params (entire model). + self.assertTrue(cp_manager.save(1, model, save_only_lora_params=False)) + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() + + # Change the model state. + nnx.update( + model, jax.tree.map(lambda x: x + 1, nnx.state(model, nnx.LoRAParam)) + ) + nnx.update(model, changed_non_lora_state) + + # Restore only the model lora params. + self.assertEqual( + cp_manager.maybe_restore(model, restore_only_lora_params=True), + (1, {}), + ) + # Check the model lora params are restored correctly. + jax.tree.map_with_path( + assert_close, + expected_lora_state, + nnx.state(model, nnx.LoRAParam), + ) + # Check the rest of the params are not restored. 
+ jax.tree.map_with_path( + assert_close, + changed_non_lora_state, + nnx.state(model, nnx.filterlib.Not(nnx.LoRAParam)), + ) + + def test_restore_full_from_lora_only_checkpoint_fails(self): + cp_path = f'{self.temp_path}/{self.id()}' + cp_manager = checkpoint_manager.CheckpointManager(cp_path) + model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh) + lora_provider = qwix.LoraProvider( + module_path='.*w1', + rank=4, + alpha=2.0, + ) + dummy_model_input = { + 'x': jnp.ones(2, dtype=jnp.int32), + } + model = qwix.apply_lora_to_model(model, lora_provider, **dummy_model_input) + + # Save only the lora params. + self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True)) + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() + + # Try to restore full model, expect failure. + with self.assertRaisesRegex( + ValueError, 'If this checkpoint only contains LoRA parameters' + ): + cp_manager.maybe_restore(model, restore_only_lora_params=False) + def test_save_and_restore_with_custom_metadata(self): cp_path = f'{self.temp_path}/{self.id()}' ckpt_manager = checkpoint_manager.CheckpointManager(cp_path) model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh) custom_metadata = {'foo': 1, 'bar': 2} ckpt_manager.save(1, model, custom_metadata=custom_metadata) - ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error + assert ckpt_manager._checkpointer is not None + ckpt_manager._checkpointer.wait() restored_step, restored_metadata = ckpt_manager.maybe_restore(model) self.assertEqual(restored_step, 1) self.assertEqual(restored_metadata, custom_metadata) @@ -257,7 +349,8 @@ def test_save_and_restore_with_optimizer_state(self): ) custom_metadata = {'foo': 1, 'bar': 2} ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata) - ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error + assert ckpt_manager._checkpointer is not None + 
ckpt_manager._checkpointer.wait() new_optimizer = nnx.Optimizer( model, @@ -281,6 +374,68 @@ def test_save_and_restore_with_optimizer_state(self): new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-3 ) + def test_save_and_restore_with_forced_single_device_sharding(self): + cp_path = f'{self.temp_path}/{self.id()}' + ckpt_manager = checkpoint_manager.CheckpointManager(cp_path) + model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh) + optimizer = nnx.Optimizer( + model, + optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3), + wrt=nnx.Param, + ) + custom_metadata = {'foo': 1, 'bar': 2} + ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata) + assert ckpt_manager._checkpointer is not None + ckpt_manager._checkpointer.wait() + + new_optimizer = nnx.Optimizer( + model, + optax.inject_hyperparams(optax.adamw)(learning_rate=1e-5), + wrt=nnx.Param, + ) + + new_optimizer.opt_state.hyperparams['learning_rate'].value = jax.device_put( + new_optimizer.opt_state.hyperparams['learning_rate'].value, + jax.devices()[0], + ) + + self.assertIsInstance( + new_optimizer.opt_state.hyperparams['learning_rate'].value.sharding, + jax.sharding.SingleDeviceSharding, + ) + + restored_step, _ = ckpt_manager.maybe_restore( + model, new_optimizer + ) + self.assertEqual(restored_step, 1) + + errors = [] + def assert_named_sharding(path, x): + if hasattr(x, 'sharding'): + try: + self.assertIsInstance( + x.sharding, + jax.sharding.NamedSharding, + f'Variable at {path} is not NamedSharding', + ) + except AssertionError as e: + errors.append(str(e)) + return + + path_str = str(path) + if 'hyperparams' in path_str: + try: + self.assertEqual(x.sharding.spec, jax.sharding.PartitionSpec()) + except AssertionError as e: + errors.append(str(e)) + + jax.tree.map_with_path( + assert_named_sharding, + nnx.state(new_optimizer, nnx.optimizer.OptState), + ) + if errors: + self.fail(f'Found sharding mismatches:\n{"\n".join(errors)}') + def 
test_restore_without_optimizer(self): cp_path = f'{self.temp_path}/{self.id()}' ckpt_manager = checkpoint_manager.CheckpointManager(cp_path) @@ -291,7 +446,8 @@ def test_restore_without_optimizer(self): wrt=nnx.Param, ) ckpt_manager.save(1, model, optimizer) - ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error + assert ckpt_manager._checkpointer is not None + ckpt_manager._checkpointer.wait() ckpt_manager.maybe_restore(model) @parameterized.parameters(['test_data/checkpoints']) @@ -317,6 +473,49 @@ def test_restore_with_backward_compatibility(self, ckpt_path): nnx.state(model), ) + @parameterized.parameters(True, False) + def test_save_aligns_with_policy(self, enable_async): + cp_path = f'{self.temp_path}/{self.id()}_{enable_async}' + options = checkpoint_options.TunixCheckpointingOptions( + save_decision_policy=( + checkpoint_manager.ocp.training.save_decision_policies.FixedIntervalPolicy( + 2 + ) + ), + enable_async_checkpointing=enable_async, + ) + cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=options) + model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh) + + # Step 1 should be skipped by FixedIntervalPolicy(2). + self.assertFalse(cp_manager.save(1, model)) + + # Step 2 should be saved. 
+ self.assertTrue(cp_manager.save(2, model)) + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() + self.assertEqual(cp_manager.latest_step(), 2) + + def test_save_force_true_overrides_policy(self): + cp_path = f'{self.temp_path}/{self.id()}' + options = checkpoint_options.TunixCheckpointingOptions( + save_decision_policy=( + checkpoint_manager.ocp.training.save_decision_policies.FixedIntervalPolicy( + 2 + ) + ), + enable_async_checkpointing=True, + ) + cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=options) + model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh) + + # Step 1 would normally be skipped by FixedIntervalPolicy(2), but force=True + # should force the save. + self.assertTrue(cp_manager.save(1, model, force=True)) + assert cp_manager._checkpointer is not None + cp_manager._checkpointer.wait() + self.assertEqual(cp_manager.latest_step(), 1) + if __name__ == '__main__': absltest.main() diff --git a/tests/sft/peft_trainer_test.py b/tests/sft/peft_trainer_test.py index 831cd101b..5365d61e5 100644 --- a/tests/sft/peft_trainer_test.py +++ b/tests/sft/peft_trainer_test.py @@ -29,8 +29,8 @@ import jax.sharding as shd import numpy as np import optax -import orbax.checkpoint as ocp from tunix.sft import checkpoint_manager +from tunix.sft import checkpoint_options from tunix.sft import hooks from tunix.sft import peft_trainer from tunix.sft import profiler @@ -547,13 +547,13 @@ def test_checkpointing( mock_checkpoint_manager.latest_step.return_value = ( expected_save_steps[-1] - 1 ) # force save at close - checkpoint_options = ocp.CheckpointManagerOptions() + checkpointing_options = checkpoint_options.create_checkpointing_options() config = peft_trainer.TrainingConfig( eval_every_n_steps=2, max_steps=100, gradient_accumulation_steps=grad_accu, checkpoint_root_directory='/tmp/checkpoint', - checkpointing_options=checkpoint_options, + checkpointing_options=checkpointing_options, ) rngs = nnx.Rngs(0) model = 
tc.get_lora_model( @@ -566,7 +566,7 @@ def test_checkpointing( trainer.train(train_ds, eval_ds) mock_checkpoint_manager_init.assert_called_once_with( - root_directory='/tmp/checkpoint', options=checkpoint_options + root_directory='/tmp/checkpoint', options=checkpointing_options ) # Assert that the checkpoint manager is called with the correct arguments # and does not have any unexpected calls. diff --git a/tunix/sft/checkpoint_manager.py b/tunix/sft/checkpoint_manager.py index d069a7706..a06385d38 100644 --- a/tunix/sft/checkpoint_manager.py +++ b/tunix/sft/checkpoint_manager.py @@ -14,21 +14,17 @@ """Checkpoint manager for PEFT.""" +from collections.abc import Mapping +import functools import os import time -from typing import Any, Tuple +from typing import Any from absl import logging from flax import nnx import jax -import orbax.checkpoint as ocp - -_DEFAULT_CHECKPOINTING_OPTIONS = ocp.CheckpointManagerOptions( - save_decision_policy=ocp.checkpoint_managers.ContinuousCheckpointingPolicy( - minimum_interval_secs=180, - ), - max_to_keep=3, -) +from orbax.checkpoint import v1 as ocp +from tunix.sft import checkpoint_options class CheckpointManager: @@ -37,7 +33,7 @@ class CheckpointManager: def __init__( self, root_directory: str | None = None, - options: ocp.CheckpointManagerOptions | None = None, + options: checkpoint_options.CheckpointingOptions | None = None, ): """Initializes the checkpoint manager. @@ -46,47 +42,76 @@ def __init__( the checkpoint manager will be disabled. options: The options for the checkpoint manager. """ - self._checkpoint_manager: ocp.CheckpointManager | None = None + self._checkpointer: ocp.training.Checkpointer | None = None + self._options = checkpoint_options.resolve_checkpointing_defaults( + options + ) if root_directory is not None: - # When using Pathways, the checkpoint manager only supports persistence - # APIs now. 
- if 'proxy' in os.getenv('JAX_PLATFORMS', ''): - item_handlers = { - 'model_params': ocp.PyTreeCheckpointHandler( - use_ocdbt=False, - use_zarr3=False, - ), - 'optimizer_state': ocp.PyTreeCheckpointHandler( - use_ocdbt=False, - use_zarr3=False, - ), - } - if os.getenv('ENABLE_PATHWAYS_PERSISTENCE', ''): - logging.info( - 'Using persistence API for checkpointing with Pathways.' - ) - else: - logging.warning( - 'Checkpointing without the persistence API, be aware of potential' - ' OOM.' - ) - else: - item_handlers = { - 'model_params': ocp.PyTreeCheckpointHandler(), - 'optimizer_state': ocp.PyTreeCheckpointHandler(), - } - item_handlers['custom_metadata'] = ocp.JsonCheckpointHandler() - self._checkpoint_manager = ocp.CheckpointManager( + self._checkpointer = ocp.training.Checkpointer( root_directory, - item_handlers=item_handlers, - options=options or _DEFAULT_CHECKPOINTING_OPTIONS, + context=self._context, + save_decision_policy=self._options.save_decision_policy, + preservation_policy=self._options.preservation_policy, + step_name_format=self._options.step_name_format, + ) + + @functools.cached_property + def _context(self) -> ocp.Context: + """The orbax context applied to every checkpointer operation.""" + ctx = ocp.Context() + if ( + self._options.async_options is not None + and self._options.async_options.timeout_secs is not None + ): + ctx.asynchronous.timeout_secs = self._options.async_options.timeout_secs + # When using Pathways, the checkpoint manager only supports persistence + # APIs now. + if 'proxy' in os.getenv('JAX_PLATFORMS', ''): + if os.getenv('ENABLE_PATHWAYS_PERSISTENCE', ''): + logging.info( + 'Using persistence API for checkpointing with Pathways.' + ) + else: + logging.warning( + 'Checkpointing without the persistence API, be aware of potential' + ' OOM.' 
+ ) + ctx.array.saving.use_ocdbt = False + ctx.array.saving.use_zarr3 = False + return ctx + + def _save_checkpointables( + self, + step: int, + checkpointables: dict[str, Any], + force: bool, + custom_metadata: Mapping[str, Any] | None, + ) -> bool: + """Internal helper to dispatch and report whether a save happened.""" + if self._checkpointer is None: + return False + if self._options.enable_async_checkpointing: + # `save_checkpointables_async` returns an `AsyncResponse` when a save is + # initiated, or `None` when the save is skipped by the save policy. + response = self._checkpointer.save_checkpointables_async( + step, + checkpointables, + force=force, + custom_metadata=custom_metadata, ) + return response is not None + return self._checkpointer.save_checkpointables( + step, + checkpointables, + force=force, + custom_metadata=custom_metadata, + ) def latest_step(self) -> int | None: """Returns the latest step.""" - if self._checkpoint_manager is None: + if self._checkpointer is None or self._checkpointer.latest is None: return None - return self._checkpoint_manager.latest_step() + return self._checkpointer.latest.step def save( self, @@ -95,7 +120,7 @@ def save( optimizer: nnx.Optimizer | None = None, save_only_lora_params: bool = False, force: bool = False, - custom_metadata: dict[str, Any] | None = None, + custom_metadata: Mapping[str, Any] | None = None, ) -> bool: """Saves the params for the given step. @@ -110,36 +135,28 @@ def save( custom_metadata: Custom metadata to save with the checkpoint. Returns: - Whether the checkpoint was saved. + Whether the checkpoint save operation was successful if synchronous, + otherwise whether the save operation was initiated. 
""" - if self._checkpoint_manager is None: - return False - if not force and not self._checkpoint_manager.should_save(step): + if self._checkpointer is None: return False if save_only_lora_params: params = nnx.state(model, nnx.LoRAParam) else: params = nnx.state(model) - model_cp_args = ocp.args.PyTreeSave( - item=params, save_args=jax.tree.map(lambda _: ocp.SaveArgs(), params) - ) - - cp_save_args = { - 'model_params': model_cp_args, - } if optimizer is not None: - optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) - optimizer_cp_args = ocp.args.PyTreeSave( - item=optimizer_state, - save_args=jax.tree.map(lambda _: ocp.SaveArgs(), optimizer_state), - ) - cp_save_args['optimizer_state'] = optimizer_cp_args - return self._checkpoint_manager.save( - step, - args=ocp.args.Composite(**cp_save_args), - custom_metadata=custom_metadata or {}, - force=force, + checkpointables = { + 'model_params': params, + 'optimizer_state': nnx.state(optimizer, nnx.optimizer.OptState), + } + else: + checkpointables = { + 'model_params': params, + } + + return self._save_checkpointables( + step, checkpointables, force, custom_metadata ) def maybe_restore( @@ -148,7 +165,7 @@ def maybe_restore( optimizer: nnx.Optimizer | None = None, step: int | None = None, restore_only_lora_params: bool = False, - ) -> Tuple[int, dict[str, Any]]: + ) -> tuple[int, Any]: """Restores the params from the latest checkpoint if available and updates the model provided. Args: @@ -161,76 +178,96 @@ def maybe_restore( restore_only_lora_params: Whether to restore only the LoRA params. Returns: - The step of the restored checkpoint or 0 if no checkpoint is available. + A tuple (step, custom_metadata), where step is the step of the restored + checkpoint or 0 if no checkpoint is available, and the custom_metadata. Raises: RuntimeError: If the checkpoint cannot be restored. 
""" restore_start = time.time() - if self._checkpoint_manager is None: + if self._checkpointer is None: return 0, {} if step is None: - step = self._checkpoint_manager.latest_step() + step = self.latest_step() # If no checkpoint is available, return 0. if step is None: return 0, {} - metadata = self._checkpoint_manager.metadata(step) + metadata = self._checkpointer.checkpointables_metadata(step) - # Load the params from the checkpoint. if restore_only_lora_params: - abstract_params = nnx.state(model, nnx.LoRAParam) + model_params_state = nnx.state(model, nnx.LoRAParam) + # Partial (LoRA) restore is the one path that overrides the persistent + # context to enable partial loading. + load_ctx = ocp.Context(self._context) + load_ctx.pytree.loading.partial_load = True else: - abstract_params = nnx.state(model) - - model_cp_args = ocp.args.PyTreeRestore( - item=abstract_params, - restore_args=ocp.checkpoint_utils.construct_restore_args( - target=abstract_params - ), - ) + model_params_state = nnx.state(model) + load_ctx = self._context + abstract_checkpointables = {'model_params': model_params_state} def fix_sharding(state): # Scalar values in optimizer states like step and count is initialized as # SingleDeviceSharding, which will fail if optimizer is sharded. To fix # it, we will replicate the scalar values. 
- shardings = jax.tree_util.tree_map(lambda x: x.sharding, state) - try: - named_sharding = next( - s - for s in jax.tree_util.tree_leaves(shardings) - if isinstance(s, jax.sharding.NamedSharding) - ) - return nnx.get_named_sharding(optimizer_state, named_sharding.mesh) - except StopIteration: - return shardings - - if optimizer is not None and 'optimizer_state' in metadata.item_metadata: - optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) - fixed_sharding = fix_sharding(optimizer_state) - optimizer_cp_args = ocp.args.PyTreeRestore( - item=optimizer_state, - restore_args=ocp.checkpoint_utils.construct_restore_args( - target=optimizer_state, sharding_tree=fixed_sharding + mesh = next( + ( + x.sharding.mesh + for x in jax.tree_util.tree_leaves(state) + if getattr(x, 'sharding', None) + and isinstance(x.sharding, jax.sharding.NamedSharding) ), + None, ) - ckpt = self._checkpoint_manager.restore( - step, - args=ocp.args.Composite( - model_params=model_cp_args, - optimizer_state=optimizer_cp_args, + + if mesh is None: + logging.info( + 'Optimizer state contains no NamedSharding. Skipping sharding' + ' replication.' 
+ ) + return state + + target_shardings = nnx.get_named_sharding(state, mesh) + return jax.tree_util.tree_map( + lambda x, shd: jax.ShapeDtypeStruct( + getattr(x, 'shape', ()), + getattr(x, 'dtype', jax.numpy.asarray(x).dtype), + sharding=shd, ), + state, + target_shardings, ) - nnx.update(optimizer, ckpt.optimizer_state) - else: - ckpt = self._checkpoint_manager.restore( - step, - args=ocp.args.Composite( - model_params=model_cp_args, - ), + + if ( + optimizer is not None + and metadata is not None + and 'optimizer_state' in metadata.metadata + ): + optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) + abstract_checkpointables['optimizer_state'] = fix_sharding( + optimizer_state ) + + try: + with load_ctx: + restored_checkpointables = self._checkpointer.load_checkpointables( + step, + abstract_checkpointables, + ) + except KeyError as e: + if not restore_only_lora_params: + raise ValueError( + f'Failed to restore from step {step}. If this checkpoint only' + ' contains LoRA parameters, please set' + ' `restore_only_lora_params=True`.' + ) from e + raise e + + if optimizer is not None and 'optimizer_state' in restored_checkpointables: + nnx.update(optimizer, restored_checkpointables['optimizer_state']) + # Update the model state with params from the restored checkpoint. 
- nnx.update(model, ckpt.model_params) + nnx.update(model, restored_checkpointables['model_params']) logging.info( 'Restored params from step: %d in %.3f seconds', step, @@ -239,8 +276,8 @@ def fix_sharding(state): custom_metadata = metadata.custom_metadata if metadata else {} return step, custom_metadata - def close(self): + def close(self) -> None: """Closes the checkpoint manager.""" - if self._checkpoint_manager is None: + if self._checkpointer is None: return - self._checkpoint_manager.close() + self._checkpointer.close() diff --git a/tunix/sft/peft_trainer.py b/tunix/sft/peft_trainer.py index 501e5739c..4c5dc0eaa 100644 --- a/tunix/sft/peft_trainer.py +++ b/tunix/sft/peft_trainer.py @@ -31,12 +31,12 @@ from jax.typing import ArrayLike # pylint: disable=g-importing-member import numpy as np import optax -import orbax.checkpoint as ocp from tunix.perf import metrics as perf_metrics from tunix.perf import trace as perf_trace from tunix.perf.experimental import constants as perf_constants from tunix.perf.experimental import tracer as perf_tracer_lib from tunix.sft import checkpoint_manager +from tunix.sft import checkpoint_options from tunix.sft import hooks from tunix.sft import inflight_throttler from tunix.sft import metrics_logger as sft_metrics_logger @@ -63,7 +63,10 @@ class TrainingConfig: # contains the model params and the train data iterator state. checkpoint_root_directory: str | None = None # Checkpoint configurations. If None, the default options will be used. - checkpointing_options: ocp.CheckpointManagerOptions | None = None + checkpointing_options: ( + checkpoint_options.CheckpointingOptions + | None + ) = None # Configs for the metrics logger. metrics_logging_options: MetricsLoggerOptions | None = None