Source code for evorl.workflows.rl_workflow

  1import copy
  2import logging
  3
  4import chex
  5import jax
  6import optax
  7from omegaconf import DictConfig, OmegaConf
  8from typing_extensions import Self  # pytype: disable=not-supported-yet
  9
 10from evorl.replay_buffers import AbstractReplayBuffer, ReplayBufferState
 11from evorl.agent import Agent, AgentState
 12from evorl.distributed import DP_AXIS_NAME, split_key_to_devices, shmap_vmap
 13from evorl.envs import Env
 14from evorl.evaluators import Evaluator
 15from evorl.metrics import EvaluateMetric, MetricBase, WorkflowMetric
 16from evorl.types import State
 17
 18from .workflow import Workflow
 19
 20logger = logging.getLogger(__name__)
 21
 22
class RLWorkflow(Workflow):
    """Base Workflow for RL algorithms."""

    def __init__(self, config: DictConfig):
        """Initialize a RLWorkflow instance.

        Args:
            config: the config object.
        """
        super().__init__(config)

        # Single-device defaults. `build_from_config()` overrides both
        # attributes when multi-devices training is enabled.
        self.dp_axis_name = None
        self.devices = jax.local_devices()[:1]

    @property
    def enable_multi_devices(self) -> bool:
        """Whether multi-devices training is enabled."""
        return self.dp_axis_name is not None

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Build the rl workflow instance from the config.

        `enable_shmap()` / `enable_jit()` are classmethods that may rebind
        methods on `cls` itself (see `enable_jit()`), so their effect is not
        limited to the instance returned here.

        Args:
            config: Config of the workflow. It is deep-copied first, so the
                caller's object is never modified.
            enable_multi_devices: Whether multi-devices training is enabled.
                When True, `enable_shmap()` is called (instead of
                `enable_jit()`), the config is rescaled in-place via
                `_rescale_config()`, and the returned workflow uses all local
                devices.
            enable_jit: Whether jit is enabled. Only consulted when
                `enable_multi_devices` is False.

        Returns:
            The created workflow instance. The config handed to
            `_build_from_config()` is read-only.
        """
        config = copy.deepcopy(config)  # avoid in-place modification

        devices = jax.local_devices()

        if enable_multi_devices:
            cls.enable_shmap(DP_AXIS_NAME)
            # _rescale_config() mutates the config, so unlock it first.
            OmegaConf.set_readonly(config, False)
            cls._rescale_config(config)
        elif enable_jit:
            cls.enable_jit()

        OmegaConf.set_readonly(config, True)

        workflow = cls._build_from_config(config)
        if enable_multi_devices:
            workflow.dp_axis_name = DP_AXIS_NAME
            workflow.devices = devices

        return workflow

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Customize the process of building the workflow instance from the config.

        Must be implemented by subclasses.

        Args:
            config: Config of the workflow (read-only).

        Returns:
            The created workflow instance.
        """
        raise NotImplementedError

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Customize the logic of rescaling the config settings when multi-devices training is enabled.

        When enable_multi_devices=True, rescale config settings in-place to
        match multi-devices. The default implementation leaves the config
        unchanged.

        Args:
            config: Config of the workflow (writable at this point).
        """
        pass

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Customize the training logic of one iteration.

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (metrics, state).
        """
        raise NotImplementedError

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Customize the evaluation logic for the workflow.

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (eval_metrics, state).
        """
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Define which methods should be jitted.

        By default, the workflow's `step()` and `evaluate()` methods are
        jitted. The jitted functions replace the methods on `cls` itself.
        """
        # `self` (arg 0) is static, so workflow instances must be hashable.
        cls.evaluate = jax.jit(cls.evaluate, static_argnums=(0,))
        cls.step = jax.jit(cls.step, static_argnums=(0,))

    @classmethod
    def enable_shmap(cls, axis_name: str) -> None:
        """Define which methods should be shmaped.

        This hook defines the multi-device behavior. It is called by
        `build_from_config()` *instead of* `enable_jit()` when multi-devices
        training is enabled.

        The base implementation does NOT wrap `step()` or `evaluate()`.
        Which state fields are per-device and which are replicated is
        workflow-specific, so subclasses supporting multi-devices training
        should override this method. Without an override, `step()` and
        `evaluate()` run unjitted and outside of any shard_map, and
        collectives over `axis_name` will not be bound.

        Args:
            axis_name: The axis_name for shmap.
        """
        logger.warning(
            "RLWorkflow.enable_shmap() does not wrap step()/evaluate(); "
            "override it in %s to run them under shard_map over axis '%s'.",
            cls.__name__,
            axis_name,
        )
