Source code for evorl.algorithms.erl.cemrl_td3.cemrl_openes

  1import math
  2from omegaconf import DictConfig
  3
  4import chex
  5import jax
  6import jax.numpy as jnp
  7import jax.tree_util as jtu
  8import optax
  9
 10from evorl.replay_buffers import ReplayBuffer
 11from evorl.metrics import MetricBase
 12from evorl.types import State, PyTreeDict
 13from evorl.evaluators import Evaluator, EpisodeCollector
 14from evorl.agent import AgentState
 15from evorl.types import Params
 16from evorl.envs import create_env, AutoresetMode
 17from evorl.recorders import get_1d_array_statistics, add_prefix
 18from evorl.ec.optimizers import OpenES, ExponentialScheduleSpec, ECState
 19from evorl.utils.jax_utils import tree_set, tree_get
 20from evorl.algorithms.td3 import make_mlp_td3_agent, TD3NetworkParams
 21from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
 22
 23from .cemrl_td3_workflow import (
 24    create_dummy_td3_trainmetric,
 25    cemrl_replace_td3_actor_params,
 26    CEMRLTD3WorkflowTemplate,
 27)
 28from ..cemrl_workflow import CEMRLTrainMetric
 29
 30
class CEMRLOpenESWorkflow(CEMRLTD3WorkflowTemplate):
    """CEM-RL workflow using OpenES as the evolutionary optimizer.

    Holds one shared set of critic params, a population of actors, and one
    shared replay buffer. Each iteration OpenES samples a population of actor
    params; after warm-up, a random subset of them is trained with TD3; the
    whole population is rolled out to obtain fitnesses; and the RL-updated
    actors are injected back into the OpenES state before ``tell``.

    NOTE(review): earlier docs state the population is split and parallelized
    with ``shard_map``; that is not visible in this class and presumably lives
    in ``CEMRLTD3WorkflowTemplate``.
    """

    @classmethod
    def name(cls):
        """Return the name of this workflow."""
        return "CEMRL-OpenES"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow and all of its components from ``config``.

        Builds the training env, the MLP TD3 agent, the gradient optimizer
        (Adam, optionally preceded by global-norm gradient clipping), the
        OpenES optimizer, the episode collector, the replay buffer and the
        evaluator.

        Raises:
            AssertionError: if neither ``warmup_iters`` nor
                ``random_timesteps`` is positive.
        """
        # Without one of these, the replay buffer would be empty when TD3
        # first samples from it.
        assert config.warmup_iters > 0 or config.random_timesteps > 0, (
            "Either warmup_iters or random_timesteps should be positive to pre-fill some data in the replay buffer"
        )

        # env for one actor
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = OpenES(
            pop_size=config.pop_size,
            lr_schedule=ExponentialScheduleSpec(**config.ec_lr),
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mirror_sampling=config.mirror_sampling,
        )

        # Choose whether fitness rollouts use the exploratory or the
        # deterministic action function of the agent.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap axes for a population agent state: actor params (and their
        # targets) carry a leading population axis, critic params are shared.
        agent_state_vmap_axes = AgentState(
            params=TD3NetworkParams(
                critic_params=None,
                actor_params=0,
                target_critic_params=None,
                target_actor_params=0,
            ),
            obs_preprocessor_state=None,
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            collector=collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialize the agent, the gradient-optimizer state and the OpenES state.

        The OpenES state is initialized from a freshly initialized actor. The
        actor params are then dropped from the agent state, because the actors
        are produced by ``ec_optimizer.ask`` at every step.

        Args:
            key: PRNG key, split between agent init and OpenES init.

        Returns:
            ``(agent_state, opt_state, ec_opt_state)``. ``opt_state.critic``
            holds the critic optimizer state; ``opt_state.actor`` is ``None``.
        """
        agent_key, ec_key = jax.random.split(key)

        # one actor + one critic
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.actor_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params=None)

        opt_state = PyTreeDict(
            # Note: we create and drop the actors' opt_state at every step
            critic=self.optimizer.init(agent_state.params.critic_params),
            actor=None,
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(
        self, ec_opt_state: ECState, pop: Params, external_indices: chex.Array
    ) -> ECState:
        """Overwrite the OpenES noise of externally updated population members.

        For each index in ``external_indices``, the stored noise is replaced by
        ``(pop[i] - mean) / noise_std``, i.e. the noise that maps the current
        mean onto the externally (RL) updated params.
        NOTE(review): assumes ``ask`` samples ``mean + noise_std * noise``.

        Args:
            ec_opt_state: OpenES state after ``ask``.
            pop: population actor params (leading axis = population).
            external_indices: unique population indices to overwrite.

        Returns:
            ``ec_opt_state`` with its ``noise`` field updated.
        """
        external_noise = jtu.tree_map(
            lambda x, m: (x - m) / ec_opt_state.noise_std,
            tree_get(pop, external_indices),
            ec_opt_state.mean,
        )
        noise = tree_set(
            ec_opt_state.noise,
            external_noise,
            external_indices,
            unique_indices=True,
        )

        return ec_opt_state.replace(noise=noise)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one CEM-RL iteration.

        1. Sample a population of actor params from OpenES (``ask``).
        2. After ``config.warmup_iters`` iterations, pick
           ``config.num_learning_offspring`` distinct members at random and
           train them (together with the shared critic) via ``self._rl_update``.
        3. Roll out the whole population. A member's fitness is its mean
           episode return.
        4. After warm-up, inject the RL-updated members into the OpenES noise,
           then ``tell`` the fitnesses.

        Args:
            state: current workflow state.

        Returns:
            ``(train_metrics, new_state)``.
        """
        pop_size = self.config.pop_size
        num_learning = self.config.num_learning_offspring
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1
        # Both the RL update and the RL injection are gated on warm-up.
        rl_enabled = iterations > self.config.warmup_iters

        key, rollout_key, perm_key, learn_key = jax.random.split(state.key, num=4)

        # ======= CEM Sample ========
        # The stored agent_state has its actors dropped; the population always
        # comes fresh from OpenES.
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        # ======== RL update ========
        # Distinct random population members that receive TD3 updates.
        learning_actor_indices = jax.random.choice(
            perm_key,
            pop_size,
            (num_learning,),
            replace=False,
        )

        def _do_rl_update(agent_state, opt_state, pop_actor_params):
            learning_actor_params = tree_get(pop_actor_params, learning_actor_indices)
            learning_agent_state = cemrl_replace_td3_actor_params(
                agent_state, learning_actor_params
            )

            # reset actors' opt_state: a fresh one is created for the selected
            # actors every step.
            learning_opt_state = opt_state.replace(
                actor=self.optimizer.init(learning_actor_params),
            )

            td3_metrics, learning_agent_state, learning_opt_state = self._rl_update(
                learning_agent_state,
                learning_opt_state,
                replay_buffer_state,
                learn_key,
            )

            # Write the trained actors back into their population slots.
            pop_actor_params = tree_set(
                pop_actor_params,
                learning_agent_state.params.actor_params,
                learning_actor_indices,
                unique_indices=True,
            )
            # drop the actors and their opt_state
            agent_state = cemrl_replace_td3_actor_params(
                learning_agent_state, pop_actor_params=None
            )
            opt_state = learning_opt_state.replace(actor=None)
            return td3_metrics, pop_actor_params, agent_state, opt_state

        def _skip_rl_update(agent_state, opt_state, pop_actor_params):
            # Warm-up branch: same output structure as _do_rl_update, no training.
            return (
                create_dummy_td3_trainmetric(num_learning),
                pop_actor_params,
                agent_state,
                opt_state,
            )

        td3_metrics, pop_actor_params, agent_state, opt_state = jax.lax.cond(
            rl_enabled,
            _do_rl_update,
            _skip_rl_update,
            agent_state,
            opt_state,
            pop_actor_params,
        )

        # ======== CEM update ========
        pop_agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params)

        # the trajectory [T, #pop*B, ...] (not used here)
        # metrics: [#pop, B]
        eval_metrics, _trajectory, replay_buffer_state = self._rollout(
            pop_agent_state, replay_buffer_state, rollout_key
        )

        # Fitness of each member: mean return over its episodes.
        fitnesses = eval_metrics.episode_returns.mean(axis=-1)

        # After warm-up, make OpenES account for the RL-updated params of the
        # learning actors instead of the params it originally sampled for them.
        ec_opt_state = jax.lax.cond(
            rl_enabled,
            self._rl_injection,
            lambda ec_opt_state, pop, external_indices: ec_opt_state,
            ec_opt_state,
            pop_actor_params,
            learning_actor_indices,
        )
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=fitnesses,
            rl_metrics=td3_metrics,
            ec_info=ec_metrics,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def learn(self, state: State) -> State:
        """Call ``step`` until ``config.total_episodes`` episodes are sampled.

        Every iteration, workflow and training metrics are written to
        ``self.recorder``. ``self.evaluate`` runs every
        ``config.eval_interval`` iterations and on the last iteration. The
        state is handed to ``self.checkpoint_manager.save`` every iteration,
        forced on the last one. The replay buffer is stripped from the saved
        state unless ``config.save_replay_buffer`` is set.

        If the episode budget is already used up, no iteration is run.

        Args:
            state: workflow state to start (or resume) from.

        Returns:
            The final workflow state.
        """
        episodes_per_iter = self.config.episodes_for_fitness * self.config.pop_size
        start_iteration = int(state.metrics.iterations)
        # Bookkeeping is done with Python ints. The counters are JAX integer
        # scalars (`step` adds uint32 increments), so subtracting on-device
        # could wrap around instead of going negative when resuming a run that
        # has already reached its budget.
        remaining_episodes = max(
            self.config.total_episodes - int(state.metrics.sampled_episodes), 0
        )
        num_iters = math.ceil(remaining_episodes / episodes_per_iter)
        final_iteration = start_iteration + num_iters

        for iters in range(start_iteration + 1, final_iteration + 1):
            is_last_iter = iters == final_iteration
            train_metrics, state = self.step(state)

            self.recorder.write(state.metrics.to_local_dict(), iters)

            train_metrics_dict = train_metrics.to_local_dict()
            # Summarize the per-member arrays instead of logging them raw.
            for field in ("pop_episode_returns", "pop_episode_lengths"):
                train_metrics_dict[field] = get_1d_array_statistics(
                    train_metrics_dict[field], histogram=True
                )

            if train_metrics_dict["rl_metrics"] is not None:
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )

            self.recorder.write(train_metrics_dict, iters)

            if iters % self.config.eval_interval == 0 or is_last_iter:
                eval_metrics, state = self.evaluate(state)

                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=is_last_iter,
            )

        return state