1import logging
2from omegaconf import DictConfig
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4
5import jax
6import jax.tree_util as jtu
7
8from evorl.types import State, Params
9from evorl.envs import AutoresetMode, create_env
10from evorl.evaluators import Evaluator, EpisodeObsCollector
11from evorl.sample_batch import SampleBatch
12from evorl.agent import AgentState
13from evorl.ec.optimizers import ARS, ECState
14from evorl.utils import running_statistics
15
16from .es_workflow import ESWorkflowTemplate
17from ..obs_utils import init_obs_preprocessor
18from ..ec_agent import make_deterministic_ec_agent
19
20
21logger = logging.getLogger(__name__)
22
23
[docs]
class ARSWorkflow(ESWorkflowTemplate):
    """ES workflow that drives the ``ARS`` optimizer.

    The class only supplies the pieces specific to ARS: how the workflow is
    assembled from a config, how the agent/optimizer states are initialised,
    and how the observation-normalisation statistics are maintained. The
    training loop itself lives in ``ESWorkflowTemplate`` (defined elsewhere).
    """

    @classmethod
    def name(cls) -> str:
        """Return the name of this workflow."""
        return "ARS"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.random_timesteps`` down to a multiple of the device count.

        The parent class's rescaling is applied first. ``config`` is then
        mutated in place; a warning is logged when the configured value is not
        already divisible by ``jax.device_count()``.

        Args:
            config: Workflow configuration; must expose ``random_timesteps``.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        random_timesteps = config.random_timesteps
        rounded_timesteps = (random_timesteps // num_devices) * num_devices

        if random_timesteps % num_devices != 0:
            # Use the module logger (not the root logger) and name the setting
            # that is actually being adjusted.
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rounding it down to %s.",
                random_timesteps,
                num_devices,
                rounded_timesteps,
            )

        config.random_timesteps = rounded_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow (envs, agent, optimizer, evaluators) from ``config``.

        Args:
            config: Workflow configuration. The fields read here are ``env``,
                ``num_envs``, ``num_eval_envs``, ``agent_network``,
                ``normalize_obs``, ``normalize_obs_mode``, ``pop_size``,
                ``num_elites``, ``lr``, ``noise_std``, ``optimizer_name``,
                ``explore`` and ``discount``.

        Returns:
            A new instance of ``cls``.

        Raises:
            AssertionError: If ``config.normalize_obs_mode`` is not one of
                ``"VBN"``, ``"RS"`` or ``"Global"``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
            policy_obs_key=config.agent_network.policy_obs_key,
        )

        ec_optimizer = ARS(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            lr=config.lr,
            noise_std=config.noise_std,
            optimizer_name=config.optimizer_name,
        )

        # The population is rolled out with either the agent's exploratory or
        # its evaluation action function, depending on the config.
        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        assert config.normalize_obs_mode in ["VBN", "RS", "Global"], (
            f"Unsupported normalize_obs_mode: {config.normalize_obs_mode!r}"
        )

        # "VBN" uses a plain Evaluator; the other modes use EpisodeObsCollector.
        # NOTE(review): presumably because "RS"/"Global" need the collected
        # observations for _update_obs_preprocessor - confirm against the template.
        if config.normalize_obs_mode == "VBN":
            ec_evaluator_cls = Evaluator
        else:
            ec_evaluator_cls = EpisodeObsCollector

        ec_evaluator = ec_evaluator_cls(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # NOTE(review): presumably consumed as vmap ``in_axes`` by the template:
        # params mapped over axis 0, obs_preprocessor_state not mapped - confirm.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialise the agent state and the optimizer state.

        The freshly initialised policy parameters seed the optimizer and are
        then removed from the returned agent state.

        Args:
            key: PRNG key; split into one key for the agent and one for the
                optimizer.

        Returns:
            ``(agent_state, ec_opt_state)`` where ``agent_state`` has its
            ``policy_params`` set to ``None``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.policy_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # remove params
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialise the observation preprocessor state when it is needed.

        Nothing is done when observation normalisation is disabled or when the
        mode is ``"RS"``; in those cases ``state`` is returned unchanged.

        Args:
            state: Workflow state after the basic setup.

        Returns:
            ``state`` with ``agent_state`` and ``key`` replaced, or ``state``
            itself when no initialisation is required.
        """
        if not self.config.normalize_obs or self.config.normalize_obs_mode == "RS":
            return state

        # setup obs_preprocessor_state
        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return a copy of ``agent_state`` whose ``policy_params`` is ``params``."""
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state carrying the optimizer's current mean as policy params."""
        pop_center = state.ec_opt_state.mean

        return self._replace_actor_params(state.agent_state, pop_center)

    def _update_obs_preprocessor(
        self, agent_state: AgentState, trajectory: SampleBatch
    ) -> AgentState:
        """Update the observation statistics according to ``normalize_obs_mode``.

        * ``"Global"``: the existing statistics are updated with the new
          trajectory.
        * ``"RS"``: statistics are re-initialised and computed from the new
          trajectory only.
        * any other mode: the statistics are left as they are.

        Args:
            agent_state: Agent state holding the current ``obs_preprocessor_state``.
            trajectory: Batch providing ``obs`` and ``dones``.

        Returns:
            A copy of ``agent_state`` with the resulting ``obs_preprocessor_state``.
        """
        if self.config.normalize_obs_mode == "Global":
            # Observations are weighted by ``1 - dones``.
            obs_preprocessor_state = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )

        elif self.config.normalize_obs_mode == "RS":
            # A single observation (index [0, 0] of every leaf) serves as the
            # template for the fresh statistics.
            # NOTE(review): assumes obs leaves have at least two leading
            # batch-like axes - confirm against the evaluator's output.
            dummy_obs = jtu.tree_map(lambda x: x[0, 0], trajectory.obs)
            obs_preprocessor_state = running_statistics.init_state(dummy_obs)
            obs_preprocessor_state = running_statistics.update(
                obs_preprocessor_state,
                trajectory.obs,
                weights=1 - trajectory.dones,
                dp_axis_name=self.dp_axis_name,
            )
        else:
            obs_preprocessor_state = agent_state.obs_preprocessor_state

        return agent_state.replace(obs_preprocessor_state=obs_preprocessor_state)