1import logging
2from omegaconf import DictConfig
3from collections.abc import Sequence
4from typing_extensions import Self # pytype: disable=not-supported-yet]
5
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import flax.linen as nn
10
11from evorl.types import State, Params
12from evorl.envs import AutoresetMode, create_env, Space, Box
13from evorl.evaluators import Evaluator, EpisodeObsCollector
14from evorl.sample_batch import SampleBatch
15from evorl.agent import AgentState
16from evorl.ec.optimizers import ARS, ECState
17from evorl.utils import running_statistics
18from evorl.networks import make_mlp, ActivationFn
19
20from evorl.algorithms.ec.so.es_workflow import ESWorkflowTemplate
21from evorl.algorithms.ec.obs_utils import init_obs_preprocessor
22from evorl.algorithms.ec.ec_agent import DeterministicECAgent
23
24
25logger = logging.getLogger(__name__)
26
27
[docs]
def make_policy_network(
    action_size: int,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    use_bias: bool = True,
    activation: ActivationFn = nn.relu,
    activation_final: ActivationFn | None = None,
    norm_layer_type: str = "none",
) -> nn.Module:
    """Build a flax module mapping observations to clipped actions.

    The module wraps an MLP created by ``make_mlp`` whose layer widths are
    ``hidden_layer_sizes`` followed by ``action_size``. All kernels use the
    ``zeros`` initializer and the MLP output is clipped to ``[-1, 1]``.

    Args:
        action_size: Width of the final (output) layer.
        hidden_layer_sizes: Widths of the layers preceding the output layer.
        use_bias: Forwarded to ``make_mlp``.
        activation: Forwarded to ``make_mlp`` as ``activation``.
        activation_final: Forwarded to ``make_mlp`` as ``activation_final``.
        norm_layer_type: Forwarded to ``make_mlp``.

    Returns:
        An instance of the locally defined ``Policy`` module.
    """

    class Policy(nn.Module):
        @nn.compact
        def __call__(self, x):
            # Output layer is appended to the hidden layers on every call.
            widths = (*tuple(hidden_layer_sizes), action_size)
            mlp = make_mlp(
                layer_sizes=widths,
                activation=activation,
                # Kernels start at zero; the original source kept
                # ``lecun_uniform`` as a commented-out alternative.
                kernel_init=jax.nn.initializers.zeros,
                activation_final=activation_final,
                use_bias=use_bias,
                norm_layer_type=norm_layer_type,
            )
            raw_actions = mlp(x)
            return jnp.clip(raw_actions, -1.0, 1.0)

    return Policy()
