1from omegaconf import DictConfig
2from typing_extensions import Self # pytype: disable=not-supported-yet]
3
4import jax
5
6from evorl.types import State, Params
7from evorl.envs import AutoresetMode, create_env
8from evorl.evaluators import Evaluator
9from evorl.agent import AgentState
10from evorl.ec.optimizers import SepCEM, ExponentialScheduleSpec, ECState
11
12from .es_workflow import ESWorkflowTemplate
13from ..obs_utils import init_obs_preprocessor
14from ..ec_agent import make_deterministic_ec_agent
15
16
class SepCEMWorkflow(ESWorkflowTemplate):
    """Evolution-strategy workflow that trains a deterministic EC agent with SepCEM.

    After setup, the policy params are not kept in the agent state: they are set
    to ``None`` there, and the SepCEM optimizer state is initialized from them.
    The population-mean actor is rebuilt on demand from ``ec_opt_state.mean``.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "SepCEM"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Assemble env, agent, optimizer and evaluators from ``config``.

        Two environments are created, both with autoreset disabled:
        one with ``config.num_envs`` parallel copies, used to score the
        population, and one with ``config.num_eval_envs`` copies, used to
        evaluate the population-mean actor.
        """

        def make_env(num_parallel):
            # Both envs share every option except the degree of parallelism.
            return create_env(
                config.env,
                episode_length=config.env.max_episode_steps,
                parallel=num_parallel,
                autoreset_mode=AutoresetMode.DISABLED,
            )

        max_steps = config.env.max_episode_steps
        net_cfg = config.agent_network

        env = make_env(config.num_envs)

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=net_cfg.actor_hidden_layer_sizes,
            use_bias=net_cfg.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=net_cfg.norm_layer_type,
            policy_obs_key=net_cfg.policy_obs_key,
        )

        cov_eps_schedule = ExponentialScheduleSpec(**config.cov_eps)
        ec_optimizer = SepCEM(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            cov_eps_schedule=cov_eps_schedule,
            weighted_update=config.weighted_update,
            rank_weight_shift=config.rank_weight_shift,
            mirror_sampling=config.mirror_sampling,
        )

        # Population members act through compute_actions when `explore` is set,
        # otherwise through evaluate_actions.
        action_fn = (
            agent.compute_actions if config.explore else agent.evaluate_actions
        )
        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=max_steps,
            discount=config.discount,
        )

        # Evaluator for the pop-mean actor; `discount` is left at Evaluator's default.
        eval_env = make_env(config.num_eval_envs)
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_steps,
        )

        # Axis spec for the agent state: params on axis 0, obs_preprocessor_state
        # unmapped (presumably used as vmap in_axes over the population — confirm).
        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=AgentState(params=0, obs_preprocessor_state=None),
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent, then seed the SepCEM optimizer from its policy params.

        Args:
            key: PRNG key; split into one key for agent init and one for the optimizer.

        Returns:
            A pair ``(agent_state, ec_opt_state)``. In ``agent_state`` the policy
            params are set to ``None``. ``ec_opt_state`` was initialized from
            those params.
        """
        init_key, opt_key = jax.random.split(key)
        full_agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, init_key
        )
        opt_state = self.ec_optimizer.init(
            full_agent_state.params.policy_params, opt_key
        )
        # Drop the policy params from the agent state; the optimizer owns them now.
        stripped_agent_state = self._replace_actor_params(full_agent_state, params=None)
        return stripped_agent_state, opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation-preprocessor state if ``normalize_obs`` is set.

        When normalization is disabled, ``state`` is returned unchanged.
        Otherwise ``state.key`` is split, one half is used to initialize the
        preprocessor, and the other half replaces ``state.key``.
        """
        if not self.config.normalize_obs:
            return state

        next_key, obs_key = jax.random.split(state.key)
        new_agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )
        # Note: the random timesteps used here are not counted in state.metrics.
        return state.replace(agent_state=new_agent_state, key=next_key)

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` replaced by ``params``."""
        new_params = agent_state.params.replace(policy_params=params)
        return agent_state.replace(params=new_params)

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state whose policy params are the optimizer's mean."""
        return self._replace_actor_params(state.agent_state, state.ec_opt_state.mean)