1import logging
2import math
3
4import chex
5import jax
6
7from evorl.agent import AgentState
8from evorl.envs import Env
9from evorl.metrics import EvaluateMetric
10from evorl.rollout import rollout, fast_eval_rollout_episode
11from evorl.types import PyTreeNode, pytree_field
12from evorl.agent import AgentActionFn
13from evorl.utils.jax_utils import rng_split
14from evorl.utils.rl_toolkits import compute_discount_return, compute_episode_length
15
16logger = logging.getLogger(__name__)
17
18
[docs]
class Evaluator(PyTreeNode):
    """Evaluate the agent in the environments.

    Attributes:
        env: Vectorized environment w/o autoreset.
        action_fn: The agent action function.
        max_episode_steps: The maximum number of steps in an episode.
        discount: The discount factor.
    """

    env: Env = pytree_field(static=True)
    action_fn: AgentActionFn = pytree_field(static=True)
    max_episode_steps: int = pytree_field(static=True)
    discount: float = pytree_field(default=1.0, static=True)

    def __post_init__(self):
        """Check that ``env`` exposes ``num_envs``.

        Raises:
            AssertionError: If ``env`` has no ``num_envs`` attribute.
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with ``-O``; the exception type and
        # message are the same as before.
        if not hasattr(self.env, "num_envs"):
            raise AssertionError("only vectorized envs are supported")
        # NOTE: max_episode_steps is not checked against
        # env.max_episode_steps here, so it may be smaller than the env's own
        # limit; see the note in evaluate() about all-zero dones.

    def evaluate(
        self,
        agent_state: AgentState,
        key: chex.PRNGKey,
        num_episodes: int,
    ) -> EvaluateMetric:
        """Evaluate the agent based on its state.

        Episodes are collected in batches of ``env.num_envs``: the env is
        reset and rolled out ``ceil(num_episodes / num_envs)`` times inside a
        ``jax.lax.scan``.

        Args:
            agent_state: The state of the agent.
            key: The PRNG key.
            num_episodes: The number of episodes to evaluate.

        Returns:
            EvaluateMetric(episode_returns, episode_lengths), each flattened
            to ``ceil(num_episodes / num_envs) * num_envs`` entries. This is
            more than ``num_episodes`` when ``num_episodes`` is not a
            multiple of ``num_envs`` (a warning is logged in that case).
        """
        num_envs = self.env.num_envs
        num_iters = math.ceil(num_episodes / num_envs)
        if num_episodes % num_envs != 0:
            # Lazy %-style arguments: the message is only built when the
            # warning is actually emitted.
            logger.warning(
                "num_episodes (%s) cannot be divided by parallel_envs (%s), "
                "set new num_episodes=%s",
                num_episodes,
                num_envs,
                num_iters * num_envs,
            )

        action_fn = self.action_fn
        env_reset_fn = self.env.reset
        env_step_fn = self.env.step
        discount = self.discount
        max_episode_steps = self.max_episode_steps

        def _evaluate_fn(key, unused_t):
            """Run one batch of episodes; the carry is the PRNG key."""
            next_key, init_env_key, eval_key = rng_split(key, 3)
            env_state = env_reset_fn(init_env_key)

            if discount == 1.0:
                # Undiscounted case: use the rollout variant that returns
                # episode metrics directly instead of a full trajectory.
                episode_metrics, env_state = fast_eval_rollout_episode(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    eval_key,
                    max_episode_steps,
                )
                episode_returns = episode_metrics.episode_returns
                episode_lengths = episode_metrics.episode_lengths
            else:
                episode_trajectory, env_state = rollout(
                    env_step_fn,
                    action_fn,
                    env_state,
                    agent_state,
                    eval_key,
                    max_episode_steps,
                )

                # Note: be careful when self.max_episode_steps <
                # env.max_episode_steps, where dones could all be zeros.
                # compute_discount_return & compute_episode_length are fine!
                episode_returns = compute_discount_return(
                    episode_trajectory.rewards,
                    episode_trajectory.dones,
                    discount,
                )
                episode_lengths = compute_episode_length(
                    episode_trajectory.dones
                )

            return next_key, (episode_returns, episode_lengths)  # [..., #envs]

        # Stacked scan outputs: [#iters, #envs]
        _, (episode_returns, episode_lengths) = jax.lax.scan(
            _evaluate_fn, key, (), length=num_iters
        )

        # [#iters, #envs] -> [#iters * #envs]
        eval_metrics = EvaluateMetric(
            episode_returns=episode_returns.flatten(),
            episode_lengths=episode_lengths.flatten(),
        )

        return eval_metrics
112
113 return eval_metrics