1import copy
2import logging
3from collections.abc import Callable
4from functools import partial
5import math
6
7import chex
8import jax
9import jax.numpy as jnp
10import jax.tree_util as jtu
11from evorl.agent import AgentState, AgentActionFn
12from evorl.envs import EnvState, EnvStepFn
13from evorl.envs.brax import BraxAdapter
14from evorl.rollout import SampleBatch
15from evorl.types import Action, PolicyExtraInfo, PyTreeDict, pytree_field
16from evorl.utils.jax_utils import rng_split
17from evorl.utils.rl_toolkits import compute_discount_return, compute_episode_length
18
19from .evaluator import Evaluator
20
# Module-level logger; BraxEvaluator.evaluate uses it to warn when the requested
# number of episodes is not a multiple of the number of parallel envs.
logger = logging.getLogger(__name__)
22
23
[docs]
class BraxEvaluator(Evaluator):
    """Multi-objective evaluator for Brax environments.

    Attributes:
        metric_names: The names of the metrics to evaluate, default is
            ("reward", "episode_lengths"). ``evaluate`` always computes
            "episode_lengths" as well, even if it is not listed here.
    """

    metric_names: tuple[str, ...] = pytree_field(
        default=("reward", "episode_lengths"), static=True
    )

    def __post_init__(self):
        super().__post_init__()
        assert isinstance(self.env.unwrapped, BraxAdapter), (
            "only support Brax environments"
        )

    def evaluate(
        self, agent_state: chex.ArrayTree, key: chex.PRNGKey, num_episodes: int
    ) -> chex.ArrayTree:
        """Evaluate ``agent_state`` over (at least) ``num_episodes`` episodes.

        Episodes are collected in ``ceil(num_episodes / num_envs)`` sequential
        rounds of ``num_envs`` parallel environments, so the number of episodes
        actually evaluated is rounded up to a multiple of ``num_envs``.

        Args:
            agent_state: Agent state forwarded to ``self.action_fn``.
            key: PRNG key. Every leading axis beyond the last one adds a
                ``jax.vmap`` level over the action/reset/step functions
                (e.g. a population axis). Assumes raw uint32 keys whose last
                axis holds the key data -- TODO confirm for typed keys.
            num_episodes: Requested number of episodes.

        Returns:
            A tree of metrics, each shaped ``(..., num_iters * num_envs)``
            where ``...`` are the vmapped leading axes of ``key``. Always
            contains "episode_lengths" in addition to ``self.metric_names``.
        """
        num_envs = self.env.num_envs
        num_iters = math.ceil(num_episodes / num_envs)
        if num_episodes % num_envs != 0:
            logger.warning(
                f"num_episodes ({num_episodes}) cannot be divided by parallel_envs ({num_envs}), "
                f"set new num_episodes={num_iters * num_envs}"
            )

        action_fn = self.action_fn
        env_reset_fn = self.env.reset
        env_step_fn = self.env.step
        # One vmap per extra leading key axis; no-op when key.ndim <= 1.
        for _ in range(key.ndim - 1):
            action_fn = jax.vmap(action_fn)
            env_reset_fn = jax.vmap(env_reset_fn)
            env_step_fn = jax.vmap(env_step_fn)

        # "episode_lengths" is always computed: it is needed downstream to
        # count the sampled timesteps, and the fast path also uses it to
        # decide when to stop. Use this augmented tuple for BOTH paths.
        metric_names = tuple(self.metric_names)
        if "episode_lengths" not in metric_names:
            metric_names = metric_names + ("episode_lengths",)

        def _evaluate_fn(key, unused_t):
            next_key, init_env_key, rollout_key = rng_split(key, 3)
            env_state = env_reset_fn(init_env_key)

            if self.discount == 1.0:
                # Undiscounted returns can be accumulated on the fly, so use
                # the early-exiting while_loop evaluation.
                metrics, env_state = fast_eval_metrics(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    rollout_key,
                    self.max_episode_steps,
                    metric_names=metric_names,
                )
            else:
                metrics, env_state = eval_metrics(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    rollout_key,
                    self.max_episode_steps,
                    discount=self.discount,
                    metric_names=metric_names,
                )

            return next_key, metrics  # [..., #envs]

        # [#iters, ..., #envs]
        _, objectives = jax.lax.scan(_evaluate_fn, key, (), length=num_iters)

        # [..., #iters * #envs]
        return jtu.tree_map(_flatten_metric, objectives)
