Source code for evorl.replay_buffers.replay_buffer

  1from abc import ABCMeta, abstractmethod
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7
  8from evorl.types import PyTreeData, PyTreeNode
  9from evorl.utils.jax_utils import tree_get, tree_set
 10
 11
class ReplayBufferState(PyTreeData):
    """State container carried alongside a replay buffer.

    Attributes:
        data: Pytree holding the stored replay buffer samples.
        current_index: Write cursor used when adding data. Defaults to a
            scalar int32 zero.
        buffer_size: Current number of stored samples. Defaults to a
            scalar int32 zero.
    """

    data: chex.ArrayTree
    current_index: chex.Array = jnp.zeros((), dtype=jnp.int32)
    buffer_size: chex.Array = jnp.zeros((), dtype=jnp.int32)
24 25
class AbstractReplayBuffer(PyTreeNode, metaclass=ABCMeta):
    """Interface that every replay buffer implementation must provide."""

    @abstractmethod
    def init(self, sample_spec: chex.ArrayTree) -> ReplayBufferState:
        """Create the initial state of the replay buffer.

        Args:
            sample_spec: One sample (or a spec of one sample) carrying the
                pytree structure together with each leaf's dtype and shape.

        Returns:
            The initial replay buffer state.
        """

    @abstractmethod
    def add(
        self, buffer_state: ReplayBufferState, xs: chex.ArrayTree
    ) -> ReplayBufferState:
        """Insert data into the replay buffer.

        Args:
            buffer_state: The replay buffer state before insertion.
            xs: The data to insert.

        Returns:
            The replay buffer state after insertion.
        """

    @abstractmethod
    def sample(
        self, buffer_state: ReplayBufferState, key: chex.PRNGKey
    ) -> chex.ArrayTree:
        """Draw a batch of data from the replay buffer.

        Args:
            buffer_state: The replay buffer state to draw from.
            key: JAX PRNGKey.

        Returns:
            A batch of data sampled from the replay buffer.
        """

    @abstractmethod
    def can_sample(self, buffer_state: ReplayBufferState) -> bool:
        """Tell whether `sample()` may be called on the given state.

        Args:
            buffer_state: The replay buffer state to inspect.

        Returns:
            Whether the replay buffer is ready for `sample()`.
        """

    @abstractmethod
    def is_full(self, buffer_state: ReplayBufferState) -> bool:
        """Tell whether the replay buffer is full.

        Args:
            buffer_state: The replay buffer state to inspect.

        Returns:
            Whether the replay buffer is full.
        """
