1import logging
2import numpy as np
3from omegaconf import DictConfig
4from typing_extensions import Self # pytype: disable=not-supported-yet]
5
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9
10from evorl.types import State, Params
11from evorl.envs import AutoresetMode, create_env
12from evorl.evaluators import Evaluator
13from evorl.agent import AgentState
14from evorl.ec.optimizers.ec_optimizer import ECState
15from evorl.utils.ec_utils import ParamVectorSpec
16
17from .es_workflow import ESWorkflowTemplate
18from ..obs_utils import init_obs_preprocessor
19from ..ec_agent import make_deterministic_ec_agent
20
21logger = logging.getLogger(__name__)
22
23
[docs]
class CMAESWorkflow(ESWorkflowTemplate):
    """Evolution-strategy workflow whose optimizer is the CMA-ES algorithm.

    A deterministic EC agent is built from the config, its policy parameters
    are flattened through ``ParamVectorSpec``, and the flattened vector is used
    as the initial center of a ``CMAES`` instance wrapped in
    ``EvoXAlgorithmAdapter``.
    """

    @classmethod
    def name(cls):
        """Return the workflow name, ``"CMAES"``."""
        return "CMAES"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.random_timesteps`` down to a multiple of the device count.

        The parent class's rescaling runs first. ``config`` is modified in
        place and nothing is returned.

        Args:
            config: Workflow configuration; must provide ``random_timesteps``.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        if config.random_timesteps % num_devices != 0:
            # The value being checked is random_timesteps, so the message
            # names that field (it previously claimed to be about pop_size).
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rounding it down to a multiple "
                "of num_devices.",
                config.random_timesteps,
                num_devices,
            )

        config.random_timesteps = (config.random_timesteps // num_devices) * num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Builds the training environment, the deterministic EC agent, the
        CMA-ES optimizer, the population evaluator and a separate evaluator
        for the population-center policy.

        Args:
            config: Workflow configuration.

        Returns:
            A new instance of ``cls``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
            policy_obs_key=config.agent_network.policy_obs_key,
        )

        # Dummy agent_state: only used to derive the parameter layout and the
        # initial center handed to the algorithm below.
        agent_key = jax.random.PRNGKey(config.seed)
        agent_state = agent.init(env.obs_space, env.action_space, agent_key)
        param_vec_spec = ParamVectorSpec(agent_state.params.policy_params)

        # Imported inside the method, as in the original code, so these
        # modules are only loaded when a workflow is actually built.
        from evorl.ec.optimizers.evox_wrapper import EvoXAlgorithmAdapter
        from evorl.ec.evox_algorithm import CMAES

        ec_optimizer = EvoXAlgorithmAdapter(
            algorithm=CMAES(
                center_init=param_vec_spec.to_vector(agent_state.params.policy_params),
                init_stdev=config.init_stdev,
                pop_size=config.pop_size,
                mu=config.num_elites,
            ),
            param_vec_spec=param_vec_spec,
        )

        # `config.explore` selects which agent action function scores the
        # population members.
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # To evaluate the pop-mean actor.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap spec: `params` is mapped over axis 0, the obs preprocessor
        # state is not mapped.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the EC optimizer state.

        Args:
            key: PRNG key; split into one key for the agent and one for the
                optimizer.

        Returns:
            ``(agent_state, ec_opt_state)`` where the agent state's
            ``policy_params`` has been set to ``None``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        ec_opt_state = self.ec_optimizer.init(ec_key)

        # Remove params from the agent state; `_get_pop_center` later fills
        # them in from the optimizer state.
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation preprocessor when ``normalize_obs`` is set.

        Args:
            state: Workflow state after the regular setup.

        Returns:
            ``state`` unchanged when observation normalization is disabled;
            otherwise a copy with the updated ``agent_state`` and an advanced
            ``key``.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics.
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` set to ``params``.

        ``params`` may be ``None``; ``_setup_agent_and_optimizer`` passes
        ``None`` to clear the policy parameters.
        """
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Build an agent state whose policy params are the optimizer's mean.

        Reads the flat ``mean`` vector from ``state.ec_opt_state.algo_state``
        and converts it back to a parameter tree with the optimizer's
        ``param_vec_spec``.
        """
        flat_pop_center = state.ec_opt_state.algo_state.mean

        pop_center = self.ec_optimizer.param_vec_spec.to_tree(flat_pop_center)

        return self._replace_actor_params(state.agent_state, pop_center)

    def _record_callback(self, state: State, iters: int) -> None:
        """Write search-distribution statistics to the recorder.

        Records min/max/mean of the per-parameter standard deviation under
        ``"ec/std"`` and the algorithm's ``sigma`` under ``"ec/sigma"``.

        Args:
            state: Current workflow state.
            iters: Iteration index passed through to ``recorder.write``.
        """
        algo_state = state.ec_opt_state.algo_state
        # `C` is presumably the full covariance matrix (its diagonal is taken
        # here) -- TODO confirm against the CMAES implementation.
        C = algo_state.C
        std = jnp.sqrt(jnp.diagonal(C)) * algo_state.sigma

        std = self.ec_optimizer.param_vec_spec.to_tree(std)
        std_statistics = _get_std_statistics(std)
        self.recorder.write({"ec/std": std_statistics}, iters)
        self.recorder.write({"ec/sigma": algo_state.sigma.tolist()}, iters)
173
174
[docs]
class SepCMAESWorkflow(CMAESWorkflow):
    """CMA-ES workflow variant that uses the ``SepCMAES`` algorithm."""

    @classmethod
    def name(cls):
        """Return the workflow name, ``"SepCMAES"``."""
        return "SepCMAES"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and its components from ``config``.

        Builds the training environment, the deterministic EC agent, a
        ``SepCMAES`` optimizer wrapped in ``EvoXAlgorithmAdapter``, the
        population evaluator and an evaluator for the population center.

        Args:
            config: Workflow configuration.

        Returns:
            A new instance of ``cls``.
        """
        # Environment used to score population members.
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        network_cfg = config.agent_network
        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=network_cfg.actor_hidden_layer_sizes,
            use_bias=network_cfg.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=network_cfg.norm_layer_type,
            policy_obs_key=network_cfg.policy_obs_key,
        )

        # Throwaway agent state: it only supplies the parameter layout and
        # the initial center of the search distribution.
        template_state = agent.init(
            env.obs_space, env.action_space, jax.random.PRNGKey(config.seed)
        )
        template_policy_params = template_state.params.policy_params
        param_vec_spec = ParamVectorSpec(template_policy_params)

        from evorl.ec.optimizers.evox_wrapper import EvoXAlgorithmAdapter
        from evorl.ec.evox_algorithm import SepCMAES

        algorithm = SepCMAES(
            center_init=param_vec_spec.to_vector(template_policy_params),
            init_stdev=config.init_stdev,
            pop_size=config.pop_size,
            mu=config.num_elites,
        )
        ec_optimizer = EvoXAlgorithmAdapter(
            algorithm=algorithm,
            param_vec_spec=param_vec_spec,
        )

        # `config.explore` picks the action function used on the population.
        action_fn = (
            agent.compute_actions if config.explore else agent.evaluate_actions
        )

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # Separate environment and evaluator for the pop-mean actor.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap spec: map `params` over axis 0, leave the obs preprocessor
        # state unmapped.
        vmap_axes = AgentState(params=0, obs_preprocessor_state=None)

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=vmap_axes,
        )

    def _record_callback(self, state: State, iters: int) -> None:
        """Write search-distribution statistics to the recorder.

        Records min/max/mean of the per-parameter standard deviation under
        ``"ec/std"`` and the algorithm's ``sigma`` under ``"ec/sigma"``.

        Args:
            state: Current workflow state.
            iters: Iteration index passed through to ``recorder.write``.
        """
        cma_state = state.ec_opt_state.algo_state
        # `C` is square-rooted element-wise here; presumably it already holds
        # one variance per parameter -- TODO confirm against SepCMAES.
        flat_std = jnp.sqrt(cma_state.C) * cma_state.sigma

        std_tree = self.ec_optimizer.param_vec_spec.to_tree(flat_std)
        self.recorder.write({"ec/std": _get_std_statistics(std_tree)}, iters)
        self.recorder.write({"ec/sigma": cma_state.sigma.tolist()}, iters)
266
267
268def _get_std_statistics(variance):
269 def _get_stats(x):
270 x = np.asarray(x)
271 return dict(
272 min=np.min(x).tolist(),
273 max=np.max(x).tolist(),
274 mean=np.mean(x).tolist(),
275 )
276
277 return jtu.tree_map(_get_stats, variance)