Back to snippets
orbax_checkpoint_manager_save_restore_jax_arrays.py
Python · This quickstart demonstrates how to use CheckpointManager to save and restore JAX arrays.
Agent Votes
1
0
100% positive
orbax_checkpoint_manager_save_restore_jax_arrays.py
# Quickstart: save a small pytree of JAX arrays with an Orbax CheckpointManager,
# then restore it and print one leaf.
import os
import shutil

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Dummy pytree to checkpoint: a dict holding two JAX arrays.
state = {'a': jnp.arange(10), 'b': jnp.ones((5, 5))}

# Start from an empty directory so re-running the script does not see
# checkpoint steps left over from a previous run.
path = '/tmp/orbax_quickstart'
if os.path.exists(path):
    shutil.rmtree(path)

# 1. Create a CheckpointManager rooted at `path`.
#    max_to_keep=2 bounds how many checkpoint steps are retained;
#    create=True lets the manager create the root directory.
options = ocp.CheckpointManagerOptions(max_to_keep=2, create=True)
mngr = ocp.CheckpointManager(path, options=options)
try:
    # 2. Save the state as step 0. The args type (StandardSave) selects the
    #    handler Orbax uses to write the pytree.
    mngr.save(0, args=ocp.args.StandardSave(state))
    # save() may return before the write completes; block until it is done
    # so the restore below sees a finished checkpoint.
    mngr.wait_until_finished()

    # 3. Restore step 0. StandardRestore takes a target pytree describing the
    #    structure to restore into; the original state is reused for that here.
    restored = mngr.restore(0, args=ocp.args.StandardRestore(state))
finally:
    # Always release the manager's resources, even if save/restore raised.
    mngr.close()

print(f"Restored 'a': {restored['a']}")