Source code for evorl.utils.orbax_utils

  1import os
  2import logging
  3from collections.abc import Mapping, Sequence
  4from typing import Any
  5
  6import jax
  7import jax.tree_util as jtu
  8import chex
  9import orbax.checkpoint as ocp
 10from omegaconf import DictConfig, OmegaConf
 11
 12logger = logging.getLogger(__name__)
 13
 14
def save(path, state: chex.ArrayTree):
    """Write a state pytree to an orbax checkpoint.

    Args:
        path: Checkpoint path. ``~`` is expanded and the result is made
            absolute before it is handed to orbax.
        state: The state to be saved. It is passed through
            ``filter_zero_size_arrays_on_save`` first.
    """
    checkpoint_path = os.path.abspath(os.path.expanduser(path))
    saveable_state = filter_zero_size_arrays_on_save(state)

    # The checkpointer is used as a context manager so it is closed on exit.
    with ocp.StandardCheckpointer() as checkpointer:
        checkpointer.save(checkpoint_path, saveable_state)
28 29
def load(path, state: chex.ArrayTree) -> chex.ArrayTree:
    """Read a state pytree back from an orbax checkpoint.

    Args:
        path: Checkpoint path. ``~`` is expanded and the result is made
            absolute before it is handed to orbax.
        state: A pytree with the same structure as the saved state. Can be
            a dummy state or its abstract state obtained via
            ``jtu.tree_map(ocp.utils.to_shape_dtype_struct, state)``.

    Returns:
        The loaded state, after ``filter_zero_size_arrays_on_restore`` has
        been applied with ``state`` as the reference tree.
    """
    checkpoint_path = os.path.abspath(os.path.expanduser(path))
    # Describe every leaf by shape/dtype so orbax knows what to restore.
    restore_target = jtu.tree_map(ocp.utils.to_shape_dtype_struct, state)

    with ocp.StandardCheckpointer() as checkpointer:
        restored = checkpointer.restore(checkpoint_path, restore_target)

    return filter_zero_size_arrays_on_restore(state, restored)
49 50
def filter_zero_size_arrays_on_save(
    tree: chex.ArrayTree,
) -> chex.ArrayTree:
    """Replace every zero-size ``jax.Array`` leaf of the pytree with ``None``.

    Args:
        tree: The pytree to filter.

    Returns:
        A pytree where each leaf that is a ``jax.Array`` with ``size == 0``
        has been swapped for ``None``; all other leaves are passed through
        unchanged.
    """

    def _drop_if_empty(leaf):
        is_empty_array = isinstance(leaf, jax.Array) and leaf.size == 0
        return None if is_empty_array else leaf

    return jtu.tree_map(_drop_if_empty, tree)
63 64
def filter_zero_size_arrays_on_restore(
    src_tree: chex.ArrayTree, dst_tree: chex.ArrayTree
) -> chex.ArrayTree:
    """Put zero-size arrays from ``src_tree`` back into ``dst_tree``.

    Args:
        src_tree: Reference pytree; its zero-size ``jax.Array`` leaves are
            the ones kept in the result.
        dst_tree: Pytree mapped leaf-by-leaf together with ``src_tree``.

    Returns:
        A pytree that takes the leaf from ``src_tree`` wherever that leaf
        is a ``jax.Array`` with ``size == 0``, and the leaf from
        ``dst_tree`` everywhere else.
    """

    def _pick_leaf(reference, restored):
        keep_reference = (
            isinstance(reference, jax.Array) and reference.size == 0
        )
        return reference if keep_reference else restored

    return jtu.tree_map(_pick_leaf, src_tree, dst_tree)
