Source code for evorl.evaluators.evaluator

  1import logging
  2import math
  3
  4import chex
  5import jax
  6
  7from evorl.agent import AgentState
  8from evorl.envs import Env
  9from evorl.metrics import EvaluateMetric
 10from evorl.rollout import rollout, fast_eval_rollout_episode
 11from evorl.types import PyTreeNode, pytree_field
 12from evorl.agent import AgentActionFn
 13from evorl.utils.jax_utils import rng_split
 14from evorl.utils.rl_toolkits import compute_discount_return, compute_episode_length
 15
 16logger = logging.getLogger(__name__)
 17
 18
class Evaluator(PyTreeNode):
    """Evaluate the agent in the environments.

    Attributes:
        env: Vectorized environment w/o autoreset.
        action_fn: The agent action function.
        max_episode_steps: The maximum number of steps in an episode.
        discount: The discount factor. When it is exactly 1.0, `evaluate`
            uses `fast_eval_rollout_episode`; otherwise it collects a full
            trajectory with `rollout` and computes discounted returns from it.
    """

    env: Env = pytree_field(static=True)
    action_fn: AgentActionFn = pytree_field(static=True)
    max_episode_steps: int = pytree_field(static=True)
    discount: float = pytree_field(default=1.0, static=True)

    def __post_init__(self):
        # Raise explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with `-O`.
        if not hasattr(self.env, "num_envs"):
            raise ValueError("only vectorized envs are supported")
        # A check that `max_episode_steps <= env.max_episode_steps` was
        # deliberately left disabled: evaluation may stop before the env
        # signals `done` (see the note in `evaluate`'s discounted branch).

    def evaluate(
        self,
        agent_state: AgentState,
        key: chex.PRNGKey,
        num_episodes: int,
    ) -> EvaluateMetric:
        """Evaluate the agent based on its state.

        Episodes are run in batches of `env.num_envs` parallel environments,
        so the number of episodes actually evaluated is `num_episodes`
        rounded up to the next multiple of `env.num_envs`. A warning is
        logged when rounding happens.

        Args:
            agent_state: The state of the agent.
            key: The PRNG key.
            num_episodes: The number of episodes to evaluate. Must be a
                non-negative Python int, because it determines the static
                `jax.lax.scan` length.

        Returns:
            EvaluateMetric(episode_returns, episode_lengths). Both are
            flattened from [#iters, #envs], i.e. they have length
            `ceil(num_episodes / num_envs) * num_envs`.

        Raises:
            ValueError: If `num_episodes` is negative.
        """
        if num_episodes < 0:
            raise ValueError(
                f"num_episodes should be non-negative, got {num_episodes}"
            )

        num_envs = self.env.num_envs
        num_iters = math.ceil(num_episodes / num_envs)
        if num_episodes % num_envs != 0:
            # Lazy %-formatting; note the explicit space between the two
            # sentence fragments.
            logger.warning(
                "num_episodes (%s) cannot be divided by parallel_envs (%s), "
                "set new num_episodes=%s",
                num_episodes,
                num_envs,
                num_iters * num_envs,
            )

        action_fn = self.action_fn
        env_reset_fn = self.env.reset
        env_step_fn = self.env.step
        discount = self.discount
        max_episode_steps = self.max_episode_steps

        def _evaluate_fn(key, unused_t):
            # One scan iteration: reset every parallel env, then run one
            # evaluation episode in each of them.
            next_key, init_env_key, eval_key = rng_split(key, 3)
            env_state = env_reset_fn(init_env_key)
            if discount == 1.0:
                # Undiscounted case: the fast rollout reports returns and
                # lengths directly, without materializing a trajectory.
                episode_metrics, _ = fast_eval_rollout_episode(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    eval_key,
                    max_episode_steps,
                )
                episode_returns = episode_metrics.episode_returns
                episode_lengths = episode_metrics.episode_lengths
            else:
                episode_trajectory, _ = rollout(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    eval_key,
                    max_episode_steps,
                )

                # Note: when max_episode_steps < env.max_episode_steps, the
                # `dones` of a trajectory may be all zeros. Both
                # compute_discount_return and compute_episode_length handle
                # that case correctly.
                episode_returns = compute_discount_return(
                    episode_trajectory.rewards, episode_trajectory.dones, discount
                )
                episode_lengths = compute_episode_length(episode_trajectory.dones)

            return next_key, (episode_returns, episode_lengths)  # [..., #envs]

        # [#iters, #envs]
        _, (episode_returns, episode_lengths) = jax.lax.scan(
            _evaluate_fn, key, (), length=num_iters
        )

        # [#iters, #envs] -> [#iters * #envs]
        return EvaluateMetric(
            episode_returns=episode_returns.flatten(),
            episode_lengths=episode_lengths.flatten(),
        )