Source code for evorl.workflows.ec_workflow

  1import copy
  2import logging
  3from omegaconf import DictConfig, OmegaConf
  4from typing_extensions import Self
  5
  6import chex
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10
 11from evorl.distributed import POP_AXIS_NAME, all_gather
 12from evorl.metrics import (
 13    MetricBase,
 14    ECWorkflowMetric,
 15    MultiObjectiveECWorkflowMetric,
 16    ECTrainMetric,
 17)
 18from evorl.ec.optimizers import EvoOptimizer, ECState
 19from evorl.envs import Env
 20from evorl.sample_batch import SampleBatch
 21from evorl.evaluators import Evaluator, EpisodeCollector
 22from evorl.agent import Agent, AgentState, AgentStateAxis
 23from evorl.distributed import get_global_ranks, psum, split_key_to_devices
 24from evorl.types import State, PyTreeData, pytree_field, Params, PyTreeDict
 25from evorl.utils.rl_toolkits import flatten_pop_rollout_episode
 26from evorl.utils.jax_utils import tree_stop_gradient
 27
 28from .workflow import Workflow
 29
 30
 31logger = logging.getLogger(__name__)
 32
 33
class DistributedInfo(PyTreeData):
    """Distributed information for multi-devices training.

    Attributes:
        rank: Rank of the current device, used by the workflow to select this
            device's slice of the population. Defaults to a 0-d int32 JAX
            array (a traced pytree leaf), not a Python int.
        world_size: Number of devices. Declared static, so it is part of the
            pytree structure (a Python value at trace time), not a traced leaf.
    """

    # 0-d int32 array by default; filled from `get_global_ranks()` when
    # multi-devices training is enabled.
    rank: chex.Array = jnp.zeros((), dtype=jnp.int32)
    # Static field: usable in Python-level shape arithmetic under jit.
    world_size: int = pytree_field(default=1, static=True)
