Source code for evorl.evaluators.mo_brax_evaluator

  1import copy
  2import logging
  3from collections.abc import Callable
  4from functools import partial
  5import math
  6
  7import chex
  8import jax
  9import jax.numpy as jnp
 10import jax.tree_util as jtu
 11from evorl.agent import AgentState, AgentActionFn
 12from evorl.envs import EnvState, EnvStepFn
 13from evorl.envs.brax import BraxAdapter
 14from evorl.rollout import SampleBatch
 15from evorl.types import Action, PolicyExtraInfo, PyTreeDict, pytree_field
 16from evorl.utils.jax_utils import rng_split
 17from evorl.utils.rl_toolkits import compute_discount_return, compute_episode_length
 18
 19from .evaluator import Evaluator
 20
# Module-level logger; BraxEvaluator.evaluate uses it to warn when
# num_episodes is not a multiple of the number of parallel envs.
logger = logging.getLogger(__name__)
 22
 23
class BraxEvaluator(Evaluator):
    """Multi-objective evaluator for Brax environments.

    Attributes:
        metric_names: The names of the metrics to evaluate, default is
            ("reward", "episode_lengths").
    """

    metric_names: tuple[str, ...] = pytree_field(
        default=("reward", "episode_lengths"), static=True
    )

    def __post_init__(self):
        super().__post_init__()
        assert isinstance(self.env.unwrapped, BraxAdapter), (
            "only support Brax environments"
        )

    def evaluate(
        self, agent_state: chex.ArrayTree, key: chex.PRNGKey, num_episodes: int
    ) -> chex.ArrayTree:
        """Evaluate the agent's objectives over (at least) `num_episodes` episodes.

        Episodes are run in `ceil(num_episodes / env.num_envs)` sequential
        iterations of `env.num_envs` parallel envs. When `key` has extra
        leading dims (e.g. a population axis), the action, reset and step
        functions are vmapped once per extra dim.

        Args:
            agent_state: Agent state passed to `self.action_fn`.
            key: PRNG key; leading dims beyond the last are treated as batch
                dims.
            num_episodes: Requested number of episodes; rounded up to a
                multiple of `env.num_envs`.

        Returns:
            A PyTreeDict keyed by metric name whose leaves have shape
            [..., num_iters * num_envs]. "episode_lengths" is always present,
            even when it is not listed in `self.metric_names`.
        """
        num_envs = self.env.num_envs
        num_iters = math.ceil(num_episodes / num_envs)
        if num_episodes % num_envs != 0:
            logger.warning(
                f"num_episode ({num_episodes}) cannot be divided by "
                f"parallel_envs ({num_envs}), "
                f"set new num_episodes={num_iters * num_envs}"
            )

        action_fn = self.action_fn
        env_reset_fn = self.env.reset
        env_step_fn = self.env.step
        # One vmap per extra leading key dim (no-op for an unbatched key).
        for _ in range(key.ndim - 1):
            action_fn = jax.vmap(action_fn)
            env_reset_fn = jax.vmap(env_reset_fn)
            env_step_fn = jax.vmap(env_step_fn)

        # We also need episode_lengths to calculate the sampled timesteps, and
        # fast_eval_metrics uses it for its stop condition.
        metric_names = tuple(self.metric_names)
        if "episode_lengths" not in metric_names:
            metric_names = metric_names + ("episode_lengths",)

        def _evaluate_fn(key, unused_t):
            next_key, init_env_key, rollout_key = rng_split(key, 3)
            env_state = env_reset_fn(init_env_key)

            if self.discount == 1.0:
                # Use the fast undiscounted evaluation.
                metrics, env_state = fast_eval_metrics(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    rollout_key,
                    self.max_episode_steps,
                    metric_names=metric_names,
                )
            else:
                # Pass the augmented `metric_names` (not `self.metric_names`)
                # so both branches return the same set of metrics.
                metrics, env_state = eval_metrics(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    rollout_key,
                    self.max_episode_steps,
                    discount=self.discount,
                    metric_names=metric_names,
                )

            return next_key, metrics  # [..., #envs]

        # [#iters, ..., #envs] (e.g. [#iters, #pop, #envs])
        _, objectives = jax.lax.scan(_evaluate_fn, key, (), length=num_iters)

        # [..., #iters * #envs]
        return jtu.tree_map(_flatten_metric, objectives)