class OnPolicyWorkflow(RLWorkflow):
    """Workflow template for On-Policy RL algorithms.

    This class constructs the template for On-Policy RL algorithms, providing the general `setup()` and `evaluate()` methods.
    """

    def __init__(
        self,
        env: Env,
        agent: Agent,
        optimizer: optax.GradientTransformation,
        evaluator: Evaluator,
        config: DictConfig,
    ):
        """Initialize an OnPolicyWorkflow instance.

        Args:
            env: Environment object.
            agent: Workflow-specific agent object.
            optimizer: Optimizer of the agent.
            evaluator: Evaluator object used in self.evaluate().
            config: Config of the workflow.
        """
        super().__init__(config)

        self.env = env
        self.agent = agent
        self.optimizer = optimizer
        self.evaluator = evaluator

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Setup Agent and Optimizer states.

        Args:
            key: JAX PRNGKey.

        Returns:
            Tuple of (agent_state, opt_state). The optimizer state is
            initialized from `agent_state.params`.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = self.optimizer.init(agent_state.params)
        return agent_state, opt_state

    def _setup_workflow_metrics(self) -> MetricBase:
        """Define Workflow metrics."""
        return WorkflowMetric()

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state.

        In multi-devices mode, metrics, agent_state and opt_state are
        replicated over `self.devices`, while the PRNG key and env_state are
        made different per device.

        Args:
            key: JAX PRNGKey.

        Returns:
            State with fields key, metrics, agent_state, env_state, opt_state.
        """
        key, agent_key, env_key = jax.random.split(key, 3)

        agent_state, opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()

        if self.enable_multi_devices:
            # One 1-D mesh over all devices, shared by every sharded op below.
            mesh = jax.sharding.Mesh(self.devices, (self.dp_axis_name,))

            # PartitionSpec() => fully replicated on every device of the mesh.
            replicated = jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec()
            )
            workflow_metrics, agent_state, opt_state = jax.device_put(
                (workflow_metrics, agent_state, opt_state), replicated
            )

            # key and env_state should be different over devices
            key = split_key_to_devices(key, self.devices)
            env_key = split_key_to_devices(env_key, self.devices)

            per_device_spec = jax.sharding.PartitionSpec(self.dp_axis_name)
            env_reset_fn = shmap_vmap(
                self.env.reset,
                mesh=mesh,
                in_specs=per_device_spec,
                out_specs=per_device_spec,
                check_rep=False,
            )
            env_state = env_reset_fn(env_key)
        else:
            env_state = self.env.reset(env_key)

        return State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the current agent with `self.evaluator`.

        Runs `config.eval_episodes` episodes, averages episode returns and
        lengths, and all-reduces the result over `self.dp_axis_name` (this
        is None in single-device mode).

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (eval_metrics, state); only `state.key` is advanced.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        # [#episodes]
        raw_eval_metrics = self.evaluator.evaluate(
            state.agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            episode_returns=raw_eval_metrics.episode_returns.mean(),
            episode_lengths=raw_eval_metrics.episode_lengths.mean(),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        state = state.replace(key=key)
        return eval_metrics, state
class OffPolicyWorkflow(RLWorkflow):
    """Workflow template for Off-Policy RL algorithms.

    This class constructs the template for Off-Policy RL algorithms, providing the general `setup()` and `evaluate()` methods.
    """

    def __init__(
        self,
        env: Env,
        agent: Agent,
        optimizer: optax.GradientTransformation,
        evaluator: Evaluator,
        replay_buffer: AbstractReplayBuffer,
        config: DictConfig,
    ):
        """Initialize an OffPolicyWorkflow instance.

        Args:
            env: Environment object.
            agent: Workflow-specific agent object.
            optimizer: Optimizer of the agent.
            evaluator: Evaluator object used in self.evaluate().
            replay_buffer: ReplayBuffer object.
            config: Config of the workflow.
        """
        super().__init__(config)

        self.env = env
        self.agent = agent
        self.optimizer = optimizer
        self.evaluator = evaluator
        self.replay_buffer = replay_buffer

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Setup Agent and Optimizer states.

        Args:
            key: JAX PRNGKey.

        Returns:
            Tuple of (agent_state, opt_state). The optimizer state is
            initialized from `agent_state.params`.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = self.optimizer.init(agent_state.params)
        return agent_state, opt_state

    def _setup_workflow_metrics(self) -> MetricBase:
        """Define Workflow metrics."""
        return WorkflowMetric()

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Setup ReplayBuffer state.

        Must be implemented by subclasses.

        Args:
            key: JAX PRNGKey (a per-device key in multi-devices mode).

        Returns:
            The initial replay buffer state.
        """
        raise NotImplementedError

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Post-setup ReplayBuffer state before training.

        Called once at the end of `setup()` with the fully assembled state.
        The default implementation returns `state` unchanged.
        """
        return state

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state.

        Initializes the agent/optimizer states, workflow metrics, env_state
        and replay_buffer_state, then passes the assembled state through
        `_postsetup_replaybuffer()`. In multi-devices mode, metrics,
        agent_state and opt_state are replicated over `self.devices`, while
        the PRNG key, env_state and replay buffer are made different per
        device.

        Args:
            key: JAX PRNGKey.

        Returns:
            The initial workflow state.
        """
        key, agent_key, env_key, rb_key = jax.random.split(key, 4)

        agent_state, opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()

        if self.enable_multi_devices:
            # One 1-D mesh over all devices, shared by every sharded op below.
            mesh = jax.sharding.Mesh(self.devices, (self.dp_axis_name,))
            per_device_spec = jax.sharding.PartitionSpec(self.dp_axis_name)

            def per_device(fn):
                # Wrap `fn` with shmap_vmap over `mesh`, sharding both inputs
                # and outputs along `self.dp_axis_name`.
                return shmap_vmap(
                    fn,
                    mesh=mesh,
                    in_specs=per_device_spec,
                    out_specs=per_device_spec,
                    check_rep=False,
                )

            # PartitionSpec() => fully replicated on every device of the mesh.
            replicated = jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec()
            )
            workflow_metrics, agent_state, opt_state = jax.device_put(
                (workflow_metrics, agent_state, opt_state), replicated
            )

            # key and env_state should be different over devices
            key = split_key_to_devices(key, self.devices)
            env_key = split_key_to_devices(env_key, self.devices)
            env_state = per_device(self.env.reset)(env_key)

            rb_key = split_key_to_devices(rb_key, self.devices)
            replay_buffer_state = per_device(self._setup_replaybuffer)(rb_key)

            # NOTE(review): `state` mixes per-device fields (key, env_state,
            # replay_buffer_state) with fields placed as replicated via
            # PartitionSpec() above (metrics, agent_state, opt_state), yet a
            # single PartitionSpec(dp_axis_name) is applied to all of them
            # here. Confirm shmap_vmap handles the replicated fields as intended.
            postsetup_fn = per_device(self._postsetup_replaybuffer)
        else:
            env_state = self.env.reset(env_key)
            replay_buffer_state = self._setup_replaybuffer(rb_key)
            postsetup_fn = self._postsetup_replaybuffer

        state = State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
            replay_buffer_state=replay_buffer_state,
        )

        logger.info("Start replay buffer post-setup")
        state = postsetup_fn(state)
        logger.info("Complete replay buffer post-setup")

        return state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the current agent with `self.evaluator`.

        Runs `config.eval_episodes` episodes, averages episode returns and
        lengths, and all-reduces the result over `self.dp_axis_name` (this
        is None in single-device mode).

        Args:
            state: State of the workflow.

        Returns:
            Tuple of (eval_metrics, state); only `state.key` is advanced.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        # [#episodes]
        raw_eval_metrics = self.evaluator.evaluate(
            state.agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            episode_returns=raw_eval_metrics.episode_returns.mean(),
            episode_lengths=raw_eval_metrics.episode_lengths.mean(),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        state = state.replace(key=key)
        return eval_metrics, state