39 40
class ECWorkflow(Workflow):
    """Base Workflow for EC (Evolutionary Computation) algorithms."""

    def __init__(self, config: DictConfig):
        """Initialize the ECWorkflow instance.

        Args:
            config: the config object
        """
        super().__init__(config)

        # Single-device defaults; `build_from_config` overrides both when
        # multi-devices training is enabled.
        self.dp_axis_name = None
        self.devices = jax.local_devices()[:1]

    @property
    def enable_multi_devices(self) -> bool:
        """Whether multi-devices training is enabled."""
        return self.dp_axis_name is not None

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Build the ec workflow instance from the config.

        `enable_shmap()` / `enable_jit()` patch methods on the class itself,
        so their effect is shared by every instance of `cls`.

        Args:
            config: Config of the workflow. It is deep-copied first, so the
                caller's object is never modified.
            enable_multi_devices: Whether multi-devices training is enabled.
                When True, `enable_shmap()` is called instead of `enable_jit()`
                and the config is rescaled in-place by `_rescale_config()`.
            enable_jit: Whether jit is enabled. Ignored when
                `enable_multi_devices` is True.

        Returns:
            The workflow instance created by `_build_from_config()`, which
            receives the config in read-only mode.
        """
        config = copy.deepcopy(config)  # avoid in-place modification

        if enable_multi_devices:
            # NOTE(review): the base `enable_shmap` is a no-op and `enable_jit`
            # is not called on this path — confirm subclasses wrap `step()`.
            cls.enable_shmap(POP_AXIS_NAME)
            # `_rescale_config` edits the config in-place, so unlock it first.
            OmegaConf.set_readonly(config, False)
            cls._rescale_config(config)
        elif enable_jit:
            cls.enable_jit()

        OmegaConf.set_readonly(config, True)

        workflow = cls._build_from_config(config)
        if enable_multi_devices:
            workflow.dp_axis_name = POP_AXIS_NAME
            workflow.devices = jax.local_devices()

        return workflow

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Customize the process of building the workflow instance from the config.

        Must be implemented by subclasses.

        Args:
            config: Config of the workflow

        Returns:
            workflow: the created workflow instance
        """
        raise NotImplementedError

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Customize the logic of rescaling the config settings when multi-devices training is enabled.

        When enable_multi_devices=True, rescale config settings in-place to
        match multi-devices. The base implementation leaves the config as is.

        Args:
            config: Config of the workflow
        """

    @classmethod
    def enable_jit(cls) -> None:
        """Define which methods should be jitted.

        By default, the workflow's `step()` method is jitted, with `self`
        treated as a static argument.
        """
        cls.step = jax.jit(cls.step, static_argnums=(0,))

    @classmethod
    def enable_shmap(cls, axis_name) -> None:
        """Define which methods should be shmaped.

        This method defines the multi-device behavior. The base implementation
        is a no-op; subclasses may override it to make methods such as
        `step()` run across devices along `axis_name`.

        Args:
            axis_name: Name of the device axis used for collective operations.
        """
132 133
class ECWorkflowTemplate(ECWorkflow):
    """Workflow template for EC algorithms.

    Attributes:
        env: Environment object.
        agent: Workflow-specific agent object.
        ec_optimizer: EC Optimizer of the agent.
        ec_evaluator: Evaluator or EpisodeCollector used by `step()` to
            compute the population's fitnesses.
        agent_state_vmap_axes: Vmap axis for the agent state.
        config: Config of the workflow.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        ec_optimizer: EvoOptimizer,
        ec_evaluator: Evaluator | EpisodeCollector,
        agent_state_vmap_axes: AgentStateAxis = 0,
        config: DictConfig,
    ):
        """Initialize the ECWorkflowTemplate instance.

        Args:
            env: Environment object.
            agent: Workflow-specific agent object.
            ec_optimizer: EC Optimizer of the agent.
            ec_evaluator: Evaluator or EpisodeCollector used by `step()`.
            agent_state_vmap_axes: Vmap axis for the agent state.
            config: Config of the workflow.
        """
        super().__init__(config)

        self.agent = agent
        self.env = env
        self.ec_optimizer = ec_optimizer
        self.ec_evaluator = ec_evaluator
        self.agent_state_vmap_axes = agent_state_vmap_axes

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, ECState]:
        """Setup Agent and ECOptimizer states.

        Must be implemented by subclasses.

        Args:
            key: JAX PRNGKey

        Returns:
            Tuple of (agent_state, ec_state)
        """
        raise NotImplementedError

    def _setup_workflow_metrics(self) -> MetricBase:
        """Define Workflow metrics.

        `best_objective` starts at the lowest finite float32 value so that the
        `jnp.maximum` in `step()` replaces it with the first observed return.
        """
        return ECWorkflowMetric(best_objective=jnp.finfo(jnp.float32).min)

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state.

        Args:
            key: JAX PRNGKey.

        Returns:
            The initial state, after `_postsetup()` has been applied.
        """
        key, agent_key = jax.random.split(key, 2)

        # agent_state: store params not optimized by EC (eg: obs_preprocessor_state)
        agent_state, ec_opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()
        distributed_info = DistributedInfo()

        if self.enable_multi_devices:
            # An empty PartitionSpec replicates these states on every device.
            sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(self.devices, (self.dp_axis_name,)),
                jax.sharding.PartitionSpec(),
            )
            agent_state, ec_opt_state, workflow_metrics = jax.device_put(
                (agent_state, ec_opt_state, workflow_metrics), sharding
            )
            key = split_key_to_devices(key, self.devices)

            distributed_info = DistributedInfo(
                rank=get_global_ranks(),
                world_size=jax.device_count(),
            )

        state = State(
            key=key,
            agent_state=agent_state,
            ec_opt_state=ec_opt_state,
            metrics=workflow_metrics,
            distributed_info=distributed_info,
        )

        state = self._postsetup(state)

        return state

    def _postsetup(self, state: State) -> State:
        """Post-setup state before training.

        By default, no post-setup is applied
        """
        return state

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Define how to replace the pop agent_state from the population params.

        Must be implemented by subclasses.

        Args:
            agent_state: State of the agent.
            params: Population params.

        Returns:
            New agent_state with replaced population params.
        """
        raise NotImplementedError

    def _update_obs_preprocessor(
        self, agent_state: AgentState, trajectory: SampleBatch
    ) -> AgentState:
        """Update the obs_preprocessor_state based on sampled trajectories.

        By default, don't update obs_preprocessor_state.

        Args:
            agent_state: State of the agent
            trajectory: Episodic trajectory (T, B, ...)
        """
        return agent_state

    def _metrics_to_fitnesses(self, metrics: MetricBase) -> chex.ArrayTree:
        """Convert the rollout metrics to fitnesses.

        By default, use the mean of episode_returns over multiple episodes as fitnesses.

        Args:
            metrics: Rollout metrics (episode_returns presumably shaped
                [#pop, #episodes]).

        Returns:
            Fitnesses of the population.
        """
        return jnp.mean(metrics.episode_returns, axis=-1)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one EC iteration: ask, evaluate, tell.

        Each device evaluates the contiguous slice
        `[rank * slice_size, (rank + 1) * slice_size)` of the asked
        population. The fitnesses are then all-gathered before `tell()`.

        Args:
            state: Current workflow state.

        Returns:
            Tuple of (train_metrics, new_state).

        Raises:
            ValueError: If the population size is not divisible by the
                number of devices.
            TypeError: If `ec_evaluator` is neither an `EpisodeCollector`
                nor an `Evaluator`.
        """
        agent_state = state.agent_state
        key, rollout_key = jax.random.split(state.key, 2)

        pop, ec_opt_state = self.ec_optimizer.ask(state.ec_opt_state)
        pop_size = jtu.tree_leaves(pop)[0].shape[0]

        # world_size is a static pytree field, so this check runs at trace time.
        world_size = state.distributed_info.world_size
        if pop_size % world_size != 0:
            raise ValueError(
                f"Population size ({pop_size}) must be divisible by the number "
                f"of devices ({world_size})."
            )
        slice_size = pop_size // world_size

        eval_pop = jtu.tree_map(
            lambda x: jax.lax.dynamic_slice_in_dim(
                x, state.distributed_info.rank * slice_size, slice_size, axis=0
            ),
            pop,
        )

        pop_agent_state = self._replace_actor_params(agent_state, eval_pop)
        rollout_keys = jax.random.split(rollout_key, num=slice_size)

        if isinstance(self.ec_evaluator, EpisodeCollector):
            # trajectory: [#pop, T, #episodes]
            rollout_metrics, trajectory = jax.vmap(
                self.ec_evaluator.rollout,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(pop_agent_state, rollout_keys, self.config.episodes_for_fitness)
            # [#pop, T, B, ...] -> [T, #pop*B, ...]
            trajectory = flatten_pop_rollout_episode(trajectory)
            trajectory = tree_stop_gradient(trajectory)
            agent_state = self._update_obs_preprocessor(agent_state, trajectory)
        elif isinstance(self.ec_evaluator, Evaluator):
            rollout_metrics = jax.vmap(
                self.ec_evaluator.evaluate,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(pop_agent_state, rollout_keys, self.config.episodes_for_fitness)
        else:
            raise TypeError(
                "ec_evaluator must be an EpisodeCollector or an Evaluator, "
                f"got {type(self.ec_evaluator).__name__}"
            )

        fitnesses = self._metrics_to_fitnesses(rollout_metrics)
        # Reassemble the full population's fitnesses for tell().
        fitnesses = all_gather(fitnesses, self.dp_axis_name, axis=0, tiled=True)

        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        # Count only the episodes rolled out on this device (slice_size
        # individuals); psum then yields the global count. Counting pop_size
        # here would over-count by a factor of world_size.
        sampled_episodes = psum(
            jnp.uint32(slice_size * self.config.episodes_for_fitness),
            self.dp_axis_name,
        )
        sampled_timesteps_m = (
            psum(rollout_metrics.episode_lengths.sum(), self.dp_axis_name) / 1e6
        )

        # rollout_metrics only covers this device's slice: gather the local
        # maxima so best_objective stays identical across devices.
        local_best = jnp.max(rollout_metrics.episode_returns).reshape(1)
        best_return = jnp.max(
            all_gather(local_best, self.dp_axis_name, axis=0, tiled=True)
        )

        workflow_metrics = state.metrics.replace(
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            sampled_timesteps_m=state.metrics.sampled_timesteps_m + sampled_timesteps_m,
            iterations=state.metrics.iterations + 1,
            best_objective=jnp.maximum(state.metrics.best_objective, best_return),
        )

        train_metrics = ECTrainMetric(
            objectives=fitnesses,
            ec_metrics=ec_metrics,
        )

        return train_metrics, state.replace(
            key=key,
            agent_state=agent_state,
            ec_opt_state=ec_opt_state,
            metrics=workflow_metrics,
        )

    @classmethod
    def enable_jit(cls) -> None:
        """Jit `step()` (via the parent class) and additionally `_postsetup()`."""
        super().enable_jit()
        cls._postsetup = jax.jit(cls._postsetup, static_argnums=(0,))

    @classmethod
    def enable_shmap(cls, axis_name) -> None:
        """Defer to the parent's multi-device behavior; nothing is wrapped here."""
        super().enable_shmap(axis_name)
        # Note: In the new paradigm, we use shmap_vmap where needed instead of wrappers here.
