Source code for evorl.sample_batch

 1import chex
 2
 3from .types import PyTreeData, PyTreeArrayMixin, ExtraInfo, Reward, RewardDict
 4from .utils.jax_utils import right_shift_with_padding
 5
 6
class SampleBatch(PyTreeData, PyTreeArrayMixin):
    """Container grouping the per-step fields of trajectory data.

    Every field is optional and defaults to ``None``, so an instance may
    carry only the subset of fields that a particular caller fills in.

    Attributes:
        obs: Observations, stored as an array pytree.
        actions: Actions, stored as an array pytree.
        rewards: Rewards, given as either a ``Reward`` or a ``RewardDict``.
        next_obs: Next observations.
        dones: Done flags; ``Episode.valid_mask`` is derived from this field.
        extras: Additional information, typed as ``ExtraInfo``.
    """

    # Field declaration order is kept exactly as in the original definition.
    obs: chex.ArrayTree | None = None
    actions: chex.ArrayTree | None = None
    rewards: Reward | RewardDict | None = None
    # NOTE(review): annotated as a single Array while `obs` is an ArrayTree;
    # confirm whether the narrower type is intentional.
    next_obs: chex.Array | None = None
    dones: chex.Array | None = None
    extras: ExtraInfo | None = None
16 17
class Episode(PyTreeData):
    """Wrapper around the trajectory of a single episode.

    Attributes:
        trajectory: The ``SampleBatch`` holding the episode's data.
    """

    trajectory: SampleBatch

    @property
    def valid_mask(self) -> chex.Array:
        """Return one minus the ``dones`` of the trajectory shifted by 1.

        The shift is delegated to ``right_shift_with_padding``.

        NOTE(review): presumably the shift runs along the time axis and pads
        with zeros, so the step where ``dones`` is first set still counts as
        valid and only later steps are masked out; confirm against
        ``utils.jax_utils.right_shift_with_padding``.
        """
        dones = self.trajectory.dones
        shifted_dones = right_shift_with_padding(dones, 1)
        return 1 - shifted_dones