Source code for evorl.algorithms.ec.so.cmaes

  1import logging
  2import numpy as np
  3from omegaconf import DictConfig
  4from typing_extensions import Self  # pytype: disable=not-supported-yet]
  5
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9
 10from evorl.types import State, Params
 11from evorl.envs import AutoresetMode, create_env
 12from evorl.evaluators import Evaluator
 13from evorl.agent import AgentState
 14from evorl.ec.optimizers.ec_optimizer import ECState
 15from evorl.utils.ec_utils import ParamVectorSpec
 16
 17from .es_workflow import ESWorkflowTemplate
 18from ..obs_utils import init_obs_preprocessor
 19from ..ec_agent import make_deterministic_ec_agent
 20
 21logger = logging.getLogger(__name__)
 22
 23
class CMAESWorkflow(ESWorkflowTemplate):
    """Evolution-strategy workflow whose optimizer is the ``CMAES`` algorithm.

    The workflow wires together an environment, a deterministic EC agent, an
    ``EvoXAlgorithmAdapter`` wrapping ``CMAES``, and two evaluators: one for
    the population and one for the population-mean actor.
    """

    @classmethod
    def name(cls):
        """Return the identifier string of this workflow."""
        return "CMAES"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.random_timesteps`` down to a multiple of the device count.

        The parent implementation runs first. ``config`` is modified in place.

        Args:
            config: Workflow configuration; must expose ``random_timesteps``.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        rescaled_timesteps = (config.random_timesteps // num_devices) * num_devices

        # NOTE(review): the message mentions enable_multi_devices, but this
        # method does not check that flag itself -- confirm the parent does.
        if config.random_timesteps % num_devices != 0:
            logger.warning(
                f"When enable_multi_devices=True, random_timesteps ({config.random_timesteps}) "
                f"should be divisible by num_devices ({num_devices}), "
                f"rescale random_timesteps to {rescaled_timesteps}"
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Args:
            config: Workflow configuration. The fields read here are ``env``,
                ``num_envs``, ``num_eval_envs``, ``agent_network``,
                ``normalize_obs``, ``seed``, ``init_stdev``, ``pop_size``,
                ``num_elites``, ``explore`` and ``discount``.

        Returns:
            A new workflow instance of type ``cls``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
            policy_obs_key=config.agent_network.policy_obs_key,
        )

        # Dummy agent_state: only used to derive the parameter layout and the
        # initial search center passed to the optimizer.
        agent_key = jax.random.PRNGKey(config.seed)
        agent_state = agent.init(env.obs_space, env.action_space, agent_key)
        param_vec_spec = ParamVectorSpec(agent_state.params.policy_params)

        from evorl.ec.optimizers.evox_wrapper import EvoXAlgorithmAdapter
        from evorl.ec.evox_algorithm import CMAES

        ec_optimizer = EvoXAlgorithmAdapter(
            algorithm=CMAES(
                center_init=param_vec_spec.to_vector(agent_state.params.policy_params),
                init_stdev=config.init_stdev,
                pop_size=config.pop_size,
                mu=config.num_elites,
            ),
            param_vec_spec=param_vec_spec,
        )

        # config.explore selects which agent action function the population
        # rollouts use.
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # Separate environment/evaluator pair to evaluate the pop-mean actor.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Presumably vmap in_axes: params mapped over axis 0, preprocessor
        # state not mapped -- confirm against the template's usage.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the optimizer state.

        Args:
            key: PRNG key; split into one key for the agent and one for the
                optimizer.

        Returns:
            A pair ``(agent_state, ec_opt_state)`` where the agent state's
            policy params have been replaced by ``None``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        ec_opt_state = self.ec_optimizer.init(ec_key)

        # Remove params from the agent state.
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation preprocessor when ``normalize_obs`` is set.

        Args:
            state: Workflow state after the regular setup.

        Returns:
            ``state`` unchanged when ``config.normalize_obs`` is falsy,
            otherwise a copy with an updated ``agent_state`` and a fresh
            ``key``.
        """
        if not self.config.normalize_obs:
            return state

        # Set up obs_preprocessor_state.
        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return a copy of ``agent_state`` whose ``policy_params`` is ``params``."""
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return an agent state whose policy params are the optimizer's mean.

        The flat ``algo_state.mean`` vector is converted back to a parameter
        tree with the optimizer's ``param_vec_spec``.
        """
        flat_pop_center = state.ec_opt_state.algo_state.mean

        pop_center = self.ec_optimizer.param_vec_spec.to_tree(flat_pop_center)

        return self._replace_actor_params(state.agent_state, pop_center)

    def _record_callback(self, state: State, iters: int) -> None:
        """Write std statistics and sigma of the optimizer to the recorder.

        Args:
            state: Current workflow state.
            iters: Step value passed to ``recorder.write``.
        """
        algo_state = state.ec_opt_state.algo_state
        C = algo_state.C
        # Per-dimension std: square root of the diagonal of C, scaled by sigma.
        std = jnp.sqrt(jnp.diagonal(C)) * algo_state.sigma

        std = self.ec_optimizer.param_vec_spec.to_tree(std)
        std_statistics = _get_std_statistics(std)
        self.recorder.write({"ec/std": std_statistics}, iters)
        self.recorder.write({"ec/sigma": algo_state.sigma.tolist()}, iters)
173 174
class SepCMAESWorkflow(CMAESWorkflow):
    """Variant of ``CMAESWorkflow`` whose optimizer is the ``SepCMAES`` algorithm.

    Only the workflow name, the construction routine and the recording
    callback differ from the parent class.
    """

    @classmethod
    def name(cls):
        """Return the identifier string of this workflow."""
        return "SepCMAES"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Assemble environment, agent, optimizer and evaluators from ``config``.

        Args:
            config: Workflow configuration.

        Returns:
            A new workflow instance of type ``cls``.
        """
        net_cfg = config.agent_network

        train_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        ec_agent = make_deterministic_ec_agent(
            action_space=train_env.action_space,
            actor_hidden_layer_sizes=net_cfg.actor_hidden_layer_sizes,
            use_bias=net_cfg.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=net_cfg.norm_layer_type,
            policy_obs_key=net_cfg.policy_obs_key,
        )

        # Throw-away agent state: it provides the parameter layout and the
        # starting point of the search.
        template_state = ec_agent.init(
            train_env.obs_space,
            train_env.action_space,
            jax.random.PRNGKey(config.seed),
        )
        template_params = template_state.params.policy_params
        vec_spec = ParamVectorSpec(template_params)

        from evorl.ec.optimizers.evox_wrapper import EvoXAlgorithmAdapter
        from evorl.ec.evox_algorithm import SepCMAES

        sep_cmaes = SepCMAES(
            center_init=vec_spec.to_vector(template_params),
            init_stdev=config.init_stdev,
            pop_size=config.pop_size,
            mu=config.num_elites,
        )
        optimizer = EvoXAlgorithmAdapter(
            algorithm=sep_cmaes,
            param_vec_spec=vec_spec,
        )

        # config.explore picks the action function used for population rollouts.
        rollout_action_fn = (
            ec_agent.compute_actions if config.explore else ec_agent.evaluate_actions
        )
        population_evaluator = Evaluator(
            env=train_env,
            action_fn=rollout_action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # Dedicated evaluator for the pop-mean actor.
        center_evaluator = Evaluator(
            env=create_env(
                config.env,
                episode_length=config.env.max_episode_steps,
                parallel=config.num_eval_envs,
                autoreset_mode=AutoresetMode.DISABLED,
            ),
            action_fn=ec_agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            config=config,
            env=train_env,
            agent=ec_agent,
            ec_optimizer=optimizer,
            ec_evaluator=population_evaluator,
            evaluator=center_evaluator,
            agent_state_vmap_axes=AgentState(
                params=0,
                obs_preprocessor_state=None,
            ),
        )

    def _record_callback(self, state: State, iters: int) -> None:
        """Write std statistics and sigma of the optimizer to the recorder.

        Args:
            state: Current workflow state.
            iters: Step value passed to ``recorder.write``.
        """
        es_state = state.ec_opt_state.algo_state
        step_size = es_state.sigma

        # C is used elementwise here (no diagonal extraction); presumably it
        # already holds per-dimension values -- confirm against SepCMAES.
        flat_std = jnp.sqrt(es_state.C) * step_size
        std_tree = self.ec_optimizer.param_vec_spec.to_tree(flat_std)

        self.recorder.write({"ec/std": _get_std_statistics(std_tree)}, iters)
        self.recorder.write({"ec/sigma": step_size.tolist()}, iters)
266 267 268def _get_std_statistics(variance): 269 def _get_stats(x): 270 x = np.asarray(x) 271 return dict( 272 min=np.min(x).tolist(), 273 max=np.max(x).tolist(), 274 mean=np.mean(x).tolist(), 275 ) 276 277 return jtu.tree_map(_get_stats, variance)