103
104
105def _flatten_metric(x):
106 """Flatten the last two dims.
107
108 Args:
109 x: jax tensor with shape (#iters, ..., #envs)
110
111 Returns:
112 flatten x with shape (..., #iters * #envs)
113 """
114 return jax.lax.collapse(jnp.moveaxis(x, 0, -2), -2)
115
116
[docs]
def eval_env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
    metric_names: tuple[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Take a single evaluation step.

    The agent picks actions for the current observations, the env is stepped
    once, and the reward signals are packed into a transition.

    Args:
        env_fn: Environment step function ``(env_state, action) -> env_state``.
        action_fn: Agent action function; its policy extras are discarded.
        env_state: Current environment state (batched over envs).
        agent_state: Agent state; only read.
        key: PRNG key passed to ``action_fn``.
        metric_names: Names of ``info.metrics`` entries to keep.

    Returns:
        ``(transition, next_env_state)``. ``transition.rewards`` holds every
        ``next_env_state.info.metrics`` entry whose name is in
        ``metric_names``, plus ``reward`` set from ``next_env_state.reward``
        (assigned last, so it wins over an info metric of the same name).
        ``transition.dones`` is ``next_env_state.done``.
    """
    obs_batch = SampleBatch(obs=env_state.obs)  # [#envs, ...]

    chosen_actions, _ = action_fn(agent_state, obs_batch, key)
    next_env_state = env_fn(env_state, chosen_actions)

    kept_metrics = {}
    for metric_name, metric_value in next_env_state.info.metrics.items():
        if metric_name in metric_names:
            kept_metrics[metric_name] = metric_value

    step_rewards = PyTreeDict(kept_metrics)
    step_rewards.reward = next_env_state.reward

    step_transition = SampleBatch(rewards=step_rewards, dones=next_env_state.done)
    return step_transition, next_env_state
149
150
[docs]
def eval_rollout_episode(
    env_fn: Callable[[EnvState, Action], EnvState],
    action_fn: Callable[
        [AgentState, SampleBatch, chex.PRNGKey], tuple[Action, PolicyExtraInfo]
    ],
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    metric_names: tuple[str, ...] = (),
) -> tuple[SampleBatch, EnvState]:
    """Evaluate a batch of episodic trajectories.

    Runs ``rollout_length`` steps with ``jax.lax.scan``. Once every env in
    the batch reports ``done``, the env is no longer stepped and the previous
    transition is repeated for the remaining steps.

    The returned metrics are defined by ``metric_names``: each step's
    ``rewards`` holds the entries produced by ``eval_env_step`` for those
    names, plus the env reward under ``reward``.

    Args:
        env_fn: Environment step function.
        action_fn: Agent action function.
        env_state: Initial (batched) environment state.
        agent_state: Agent state; only read.
        key: PRNG key for action sampling.
        rollout_length: Number of scan steps (maximum episode length).
        metric_names: Names of ``info.metrics`` entries to record.

    Returns:
        ``(trajectory, env_state)``: ``trajectory`` is a SampleBatch with a
        leading time axis of length ``rollout_length``; ``env_state`` is the
        final environment state.
    """
    _eval_env_step = partial(
        eval_env_step, env_fn, action_fn, metric_names=metric_names
    )

    def _one_step_rollout(carry, unused_t):
        env_state, current_key, prev_transition = carry
        next_key, current_key = rng_split(current_key, 2)

        # Both branches must return (transition, env_state) -- the same order
        # and pytree structure as ``eval_env_step`` -- otherwise ``lax.cond``
        # rejects the mismatched branch outputs.
        transition, env_nstate = jax.lax.cond(
            env_state.done.all(),
            lambda *x: (prev_transition.replace(), env_state.replace()),
            _eval_env_step,
            env_state,
            agent_state,
            current_key,
        )

        return (env_nstate, next_key, transition), transition

    # Run one step up front to obtain a transition with the right pytree
    # structure to seed ``prev_transition`` in the scan carry. It only ends up
    # in the trajectory if the batch is already fully done at the first scan
    # step, which is not the case when ``env_state`` comes from ``env.reset()``.
    bootstrap_transition, _ = _eval_env_step(env_state, agent_state, key)

    (env_state, _, _), trajectory = jax.lax.scan(
        _one_step_rollout,
        (env_state, key, bootstrap_transition),
        (),
        length=rollout_length,
    )

    return trajectory, env_state
199
200
[docs]
def eval_metrics(
    env_fn: Callable[[EnvState, Action], EnvState],
    action_fn: Callable[
        [AgentState, SampleBatch, chex.PRNGKey], tuple[Action, PolicyExtraInfo]
    ],
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    discount: float,
    metric_names: tuple[str] = (),
) -> tuple[PyTreeDict, EnvState]:
    """Roll out full episodes and reduce them to per-env metrics.

    For each name in ``metric_names`` (in order):
      * ``"episode_lengths"``: episode length computed from the dones;
      * any name containing ``"reward"`` (e.g. ``reward_forward``,
        ``reward_ctrl``): discounted return with ``discount``;
      * anything else (e.g. ``x_position``): the value at the last time step.

    Returns:
        ``(metrics, env_state)`` with ``env_state`` the final env state.
    """
    trajectory, final_env_state = eval_rollout_episode(
        env_fn,
        action_fn,
        env_state,
        agent_state,
        key,
        rollout_length,
        metric_names=metric_names,
    )
    step_rewards = trajectory.rewards
    step_dones = trajectory.dones

    collected = PyTreeDict()
    for metric_name in metric_names:
        if metric_name == "episode_lengths":
            collected[metric_name] = compute_episode_length(step_dones)
        elif "reward" not in metric_name:
            # Taking the final step is fine: the wrapper repeats the terminal
            # step value after the episode ends.
            collected[metric_name] = step_rewards[metric_name][-1]
        else:
            collected[metric_name] = compute_discount_return(
                step_rewards[metric_name],
                step_dones,
                discount,
            )

    return collected, final_env_state
241
242
[docs]
def fast_eval_metrics(
    env_fn: Callable[[EnvState, Action], EnvState],
    action_fn: Callable[
        [AgentState, SampleBatch, chex.PRNGKey], tuple[Action, PolicyExtraInfo]
    ],
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    metric_names: tuple[str, ...] = (),
) -> tuple[PyTreeDict, EnvState]:
    """Fast (undiscounted) evaluation of a batch of episodic trajectories.

    Instead of materializing the trajectory, metrics are accumulated inside a
    ``jax.lax.while_loop`` that stops once every env is done or an env has
    taken ``rollout_length`` steps.

    The returned metrics are defined by ``metric_names``:
      * names containing ``"reward"``: undiscounted sum of rewards from steps
        taken while the env was not yet done (terminal step included);
      * ``"episode_lengths"``: number of such steps. This name MUST be present;
        the loop's termination condition reads it;
      * any other name (e.g. ``x_position``): the value from the latest step.

    Returns:
        ``(metrics, env_state)``. Each metric has the batch shape of
        ``env_state.reward``; ``env_state`` is the state when the loop stopped.
    """
    _eval_env_step = partial(
        eval_env_step, env_fn, action_fn, metric_names=metric_names
    )

    def _terminate_cond(carry):
        env_state, _, prev_metrics = carry
        return (prev_metrics.episode_lengths < rollout_length).all() & (
            ~env_state.done.all()
        )

    def _one_step_rollout(carry):
        env_state, current_key, prev_metrics = carry
        next_key, current_key = rng_split(current_key, 2)

        transition, env_nstate = _eval_env_step(env_state, agent_state, current_key)

        # Mask by the done flag from BEFORE this step, so the step that ends
        # the episode still counts but later steps do not.
        prev_dones = env_state.done

        # Every name must be written back: the while_loop carry has to keep
        # exactly the structure of its initial value.
        metrics = PyTreeDict()
        for name in metric_names:
            if "reward" in name:
                metrics[name] = (
                    prev_metrics[name] + (1 - prev_dones) * transition.rewards[name]
                )
            elif "episode_lengths" == name:
                metrics[name] = prev_metrics[name] + (1 - prev_dones)
            else:
                # For other metrics like 'x_position', keep the latest value,
                # consistent with ``eval_metrics`` using the last step.
                metrics[name] = transition.rewards[name]

        return env_nstate, next_key, metrics

    batch_shape = env_state.reward.shape

    env_state, _, metrics = jax.lax.while_loop(
        _terminate_cond,
        _one_step_rollout,
        (
            env_state,
            key,
            PyTreeDict({name: jnp.zeros(batch_shape) for name in metric_names}),
        ),
    )

    return metrics, env_state