86 87
class ReplayBuffer(AbstractReplayBuffer):
    """ReplayBuffer with uniform sampling.

    Data are added and sampled in a 1d-like structure: every leaf of the
    stored pytree has a leading axis of length `capacity`, and new entries
    are written at `current_index`, wrapping around modulo `capacity`.

    Attributes:
        capacity: The maximum capacity of the replay buffer.
        sample_batch_size: The batch size for `sample()`.
        min_sample_timesteps: The minimum number of stored timesteps before
            the replay buffer can sample.
    """

    capacity: int
    sample_batch_size: int
    min_sample_timesteps: int = 0

    def init(self, spec: chex.ArrayTree) -> ReplayBufferState:
        """Create an empty buffer state shaped after `spec`.

        Args:
            spec: A single sample (or sample spec); each leaf supplies the
                dtype and shape of one stored entry.

        Returns:
            A state whose `data` leaves have shape `(capacity, *leaf.shape)`
            and whose `current_index` and `buffer_size` are int32 zeros.
        """
        # NOTE(review): the original comment states that broadcast_to will
        # not pre-allocate memory; confirm this for the eager (non-jit) path.
        data = jtu.tree_map(
            lambda x: jnp.broadcast_to(jnp.empty_like(x), (self.capacity, *x.shape)),
            spec,
        )

        return ReplayBufferState(
            data=data,
            current_index=jnp.zeros((), jnp.int32),
            buffer_size=jnp.zeros((), jnp.int32),
        )

    def is_full(self, buffer_state: ReplayBufferState) -> bool:
        """Check whether the buffer holds `capacity` entries.

        Args:
            buffer_state: The current state of the replay buffer.

        Returns:
            The result of `buffer_size == capacity` (a scalar boolean array
            when `buffer_size` is a JAX array).
        """
        return buffer_state.buffer_size == self.capacity

    def can_sample(self, buffer_state: ReplayBufferState) -> bool:
        """Check whether enough data is stored to call `sample()`.

        Args:
            buffer_state: The current state of the replay buffer.

        Returns:
            The result of `buffer_size >= min_sample_timesteps` (a scalar
            boolean array when `buffer_size` is a JAX array).
        """
        return buffer_state.buffer_size >= self.min_sample_timesteps

    def add(
        self,
        buffer_state: ReplayBufferState,
        xs: chex.ArrayTree,
        mask: chex.Array | None = None,
    ) -> ReplayBufferState:
        """Add a batch of data to the replay buffer.

        Args:
            buffer_state: The current state of the replay buffer.
            xs: The data to add; the leading axis of every leaf is the batch
                axis and leaf dtypes must match those of `buffer_state.data`.
            mask: Optional 1-d array over the batch axis. Entries where the
                mask is falsy are skipped; `None` adds the whole batch.

        Returns:
            Updated state of the replay buffer.

        Raises:
            AssertionError: If dtypes of `xs` and the stored data differ, if
                `mask` is not 1-d, or if `mask.shape` is not a shape prefix
                of the leaves of `xs`.
        """
        # Tip from the original author: when jitting this function, set
        # `mask` to static.

        chex.assert_trees_all_equal_dtypes(xs, buffer_state.data)

        if mask is not None:
            # chex assertion instead of a bare `assert`, so the check is not
            # stripped when Python runs with -O.
            chex.assert_rank(mask, 1)
            chex.assert_tree_shape_prefix(xs, mask.shape)
            # Force int32 so a float/bool/uint mask cannot change the dtype
            # of the int32 `current_index` / `buffer_size` counters.
            batch_size = jnp.sum(mask, dtype=jnp.int32)

            # Masked-out entries are redirected to the out-of-range index
            # `self.capacity`, which is expected to be ignored by the set().
            # NOTE(review): the original comment attributes this to
            # mode="promise_in_bounds"; `tree_set` is defined elsewhere, so
            # confirm that it drops out-of-bounds updates.
            # eg: mask = [1,0,1,1,0], capacity = n > 5
            # Then, cumsum_mask = [1,1,2,3,3], cumsum_mask-1 = [0,0,1,2,2]
            # assume current_index = 0, then indices = [0,n,1,2,n]
            cumsum_mask = jnp.cumsum(mask, axis=0, dtype=jnp.int32)
            indices = (buffer_state.current_index + cumsum_mask - 1) % self.capacity
            indices = jnp.where(mask, indices, self.capacity)
        else:
            # Static batch size taken from the leading axis of the first leaf.
            batch_size = jtu.tree_leaves(xs)[0].shape[0]

            indices = (
                buffer_state.current_index + jnp.arange(batch_size, dtype=jnp.int32)
            ) % self.capacity

        data = tree_set(buffer_state.data, xs, indices, unique_indices=False)

        # Advance the write cursor with wrap-around; the size saturates at
        # capacity once the buffer has been filled.
        current_index = (buffer_state.current_index + batch_size) % self.capacity
        buffer_size = jnp.minimum(buffer_state.buffer_size + batch_size, self.capacity)

        return buffer_state.replace(
            data=data, current_index=current_index, buffer_size=buffer_size
        )

    def sample(
        self, buffer_state: ReplayBufferState, key: chex.PRNGKey
    ) -> chex.ArrayTree:
        """Sample a batch of data uniformly from the stored entries.

        Args:
            buffer_state: The current state of the replay buffer.
            key: JAX PRNGKey.

        Returns:
            A batch of `sample_batch_size` entries gathered from
            `buffer_state.data` at indices drawn by `jax.random.randint`
            with `minval=0` and `maxval=buffer_size`.
        """
        indices = jax.random.randint(
            key, (self.sample_batch_size,), minval=0, maxval=buffer_state.buffer_size
        )

        batch = tree_get(buffer_state.data, indices)

        return batch