359 360
class MultiObjectiveECWorkflowTemplate(ECWorkflowTemplate):
    """Workflow template for multi-objective EC algorithms."""

    def _metrics_to_fitnesses(self, metrics: MetricBase) -> chex.ArrayTree:
        """Stack the episode-averaged metrics named in `config.metric_names`.

        Args:
            metrics: Rollout metrics.

        Returns:
            Fitnesses with the objectives on the last axis. The axis is
            squeezed away when there is a single objective.
        """
        fitnesses = jnp.stack(
            [jnp.mean(metrics[k], axis=-1) for k in self.config.metric_names], axis=-1
        )
        if fitnesses.shape[-1] == 1:
            fitnesses = fitnesses.squeeze(-1)

        return fitnesses

    def _setup_workflow_metrics(self) -> MetricBase:
        """Define Workflow metrics (no `best_objective` for multiple objectives)."""
        return MultiObjectiveECWorkflowMetric()

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one multi-objective EC iteration: ask, evaluate, tell.

        Each device evaluates the contiguous slice
        `[rank * slice_size, (rank + 1) * slice_size)` of the asked
        population. The fitnesses are then all-gathered before `tell()`.

        Args:
            state: Current workflow state.

        Returns:
            Tuple of (train_metrics, new_state).

        Raises:
            ValueError: If the population size is not divisible by the
                number of devices.
            TypeError: If `ec_evaluator` is neither an `EpisodeCollector`
                nor an `Evaluator`.
        """
        agent_state = state.agent_state
        key, rollout_key = jax.random.split(state.key, 2)

        pop, ec_opt_state = self.ec_optimizer.ask(state.ec_opt_state)
        pop_size = jtu.tree_leaves(pop)[0].shape[0]

        # world_size is a static pytree field, so this check runs at trace time.
        world_size = state.distributed_info.world_size
        if pop_size % world_size != 0:
            raise ValueError(
                f"Population size ({pop_size}) must be divisible by the number "
                f"of devices ({world_size})."
            )
        slice_size = pop_size // world_size

        eval_pop = jtu.tree_map(
            lambda x: jax.lax.dynamic_slice_in_dim(
                x, state.distributed_info.rank * slice_size, slice_size, axis=0
            ),
            pop,
        )

        pop_agent_state = self._replace_actor_params(agent_state, eval_pop)
        rollout_keys = jax.random.split(rollout_key, num=slice_size)

        if isinstance(self.ec_evaluator, EpisodeCollector):
            # trajectory: [#pop, T, #episodes]
            rollout_metrics, trajectory = jax.vmap(
                self.ec_evaluator.rollout,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(pop_agent_state, rollout_keys, self.config.episodes_for_fitness)
            # [#pop, T, B, ...] -> [T, #pop*B, ...]
            trajectory = flatten_pop_rollout_episode(trajectory)
            trajectory = tree_stop_gradient(trajectory)
            agent_state = self._update_obs_preprocessor(agent_state, trajectory)
        elif isinstance(self.ec_evaluator, Evaluator):
            rollout_metrics = jax.vmap(
                self.ec_evaluator.evaluate,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(pop_agent_state, rollout_keys, self.config.episodes_for_fitness)
        else:
            raise TypeError(
                "ec_evaluator must be an EpisodeCollector or an Evaluator, "
                f"got {type(self.ec_evaluator).__name__}"
            )

        fitnesses = self._metrics_to_fitnesses(rollout_metrics)
        # Reassemble the full population's fitnesses for tell().
        fitnesses = all_gather(fitnesses, self.dp_axis_name, axis=0, tiled=True)

        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        # Count only the episodes rolled out on this device (slice_size
        # individuals); psum then yields the global count. Counting pop_size
        # here would over-count by a factor of world_size.
        sampled_episodes = psum(
            jnp.uint32(slice_size * self.config.episodes_for_fitness),
            self.dp_axis_name,
        )
        sampled_timesteps_m = (
            psum(rollout_metrics.episode_lengths.sum(), self.dp_axis_name) / 1e6
        )

        workflow_metrics = state.metrics.replace(
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            sampled_timesteps_m=state.metrics.sampled_timesteps_m + sampled_timesteps_m,
            iterations=state.metrics.iterations + 1,
        )

        train_metrics = ECTrainMetric(objectives=fitnesses, ec_metrics=ec_metrics)

        return train_metrics, state.replace(
            key=key,
            agent_state=agent_state,
            ec_opt_state=ec_opt_state,
            metrics=workflow_metrics,
        )