Source code for evorl.sample_batch
1import chex
2
3from .types import PyTreeData, PyTreeArrayMixin, ExtraInfo, Reward, RewardDict
4from .utils.jax_utils import right_shift_with_padding
5
6
[docs]
class SampleBatch(PyTreeData, PyTreeArrayMixin):
    """Data container for trajectory data.

    Every field is optional and defaults to ``None``, so an instance can be
    created with only the subset of fields that is available.

    Behaviour inherited from ``PyTreeData`` and ``PyTreeArrayMixin`` is
    defined in ``.types`` and is not described here.
    """

    # NOTE(review): keep the declaration order stable -- presumably the
    # PyTreeData base derives its field order from it; confirm before
    # reordering.

    # Observations; annotated as an arbitrary pytree of arrays.
    obs: chex.ArrayTree | None = None
    # Actions; annotated as an arbitrary pytree of arrays.
    actions: chex.ArrayTree | None = None
    # Rewards; either a ``Reward`` or a ``RewardDict`` (both from ``.types``).
    rewards: Reward | RewardDict | None = None
    # Next observations.
    # NOTE(review): annotated as a single ``chex.Array`` while ``obs`` is a
    # ``chex.ArrayTree`` -- presumably both share one structure; confirm
    # whether this annotation should be ``chex.ArrayTree`` as well.
    next_obs: chex.Array | None = None
    # Done flags; read by ``Episode.valid_mask``.
    dones: chex.Array | None = None
    # Additional information of type ``ExtraInfo`` (from ``.types``).
    extras: ExtraInfo | None = None
16
17
[docs]
class Episode(PyTreeData):
    """Container wrapping the trajectory of an episode.

    The trajectory itself is stored as a single ``SampleBatch``.
    """

    trajectory: SampleBatch

    @property
    def valid_mask(self) -> chex.Array:
        """Mask derived from the trajectory's ``dones`` field.

        Returns:
            ``1 - right_shift_with_padding(dones, 1)``, where ``dones`` is
            ``self.trajectory.dones``.
        """
        done_flags = self.trajectory.dones
        # NOTE(review): presumably this moves every done flag one position
        # later and pads the freed slot, so the step on which ``done`` is
        # first raised still counts as valid -- confirm against
        # ``right_shift_with_padding`` in ``.utils.jax_utils``.
        shifted_dones = right_shift_with_padding(done_flags, 1)
        return 1 - shifted_dones