Source code for evorl.algorithms.contrib.openes_noise_table

  1import logging
  2from omegaconf import DictConfig
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4
  5import jax
  6
  7from evorl.types import State, Params
  8from evorl.envs import AutoresetMode, create_env
  9from evorl.evaluators import Evaluator
 10from evorl.agent import AgentState
 11from evorl.ec.optimizers import OpenESNoiseTable, ExponentialScheduleSpec, ECState
 12
 13
 14from evorl.algorithms.ec.so.es_workflow import ESWorkflowTemplate
 15from evorl.algorithms.ec.obs_utils import init_obs_preprocessor
 16from evorl.algorithms.ec.ec_agent import make_deterministic_ec_agent
 17
 18
 19logger = logging.getLogger(__name__)
 20
 21
class OpenESWorkflow(ESWorkflowTemplate):
    """ES workflow that trains a deterministic EC agent with ``OpenESNoiseTable``.

    The optimizer state owns the population center: after setup the agent
    state carries no policy params, and ``_get_pop_center`` re-inserts
    ``ec_opt_state.mean`` when a concrete actor is needed.
    """

    @classmethod
    def name(cls) -> str:
        """Return the display name of this workflow."""
        return "OpenES-NoiseTable"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Apply the parent rescaling, then align ``random_timesteps`` with devices.

        ``config.random_timesteps`` is rounded down to a multiple of
        ``jax.device_count()`` (in place). A warning is logged when the
        value had to change.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        rescaled_timesteps = (config.random_timesteps // num_devices) * num_devices
        # NOTE(review): a value smaller than num_devices rounds down to 0.
        if rescaled_timesteps != config.random_timesteps:
            logger.warning(
                "random_timesteps (%s) should be divisible by num_devices (%s); "
                "rounding it down to %s.",
                config.random_timesteps,
                num_devices,
                rescaled_timesteps,
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Build envs, agent, optimizer and evaluators from ``config``.

        Returns:
            A new workflow instance wired with the components below.
        """
        # Training env used to score population members; autoreset is disabled.
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
        )

        ec_optimizer = OpenESNoiseTable(
            pop_size=config.pop_size,
            noise_table_size=config.noise_table_size,
            lr_schedule=ExponentialScheduleSpec(**config.ec_lr),
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mirror_sampling=config.mirror_sampling,
            weight_decay=config.weight_decay,
            optimizer_name=config.optimizer_name,
        )

        # `explore` selects which agent action function scores the population
        # (presumably exploratory vs. evaluation-mode actions — confirm in agent).
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # Separate env + evaluator to evaluate the pop-mean actor.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        # NOTE(review): no `discount` is passed here, unlike ec_evaluator;
        # presumably Evaluator's default is intended — confirm.
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Per-field axes: params batched along axis 0, obs_preprocessor_state
        # not batched (None); presumably consumed as vmap in_axes.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent and the ES optimizer from one PRNG key.

        The freshly initialized policy params seed the optimizer and are then
        removed from the agent state, so the optimizer state is the only holder
        of the actor parameters.

        Returns:
            ``(agent_state, ec_opt_state)``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # Drop the policy params; they live in ec_opt_state from here on.
        agent_state = self._replace_actor_params(agent_state, params=None)

        # No noise-table setup is done here: OpenESNoiseTable is constructed
        # with `noise_table_size` in _build_from_config.
        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation preprocessor when ``normalize_obs`` is set.

        Returns ``state`` unchanged otherwise. When enabled, a fresh subkey is
        consumed and both ``agent_state`` and ``key`` are replaced.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` set to ``params``."""
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state whose policy params are the optimizer's mean."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)