1import copy
2import logging
3
4import chex
5import jax
6import optax
7from omegaconf import DictConfig, OmegaConf
8from typing_extensions import Self # pytype: disable=not-supported-yet
9
10from evorl.replay_buffers import AbstractReplayBuffer, ReplayBufferState
11from evorl.agent import Agent, AgentState
12from evorl.distributed import DP_AXIS_NAME, split_key_to_devices, shmap_vmap
13from evorl.envs import Env
14from evorl.evaluators import Evaluator
15from evorl.metrics import EvaluateMetric, MetricBase, WorkflowMetric
16from evorl.types import State
17
18from .workflow import Workflow
19
20logger = logging.getLogger(__name__)
21
22
class RLWorkflow(Workflow):
    """Common base class for reinforcement-learning workflows.

    Keeps the device / data-parallel-axis bookkeeping shared by RL workflows
    and provides the `build_from_config()` factory. Subclasses are expected to
    implement `_build_from_config()`, `step()` and `evaluate()`.
    """

    def __init__(self, config: DictConfig):
        """Create the workflow with single-device defaults.

        Args:
            config: Config object, forwarded to the parent `Workflow`.
        """
        super().__init__(config)

        # Single-device defaults; `build_from_config()` overwrites both
        # attributes when multi-device training is requested.
        self.dp_axis_name = None
        self.devices = jax.local_devices()[:1]

    @property
    def enable_multi_devices(self) -> bool:
        """True once a data-parallel axis name has been assigned."""
        return self.dp_axis_name is not None

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Create a workflow instance from a config.

        The config is deep-copied first, so the caller's object is never
        modified, and the copy is marked read-only before it reaches
        `_build_from_config()`.

        Args:
            config: Config of the workflow.
            enable_multi_devices: If True, `enable_shmap()` and
                `_rescale_config()` are invoked and the instance is given
                `DP_AXIS_NAME` and all local devices.
            enable_jit: If True, `enable_jit()` is invoked. Only consulted
                when `enable_multi_devices` is False.

        Returns:
            The workflow instance created by `_build_from_config()`.
        """
        cfg = copy.deepcopy(config)
        local_devices = jax.local_devices()

        if enable_multi_devices:
            cls.enable_shmap(DP_AXIS_NAME)
            # Unlock the copy so `_rescale_config()` may edit it in place.
            OmegaConf.set_readonly(cfg, False)
            cls._rescale_config(cfg)
        elif enable_jit:
            cls.enable_jit()

        # Freeze the final config before the instance gets hold of it.
        OmegaConf.set_readonly(cfg, True)

        instance = cls._build_from_config(cfg)
        if enable_multi_devices:
            instance.dp_axis_name = DP_AXIS_NAME
            instance.devices = local_devices

        return instance

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow instance; must be provided by subclasses.

        Args:
            config: Config of the workflow (already read-only when called
                from `build_from_config()`).

        Returns:
            The created workflow instance.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Hook for adapting config values to multi-device training.

        Called by `build_from_config()` only when `enable_multi_devices=True`;
        implementations modify `config` in place. The base implementation
        changes nothing.

        Args:
            config: Config of the workflow.
        """

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration; must be provided by subclasses.

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (metrics, state).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Run the evaluation logic; must be provided by subclasses.

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (metrics, state).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Replace `evaluate()` and `step()` on the class with jitted versions.

        `self` (argument 0) is declared static for `jax.jit`. Note that this
        rebinds attributes on the class itself, not on a single instance.
        """
        for method_name in ("evaluate", "step"):
            jitted = jax.jit(getattr(cls, method_name), static_argnums=(0,))
            setattr(cls, method_name, jitted)

    @classmethod
    def enable_shmap(cls, axis_name: str) -> None:
        """Hook that defines the multi-device behavior of the workflow.

        The base implementation performs no wrapping and leaves `step()` and
        `evaluate()` untouched.

        Args:
            axis_name: The axis_name for shmap.
        """
        # NOTE(review): nothing is wrapped here, and `build_from_config()`
        # does not call `enable_jit()` in the multi-device branch, so with
        # this base implementation `step()`/`evaluate()` are neither jitted
        # nor shard-mapped. Presumably subclasses or callers are expected to
        # apply `shmap_vmap` with a mesh themselves -- TODO confirm.
141
142
class OnPolicyWorkflow(RLWorkflow):
    """Workflow template for On-Policy RL algorithms.

    Supplies generic `setup()` and `evaluate()` implementations on top of
    `RLWorkflow`; `step()` is left to the concrete algorithm.
    """

    def __init__(
        self,
        env: Env,
        agent: Agent,
        optimizer: optax.GradientTransformation,
        evaluator: Evaluator,
        config: DictConfig,
    ):
        """Store the components of the workflow.

        Args:
            env: Environment object.
            agent: Workflow-specific agent object.
            optimizer: Optimizer of the agent.
            evaluator: Evaluator object used in `self.evaluate()`.
            config: Config of the workflow.
        """
        super().__init__(config)

        self.env = env
        self.agent = agent
        self.optimizer = optimizer
        self.evaluator = evaluator

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and the matching optimizer state.

        Args:
            key: JAX PRNGKey passed to `agent.init()`.

        Returns:
            Tuple of (agent_state, opt_state), where `opt_state` is
            initialized from `agent_state.params`.
        """
        env = self.env
        agent_state = self.agent.init(env.obs_space, env.action_space, key)
        return agent_state, self.optimizer.init(agent_state.params)

    def _setup_workflow_metrics(self) -> MetricBase:
        """Create the initial workflow metrics (a fresh `WorkflowMetric`)."""
        return WorkflowMetric()

    def setup(self, key: chex.PRNGKey) -> State:
        """Build the initial workflow state.

        Args:
            key: JAX PRNGKey; split into a workflow key, an agent key and an
                env key.

        Returns:
            State holding `key`, `metrics`, `agent_state`, `env_state` and
            `opt_state`. In multi-device mode the metrics, agent state and
            optimizer state are placed with a replicated sharding, while the
            workflow key and the env state are produced per device.
        """
        key, agent_key, env_key = jax.random.split(key, 3)

        agent_state, opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()

        if not self.enable_multi_devices:
            env_state = self.env.reset(env_key)
        else:
            mesh = jax.sharding.Mesh(self.devices, (self.dp_axis_name,))
            # Empty PartitionSpec: every device holds a full copy.
            replicated = jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec()
            )
            # Leading axis is split across the data-parallel mesh axis.
            per_device = jax.sharding.PartitionSpec(self.dp_axis_name)

            workflow_metrics, agent_state, opt_state = jax.device_put(
                (workflow_metrics, agent_state, opt_state), replicated
            )

            # key and env_state should be different over devices
            key = split_key_to_devices(key, self.devices)
            env_key = split_key_to_devices(env_key, self.devices)

            reset_on_devices = shmap_vmap(
                self.env.reset,
                mesh=mesh,
                in_specs=per_device,
                out_specs=per_device,
                check_rep=False,
            )
            env_state = reset_on_devices(env_key)

        return State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the current agent and report mean episode statistics.

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (eval_metrics, state): the mean episode return and
            length after `all_reduce()` over `self.dp_axis_name`, and the
            input state with its `key` advanced.
        """
        next_key, eval_key = jax.random.split(state.key, num=2)

        # Raw metrics are per episode: [#episodes]
        raw_metrics = self.evaluator.evaluate(
            state.agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        mean_metrics = EvaluateMetric(
            episode_returns=raw_metrics.episode_returns.mean(),
            episode_lengths=raw_metrics.episode_lengths.mean(),
        )
        eval_metrics = mean_metrics.all_reduce(dp_axis_name=self.dp_axis_name)

        return eval_metrics, state.replace(key=next_key)
246
247
class OffPolicyWorkflow(RLWorkflow):
    """Workflow template for Off-Policy RL algorithms.

    Supplies generic `setup()` and `evaluate()` implementations on top of
    `RLWorkflow`, including replay-buffer initialization hooks; `step()` is
    left to the concrete algorithm.
    """

    def __init__(
        self,
        env: Env,
        agent: Agent,
        optimizer: optax.GradientTransformation,
        evaluator: Evaluator,
        replay_buffer: AbstractReplayBuffer,
        config: DictConfig,
    ):
        """Store the components of the workflow.

        Args:
            env: Environment object.
            agent: Workflow-specific agent object.
            optimizer: Optimizer of the agent.
            evaluator: Evaluator object used in `self.evaluate()`.
            replay_buffer: ReplayBuffer object.
            config: Config of the workflow.
        """
        super().__init__(config)

        self.env = env
        self.agent = agent
        self.optimizer = optimizer
        self.evaluator = evaluator
        self.replay_buffer = replay_buffer

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and the matching optimizer state.

        Args:
            key: JAX PRNGKey passed to `agent.init()`.

        Returns:
            Tuple of (agent_state, opt_state), where `opt_state` is
            initialized from `agent_state.params`.
        """
        env = self.env
        agent_state = self.agent.init(env.obs_space, env.action_space, key)
        return agent_state, self.optimizer.init(agent_state.params)

    def _setup_workflow_metrics(self) -> MetricBase:
        """Create the initial workflow metrics (a fresh `WorkflowMetric`)."""
        return WorkflowMetric()

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Create the initial replay-buffer state; subclasses must implement.

        Args:
            key: JAX PRNGKey.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Hook run on the freshly built state before training starts.

        The default implementation returns `state` unchanged.

        Args:
            state: State of the workflow.

        Returns:
            The (possibly updated) workflow state.
        """
        return state

    def setup(self, key: chex.PRNGKey) -> State:
        """Build the initial workflow state, including the replay buffer.

        Args:
            key: JAX PRNGKey; split into a workflow key, an agent key, an env
                key and a replay-buffer key.

        Returns:
            State holding `key`, `metrics`, `agent_state`, `env_state`,
            `opt_state` and `replay_buffer_state`, after
            `_postsetup_replaybuffer()` has been applied. In multi-device
            mode the metrics, agent state and optimizer state are placed with
            a replicated sharding, while the workflow key, env state and
            replay-buffer state are produced per device.
        """
        key, agent_key, env_key, rb_key = jax.random.split(key, 4)

        agent_state, opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()

        multi_devices = self.enable_multi_devices

        if multi_devices:
            mesh = jax.sharding.Mesh(self.devices, (self.dp_axis_name,))
            # Leading axis is split across the data-parallel mesh axis.
            per_device = jax.sharding.PartitionSpec(self.dp_axis_name)

            def on_devices(fn):
                # Same shmap_vmap configuration for every per-device call.
                return shmap_vmap(
                    fn,
                    mesh=mesh,
                    in_specs=per_device,
                    out_specs=per_device,
                    check_rep=False,
                )

            # Empty PartitionSpec: every device holds a full copy.
            replicated = jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec()
            )
            workflow_metrics, agent_state, opt_state = jax.device_put(
                (workflow_metrics, agent_state, opt_state), replicated
            )

            # key and env_state should be different over devices
            key = split_key_to_devices(key, self.devices)

            env_key = split_key_to_devices(env_key, self.devices)
            env_state = on_devices(self.env.reset)(env_key)

            rb_key = split_key_to_devices(rb_key, self.devices)
            replay_buffer_state = on_devices(self._setup_replaybuffer)(rb_key)
        else:
            env_state = self.env.reset(env_key)
            replay_buffer_state = self._setup_replaybuffer(rb_key)

        state = State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
            replay_buffer_state=replay_buffer_state,
        )

        logger.info("Start replay buffer post-setup")
        if multi_devices:
            # NOTE(review): the whole state is mapped with the device-axis
            # spec although metrics/agent/optimizer leaves were placed
            # replicated above -- confirm that `shmap_vmap` handles this.
            state = on_devices(self._postsetup_replaybuffer)(state)
        else:
            state = self._postsetup_replaybuffer(state)
        logger.info("Complete replay buffer post-setup")

        return state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the current agent and report mean episode statistics.

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (eval_metrics, state): the mean episode return and
            length after `all_reduce()` over `self.dp_axis_name`, and the
            input state with its `key` advanced.
        """
        next_key, eval_key = jax.random.split(state.key, num=2)

        # Raw metrics are per episode: [#episodes]
        raw_metrics = self.evaluator.evaluate(
            state.agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        mean_metrics = EvaluateMetric(
            episode_returns=raw_metrics.episode_returns.mean(),
            episode_lengths=raw_metrics.episode_lengths.mean(),
        )
        eval_metrics = mean_metrics.all_reduce(dp_axis_name=self.dp_axis_name)

        return eval_metrics, state.replace(key=next_key)