Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/reliability.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,31 @@ training step count. By default, checkpointing is disabled if
`checkpoint_root_directory` is not specified. Users can further customize
checkpointing behavior via `checkpointing_options` in the config.

Users customize background preservation behavior granularly using components
defined inside `checkpoint_options`:

* **Save Decision Policies**: Dictates when to initiate a checkpoint based on
defined steps or intervals. Supported configurations include
`FixedIntervalPolicy` and `ContinuousCheckpointingPolicy`. The default is
`ContinuousCheckpointingPolicy(minimum_interval_secs=180)` (saves every 180
seconds). See Orbax v1 [`save_decision_policies.py`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/experimental/v1/_src/training/save_decision_policies.py)
for the complete interface contracts.
* **Preservation Policies**: Sets specifications regarding tracking
checkpoints over bounded timelines (e.g., `LatestN`). The default is
`LatestN(n=3)` (keeps the latest 3 checkpoints). See Orbax v1
[`preservation_policies.py`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/experimental/v1/_src/training/preservation_policies.py)
for the complete interface contracts.
* **Step Name Format**: Defines the representation of directory names for step
checkpoints. The default is `ocp.path.step.standard_name_format()` (uses
simple integer step names).
* **Asynchronous Processing**: Manage asynchronous behavior by specifying:
* `enable_async_checkpointing`: Whether to use async checkpointing.
Defaults to `True`. **It is recommended to keep this enabled** to
prevent the main thread from blocking during training runs while
checkpoints are written to storage.
* `timeout_secs`: The timeout for asynchronous operations.
Defaults to `1200` seconds.

## Fault Tolerance

Tunix ensures fault tolerance primarily through its checkpointing mechanism,
Expand Down
219 changes: 209 additions & 10 deletions tests/sft/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
import tempfile
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
Expand All @@ -28,6 +29,7 @@
import optax
import qwix
from tunix.sft import checkpoint_manager
from tunix.sft import checkpoint_options

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

Expand Down Expand Up @@ -110,20 +112,34 @@ def test_empty_root_directory(self):
def test_checkpoint_manager_options_none_sets_default(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=None)
self.assertIsNotNone(cp_manager._checkpoint_manager)
self.assertIsNotNone(cp_manager._checkpointer)
self.assertEqual(
cp_manager._checkpoint_manager._options, # pytype: disable=attribute-error
checkpoint_manager._DEFAULT_CHECKPOINTING_OPTIONS,
cp_manager._options,
checkpoint_options.DEFAULT_CHECKPOINTING_OPTIONS,
)

def test_context_property(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path)
self.assertIsNotNone(cp_manager._context)

def test_context_property_with_pathways(self):
with mock.patch.dict(os.environ, {'JAX_PLATFORMS': 'proxy'}):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path)
self.assertIsNotNone(cp_manager._context)
self.assertFalse(cp_manager._context.array_options.saving.use_ocdbt)
self.assertFalse(cp_manager._context.array_options.saving.use_zarr3)

