1from abc import ABCMeta, abstractmethod
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7
8from evorl.types import PyTreeData, PyTreeNode
9from evorl.utils.jax_utils import tree_get, tree_set
10
11
class ReplayBufferState(PyTreeData):
    """Pytree bundling the stored samples with the buffer's bookkeeping counters.

    Attributes:
        data: the stored replay buffer data.
        current_index: the slot at which the next `add()` starts writing.
        buffer_size: how many valid entries the buffer currently holds.
    """

    data: chex.ArrayTree
    # Both counters default to a scalar int32 zero.
    current_index: chex.Array = jnp.array(0, dtype=jnp.int32)
    buffer_size: chex.Array = jnp.array(0, dtype=jnp.int32)


class AbstractReplayBuffer(PyTreeNode, metaclass=ABCMeta):
    """Interface for replay buffers.

    Buffers are functional: `init` produces a `ReplayBufferState`, and `add`
    receives the current state and returns an updated one.
    """

    @abstractmethod
    def init(self, sample_spec: chex.ArrayTree) -> ReplayBufferState:
        """Create the initial buffer state.

        Args:
            sample_spec: One example sample, or a spec of it, giving the pytree
                structure together with the shape and dtype of every leaf.

        Returns:
            The initial state of the replay buffer.
        """
        ...

    @abstractmethod
    def add(self, buffer_state: ReplayBufferState, xs: chex.ArrayTree) -> ReplayBufferState:
        """Insert a batch of data.

        Args:
            buffer_state: The state to insert into.
            xs: The data to be stored.

        Returns:
            The replay buffer state after insertion.
        """
        ...

    @abstractmethod
    def sample(self, buffer_state: ReplayBufferState, key: chex.PRNGKey) -> chex.ArrayTree:
        """Draw a batch of stored data.

        Args:
            buffer_state: The state to sample from.
            key: JAX PRNGKey driving the random draw.

        Returns:
            The sampled batch.
        """
        ...

    @abstractmethod
    def can_sample(self, buffer_state: ReplayBufferState) -> bool:
        """Tell whether `sample()` may be called on this state.

        Args:
            buffer_state: The state to inspect.

        Returns:
            Whether the replay buffer is ready for sampling.
        """
        ...

    @abstractmethod
    def is_full(self, buffer_state: ReplayBufferState) -> bool:
        """Tell whether the replay buffer is full.

        Args:
            buffer_state: The state to inspect.

        Returns:
            Whether the replay buffer is full.
        """
        ...