54
55
[docs]
def make_deterministic_ec_agent(
    action_space: Space,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    use_bias: bool = True,
    norm_layer_type: str = "none",
    normalize_obs: bool = False,
) -> DeterministicECAgent:
    """Create a ``DeterministicECAgent`` for a ``Box`` action space.

    Args:
        action_space: Action space of the environment; must be a ``Box``.
            Its first shape entry is used as the policy's output size.
        actor_hidden_layer_sizes: Hidden layer widths of the policy network.
        use_bias: Forwarded to ``make_policy_network``.
        norm_layer_type: Forwarded to ``make_policy_network``.
        normalize_obs: When true, ``running_statistics.normalize`` is used as
            the agent's observation preprocessor; otherwise none is set.

    Returns:
        The constructed ``DeterministicECAgent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    assert isinstance(action_space, Box), (
        "Only continuous action space is supported."
    )

    action_size = action_space.shape[0]

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        use_bias=use_bias,
        activation_final=None,
        norm_layer_type=norm_layer_type,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DeterministicECAgent(
        policy_network=policy_network,
        obs_preprocessor=obs_preprocessor,
    )
84
85
[docs]
class ARSWorkflow(ESWorkflowTemplate):
    """ES workflow that optimizes a deterministic policy with the ``ARS`` optimizer.

    The workflow is built from a config via ``_build_from_config``. The policy
    parameters live in the EC optimizer state; the agent state only carries
    the (optional) observation-preprocessor state between iterations.
    """

    @classmethod
    def name(cls):
        """Return the identifier string of this workflow."""
        return "ARS"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.random_timesteps`` down to a multiple of the device count.

        Runs the parent rescaling first, then mutates ``config`` in place. A
        warning is logged when the configured value had to be changed.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        rescaled_timesteps = (config.random_timesteps // num_devices) * num_devices

        if config.random_timesteps % num_devices != 0:
            # The message names the option that is actually checked and
            # reports the value it is rounded down to.
            logger.warning(
                f"When enable_multi_devices=True, random_timesteps ({config.random_timesteps}) "
                f"should be divisible by num_devices ({num_devices}), "
                f"rescale random_timesteps to {rescaled_timesteps}"
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Creates the training env, the deterministic agent, the ``ARS``
        optimizer, the population evaluator and a separate evaluator (with its
        own env) for the population-mean actor.

        Raises:
            AssertionError: If ``config.normalize_obs_mode`` is not one of
                ``"VBN"``, ``"RS"`` or ``"Global"``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
        )

        ec_optimizer = ARS(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            lr=config.lr,
            noise_std=config.noise_std,
            optimizer_name=config.optimizer_name,
        )

        # Choose which agent method drives the population rollouts.
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        assert config.normalize_obs_mode in ["VBN", "RS", "Global"]
        if config.normalize_obs_mode == "VBN":
            ec_evaluator = Evaluator(
                env=env,
                action_fn=action_fn,
                max_episode_steps=config.env.max_episode_steps,
                discount=config.discount,
            )
        else:
            # "RS" and "Global" use the collector variant; its trajectories
            # feed `_update_obs_preprocessor` (presumably it also returns the
            # visited observations -- confirm against EpisodeObsCollector).
            ec_evaluator = EpisodeObsCollector(
                env=env,
                action_fn=action_fn,
                max_episode_steps=config.env.max_episode_steps,
                discount=config.discount,
            )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Presumably used as vmap in_axes by the template: params are mapped
        # over axis 0, the obs-preprocessor state is not mapped.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the EC optimizer state.

        The freshly initialized policy parameters seed the optimizer and are
        then cleared from the returned agent state.

        Args:
            key: PRNG key; split into one key for the agent and one for the
                optimizer.

        Returns:
            ``(agent_state, ec_opt_state)`` where ``agent_state`` has its
            ``policy_params`` set to ``None``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # remove params
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation-preprocessor state when normalization is on.

        Returns ``state`` unchanged if ``config.normalize_obs`` is false;
        otherwise returns a copy with the new agent state and an advanced key.
        """
        if not self.config.normalize_obs:
            return state

        # setup obs_preprocessor_state
        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return a copy of ``agent_state`` whose ``policy_params`` is ``params``."""
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state carrying the optimizer's ``mean`` as policy params."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)

    def _update_obs_preprocessor(
        self, agent_state: AgentState, trajectory: SampleBatch
    ) -> AgentState:
        """Refresh the observation statistics according to ``normalize_obs_mode``.

        * ``"Global"``: update the existing statistics with ``trajectory.obs``.
        * ``"RS"``: discard the existing statistics, re-initialize them and
          update with ``trajectory.obs`` only.
        * anything else: keep the existing statistics.

        In both updating modes ``1 - trajectory.dones`` is passed as the
        per-sample weights.

        Returns:
            A copy of ``agent_state`` with the resulting preprocessor state.
        """
        mode = self.config.normalize_obs_mode

        if mode == "Global":
            obs_preprocessor_state = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )

        elif mode == "RS":
            # A single observation (index [0, 0] of every leaf) serves as the
            # template for a fresh statistics state.
            dummy_obs = jtu.tree_map(lambda x: x[0, 0], trajectory.obs)
            obs_preprocessor_state = running_statistics.init_state(dummy_obs)
            obs_preprocessor_state = running_statistics.update(
                obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )
        else:
            obs_preprocessor_state = agent_state.obs_preprocessor_state

        return agent_state.replace(obs_preprocessor_state=obs_preprocessor_state)
246 return agent_state.replace(obs_preprocessor_state=obs_preprocessor_state)