1import logging
2from omegaconf import DictConfig
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4
5import jax
6
7from evorl.types import State, Params
8from evorl.envs import AutoresetMode, create_env
9from evorl.evaluators import Evaluator
10from evorl.agent import AgentState
11from evorl.ec.optimizers import VanillaES, ExponentialScheduleSpec, ECState
12
13
14from .es_workflow import ESWorkflowTemplate
15from ..obs_utils import init_obs_preprocessor
16from ..ec_agent import make_deterministic_ec_agent
17
18
# Module-level logger, named after this module so records can be filtered per-module.
logger = logging.getLogger(name=__name__)
20
21
class VanillaESWorkflow(ESWorkflowTemplate):
    """ES workflow that optimizes a deterministic policy with ``VanillaES``.

    Population members are scored by ``ec_evaluator``. The population center
    (the optimizer's ``mean``) is scored separately by ``evaluator`` on its
    own environment instance.
    """

    @classmethod
    def name(cls):
        """Return the name under which this workflow is identified."""
        return "VanillaES"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Adjust ``config`` in place for the number of visible JAX devices.

        After the base-class rescaling, ``config.random_timesteps`` is rounded
        down to a multiple of ``jax.device_count()``. A warning is logged when
        it was not already divisible.

        Args:
            config: Workflow config; ``random_timesteps`` is mutated in place.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        random_timesteps = config.random_timesteps
        rescaled_timesteps = (random_timesteps // num_devices) * num_devices

        if random_timesteps % num_devices != 0:
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rounding it down to %s.",
                random_timesteps,
                num_devices,
                rescaled_timesteps,
            )

        # NOTE(review): if 0 < random_timesteps < num_devices this rounds to 0;
        # confirm init_obs_preprocessor tolerates zero random timesteps.
        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Builds the training env, a deterministic EC agent, the ``VanillaES``
        optimizer, an evaluator for population members, and a second
        evaluator (with its own env) for the population-mean actor.

        Args:
            config: Full workflow config.

        Returns:
            A new ``VanillaESWorkflow`` instance.
        """
        # Training env used to score population members; autoreset is
        # disabled for both envs built here.
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
            policy_obs_key=config.agent_network.policy_obs_key,
        )

        ec_optimizer = VanillaES(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
        )

        # config.explore selects which agent action function scores the
        # population members.
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap over the leading (population) axis of params; the obs
        # preprocessor state is not mapped (None), i.e. shared by all members.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the EC optimizer state.

        The agent's initial policy params seed the optimizer. The policy
        params are then cleared (set to ``None``) in the returned agent state.

        Args:
            key: PRNG key, split between agent and optimizer initialization.

        Returns:
            ``(agent_state, ec_opt_state)``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # remove params: the optimizer state now holds them
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation preprocessor state when enabled.

        When ``config.normalize_obs`` is false, ``state`` is returned
        unchanged. Otherwise a fresh sub-key is passed to
        ``init_obs_preprocessor``, and the advanced key and updated agent
        state are written back.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return a copy of ``agent_state`` with its policy params replaced.

        Args:
            agent_state: Agent state whose ``params.policy_params`` is swapped.
            params: New policy params (may be ``None`` to clear them).
        """
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state whose policy params are the optimizer mean."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)