Source code for evorl.algorithms.ec.so.ars

  1import logging
  2from omegaconf import DictConfig
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4
  5import jax
  6import jax.tree_util as jtu
  7
  8from evorl.types import State, Params
  9from evorl.envs import AutoresetMode, create_env
 10from evorl.evaluators import Evaluator, EpisodeObsCollector
 11from evorl.sample_batch import SampleBatch
 12from evorl.agent import AgentState
 13from evorl.ec.optimizers import ARS, ECState
 14from evorl.utils import running_statistics
 15
 16from .es_workflow import ESWorkflowTemplate
 17from ..obs_utils import init_obs_preprocessor
 18from ..ec_agent import make_deterministic_ec_agent
 19
 20
 21logger = logging.getLogger(__name__)
 22
 23
class ARSWorkflow(ESWorkflowTemplate):
    """ES workflow that optimizes a deterministic policy with ``ARS``.

    Builds the training env (``config.num_envs`` parallel envs, autoreset
    disabled), a deterministic EC agent, the ``ARS`` optimizer, an evaluator
    for population candidates, and a separate evaluator (``config.num_eval_envs``
    envs) for the population-center actor.

    Observation normalization is selected by ``config.normalize_obs_mode``:

    - ``"VBN"``: statistics are initialized once in ``_postsetup`` (when
      ``config.normalize_obs`` is set) and left unchanged afterwards.
    - ``"Global"``: statistics are initialized once in ``_postsetup`` and then
      accumulated over every trajectory passed to ``_update_obs_preprocessor``.
    - ``"RS"``: no initialization in ``_postsetup``; statistics are rebuilt from
      scratch from each trajectory passed to ``_update_obs_preprocessor``.
    """

    @classmethod
    def name(cls) -> str:
        """Return the display name of this workflow."""
        return "ARS"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Make ``config.random_timesteps`` divisible by the device count.

        Runs the parent rescaling first, then rounds ``random_timesteps`` down
        to the nearest multiple of ``jax.device_count()``. A warning is logged
        when the value had to be changed. The config is mutated in place.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        if config.random_timesteps % num_devices != 0:
            # Use the module logger rather than the root logger.
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rounding it down to a multiple "
                "of num_devices.",
                config.random_timesteps,
                num_devices,
            )

        config.random_timesteps = (config.random_timesteps // num_devices) * num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow (envs, agent, optimizer, evaluators) from ``config``.

        Raises:
            ValueError: if ``config.normalize_obs_mode`` is not one of
                ``"VBN"``, ``"RS"`` or ``"Global"``.
        """
        # Validate before building anything. An ``assert`` here would be
        # stripped under ``python -O``, and an unknown mode would then silently
        # fall through to the EpisodeObsCollector branch below.
        valid_modes = ("VBN", "RS", "Global")
        if config.normalize_obs_mode not in valid_modes:
            raise ValueError(
                f"normalize_obs_mode must be one of {valid_modes}, "
                f"got {config.normalize_obs_mode!r}"
            )

        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
            policy_obs_key=config.agent_network.policy_obs_key,
        )

        ec_optimizer = ARS(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            lr=config.lr,
            noise_std=config.noise_std,
            optimizer_name=config.optimizer_name,
        )

        # Candidates act via ``compute_actions`` when ``config.explore`` is set,
        # otherwise via ``evaluate_actions``.
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        if config.normalize_obs_mode == "VBN":
            ec_evaluator = Evaluator(
                env=env,
                action_fn=action_fn,
                max_episode_steps=config.env.max_episode_steps,
                discount=config.discount,
            )
        else:
            # "RS" / "Global" update the obs statistics from rollouts. Presumably
            # EpisodeObsCollector also returns the collected trajectory for
            # _update_obs_preprocessor. TODO confirm.
            ec_evaluator = EpisodeObsCollector(
                env=env,
                action_fn=action_fn,
                max_episode_steps=config.env.max_episode_steps,
                discount=config.discount,
            )

        # Separate env + evaluator to evaluate the pop-mean (center) actor.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Params are mapped over axis 0 (one set per candidate), while the
        # obs-preprocessor state is shared. Presumably used as vmap in_axes.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent and the ARS state from the agent's initial actor params.

        The actor params are handed to ``self.ec_optimizer.init`` and then
        cleared (set to ``None``) in the returned agent state, so the optimizer
        state is their only holder.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # remove params: they now live in ec_opt_state
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize obs-normalizer statistics for the "VBN" and "Global" modes.

        Does nothing when ``config.normalize_obs`` is off or the mode is "RS",
        because "RS" rebuilds its statistics on every update.
        """
        if self.config.normalize_obs and self.config.normalize_obs_mode != "RS":
            key, obs_key = jax.random.split(state.key, 2)
            agent_state = init_obs_preprocessor(
                agent_state=state.agent_state,
                config=self.config,
                key=obs_key,
                dp_axis_name=self.dp_axis_name,
            )

            # Note: we don't count these random timesteps in state.metrics
            return state.replace(
                agent_state=agent_state,
                key=key,
            )
        else:
            return state

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` replaced by ``params``."""
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state whose actor params are the ARS mean (``ec_opt_state.mean``)."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)

    def _update_obs_preprocessor(
        self, agent_state: AgentState, trajectory: SampleBatch
    ) -> AgentState:
        """Update the observation-normalizer statistics from ``trajectory``.

        - "Global": accumulate ``trajectory.obs`` into the existing statistics.
        - "RS": discard the old statistics and compute fresh ones from
          ``trajectory.obs`` only.
        - otherwise ("VBN"): keep the statistics unchanged.

        Steps are weighted by ``1 - trajectory.dones``, presumably to ignore
        steps after an episode has ended. TODO confirm.
        """
        if self.config.normalize_obs_mode == "Global":
            obs_preprocessor_state = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )

        elif self.config.normalize_obs_mode == "RS":
            # A single observation (index [0, 0] of every leaf) is used only to
            # build an empty statistics state with the right structure.
            dummy_obs = jtu.tree_map(lambda x: x[0, 0], trajectory.obs)
            obs_preprocessor_state = running_statistics.init_state(dummy_obs)
            obs_preprocessor_state = running_statistics.update(
                obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )
        else:
            obs_preprocessor_state = agent_state.obs_preprocessor_state

        return agent_state.replace(obs_preprocessor_state=obs_preprocessor_state)