def test_save(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

# Save the model state.
self.assertTrue(cp_manager.save(1, model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()
self.assertEqual(cp_manager.latest_step(), 1)

cp_manager.close()
Expand All @@ -139,7 +155,8 @@ def test_restore(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand All @@ -162,7 +179,8 @@ def test_restore_different_sharding(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, unsharded_model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Restore the model without shardings.
self.assertEqual(cp_manager.maybe_restore(unsharded_model), (1, {}))
Expand Down Expand Up @@ -211,7 +229,8 @@ def test_restore_with_lora(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand All @@ -235,13 +254,86 @@ def test_restore_with_lora(self):
nnx.state(model, nnx.filterlib.Not(nnx.LoRAParam)),
)

def test_restore_only_lora_params(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
lora_provider = qwix.LoraProvider(
module_path='.*w1',
rank=4,
alpha=2.0,
)
dummy_model_input = {
'x': jnp.ones(2, dtype=jnp.int32),
}
model = qwix.apply_lora_to_model(model, lora_provider, **dummy_model_input)
expected_lora_state = nnx.clone(nnx.state(model, nnx.LoRAParam))
changed_non_lora_state = jax.tree.map(
lambda x: x + 2, nnx.state(model, (nnx.filterlib.Not(nnx.LoRAParam)))
)

# Save the model params (entire model).
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=False))
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Change the model state.
nnx.update(
model, jax.tree.map(lambda x: x + 1, nnx.state(model, nnx.LoRAParam))
)
nnx.update(model, changed_non_lora_state)

# Restore only the model lora params.
self.assertEqual(
cp_manager.maybe_restore(model, restore_only_lora_params=True),
(1, {}),
)
# Check the model lora params are restored correctly.
jax.tree.map_with_path(
assert_close,
expected_lora_state,
nnx.state(model, nnx.LoRAParam),
)
# Check the rest of the params are not restored.
jax.tree.map_with_path(
assert_close,
changed_non_lora_state,
nnx.state(model, nnx.filterlib.Not(nnx.LoRAParam)),
)

def test_restore_full_from_lora_only_checkpoint_fails(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
lora_provider = qwix.LoraProvider(
module_path='.*w1',
rank=4,
alpha=2.0,
)
dummy_model_input = {
'x': jnp.ones(2, dtype=jnp.int32),
}
model = qwix.apply_lora_to_model(model, lora_provider, **dummy_model_input)

# Save only the lora params.
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Try to restore full model, expect failure.
with self.assertRaisesRegex(
ValueError, 'If this checkpoint only contains LoRA parameters'
):
cp_manager.maybe_restore(model, restore_only_lora_params=False)

def test_save_and_restore_with_custom_metadata(self):
cp_path = f'{self.temp_path}/{self.id()}'
ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert ckpt_manager._checkpointer is not None
ckpt_manager._checkpointer.wait()
restored_step, restored_metadata = ckpt_manager.maybe_restore(model)
self.assertEqual(restored_step, 1)
self.assertEqual(restored_metadata, custom_metadata)
Expand All @@ -257,7 +349,8 @@ def test_save_and_restore_with_optimizer_state(self):
)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert ckpt_manager._checkpointer is not None
ckpt_manager._checkpointer.wait()

new_optimizer = nnx.Optimizer(
model,
Expand All @@ -281,6 +374,68 @@ def test_save_and_restore_with_optimizer_state(self):
new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-3
)

def test_save_and_restore_with_forced_single_device_sharding(self):
cp_path = f'{self.temp_path}/{self.id()}'
ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
optimizer = nnx.Optimizer(
model,
optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
wrt=nnx.Param,
)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
assert ckpt_manager._checkpointer is not None
ckpt_manager._checkpointer.wait()

new_optimizer = nnx.Optimizer(
model,
optax.inject_hyperparams(optax.adamw)(learning_rate=1e-5),
wrt=nnx.Param,
)

new_optimizer.opt_state.hyperparams['learning_rate'].value = jax.device_put(
new_optimizer.opt_state.hyperparams['learning_rate'].value,
jax.devices()[0],
)

self.assertIsInstance(
new_optimizer.opt_state.hyperparams['learning_rate'].value.sharding,
jax.sharding.SingleDeviceSharding,
)

restored_step, _ = ckpt_manager.maybe_restore(
model, new_optimizer
)
self.assertEqual(restored_step, 1)

errors = []
def assert_named_sharding(path, x):
if hasattr(x, 'sharding'):
try:
self.assertIsInstance(
x.sharding,
jax.sharding.NamedSharding,
f'Variable at {path} is not NamedSharding',
)
except AssertionError as e:
errors.append(str(e))
return

path_str = str(path)
if 'hyperparams' in path_str:
try:
self.assertEqual(x.sharding.spec, jax.sharding.PartitionSpec())
except AssertionError as e:
errors.append(str(e))

jax.tree.map_with_path(
assert_named_sharding,
nnx.state(new_optimizer, nnx.optimizer.OptState),
)
if errors:
self.fail(f'Found sharding mismatches:\n{"\n".join(errors)}')

def test_restore_without_optimizer(self):
cp_path = f'{self.temp_path}/{self.id()}'
ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
Expand All @@ -291,7 +446,8 @@ def test_restore_without_optimizer(self):
wrt=nnx.Param,
)
ckpt_manager.save(1, model, optimizer)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert ckpt_manager._checkpointer is not None
ckpt_manager._checkpointer.wait()
ckpt_manager.maybe_restore(model)

@parameterized.parameters(['test_data/checkpoints'])
Expand All @@ -317,6 +473,49 @@ def test_restore_with_backward_compatibility(self, ckpt_path):
nnx.state(model),
)

@parameterized.parameters(True, False)
def test_save_aligns_with_policy(self, enable_async):
cp_path = f'{self.temp_path}/{self.id()}_{enable_async}'
options = checkpoint_options.TunixCheckpointingOptions(
save_decision_policy=(
checkpoint_manager.ocp.training.save_decision_policies.FixedIntervalPolicy(
2
)
),
enable_async_checkpointing=enable_async,
)
cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=options)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

# Step 1 should be skipped by FixedIntervalPolicy(2).
self.assertFalse(cp_manager.save(1, model))

# Step 2 should be saved.
self.assertTrue(cp_manager.save(2, model))
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()
self.assertEqual(cp_manager.latest_step(), 2)

def test_save_force_true_overrides_policy(self):
cp_path = f'{self.temp_path}/{self.id()}'
options = checkpoint_options.TunixCheckpointingOptions(
save_decision_policy=(
checkpoint_manager.ocp.training.save_decision_policies.FixedIntervalPolicy(
2
)
),
enable_async_checkpointing=True,
)
cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=options)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

# Step 1 would normally be skipped by FixedIntervalPolicy(2), but force=True
# should force the save.
self.assertTrue(cp_manager.save(1, model, force=True))
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()
self.assertEqual(cp_manager.latest_step(), 1)


if __name__ == '__main__':
absltest.main()
8 changes: 4 additions & 4 deletions tests/sft/peft_trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import jax.sharding as shd
import numpy as np
import optax
import orbax.checkpoint as ocp
from tunix.sft import checkpoint_manager
from tunix.sft import checkpoint_options
from tunix.sft import hooks
from tunix.sft import peft_trainer
from tunix.sft import profiler
Expand Down Expand Up @@ -547,13 +547,13 @@ def test_checkpointing(
mock_checkpoint_manager.latest_step.return_value = (
expected_save_steps[-1] - 1
) # force save at close
checkpoint_options = ocp.CheckpointManagerOptions()
checkpointing_options = checkpoint_options.create_checkpointing_options()
config = peft_trainer.TrainingConfig(
eval_every_n_steps=2,
max_steps=100,
gradient_accumulation_steps=grad_accu,
checkpoint_root_directory='/tmp/checkpoint',
checkpointing_options=checkpoint_options,
checkpointing_options=checkpointing_options,
)
rngs = nnx.Rngs(0)
model = tc.get_lora_model(
Expand All @@ -566,7 +566,7 @@ def test_checkpointing(
trainer.train(train_ds, eval_ds)

mock_checkpoint_manager_init.assert_called_once_with(
root_directory='/tmp/checkpoint', options=checkpoint_options
root_directory='/tmp/checkpoint', options=checkpointing_options
)
# Assert that the checkpoint manager is called with the correct arguments
# and does not have any unexpected calls.
Expand Down
Loading
Loading