class ReplayBuffer(AbstractReplayBuffer):
    """ReplayBuffer with uniform sampling.

    Data are added and sampled in a 1d-like structure: every leaf of the stored
    pytree gets a leading axis of length `capacity`. Writes wrap around that
    axis, so the oldest entries are overwritten first. `sample()` draws indices
    uniformly from `[0, buffer_size)`.

    Attributes:
        capacity: the maximum capacity of the replay buffer.
        sample_batch_size: the batch size for `sample()`.
        min_sample_timesteps: the minimum number of timesteps before the replay buffer can sample.
    """

    capacity: int
    sample_batch_size: int
    min_sample_timesteps: int = 0

    def init(self, spec: chex.ArrayTree) -> ReplayBufferState:
        """Build an empty buffer state laid out after `spec`.

        Args:
            spec: A single sample, or a spec exposing `shape`/`dtype`, whose
                pytree structure, shapes and dtypes describe one buffer entry.

        Returns:
            A `ReplayBufferState` whose leaves have shape `(capacity, *leaf.shape)`,
            with `current_index` and `buffer_size` both set to int32 zero.
        """
        # NOTE(review): JAX arrays have no strided views, so outside of jit this
        # broadcast likely materializes the full (capacity, ...) buffer. Confirm
        # before relying on it to avoid pre-allocation.
        data = jtu.tree_map(
            lambda x: jnp.broadcast_to(jnp.empty_like(x), (self.capacity, *x.shape)),
            spec,
        )

        return ReplayBufferState(
            data=data,
            current_index=jnp.zeros((), jnp.int32),
            buffer_size=jnp.zeros((), jnp.int32),
        )

    def is_full(self, buffer_state: ReplayBufferState) -> bool:
        """Return whether `buffer_size` has reached `capacity`.

        The result is a scalar boolean JAX array, so it can be used under jit.
        """
        return buffer_state.buffer_size == self.capacity

    def can_sample(self, buffer_state: ReplayBufferState) -> bool:
        """Return whether at least `min_sample_timesteps` entries are stored.

        With the default `min_sample_timesteps=0` this is true even for an
        empty buffer. The result is a scalar boolean JAX array.
        """
        return buffer_state.buffer_size >= self.min_sample_timesteps

    def add(
        self,
        buffer_state: ReplayBufferState,
        xs: chex.ArrayTree,
        mask: chex.Array | None = None,
    ) -> ReplayBufferState:
        """Write a batch of samples into the buffer, overwriting the oldest slots.

        Args:
            buffer_state: The current state of the replay buffer.
            xs: A batch of samples. Every leaf has a leading batch axis and the
                same dtype as the matching leaf of `buffer_state.data`.
            mask: Optional 1-D mask over the batch axis. Only entries where it is
                truthy are written, packed contiguously from `current_index`.

        Returns:
            The updated buffer state.
        """
        # `mask is None` is a Python-level check evaluated at trace time, so under
        # jit the masked and unmasked variants compile to separate programs.
        chex.assert_trees_all_equal_dtypes(xs, buffer_state.data)

        if mask is not None:
            assert mask.ndim == 1
            chex.assert_tree_shape_prefix(xs, mask.shape)
            # Pin the count to int32, matching `cumsum_mask` below. A plain
            # `mask.sum()` gives float32 for a float mask and int64 under x64.
            # Either would change the dtype of `current_index`/`buffer_size`
            # and break jit / scan carries.
            batch_size = jnp.sum(mask, dtype=jnp.int32)

            # Masked-out rows are routed to index `capacity`, which is out of
            # bounds. JAX's default ("promise_in_bounds") scatter drops such
            # updates in practice, so those rows are never written.
            # NOTE(review): passing mode="drop" explicitly (if tree_set forwards
            # it) would make this contract explicit.
            # e.g. mask = [1,0,1,1,0], capacity = n > 5:
            #   cumsum_mask = [1,1,2,3,3], cumsum_mask - 1 = [0,0,1,2,2]
            #   with current_index = 0, indices = [0,n,1,2,n]
            cumsum_mask = jnp.cumsum(mask, axis=0, dtype=jnp.int32)
            indices = (buffer_state.current_index + cumsum_mask - 1) % self.capacity
            indices = jnp.where(mask, indices, self.capacity)
        else:
            batch_size = jtu.tree_leaves(xs)[0].shape[0]
            indices = (
                buffer_state.current_index + jnp.arange(batch_size, dtype=jnp.int32)
            ) % self.capacity

        # NOTE(review): if one call writes more than `capacity` entries, indices
        # repeat, and which duplicate write wins is presumably unspecified.
        data = tree_set(buffer_state.data, xs, indices, unique_indices=False)

        current_index = (buffer_state.current_index + batch_size) % self.capacity
        buffer_size = jnp.minimum(buffer_state.buffer_size + batch_size, self.capacity)

        return buffer_state.replace(
            data=data, current_index=current_index, buffer_size=buffer_size
        )

    def sample(
        self, buffer_state: ReplayBufferState, key: chex.PRNGKey
    ) -> chex.ArrayTree:
        """Draw `sample_batch_size` entries uniformly, with replacement.

        Args:
            buffer_state: The current state of the replay buffer.
            key: JAX PRNGKey.

        Returns:
            The entries of `buffer_state.data` at the sampled indices, gathered
            via `tree_get`.
        """
        # Only the filled slots [0, buffer_size) are eligible.
        # NOTE(review): for an empty buffer, jax.random.randint appears to return
        # `minval` (0), so this silently yields placeholder data. Callers should
        # ensure buffer_size > 0.
        indices = jax.random.randint(
            key, (self.sample_batch_size,), minval=0, maxval=buffer_state.buffer_size
        )

        return tree_get(buffer_state.data, indices)