evorl.utils.orbax_utils

Module Contents

Classes

CheckpointManager

DummyCheckpointManager

A dummy checkpoint manager that does nothing.

Functions

filter_zero_size_arrays_on_restore

Filter out zero-size arrays from the pytree.

filter_zero_size_arrays_on_save

Filter out zero-size arrays from the pytree.

load

Load state from a file.

save

Save state to a file.

setup_checkpoint_manager

Set up a checkpoint manager.

API

class evorl.utils.orbax_utils.CheckpointManager(directory: etils.epath.PathLike, checkpointers: Optional[Union[orbax.checkpoint.checkpoint_manager.AbstractCheckpointer, orbax.checkpoint.checkpoint_manager.CheckpointersDict]] = None, options: Optional[orbax.checkpoint.checkpoint_manager.CheckpointManagerOptions] = None, metadata: Optional[Mapping[str, Any]] = None, item_names: Optional[Sequence[str]] = None, item_handlers: Optional[Union[orbax.checkpoint.checkpoint_manager.CheckpointHandler, orbax.checkpoint.checkpoint_manager.CheckpointHandlersDict]] = None, logger: Optional[orbax.checkpoint._src.logging.abstract_logger.AbstractLogger] = None, handler_registry: Optional[orbax.checkpoint.checkpoint_manager.CheckpointHandlerRegistry] = None)

Bases: orbax.checkpoint.CheckpointManager

restore(step, items, **kwargs)
save(step, items, **kwargs)

class evorl.utils.orbax_utils.DummyCheckpointManager

Bases: orbax.checkpoint.AbstractCheckpointManager

A dummy checkpoint manager that does nothing.

all_steps(read: bool = False) → collections.abc.Sequence[int]
best_step() → int | None
check_for_errors()
close()
delete(step: int)
directory()
item_metadata(step: int)
latest_step() → int | None
metadata() → collections.abc.Mapping[str, Any]
metrics(step: int) → Any | None
reached_preemption(step: int) → bool
reload()
abstract restore(step, items, **kwargs)
save(step, items, **kwargs)
should_save(step: int) → bool
wait_until_finished()
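
A "do nothing" manager of this kind is an instance of the null-object pattern: training code can call the checkpointing interface unconditionally, and checkpointing is disabled by swapping in the no-op object. The sketch below illustrates the idea in plain Python. NoOpCheckpointManager and train are hypothetical names, only a subset of the methods is shown, and the return values (empty list, None, False) are assumptions, not the documented behavior of DummyCheckpointManager.

```python
from typing import Any, Optional, Sequence


class NoOpCheckpointManager:
    """Hypothetical stand-in: every method is a no-op with a neutral return value."""

    def all_steps(self, read: bool = False) -> Sequence[int]:
        return []  # assumption: no steps are ever recorded

    def latest_step(self) -> Optional[int]:
        return None  # assumption: there is never a latest checkpoint

    def should_save(self, step: int) -> bool:
        return False

    def save(self, step: int, items: Any, **kwargs) -> bool:
        return False  # nothing is written

    def wait_until_finished(self) -> None:
        pass

    def close(self) -> None:
        pass


def train(num_steps: int, ckpt_manager) -> int:
    """Toy loop that calls the manager without checking whether it is a dummy."""
    state = 0
    for step in range(num_steps):
        state += 1
        ckpt_manager.save(step, state)
    ckpt_manager.wait_until_finished()
    ckpt_manager.close()
    return state


manager = NoOpCheckpointManager()
final_state = train(3, manager)
```

The loop body contains no "if checkpointing is enabled" branch; that decision is made once, when the manager is constructed.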
evorl.utils.orbax_utils.filter_zero_size_arrays_on_restore(src_tree: chex.ArrayTree, dst_tree: chex.ArrayTree) → chex.ArrayTree

Filter out zero-size arrays from the pytree.

evorl.utils.orbax_utils.filter_zero_size_arrays_on_save(tree: chex.ArrayTree) → chex.ArrayTree

Filter out zero-size arrays from the pytree.
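
The page does not say how zero-size arrays are filtered. As a rough illustration only, the sketch below replaces every zero-size leaf with None using jax.tree_util.tree_map; drop_zero_size_leaves is a hypothetical helper, and replacing with None is an assumption, not necessarily what filter_zero_size_arrays_on_save does.

```python
import jax
import jax.numpy as jnp


def drop_zero_size_leaves(tree):
    """Hypothetical helper: replace each zero-size array leaf with None."""
    return jax.tree_util.tree_map(
        lambda x: None if x.size == 0 else x,
        tree,
    )


tree = {
    "weights": jnp.ones((2, 3)),
    "empty_buffer": jnp.zeros((0,)),  # zero-size array
}
filtered = drop_zero_size_leaves(tree)

# JAX treats None as an empty subtree, so it no longer counts as a leaf.
remaining_leaves = jax.tree_util.tree_leaves(filtered)
```

Because None is an empty node in JAX pytrees, the filtered tree keeps its dictionary keys while exposing only the non-empty arrays as leaves.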

evorl.utils.orbax_utils.load(path, state: chex.ArrayTree) → chex.ArrayTree

Load state from a file.

Parameters:
  • path – Checkpoint path.

  • state – A pytree with the same structure as the saved state, used as the restore target. It can be a dummy state or the corresponding abstract state, obtained with jtu.tree_map(ocp.utils.to_shape_dtype_struct, state).

Returns:

The loaded state.
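
To make the "abstract state" mentioned for the state parameter concrete, the sketch below builds one with plain JAX by mapping every array leaf to a jax.ShapeDtypeStruct, which records shape and dtype but holds no data. This is assumed to be similar in spirit to the ocp.utils.to_shape_dtype_struct call named above, not a replacement for it; to_abstract and dummy_state are hypothetical names.

```python
import jax
import jax.numpy as jnp


def to_abstract(state):
    """Hypothetical helper: map each array leaf to its shape/dtype description."""
    return jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        state,
    )


# A dummy state with the same structure as the (hypothetical) saved state.
dummy_state = {
    "params": jnp.zeros((4, 2), dtype=jnp.float32),
    "step": jnp.zeros((), dtype=jnp.int32),
}
abstract_state = to_abstract(dummy_state)
```

Either dummy_state or abstract_state describes the tree structure, shapes, and dtypes that a restore needs; the abstract version avoids keeping placeholder arrays alive.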

evorl.utils.orbax_utils.save(path, state: chex.ArrayTree)

Save state to a file.

Parameters:
  • path – Checkpoint path.

  • state – The state to be saved.

evorl.utils.orbax_utils.setup_checkpoint_manager(config: omegaconf.DictConfig) → orbax.checkpoint.CheckpointManager

Set up a checkpoint manager.