Source code for evorl.evaluators.episode_collector

  1import logging
  2import math
  3from collections.abc import Sequence
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9
 10from evorl.agent import AgentActionFn
 11from evorl.envs import Env
 12from evorl.metrics import EvaluateMetric
 13from evorl.rollout import rollout, RolloutFn
 14from evorl.types import PyTreeNode, pytree_field
 15from evorl.sample_batch import SampleBatch
 16from evorl.utils.jax_utils import rng_split
 17from evorl.utils.rl_toolkits import compute_discount_return, compute_episode_length
 18
 19logger = logging.getLogger(__name__)
 20
 21
class EpisodeCollector(PyTreeNode):
    """Evaluate and return eval metrics and episodic trajectory.

    EpisodeCollector is similar to `Evaluator`, but it also returns the
    trajectories. When evaluating the agent, call `rollout()`.

    Attributes:
        env: Vectorized environment w/o autoreset.
        action_fn: The agent action function.
        max_episode_steps: The maximum number of steps in an episode.
        env_extra_fields: The extra fields to collect from the environment.
        discount: The discount factor.
        rollout_fn: The function used to roll out one batch of episodes.
    """

    env: Env
    action_fn: AgentActionFn
    max_episode_steps: int = pytree_field(static=True)
    env_extra_fields: Sequence[str] = ()
    discount: float = 1.0

    rollout_fn: RolloutFn = pytree_field(default=rollout, static=True)

    def __post_init__(self):
        # The collector reads `env.num_envs` to size each rollout batch.
        assert hasattr(self.env, "num_envs"), "only parrallel envs are supported"

    def rollout(
        self,
        agent_state,
        key: chex.PRNGKey,
        num_episodes: int,
    ) -> tuple[EvaluateMetric, SampleBatch]:
        """Collect episodes and compute their evaluation metrics.

        Episodes are gathered in ``ceil(num_episodes / env.num_envs)``
        sequential batches of ``env.num_envs`` parallel episodes, so the
        effective number of episodes is rounded up to a multiple of
        ``env.num_envs`` (a warning is logged when rounding happens).

        Args:
            agent_state: The agent state forwarded to the rollout function.
            key: PRNG key driving environment resets and rollouts.
            num_episodes: Requested number of episodes. It is used with
                ``math.ceil`` and ``%``, so it must be a concrete Python
                number rather than a traced value.

        Returns:
            A tuple ``(eval_metrics, episode_trajectory)``: the discounted
            returns and episode lengths, and the collected trajectory whose
            batch-iteration and env axes are merged into one episode axis.
        """
        num_envs = self.env.num_envs
        num_iters = math.ceil(num_episodes / num_envs)
        if num_episodes % num_envs != 0:
            logger.warning(
                f"num_episode ({num_episodes}) cannot be divided by parallel_envs ({num_envs}),"
                f"set new num_episodes={num_iters * num_envs}"
            )

        action_fn = self.action_fn
        env_reset_fn = self.env.reset
        env_step_fn = self.env.step

        def _evaluate_fn(key, unused_t):
            # Derive three independent keys. The parent `key` must not be
            # used again after splitting: passing it to the rollout would let
            # the rollout re-derive the very subkeys used for the env reset
            # and for the next scan iteration, correlating the streams.
            next_key, init_env_key, rollout_key = rng_split(key, 3)
            env_state = env_reset_fn(init_env_key)

            # Note: be careful when self.max_episode_steps < env.max_episode_steps,
            # where dones could all be zeros.
            episode_trajectory, env_state = self.rollout_fn(
                env_step_fn,
                action_fn,
                env_state,
                agent_state,
                rollout_key,
                self.max_episode_steps,
                self.env_extra_fields,
            )

            return next_key, episode_trajectory

        # [#iters, T, #envs]
        _, episode_trajectory = jax.lax.scan(_evaluate_fn, key, (), length=num_iters)

        # [#iters, T, #envs] -> [T, num_episodes]
        episode_trajectory = jtu.tree_map(
            lambda x: jax.lax.collapse(jnp.swapaxes(x, 0, 1), 1, 3), episode_trajectory
        )

        # [num_episodes]
        discount_returns = compute_discount_return(
            episode_trajectory.rewards, episode_trajectory.dones, self.discount
        )

        episode_lengths = compute_episode_length(episode_trajectory.dones)

        eval_metrics = EvaluateMetric(
            episode_returns=discount_returns,
            episode_lengths=episode_lengths,
        )

        return eval_metrics, episode_trajectory