1import logging
2from omegaconf import DictConfig
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4
5import jax
6
7from evorl.types import State, Params
8from evorl.envs import AutoresetMode, create_env
9from evorl.evaluators import Evaluator
10from evorl.agent import AgentState
11from evorl.ec.optimizers import OpenES, ExponentialScheduleSpec, ECState
12
13
14from .es_workflow import ESWorkflowTemplate
15from ..obs_utils import init_obs_preprocessor
16from ..ec_agent import make_deterministic_ec_agent
17
18
# Module-level logger, named after this module so it inherits handler config.
logger = logging.getLogger(name=__name__)
20
21
[docs]
class OpenESWorkflow(ESWorkflowTemplate):
    """ES workflow that optimizes a deterministic EC agent's policy with OpenES.

    Candidate policies are produced by ``make_deterministic_ec_agent``; the
    population distribution is maintained by the ``OpenES`` optimizer, and the
    population mean (``ec_opt_state.mean``) is evaluated by a separate
    evaluator on its own environment.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "OpenES"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Adjust ``config`` in place for the number of available JAX devices.

        After the parent class's rescaling, ``config.random_timesteps`` is
        rounded down to a multiple of ``jax.device_count()``; a warning is
        logged when that changes its value.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        random_timesteps = config.random_timesteps
        rounded_timesteps = (random_timesteps // num_devices) * num_devices

        if random_timesteps % num_devices != 0:
            logger.warning(
                "random_timesteps (%s) should be divisible by num_devices (%s); "
                "rounding it down to %s.",
                random_timesteps,
                num_devices,
                rounded_timesteps,
            )
            # NOTE(review): if 0 < random_timesteps < num_devices this rounds
            # down to 0 — confirm that is acceptable for the consumers of
            # random_timesteps.

        config.random_timesteps = rounded_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Builds:
          * the training env (``config.num_envs`` parallel copies, autoreset
            disabled),
          * the deterministic EC agent,
          * the OpenES optimizer (lr / noise-std follow exponential schedules),
          * an evaluator for population members (with ``config.discount``),
          * a separate env (``config.num_eval_envs`` copies) and evaluator for
            the population-mean actor.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
            policy_obs_key=config.agent_network.policy_obs_key,
        )

        ec_optimizer = OpenES(
            pop_size=config.pop_size,
            lr_schedule=ExponentialScheduleSpec(**config.ec_lr),
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mirror_sampling=config.mirror_sampling,
            weight_decay=config.weight_decay,
            optimizer_name=config.optimizer_name,
        )

        # Population members use ``compute_actions`` when ``config.explore`` is
        # set, otherwise the same ``evaluate_actions`` used for the pop mean.
        action_fn = agent.compute_actions if config.explore else agent.evaluate_actions

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Params are mapped along axis 0 (presumably one entry per population
        # member); the obs preprocessor state is shared, not mapped.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the OpenES optimizer state.

        The optimizer is initialized from the agent's initial policy params,
        which are then removed from the returned agent state; the population
        center is later read back from ``ec_opt_state.mean``.

        Args:
            key: PRNG key, split into an agent key and an optimizer key.

        Returns:
            ``(agent_state, ec_opt_state)``, with ``policy_params`` set to None
            in ``agent_state``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # remove params
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation-preprocessor state if ``normalize_obs`` is set.

        When disabled, ``state`` is returned unchanged. When enabled, a subkey
        split from ``state.key`` is passed to ``init_obs_preprocessor`` and the
        returned state carries the updated agent state and the advanced key.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` replaced by ``params``.

        ``params`` may be None (used to strip the policy params during setup).
        """
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state whose policy params are the OpenES mean."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)