Source code for evorl.algorithms.ec.so.sepcem

  1from omegaconf import DictConfig
  2from typing_extensions import Self  # pytype: disable=not-supported-yet]
  3
  4import jax
  5
  6from evorl.types import State, Params
  7from evorl.envs import AutoresetMode, create_env
  8from evorl.evaluators import Evaluator
  9from evorl.agent import AgentState
 10from evorl.ec.optimizers import SepCEM, ExponentialScheduleSpec, ECState
 11
 12from .es_workflow import ESWorkflowTemplate
 13from ..obs_utils import init_obs_preprocessor
 14from ..ec_agent import make_deterministic_ec_agent
 15
 16
class SepCEMWorkflow(ESWorkflowTemplate):
    """ES workflow that pairs a deterministic EC agent with the ``SepCEM`` optimizer.

    The optimizer state owns the actor parameters: after initialization the
    policy parameters stored in the agent state are replaced by ``None`` and
    are re-inserted on demand (see ``_replace_actor_params`` and
    ``_get_pop_center``).
    """

    @classmethod
    def name(cls):
        """Return the workflow identifier, ``"SepCEM"``."""
        return "SepCEM"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Builds the training env, the deterministic agent, the ``SepCEM``
        optimizer, an evaluator for the population and a second evaluator
        (on its own env) for the population-mean actor.

        Args:
            config: Workflow configuration. The fields read here are
                ``env``, ``num_envs``, ``num_eval_envs``, ``agent_network``,
                ``normalize_obs``, ``pop_size``, ``num_elites``, ``cov_eps``,
                ``weighted_update``, ``rank_weight_shift``,
                ``mirror_sampling``, ``explore`` and ``discount``.

        Returns:
            A new workflow instance.
        """
        max_steps = config.env.max_episode_steps

        env = create_env(
            config.env,
            episode_length=max_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        net_cfg = config.agent_network
        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=net_cfg.actor_hidden_layer_sizes,
            use_bias=net_cfg.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=net_cfg.norm_layer_type,
            policy_obs_key=net_cfg.policy_obs_key,
        )

        ec_optimizer = SepCEM(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            cov_eps_schedule=ExponentialScheduleSpec(**config.cov_eps),
            weighted_update=config.weighted_update,
            rank_weight_shift=config.rank_weight_shift,
            mirror_sampling=config.mirror_sampling,
        )

        # Population rollouts use ``compute_actions`` when ``config.explore``
        # is truthy and ``evaluate_actions`` otherwise.
        rollout_action_fn = (
            agent.compute_actions if config.explore else agent.evaluate_actions
        )

        ec_evaluator = Evaluator(
            env=env,
            action_fn=rollout_action_fn,
            max_episode_steps=max_steps,
            discount=config.discount,
        )

        # A separate env/evaluator pair is used to evaluate the pop-mean actor.
        eval_env = create_env(
            config.env,
            episode_length=max_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_steps,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            # NOTE(review): presumably vmap axes -- params mapped over axis 0,
            # obs preprocessor state not mapped; confirm against the template.
            agent_state_vmap_axes=AgentState(
                params=0,
                obs_preprocessor_state=None,
            ),
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the optimizer state.

        Args:
            key: PRNG key; it is split into one key for the agent and one for
                the optimizer.

        Returns:
            A pair ``(agent_state, ec_opt_state)``. The returned agent state
            has its policy parameters set to ``None``; the initial parameters
            are handed to the optimizer instead.
        """
        agent_key, ec_key = jax.random.split(key)

        full_agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        # Seed the optimizer with the freshly initialized actor parameters.
        ec_opt_state = self.ec_optimizer.init(
            full_agent_state.params.policy_params, ec_key
        )

        # Drop the actor params from the agent state.
        stripped_agent_state = self._replace_actor_params(
            full_agent_state, params=None
        )

        return stripped_agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation preprocessor state when it is enabled.

        Args:
            state: Workflow state produced by setup.

        Returns:
            ``state`` unchanged when ``config.normalize_obs`` is falsy;
            otherwise a copy with the agent state returned by
            ``init_obs_preprocessor`` and an advanced PRNG key.
        """
        if not self.config.normalize_obs:
            return state

        # setup obs_preprocessor_state
        next_key, obs_key = jax.random.split(state.key, 2)
        new_agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=new_agent_state,
            key=next_key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return a copy of ``agent_state`` whose ``policy_params`` is ``params``.

        Args:
            agent_state: Agent state to copy.
            params: New value for ``agent_state.params.policy_params``
                (``None`` is used to clear it).

        Returns:
            The updated agent state.
        """
        updated_params = agent_state.params.replace(policy_params=params)
        return agent_state.replace(params=updated_params)

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state carrying the optimizer's ``mean`` as actor params.

        Args:
            state: Current workflow state.

        Returns:
            ``state.agent_state`` with its policy parameters replaced by
            ``state.ec_opt_state.mean``.
        """
        center_params = state.ec_opt_state.mean
        return self._replace_actor_params(state.agent_state, center_params)