1import os
2import logging
3from collections.abc import Mapping, Sequence
4from typing import Any
5
6import jax
7import jax.tree_util as jtu
8import chex
9import orbax.checkpoint as ocp
10from omegaconf import DictConfig, OmegaConf
11
12logger = logging.getLogger(__name__)
13
14
[docs]
def save(path, state: chex.ArrayTree):
    """Write a state pytree to an orbax checkpoint.

    Zero-size ``jax.Array`` leaves are replaced by ``None`` (see
    ``filter_zero_size_arrays_on_save``) before the state is handed to
    the checkpointer.

    Args:
        path: Checkpoint path. ``~`` is expanded and the result is made
            absolute before it is passed to orbax.
        state: The state to be saved.
    """
    abs_path = os.path.abspath(os.path.expanduser(path))
    saveable_state = filter_zero_size_arrays_on_save(state)

    with ocp.StandardCheckpointer() as checkpointer:
        checkpointer.save(abs_path, saveable_state)
28
29
[docs]
def load(path, state: chex.ArrayTree) -> chex.ArrayTree:
    """Restore a state pytree from an orbax checkpoint.

    Args:
        path: Checkpoint path. ``~`` is expanded and the result is made
            absolute before it is passed to orbax.
        state: A pytree with the same structure as the saved state. Can
            be a dummy state or its abstract state obtained by
            ``jtu.tree_map(ocp.utils.to_shape_dtype_struct, state)``.

    Returns:
        The loaded state. Leaves that are zero-size ``jax.Array``
        instances in ``state`` are taken from ``state`` itself rather
        than from the checkpoint.
    """
    abs_path = os.path.abspath(os.path.expanduser(path))

    # The checkpointer restores against a shape/dtype description of the
    # template rather than against the template's concrete values.
    restore_target = jtu.tree_map(ocp.utils.to_shape_dtype_struct, state)

    with ocp.StandardCheckpointer() as checkpointer:
        restored = checkpointer.restore(abs_path, restore_target)

    return filter_zero_size_arrays_on_restore(state, restored)
49
50
[docs]
def filter_zero_size_arrays_on_save(
    tree: chex.ArrayTree,
) -> chex.ArrayTree:
    """Replace every zero-size ``jax.Array`` leaf of ``tree`` with ``None``.

    All other leaves are passed through unchanged.

    Args:
        tree: The pytree to filter.

    Returns:
        A pytree built by ``jtu.tree_map`` over ``tree`` in which each
        zero-size ``jax.Array`` leaf has been swapped for ``None``.
    """

    def _drop_if_empty(leaf):
        is_empty_array = isinstance(leaf, jax.Array) and leaf.size == 0
        return None if is_empty_array else leaf

    return jtu.tree_map(_drop_if_empty, tree)
63
64
[docs]
def filter_zero_size_arrays_on_restore(
    src_tree: chex.ArrayTree, dst_tree: chex.ArrayTree
) -> chex.ArrayTree:
    """Put the zero-size arrays of ``src_tree`` back into ``dst_tree``.

    The two trees are walked together with ``jtu.tree_map``. Wherever the
    leaf of ``src_tree`` is a zero-size ``jax.Array`` that leaf is kept;
    everywhere else the matching leaf of ``dst_tree`` is used.

    Args:
        src_tree: The template pytree that may hold zero-size arrays.
        dst_tree: The pytree whose leaves are used for all other positions.

    Returns:
        The merged pytree.
    """

    def _choose(template_leaf, restored_leaf):
        keep_template = (
            isinstance(template_leaf, jax.Array) and template_leaf.size == 0
        )
        return template_leaf if keep_template else restored_leaf

    return jtu.tree_map(_choose, src_tree, dst_tree)
