Source code for evorl.algorithms.ec.so.vanilla_es

  1import logging
  2from omegaconf import DictConfig
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4
  5import jax
  6
  7from evorl.types import State, Params
  8from evorl.envs import AutoresetMode, create_env
  9from evorl.evaluators import Evaluator
 10from evorl.agent import AgentState
 11from evorl.ec.optimizers import VanillaES, ExponentialScheduleSpec, ECState
 12
 13
 14from .es_workflow import ESWorkflowTemplate
 15from ..obs_utils import init_obs_preprocessor
 16from ..ec_agent import make_deterministic_ec_agent
 17
 18
 19logger = logging.getLogger(__name__)
 20
 21
class VanillaESWorkflow(ESWorkflowTemplate):
    """ES workflow that optimizes a deterministic EC agent with ``VanillaES``.

    The EC optimizer owns the policy parameters: after setup, the policy
    params inside ``agent_state`` are set to ``None``. Candidate or mean
    parameters are spliced back in via ``_replace_actor_params``.
    """

    @classmethod
    def name(cls):
        """Return the workflow's name, ``"VanillaES"``."""
        return "VanillaES"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Adjust ``config`` in place for the number of available devices.

        After the parent class's adjustments, ``config.random_timesteps`` is
        rounded down to a multiple of ``jax.device_count()``. A warning is
        logged when the value was not already divisible.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        random_timesteps = config.random_timesteps
        rescaled_timesteps = (random_timesteps // num_devices) * num_devices

        if random_timesteps % num_devices != 0:
            # Name the key that is actually being rescaled, so users know
            # which config entry to change.
            logger.warning(
                f"random_timesteps ({random_timesteps}) should be divisible by "
                f"num_devices ({num_devices}); rounding it down to "
                f"{rescaled_timesteps}."
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow from ``config``.

        Builds the training env, the deterministic EC agent, the ``VanillaES``
        optimizer, an evaluator for candidate rollouts, and a separate
        evaluator (with its own env) for the population-mean actor.
        """
        # Env used to roll out candidates: config.num_envs parallel copies,
        # with autoreset disabled.
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
            policy_obs_key=config.agent_network.policy_obs_key,
        )

        ec_optimizer = VanillaES(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
        )

        # config.explore selects compute_actions over evaluate_actions for
        # the candidate rollouts.
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap axes for AgentState: params are mapped along axis 0
        # (presumably the population axis; confirm in ESWorkflowTemplate).
        # obs_preprocessor_state is unmapped (None).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the EC optimizer state.

        The optimizer is initialized from the agent's initial policy params.
        Those params are then removed (set to ``None``) from the returned
        agent state, because the optimizer state now holds them.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # remove params
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation preprocessor when ``normalize_obs`` is set.

        Otherwise ``state`` is returned unchanged. When it is set, a fresh
        subkey drives ``init_obs_preprocessor``, and ``state.key`` is advanced.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` set to ``params``.

        ``params`` may be ``None``, which is how setup strips the policy params.
        """
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state using the optimizer's mean as policy params."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)