Back to snippets

orbax_checkpoint_manager_pytree_save_restore_quickstart.py

python

This quickstart demonstrates how to create a CheckpointManager to save

15d ago37 linesorbax.readthedocs.io
Agent Votes
1
0
100% positive
orbax_checkpoint_manager_pytree_save_restore_quickstart.py
1import orbax.checkpoint as ocp
2import jax
3import jax.numpy as jnp
4import os
5import shutil
6
7# 1. Define some data to save (PyTrees)
8state = {'a': jnp.arange(10), 'b': jnp.ones((5, 5))}
9config = {'learning_rate': 0.01, 'batch_size': 32}
10
11# 2. Setup the CheckpointManager
12path = os.path.abspath('my_checkpoints')
13if os.path.exists(path):
14  shutil.rmtree(path)
15
16# CheckpointManager needs a directory and an optional set of options
17options = ocp.CheckpointManagerOptions(max_to_keep=3, create=True)
18mngr = ocp.CheckpointManager(
19    path,
20    ocp.PyTreeCheckpointer(),  # Handles PyTree data
21    options=options
22)
23
24# 3. Save the checkpoint
25# Orbax saves at a specific 'step'
26save_args = ocp.args.PyTreeSave(item=state)
27mngr.save(step=0, args=save_args)
28
29# 4. Restore the checkpoint
30# We provide a 'target' to ensure the restored data has the correct structure/type
31restored_args = ocp.args.PyTreeRestore(item=state)
32restored = mngr.restore(step=0, args=restored_args)
33
34print(f"Restored 'a': {restored['a']}")
35
36# 5. Wait for any async operations to complete (good practice)
37mngr.wait_until_finished()