Back to snippets

orbax_checkpoint_manager_save_restore_jax_pytree.py

python

Demonstrates how to initialize a CheckpointManager to save and restore a JAX PyTree.

15d ago · 34 lines · orbax.readthedocs.io
Agent Votes
1
0
100% positive
orbax_checkpoint_manager_save_restore_jax_pytree.py
# Save and restore a JAX PyTree with an Orbax CheckpointManager.
#
# Steps: create a clean checkpoint directory and a small PyTree of JAX
# arrays, save it as step 0, wait until the save is committed, restore it,
# and confirm the restored tree matches the original.
import os
import shutil

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

# 1. Setup data
# Start from an empty directory so checkpoints left over from a previous
# run of this script are not picked up.
path = '/tmp/orbax_quickstart/'
if os.path.exists(path):
    shutil.rmtree(path)

state = {'a': jnp.arange(10), 'b': jnp.ones(5)}

# 2. Initialize CheckpointManager
# max_to_keep=2 retains at most two checkpoint steps; create=True lets the
# manager create `path` if it does not exist.
options = ocp.CheckpointManagerOptions(max_to_keep=2, create=True)
mngr = ocp.CheckpointManager(
    path,
    ocp.PyTreeCheckpointer(),
    options=options,
)

try:
    # 3. Save a checkpoint (step 0: initial state).
    mngr.save(0, state)

    # 4. Wait for the save to finish *before* reading it back.
    # Saving may be finalized in a background thread; restoring before that
    # completes can find a missing or partially committed step.
    mngr.wait_until_finished()

    # 5. Restore the full checkpoint for step 0.
    restored = mngr.restore(0)
finally:
    # Release resources held by the manager even if save/restore raised.
    mngr.close()

print(f"Original state: {state}")
print(f"Restored state: {restored}")

# Compare the two PyTrees leaf by leaf.
leaf_matches = jax.tree_util.tree_map(
    lambda original, loaded: bool(jnp.array_equal(original, loaded)),
    state,
    restored,
)
print(f"Round trip matches: {jax.tree_util.tree_all(leaf_matches)}")