Source code for evorl.algorithms.contrib.ars_linear

  1import logging
  2from omegaconf import DictConfig
  3from collections.abc import Sequence
  4from typing_extensions import Self  # pytype: disable=not-supported-yet]
  5
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import flax.linen as nn
 10
 11from evorl.types import State, Params
 12from evorl.envs import AutoresetMode, create_env, Space, Box
 13from evorl.evaluators import Evaluator, EpisodeObsCollector
 14from evorl.sample_batch import SampleBatch
 15from evorl.agent import AgentState
 16from evorl.ec.optimizers import ARS, ECState
 17from evorl.utils import running_statistics
 18from evorl.networks import make_mlp, ActivationFn
 19
 20from evorl.algorithms.ec.so.es_workflow import ESWorkflowTemplate
 21from evorl.algorithms.ec.obs_utils import init_obs_preprocessor
 22from evorl.algorithms.ec.ec_agent import DeterministicECAgent
 23
 24
 25logger = logging.getLogger(__name__)
 26
 27
def make_policy_network(
    action_size: int,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    use_bias: bool = True,
    activation: ActivationFn = nn.relu,
    activation_final: ActivationFn | None = None,
    norm_layer_type: str = "none",
) -> nn.Module:
    """Build a policy module whose outputs are clipped to ``[-1, 1]``.

    Args:
        action_size: Width of the last layer handed to ``make_mlp``.
        hidden_layer_sizes: Widths of the layers placed before the last one.
        use_bias: Forwarded to ``make_mlp`` as ``use_bias``.
        activation: Forwarded to ``make_mlp`` as ``activation``.
        activation_final: Forwarded to ``make_mlp`` as ``activation_final``.
        norm_layer_type: Forwarded to ``make_mlp`` as ``norm_layer_type``.

    Returns:
        An instance of a locally defined ``nn.Module`` that feeds its input
        through the MLP and clips the result with ``jnp.clip(..., -1.0, 1.0)``.
    """

    class Policy(nn.Module):
        @nn.compact
        def __call__(self, x):
            # Hidden widths followed by the action dimension.
            mlp_layer_sizes = (*hidden_layer_sizes, action_size)

            # The MLP is created inside the compact method; every kernel is
            # initialised with zeros (rather than e.g. lecun_uniform).
            mlp = make_mlp(
                layer_sizes=mlp_layer_sizes,
                activation=activation,
                kernel_init=jax.nn.initializers.zeros,
                activation_final=activation_final,
                use_bias=use_bias,
                norm_layer_type=norm_layer_type,
            )

            raw_actions = mlp(x)
            return jnp.clip(raw_actions, -1.0, 1.0)

    return Policy()
54 55
def make_deterministic_ec_agent(
    action_space: Space,
    actor_hidden_layer_sizes: Sequence[int] = (256, 256),
    use_bias: bool = True,
    norm_layer_type: str = "none",
    normalize_obs: bool = False,
) -> DeterministicECAgent:
    """Create a ``DeterministicECAgent`` backed by ``make_policy_network``.

    Args:
        action_space: Action space of the environment; must be a ``Box``.
            Its first shape entry is used as the action size.
        actor_hidden_layer_sizes: Hidden layer widths of the policy network.
        use_bias: Forwarded to ``make_policy_network``.
        norm_layer_type: Forwarded to ``make_policy_network``.
        normalize_obs: When true, ``running_statistics.normalize`` is used as
            the agent's observation preprocessor; otherwise no preprocessor
            is set.

    Returns:
        The constructed ``DeterministicECAgent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    assert isinstance(action_space, Box), "Only continuous action space is supported."

    policy_network = make_policy_network(
        action_size=action_space.shape[0],
        hidden_layer_sizes=actor_hidden_layer_sizes,
        use_bias=use_bias,
        activation_final=None,
        norm_layer_type=norm_layer_type,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DeterministicECAgent(
        policy_network=policy_network,
        obs_preprocessor=obs_preprocessor,
    )
84 85
class ARSWorkflow(ESWorkflowTemplate):
    """Workflow that drives the ``ARS`` optimizer with a deterministic policy agent."""

    @classmethod
    def name(cls):
        """Return the workflow name, ``"ARS"``."""
        return "ARS"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.random_timesteps`` down to a multiple of the device count.

        The parent class' ``_rescale_config`` runs first. ``config`` is
        modified in place.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        rescaled_timesteps = (config.random_timesteps // num_devices) * num_devices

        if config.random_timesteps % num_devices != 0:
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rounding it down to %s.",
                config.random_timesteps,
                num_devices,
                rescaled_timesteps,
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Builds the training env, the agent, the ``ARS`` optimizer, the
        population evaluator and a separate evaluator for the pop-mean actor.

        Raises:
            AssertionError: If ``config.normalize_obs_mode`` is not one of
                ``"VBN"``, ``"RS"`` or ``"Global"``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
        )

        ec_optimizer = ARS(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            lr=config.lr,
            noise_std=config.noise_std,
            optimizer_name=config.optimizer_name,
        )

        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        assert config.normalize_obs_mode in ["VBN", "RS", "Global"]
        if config.normalize_obs_mode == "VBN":
            ec_evaluator = Evaluator(
                env=env,
                action_fn=action_fn,
                max_episode_steps=config.env.max_episode_steps,
                discount=config.discount,
            )
        else:
            # "RS" and "Global" read trajectory.obs in _update_obs_preprocessor;
            # presumably EpisodeObsCollector is what provides them -- TODO confirm.
            ec_evaluator = EpisodeObsCollector(
                env=env,
                action_fn=action_fn,
                max_episode_steps=config.env.max_episode_steps,
                discount=config.discount,
            )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Axis spec per AgentState field: 0 for params, None for the obs
        # preprocessor state (presumably consumed as vmap in_axes -- confirm
        # against ESWorkflowTemplate).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialise the agent state and the optimizer state from ``key``.

        The freshly initialised policy params seed the optimizer and are then
        removed from the returned agent state (``policy_params`` is set to
        ``None``).
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # remove params
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Set up the obs preprocessor state when ``config.normalize_obs`` is on.

        Returns ``state`` unchanged otherwise.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` set to ``params``."""
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state whose policy params are ``state.ec_opt_state.mean``."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)

    def _update_obs_preprocessor(
        self, agent_state: AgentState, trajectory: SampleBatch
    ) -> AgentState:
        """Update the obs preprocessor state according to ``normalize_obs_mode``.

        - ``"Global"``: update the existing statistics with ``trajectory.obs``.
        - ``"RS"``: start from freshly initialised statistics and update them
          with ``trajectory.obs`` only.
        - anything else: keep the current preprocessor state.

        In both updating modes each observation is weighted by
        ``1 - trajectory.dones``.
        """
        if self.config.normalize_obs_mode == "Global":
            obs_preprocessor_state = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )

        elif self.config.normalize_obs_mode == "RS":
            # Index [0, 0] on every leaf gives a single sample that is used
            # only to initialise the statistics state.
            dummy_obs = jtu.tree_map(lambda x: x[0, 0], trajectory.obs)
            obs_preprocessor_state = running_statistics.init_state(dummy_obs)
            obs_preprocessor_state = running_statistics.update(
                obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )
        else:
            obs_preprocessor_state = agent_state.obs_preprocessor_state

        return agent_state.replace(obs_preprocessor_state=obs_preprocessor_state)