Source code for evorl.workflows.workflow
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any

import chex
from omegaconf import DictConfig
from typing_extensions import Self  # pytype: disable=not-supported-yet

from evorl.recorders import ChainRecorder, Recorder
from evorl.types import State
from evorl.utils.orbax_utils import setup_checkpoint_manager
11
12# TODO: remove it when evox is updated
13
14
[docs]
class AbstractWorkflow(ABC):
    """Interface that EvoRL training pipelines (workflows) implement.

    Concrete workflows must provide ``init()`` and ``step()``; ``name()``
    has a default implementation.
    """

    @abstractmethod
    def init(self, key: chex.PRNGKey) -> State:
        """Create the initial state of the workflow.

        Args:
            key: JAX PRNGKey.

        Returns:
            The state of the workflow.
        """
        raise NotImplementedError()

    @abstractmethod
    def step(self, state: State) -> tuple[Any, State]:
        """Run the logic of a single training iteration.

        Args:
            state: the workflow state to advance.

        Returns:
            A 2-tuple whose second element is, per the annotation, a
            workflow ``State``; the first element is left to the subclass.
        """
        raise NotImplementedError()

    @classmethod
    def name(cls) -> str:
        """Return the name of the workflow (e.g. PPO, PSO, etc.).

        Unless a subclass overrides this method, the name is simply the
        name of the class.
        """
        workflow_name = cls.__name__
        return workflow_name
42
43
[docs]
class Workflow(AbstractWorkflow):
    """The base class for all Workflows.

    All workflow classes inherit from this class and customize it by
    overriding the methods that raise ``NotImplementedError`` here
    (``build_from_config()``, ``setup()`` and ``learn()``) together with
    the abstract ``step()`` inherited from ``AbstractWorkflow``.
    """

    def __init__(self, config: DictConfig):
        """Initialize a Workflow instance.

        Args:
            config: the config object.
        """
        self.config = config
        # Starts with an empty recorder chain; add_recorders() fills it later.
        self.recorder = ChainRecorder([])
        self.checkpoint_manager = setup_checkpoint_manager(config)

    @classmethod
    def build_from_config(cls, config: DictConfig, *args, **kwargs) -> Self:
        """Build the workflow instance from the config.

        This is the public API to call for instantiating a new workflow
        object from config. Normally, it will call __init__() and do some
        pre- and post-processing.

        Args:
            config: config object
            *args: extra positional arguments for subclass implementations.
            **kwargs: extra keyword arguments for subclass implementations.

        Returns:
            A workflow instance

        Raises:
            NotImplementedError: always, in this base class; subclasses
                must override this method.
        """
        raise NotImplementedError

    def init(self, key: chex.PRNGKey) -> State:
        """Initialize the state of the workflow.

        This is the public API to call for instance state initialization.
        It initializes the recorder first and then delegates the creation
        of the state to ``setup()``.

        Args:
            key: JAX PRNGKey, forwarded to ``setup()``.

        Returns:
            The state returned by ``setup()``.
        """
        self.recorder.init()
        state = self.setup(key)
        return state

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the workflow state; called by ``init()``.

        Args:
            key: JAX PRNGKey.

        Returns:
            The state of the workflow.

        Raises:
            NotImplementedError: always, in this base class; subclasses
                must override this method.
        """
        raise NotImplementedError

    def add_recorders(self, recorders: Iterable[Recorder]) -> None:
        """Attach every recorder in ``recorders`` to the workflow's recorder.

        Args:
            recorders: the recorders to add; each one is passed to
                ``self.recorder.add_recorder()`` in iteration order.
        """
        for recorder in recorders:
            self.recorder.add_recorder(recorder)

    def close(self) -> None:
        """Close the workflow's components.

        The checkpoint manager is closed even when closing the recorder
        raises; the recorder's exception still propagates to the caller.
        """
        try:
            self.recorder.close()
        finally:
            # Must run regardless of recorder failures, otherwise the
            # checkpoint manager would never be released.
            self.checkpoint_manager.close()

    def learn(self, state: State) -> State:
        """Run the complete learning process.

        The learning process includes:

        - calling step() multiple times
        - recording the metrics
        - saving checkpoints

        Args:
            state: the workflow state to start learning from.

        Returns:
            The workflow state after learning.

        Raises:
            NotImplementedError: always, in this base class; subclasses
                must override this method.
        """
        raise NotImplementedError