Source code for evorl.workflows.workflow

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any

import chex
from omegaconf import DictConfig
from typing_extensions import Self  # pytype: disable=not-supported-yet

from evorl.recorders import ChainRecorder, Recorder
from evorl.types import State
from evorl.utils.orbax_utils import setup_checkpoint_manager
 11
 12# TODO: remove it when evox is updated
 13
 14
class AbstractWorkflow(ABC):
    """A Workflow Interface for EvoRL training pipelines.

    Subclasses must implement :meth:`init` and :meth:`step`; :meth:`name`
    has a default implementation based on the class name.
    """

    @abstractmethod
    def init(self, key: chex.PRNGKey) -> State:
        """Initialize the workflow's state.

        Args:
            key: JAX PRNGKey

        Returns:
            state: the state of the workflow
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, state: State) -> tuple[Any, State]:
        """Define the logic of one training iteration.

        Args:
            state: the workflow state to run the iteration on.

        Returns:
            A 2-tuple whose second element is a workflow state. The first
            element is typed ``Any``; its meaning is left to the concrete
            workflow.
        """
        raise NotImplementedError

    @classmethod
    def name(cls) -> str:
        """Define the name of the workflow (e.g. PPO, PSO, etc.).

        Default workflow name is its class name.
        """
        return cls.__name__
42 43
class Workflow(AbstractWorkflow):
    """The base class for all Workflows.

    Concrete workflow classes inherit from this class and customize it by
    overriding the hooks that raise ``NotImplementedError`` here
    (:meth:`build_from_config`, :meth:`setup`, :meth:`learn`) and by
    implementing the abstract ``step()`` inherited from ``AbstractWorkflow``.
    """

    def __init__(self, config: DictConfig):
        """Initialize a Workflow instance.

        Args:
            config: the config object.
        """
        self.config = config
        # Start with an empty chain; recorders are attached via add_recorders().
        self.recorder = ChainRecorder([])
        self.checkpoint_manager = setup_checkpoint_manager(config)

    @classmethod
    def build_from_config(cls, config: DictConfig, *args, **kwargs) -> Self:
        """Build the workflow instance from the config.

        This is the public API to call for instantiating a new workflow
        object from config. Normally, it will call __init__() and do some
        pre- and post-processing.

        Args:
            config: config object
            *args: extra positional arguments for the concrete workflow.
            **kwargs: extra keyword arguments for the concrete workflow.

        Returns:
            A workflow instance

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def init(self, key: chex.PRNGKey) -> State:
        """Initialize the state of the workflow.

        This is the public API to call for instance state initialization.
        It initializes the recorder first and then delegates to
        :meth:`setup`.

        Args:
            key: JAX PRNGKey

        Returns:
            The state returned by :meth:`setup`.
        """
        self.recorder.init()
        return self.setup(key)

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the workflow state; called by :meth:`init`.

        Args:
            key: JAX PRNGKey

        Returns:
            The state of the workflow.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def add_recorders(self, recorders: Iterable[Recorder]) -> None:
        """Attach every recorder in ``recorders`` to the workflow's recorder.

        Args:
            recorders: the recorders to add, in iteration order.
        """
        for recorder in recorders:
            self.recorder.add_recorder(recorder)

    def close(self) -> None:
        """Close the workflow's components.

        The checkpoint manager is closed even if closing the recorder
        raises; the recorder's exception still propagates to the caller.
        """
        try:
            self.recorder.close()
        finally:
            self.checkpoint_manager.close()

    def learn(self, state: State) -> State:
        """Run the complete learning process.

        The learning process includes:

        - call multiple times of step()
        - record the metrics
        - save checkpoints

        Args:
            state: the workflow state to start learning from.

        Returns:
            The workflow state after learning.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError