1import logging
2from omegaconf import DictConfig
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4
5import jax
6
7from evorl.types import State, Params
8from evorl.envs import AutoresetMode, create_env
9from evorl.evaluators import Evaluator
10from evorl.agent import AgentState
11from evorl.ec.optimizers import OpenESNoiseTable, ExponentialScheduleSpec, ECState
12
13
14from evorl.algorithms.ec.so.es_workflow import ESWorkflowTemplate
15from evorl.algorithms.ec.obs_utils import init_obs_preprocessor
16from evorl.algorithms.ec.ec_agent import make_deterministic_ec_agent
17
18
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
20
21
[docs]
class OpenESWorkflow(ESWorkflowTemplate):
    """OpenES workflow driven by the ``OpenESNoiseTable`` optimizer.

    Wires together:
      * a training env (autoreset disabled) and a deterministic EC agent,
      * an ``OpenESNoiseTable`` optimizer configured from ``config``,
      * ``ec_evaluator`` for scoring population members, and
      * ``evaluator`` (on a separate eval env) for the population-center actor.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "OpenES-NoiseTable"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Rescale config values so they split evenly across JAX devices.

        After the parent class's rescaling, ``config.random_timesteps`` is
        rounded down (in place) to the nearest multiple of
        ``jax.device_count()``. A warning is logged when it was not already a
        multiple.

        Args:
            config: The workflow config; mutated in place.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        random_timesteps = config.random_timesteps
        rescaled_timesteps = (random_timesteps // num_devices) * num_devices

        if random_timesteps % num_devices != 0:
            # The value checked and rescaled here is random_timesteps (not
            # pop_size), so report that key and the value it becomes. It can
            # drop to 0 when smaller than the device count.
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rounding it down to %s.",
                random_timesteps,
                num_devices,
                rescaled_timesteps,
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Args:
            config: The workflow config. Reads ``env``, ``num_envs``,
                ``num_eval_envs``, ``agent_network``, ``normalize_obs``,
                ``pop_size``, ``noise_table_size``, ``ec_lr``,
                ``ec_noise_std``, ``mirror_sampling``, ``weight_decay``,
                ``optimizer_name``, ``explore`` and ``discount``.

        Returns:
            A new instance of ``cls``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
        )

        ec_optimizer = OpenESNoiseTable(
            pop_size=config.pop_size,
            noise_table_size=config.noise_table_size,
            lr_schedule=ExponentialScheduleSpec(**config.ec_lr),
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mirror_sampling=config.mirror_sampling,
            weight_decay=config.weight_decay,
            optimizer_name=config.optimizer_name,
        )

        # Population members act with compute_actions when exploration is
        # enabled, and with evaluate_actions otherwise.
        action_fn = agent.compute_actions if config.explore else agent.evaluate_actions

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # Separate env/evaluator used to evaluate the pop-mean (center) actor.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Presumably used as vmap in_axes over the population: params are
        # batched on axis 0, the obs preprocessor state is shared (None).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the EC optimizer state.

        The freshly initialized policy params seed the optimizer. They are then
        cleared (set to ``None``) in the returned agent state;
        ``_get_pop_center`` re-inserts ``ec_opt_state.mean`` when an actor is
        needed.

        Args:
            key: PRNG key; split into an agent key and an optimizer key.

        Returns:
            ``(agent_state, ec_opt_state)``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # Drop the actor params from the agent state; the optimizer state
        # holds the search distribution from here on.
        # NOTE(review): the shared noise table is presumably created inside
        # OpenESNoiseTable.init / ec_opt_state — confirm.
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the obs preprocessor state when ``normalize_obs`` is set.

        Args:
            state: The workflow state after setup.

        Returns:
            ``state`` unchanged if ``config.normalize_obs`` is false. Otherwise,
            a copy with an initialized ``agent_state`` and an advanced ``key``.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return a copy of ``agent_state`` with its policy params replaced.

        Args:
            agent_state: The agent state to copy.
            params: New policy params; ``None`` is passed during setup to
                clear them.

        Returns:
            The updated agent state.
        """
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state whose policy params are the optimizer mean.

        Args:
            state: The workflow state; ``state.ec_opt_state.mean`` is used as
                the population center.

        Returns:
            ``state.agent_state`` with its policy params set to that mean.
        """
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)