77
78
[docs]
class DummyCheckpointManager(ocp.AbstractCheckpointManager):
    """A dummy checkpoint manager that does nothing.

    Every method is a no-op or returns a fixed placeholder value, so the
    object can stand in for a real checkpoint manager when checkpointing
    is disabled. Only ``restore`` raises, because there is nothing to
    restore from.
    """

    def directory(self):
        """Return a placeholder string; no real directory is used."""
        return "UwU"

    def all_steps(self, read: bool = False) -> Sequence[int]:
        """Return an empty list: no steps are ever recorded."""
        return []

    def latest_step(self) -> int | None:
        """Return ``None``: there is no latest step."""
        return None

    def best_step(self) -> int | None:
        """Return ``None``: there is no best step."""
        return None

    def reload(self):
        """Do nothing."""
        pass

    def reached_preemption(self, step: int) -> bool:
        """Always return ``True``.

        NOTE(review): a constant ``True`` reports preemption at every
        step -- confirm this is what callers of the dummy expect.
        """
        return True

    def should_save(self, step: int) -> bool:
        """Always return ``False``: the dummy never asks for a save."""
        return False

    def delete(self, step: int):
        """Do nothing."""
        pass

    # `item_metadata` and `metadata` are part of the base interface; they
    # are overridden so that the dummy covers it completely and can be
    # instantiated.
    def item_metadata(self, step: int) -> Any | None:
        """Return ``None``: no item metadata is stored."""
        return None

    def metadata(self) -> Mapping[str, Any]:
        """Return an empty mapping: no metadata is stored."""
        return {}

    def metrics(self, step: int) -> Any | None:
        """Return ``None``: no metrics are stored."""
        return None

    def wait_until_finished(self):
        """Do nothing."""
        pass

    def check_for_errors(self):
        """Do nothing."""
        pass

    def save(self, step, items, **kwargs):
        """Accept the arguments, save nothing, and return ``True``."""
        return True

    def restore(self, step, items, **kwargs):
        """Raise ``NotImplementedError``: the dummy cannot restore."""
        raise NotImplementedError("UwU")

    def close(self):
        """Do nothing."""
        pass
129
130
[docs]
class CheckpointManager(ocp.CheckpointManager):
    """Orbax checkpoint manager with zero-size array handling.

    ``save`` replaces zero-size ``jax.Array`` leaves with ``None`` before
    saving, and ``restore`` puts the zero-size arrays of the template back
    into the restored result.
    """

    def save(self, step, items, **kwargs):
        """Save ``items`` at ``step`` as a standard orbax checkpoint.

        Args:
            step: The step to save at.
            items: The pytree to save.
            **kwargs: Forwarded to the parent ``save``.

        Returns:
            Whatever the parent ``save`` returns.
        """
        filtered_items = filter_zero_size_arrays_on_save(items)
        save_args = ocp.args.StandardSave(filtered_items)
        return super().save(step, args=save_args, **kwargs)

    def restore(self, step, items, **kwargs):
        """Restore the checkpoint at ``step`` using ``items`` as template.

        Args:
            step: The step to restore.
            items: The template pytree passed to ``StandardRestore``.
            **kwargs: Forwarded to the parent ``restore``.

        Returns:
            The restored pytree, with the zero-size ``jax.Array`` leaves
            of ``items`` kept in place.
        """
        restore_args = ocp.args.StandardRestore(items)
        restored_items = super().restore(step, args=restore_args, **kwargs)
        return filter_zero_size_arrays_on_restore(items, restored_items)
143
144
[docs]
def setup_checkpoint_manager(config: DictConfig) -> ocp.CheckpointManager:
    """Build the checkpoint manager selected by ``config``.

    Args:
        config: Config read for ``checkpoint.enable``,
            ``checkpoint.save_interval_steps``,
            ``checkpoint.max_to_keep`` and ``output_dir``.

    Returns:
        A ``CheckpointManager`` writing to ``<output_dir>/checkpoints``
        when ``config.checkpoint.enable`` is truthy, otherwise a
        ``DummyCheckpointManager``.
    """
    if not config.checkpoint.enable:
        return DummyCheckpointManager()

    manager_options = ocp.CheckpointManagerOptions(
        save_interval_steps=config.checkpoint.save_interval_steps,
        max_to_keep=config.checkpoint.max_to_keep,
    )

    # Note: orbax only supports absolute path
    root_dir = os.path.abspath(os.path.expanduser(config.output_dir))
    ckpt_path = os.path.join(root_dir, "checkpoints")
    logger.info(f"set checkpoint path: {ckpt_path}")

    # The fully resolved config is stored as the manager's metadata.
    resolved_config = OmegaConf.to_container(config, resolve=True)
    return CheckpointManager(
        ckpt_path,
        options=manager_options,
        metadata=resolved_config,
    )
166 return checkpoint_manager