77 78
class DummyCheckpointManager(ocp.AbstractCheckpointManager):
    """No-op stand-in for a checkpoint manager.

    Every query reports that nothing is stored and every mutating call
    returns without doing any work; only ``restore`` raises.
    """

    def directory(self):
        """Return the placeholder directory string."""
        # NOTE(review): this is a plain method, so it must be called as
        # ``directory()``; orbax's own manager appears to expose
        # ``directory`` as a property -- confirm what callers expect.
        return "UwU"

    def all_steps(self, read: bool = False) -> Sequence[int]:
        """Return an empty list; ``read`` is ignored."""
        return []

    def latest_step(self) -> int | None:
        """Return ``None``."""
        return None

    def best_step(self) -> int | None:
        """Return ``None``."""
        return None

    def reload(self):
        """Do nothing."""
        return None

    def reached_preemption(self, step: int) -> bool:
        """Return ``True`` for every ``step``."""
        # NOTE(review): unconditionally True -- confirm no caller treats
        # this as a signal to stop or to force a save.
        return True

    def should_save(self, step: int) -> bool:
        """Return ``False`` for every ``step``."""
        return False

    def delete(self, step: int):
        """Do nothing."""
        return None

    def item_metadata(self, step: int):
        """Return ``None`` for every ``step``."""
        return None

    def metadata(self) -> Mapping[str, Any]:
        """Return an empty mapping."""
        return {}

    def metrics(self, step: int) -> Any | None:
        """Return ``None`` for every ``step``."""
        return None

    def wait_until_finished(self):
        """Return immediately."""
        return None

    def check_for_errors(self):
        """Do nothing."""
        return None

    def save(self, step, items, **kwargs):
        """Return ``True`` without storing anything; arguments are ignored."""
        return True

    def restore(self, step, items, **kwargs):
        """Always raise, since the dummy manager cannot restore.

        Raises:
            NotImplementedError: On every call.
        """
        raise NotImplementedError("UwU")

    def close(self):
        """Do nothing."""
        return None
129 130
class CheckpointManager(ocp.CheckpointManager):
    """Checkpoint manager that filters zero-size arrays on save and restore.

    ``items`` is wrapped in orbax's standard save/restore args, with the
    zero-size-array filters applied before saving and after restoring.
    """

    def save(self, step, items, **kwargs):
        """Save ``items`` at ``step``.

        Args:
            step: Step passed to the parent ``save``.
            items: Pytree to save; it goes through
                ``filter_zero_size_arrays_on_save`` first.
            **kwargs: Forwarded to the parent ``save``.

        Returns:
            Whatever the parent ``save`` returns.
        """
        filtered_items = filter_zero_size_arrays_on_save(items)
        save_args = ocp.args.StandardSave(filtered_items)
        return super().save(step, args=save_args, **kwargs)

    def restore(self, step, items, **kwargs):
        """Restore the checkpoint at ``step`` using ``items`` as the target.

        Args:
            step: Step passed to the parent ``restore``.
            items: Pytree handed to ``StandardRestore`` and used as the
                reference tree for ``filter_zero_size_arrays_on_restore``.
            **kwargs: Forwarded to the parent ``restore``.

        Returns:
            The restored pytree after
            ``filter_zero_size_arrays_on_restore`` has been applied.
        """
        restore_args = ocp.args.StandardRestore(items)
        restored = super().restore(step, args=restore_args, **kwargs)
        return filter_zero_size_arrays_on_restore(items, restored)
143 144
def setup_checkpoint_manager(config: DictConfig) -> ocp.CheckpointManager:
    """Build the checkpoint manager described by ``config``.

    Args:
        config: Config providing ``checkpoint.enable``,
            ``checkpoint.save_interval_steps``, ``checkpoint.max_to_keep``
            and ``output_dir``.

    Returns:
        A ``DummyCheckpointManager`` when ``config.checkpoint.enable`` is
        falsy; otherwise a ``CheckpointManager`` rooted at
        ``<output_dir>/checkpoints``.
    """
    if not config.checkpoint.enable:
        return DummyCheckpointManager()

    manager_options = ocp.CheckpointManagerOptions(
        save_interval_steps=config.checkpoint.save_interval_steps,
        max_to_keep=config.checkpoint.max_to_keep,
    )

    # Note: orbax only supports absolute path
    base_dir = os.path.abspath(os.path.expanduser(config.output_dir))
    ckpt_path = os.path.join(base_dir, "checkpoints")
    logger.info(f"set checkpoint path: {ckpt_path}")

    # The config is stored as checkpoint metadata with interpolations
    # resolved.
    resolved_config = OmegaConf.to_container(config, resolve=True)
    return CheckpointManager(
        ckpt_path,
        options=manager_options,
        metadata=resolved_config,
    )