1import logging
2import math
3from collections.abc import Sequence
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9
10from evorl.agent import AgentActionFn
11from evorl.envs import Env
12from evorl.metrics import EvaluateMetric
13from evorl.rollout import rollout, RolloutFn
14from evorl.types import PyTreeNode, pytree_field
15from evorl.sample_batch import SampleBatch
16from evorl.utils.jax_utils import rng_split
17from evorl.utils.rl_toolkits import compute_discount_return, compute_episode_length
18
19logger = logging.getLogger(__name__)
20
21
[docs]
class EpisodeCollector(PyTreeNode):
    """Evaluate an agent and return both eval metrics and episodic trajectories.

    EpisodeCollector is similar to `Evaluator`, but it also returns the
    trajectories. When evaluating the agent, call `rollout()`.

    Attributes:
        env: Vectorized environment w/o autoreset.
        action_fn: The agent action function.
        max_episode_steps: The maximum number of steps in an episode.
        env_extra_fields: The extra fields to collect from the environment.
        discount: The discount factor.
        rollout_fn: The function used to roll out one batch of episodes.
            Defaults to `rollout`.
    """

    env: Env
    action_fn: AgentActionFn
    max_episode_steps: int = pytree_field(static=True)
    env_extra_fields: Sequence[str] = ()
    discount: float = 1.0

    rollout_fn: RolloutFn = pytree_field(default=rollout, static=True)

    def __post_init__(self):
        """Check that the wrapped environment exposes `num_envs`."""
        assert hasattr(self.env, "num_envs"), "only parrallel envs are supported"

    def rollout(
        self,
        agent_state,
        key: chex.PRNGKey,
        num_episodes: int,
    ) -> tuple[EvaluateMetric, SampleBatch]:
        """Collect episodes with the agent and compute their evaluation metrics.

        Episodes are gathered in `ceil(num_episodes / env.num_envs)` sequential
        iterations. Each iteration resets the environment and calls
        `rollout_fn` for `max_episode_steps` steps. When `num_episodes` is not
        a multiple of `env.num_envs`, the number of collected episodes is
        rounded up to the next multiple and a warning is logged.

        Args:
            agent_state: The agent state forwarded to `rollout_fn`.
            key: PRNG key from which per-iteration keys are derived.
            num_episodes: The requested number of episodes.

        Returns:
            A tuple `(eval_metrics, episode_trajectory)`. `eval_metrics` holds
            the discounted return and the length of every episode;
            `episode_trajectory` is the collected trajectory with the
            iteration and env axes merged into a single episode axis.
        """
        num_envs = self.env.num_envs
        num_iters = math.ceil(num_episodes / num_envs)
        if num_episodes % num_envs != 0:
            logger.warning(
                f"num_episode ({num_episodes}) cannot be divided by parallel_envs ({num_envs}),"
                f"set new num_episodes={num_iters * num_envs}"
            )

        action_fn = self.action_fn
        env_reset_fn = self.env.reset
        env_step_fn = self.env.step

        def _evaluate_fn(key, unused_t):
            # Derive three distinct keys: the scan carry, the env reset and the
            # rollout. The parent key must not be consumed again after it has
            # been split, otherwise the rollout's randomness is derived from
            # the same source as the other two keys.
            next_key, init_env_key, rollout_key = rng_split(key, 3)
            env_state = env_reset_fn(init_env_key)

            # Note: be careful when self.max_episode_steps < env.max_episode_steps,
            # where dones could all be zeros.
            episode_trajectory, env_state = self.rollout_fn(
                env_step_fn,
                action_fn,
                env_state,
                agent_state,
                rollout_key,
                self.max_episode_steps,
                self.env_extra_fields,
            )

            return next_key, episode_trajectory

        # [#iters, T, #envs]
        _, episode_trajectory = jax.lax.scan(_evaluate_fn, key, (), length=num_iters)

        # [#iters, T, #envs] -> [T, num_episodes]
        # Move T to the front, then merge the (#iters, #envs) axes into one.
        episode_trajectory = jtu.tree_map(
            lambda x: jax.lax.collapse(jnp.swapaxes(x, 0, 1), 1, 3), episode_trajectory
        )

        # [num_episodes]
        discount_returns = compute_discount_return(
            episode_trajectory.rewards, episode_trajectory.dones, self.discount
        )

        episode_lengths = compute_episode_length(episode_trajectory.dones)

        eval_metrics = EvaluateMetric(
            episode_returns=discount_returns,
            episode_lengths=episode_lengths,
        )

        return eval_metrics, episode_trajectory