def _flatten_metric(x):
    """Merge the leading iteration axis into the trailing env axis.

    Args:
        x: jax array with shape (#iters, ..., #envs).

    Returns:
        Array with shape (..., #iters * #envs), ordered iteration-major:
        all envs of iteration 0, then all envs of iteration 1, and so on.
    """
    moved = jnp.moveaxis(x, 0, -2)  # (..., #iters, #envs)
    *batch_dims, n_iters, n_envs = moved.shape
    return moved.reshape((*batch_dims, n_iters * n_envs))
def eval_env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
    metric_names: tuple[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Take one policy-driven environment step for evaluation.

    The agent acts on the current observations, the environment is advanced
    with those actions, and a minimal transition holding only the selected
    metrics and the done flags is returned.

    Args:
        env_fn: Environment step function, called as `env_fn(env_state, actions)`.
        action_fn: Policy, called as `action_fn(agent_state, sample_batch, key)`;
            its extra-info output is discarded.
        env_state: Current environment state.
        agent_state: Agent state (read only).
        key: PRNG key forwarded to `action_fn`.
        metric_names: Names of `info.metrics` entries to keep.

    Returns:
        (transition, next_env_state). `transition.rewards` is a PyTreeDict of
        the selected `info.metrics` entries plus `reward` (the env reward);
        `transition.dones` is the next state's done flag.
    """
    policy_input = SampleBatch(obs=env_state.obs)  # [#envs, ...]
    actions, _policy_extras = action_fn(agent_state, policy_input, key)
    next_env_state = env_fn(env_state, actions)

    available = next_env_state.info.metrics
    selected = PyTreeDict(
        {
            metric: value
            for metric, value in available.items()
            if metric in metric_names
        }
    )
    # The env reward is always recorded, independent of `metric_names`.
    selected.reward = next_env_state.reward

    step_record = SampleBatch(rewards=selected, dones=next_env_state.done)
    return step_record, next_env_state
149 150
def eval_rollout_episode(
    env_fn: Callable[[EnvState, Action], EnvState],
    action_fn: Callable[
        [AgentState, SampleBatch, chex.PRNGKey], tuple[Action, PolicyExtraInfo]
    ],
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    metric_names: tuple[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Evaluate a batch of episodic trajectories.

    Runs `rollout_length` steps with `jax.lax.scan`. Once every env in the
    batch reports done, the env is no longer stepped; the previous transition
    and env state are carried forward for the remaining steps.

    The returned metrics are defined by `metric_names` (see `eval_env_step`).

    Args:
        env_fn: Environment step function.
        action_fn: Policy function `(agent_state, sample_batch, key)`.
        env_state: Initial environment state (typically from `env.reset()`).
        agent_state: Agent state (read only).
        key: PRNG key.
        rollout_length: Number of scan steps.
        metric_names: Names of env metrics to record in each transition.

    Returns:
        (trajectory, env_state): a SampleBatch with a leading time axis of
        length `rollout_length`, and the final env state.
    """
    _eval_env_step = partial(
        eval_env_step, env_fn, action_fn, metric_names=metric_names
    )

    def _one_step_rollout(carry, unused_t):
        env_state, current_key, prev_transition = carry
        next_key, current_key = rng_split(current_key, 2)

        def _repeat_last(env_state, *_unused):
            # All envs are done: skip stepping and repeat the previous
            # transition. The output order must match `_eval_env_step`,
            # i.e. (transition, env_state), because `lax.cond` requires both
            # branches to return identical pytree structures.
            return prev_transition.replace(), env_state.replace()

        transition, env_nstate = jax.lax.cond(
            env_state.done.all(),
            _repeat_last,
            _eval_env_step,
            env_state,
            agent_state,
            current_key,
        )

        return (env_nstate, next_key, transition), transition

    # Run one step first to obtain a bootstrap transition with the right
    # structure for the scan carry. It is not included in the trajectory;
    # when env_state comes from env.reset() it is never emitted.
    bootstrap_transition, _ = _eval_env_step(env_state, agent_state, key)

    (env_state, _, _), trajectory = jax.lax.scan(
        _one_step_rollout,
        (env_state, key, bootstrap_transition),
        (),
        length=rollout_length,
    )

    return trajectory, env_state
199 200
def eval_metrics(
    env_fn: Callable[[EnvState, Action], EnvState],
    action_fn: Callable[
        [AgentState, SampleBatch, chex.PRNGKey], tuple[Action, PolicyExtraInfo]
    ],
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    discount: float,
    metric_names: tuple[str] = (),
) -> tuple[PyTreeDict, EnvState]:
    """Roll out a batch of episodes and reduce each trajectory to metrics.

    A trajectory of `rollout_length` steps is collected with
    `eval_rollout_episode`. Each name in `metric_names` is then reduced as
    follows:

      - names containing "reward": discounted return
        (`compute_discount_return` with `discount`);
      - "episode_lengths": `compute_episode_length` of the done flags;
      - anything else: the value at the final time step.

    Returns:
        (metrics, env_state): a PyTreeDict keyed by metric name, and the
        final env state.
    """
    trajectory, final_env_state = eval_rollout_episode(
        env_fn,
        action_fn,
        env_state,
        agent_state,
        key,
        rollout_length,
        metric_names=metric_names,
    )

    def _reduce(name):
        if "reward" in name:
            # e.g. 'reward', 'reward_forward', 'reward_ctrl'
            return compute_discount_return(
                trajectory.rewards[name], trajectory.dones, discount
            )
        if name == "episode_lengths":
            return compute_episode_length(trajectory.dones)
        # e.g. 'x_position': take the last value. Per the original note, the
        # wrapper makes the last value repeat the terminal step value.
        return trajectory.rewards[name][-1]

    metrics = PyTreeDict()
    for name in metric_names:
        metrics[name] = _reduce(name)

    return metrics, final_env_state
241 242
def fast_eval_metrics(
    env_fn: Callable[[EnvState, Action], EnvState],
    action_fn: Callable[
        [AgentState, SampleBatch, chex.PRNGKey], tuple[Action, PolicyExtraInfo]
    ],
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    metric_names: tuple[str] = (),
) -> tuple[PyTreeDict, EnvState]:
    """Fast, undiscounted evaluation of a batch of episodic trajectories.

    Instead of materializing a full trajectory, metrics are accumulated
    inside a `jax.lax.while_loop`. The loop stops once any env reaches
    `rollout_length` steps or all envs are done.

    Each name in `metric_names` is reduced as follows:

      - names containing "reward": undiscounted sum over the steps taken
        while the env was not yet done;
      - "episode_lengths": number of steps taken while not done;
      - anything else: the value from the latest step taken while not done,
        i.e. the step on which the episode ended, or the last step otherwise.

    Args:
        env_fn: Environment step function.
        action_fn: Policy function `(agent_state, sample_batch, key)`.
        env_state: Initial environment state.
        agent_state: Agent state (read only).
        key: PRNG key.
        rollout_length: Maximum number of steps per episode.
        metric_names: Metrics to compute; must include "episode_lengths".

    Returns:
        (metrics, env_state): a PyTreeDict keyed by metric name, and the
        final env state.

    Raises:
        ValueError: If "episode_lengths" is not in `metric_names`; the loop's
            stop condition depends on it.
    """
    if "episode_lengths" not in metric_names:
        raise ValueError(
            "fast_eval_metrics requires 'episode_lengths' in metric_names"
        )

    _eval_env_step = partial(
        eval_env_step, env_fn, action_fn, metric_names=metric_names
    )

    def _continue_cond(carry):
        env_state, _, prev_metrics = carry
        # Keep stepping while no env exceeded the step budget and at least
        # one env is still running.
        return (prev_metrics.episode_lengths < rollout_length).all() & (
            ~env_state.done.all()
        )

    def _one_step_rollout(carry):
        env_state, current_key, prev_metrics = carry
        next_key, current_key = rng_split(current_key, 2)

        transition, env_nstate = _eval_env_step(env_state, agent_state, current_key)

        prev_dones = env_state.done
        # 1 for envs that were still running before this step, else 0.
        alive = 1 - prev_dones

        metrics = PyTreeDict()
        for name in metric_names:
            if "reward" in name:
                metrics[name] = prev_metrics[name] + alive * transition.rewards[name]
            elif "episode_lengths" == name:
                metrics[name] = prev_metrics[name] + alive
            else:
                # Freeze the value once the episode has ended. Every name must
                # be written, because while_loop requires the carry structure
                # to stay fixed across iterations.
                metrics[name] = jnp.where(
                    prev_dones.astype(bool),
                    prev_metrics[name],
                    transition.rewards[name],
                )

        return env_nstate, next_key, metrics

    batch_shape = env_state.reward.shape

    env_state, _, metrics = jax.lax.while_loop(
        _continue_cond,
        _one_step_rollout,
        (
            env_state,
            key,
            PyTreeDict({name: jnp.zeros(batch_shape) for name in metric_names}),
        ),
    )

    return metrics, env_state