1"""Prioritized Replay Buffer with LAP (Loss-Adjusted Prioritization) support."""
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7
8from evorl.utils.jax_utils import tree_get
9
10from .replay_buffer import ReplayBuffer, ReplayBufferState
11
12
[docs]
class PrioritizedReplayBufferState(ReplayBufferState):
    """Replay buffer state extended with per-slot priorities.

    Attributes:
        priority: Stored priority of every buffer slot, kept before the
            `alpha` exponent is applied.
        max_priority: Running maximum priority; newly added entries are
            given this value.
        sample_indices: Slot indices produced by the most recent `sample()`
            call, which `update_priority()` writes new priorities to.
    """

    # Per-slot priority; this scalar default is a placeholder (init() builds
    # a vector of length `capacity`).
    priority: chex.Array = jnp.zeros(shape=(), dtype=jnp.float32)
    # Running maximum priority, starting at 1.0.
    max_priority: chex.Array = jnp.ones(shape=(), dtype=jnp.float32)
    # Last sampled indices; this scalar default is a placeholder (init()
    # builds a vector of length `sample_batch_size`).
    sample_indices: chex.Array = jnp.zeros(shape=(), dtype=jnp.int32)
25
26
[docs]
class PrioritizedReplayBuffer(ReplayBuffer):
    """ReplayBuffer with proportional prioritized sampling (LAP).

    Uses Loss-Adjusted Prioritization from the TD7 paper. New samples are
    assigned ``max_priority``. Priorities are updated from the TD-error after
    critic training via ``update_priority()``.

    Priorities are stored raw; the ``alpha`` exponent is applied only inside
    ``sample()``.

    Attributes:
        capacity: the maximum capacity of the replay buffer.
        sample_batch_size: the batch size for `sample()`.
        min_sample_timesteps: the minimum number of timesteps before the
            replay buffer can sample.
        alpha: the exponent determining how much prioritization is used
            (0 = uniform, 1 = full).
    """

    alpha: float = 0.6

    def init(self, spec: chex.ArrayTree) -> PrioritizedReplayBufferState:
        """Create an empty buffer state laid out after ``spec``.

        Args:
            spec: Pytree describing one entry; every leaf gets a leading
                ``capacity`` axis in the stored data.

        Returns:
            A state with size 0, all-zero priorities, ``max_priority`` 1.0 and
            a zeroed ``sample_indices`` vector of length ``sample_batch_size``.
        """
        data = jtu.tree_map(
            lambda x: jnp.broadcast_to(jnp.empty_like(x), (self.capacity, *x.shape)),
            spec,
        )

        return PrioritizedReplayBufferState(
            data=data,
            current_index=jnp.zeros((), jnp.int32),
            buffer_size=jnp.zeros((), jnp.int32),
            priority=jnp.zeros(self.capacity, jnp.float32),
            max_priority=jnp.ones((), jnp.float32),
            sample_indices=jnp.zeros(self.sample_batch_size, jnp.int32),
        )

    def add(
        self,
        buffer_state: PrioritizedReplayBufferState,
        xs: chex.ArrayTree,
        mask: chex.Array | None = None,
    ) -> PrioritizedReplayBufferState:
        """Add a batch of entries and give each new slot ``max_priority``.

        Args:
            buffer_state: Current buffer state.
            xs: Batch of entries; the batch runs along the leading axis of
                every leaf.
            mask: Optional per-row mask over the batch. Rows where it is
                falsy receive no priority update here.

        Returns:
            Updated buffer state with the parent's data/index/size and the
            new priorities.
        """
        # Recompute the slots written by the parent `add` so the matching
        # priorities can be set. NOTE(review): assumes ReplayBuffer.add uses
        # this same slot layout; keep the two in sync.
        if mask is not None:
            cumsum_mask = jnp.cumsum(mask, axis=0, dtype=jnp.int32)
            indices = (buffer_state.current_index + cumsum_mask - 1) % self.capacity
            # Masked-out rows get an out-of-range index, which mode="drop"
            # below discards.
            indices = jnp.where(mask, indices, self.capacity)
        else:
            batch_size = jtu.tree_leaves(xs)[0].shape[0]
            indices = (
                buffer_state.current_index + jnp.arange(batch_size, dtype=jnp.int32)
            ) % self.capacity

        new_state = super().add(buffer_state, xs, mask)

        priority = buffer_state.priority.at[indices].set(
            buffer_state.max_priority,
            mode="drop",
        )

        return PrioritizedReplayBufferState(
            data=new_state.data,
            current_index=new_state.current_index,
            buffer_size=new_state.buffer_size,
            priority=priority,
            max_priority=buffer_state.max_priority,
            sample_indices=buffer_state.sample_indices,
        )

    def sample(
        self,
        buffer_state: PrioritizedReplayBufferState,
        key: chex.PRNGKey,
        beta: float | chex.Array = 0.4,
    ) -> tuple[chex.ArrayTree, chex.Array, PrioritizedReplayBufferState]:
        """Sample a batch proportional to priorities with IS weights.

        Slot ``i`` is drawn with probability
        ``P(i) = p_i**alpha / sum_j p_j**alpha`` over the filled slots.

        Args:
            buffer_state: Current buffer state.
            key: PRNG key.
            beta: Importance sampling exponent.

        Returns:
            A tuple of (batch, weights, updated_buffer_state).
            The weights are the Importance Sampling (IS) weights
            ``(N * P(i)) ** -beta``, divided by their maximum within the batch.
            The returned state records the sampled indices in
            ``sample_indices`` for a later ``update_priority()`` call.
        """
        # Only the first `buffer_size` slots hold data. Mask *after* the power
        # so that alpha == 0 cannot turn empty slots into 0**0 == 1.
        valid = jnp.arange(self.capacity) < buffer_state.buffer_size
        priority_alpha = jnp.where(valid, buffer_state.priority**self.alpha, 0.0)
        total_priority = jnp.sum(priority_alpha)

        # Inverse-CDF sampling: draw u in [0, total) and pick the slot whose
        # cumulative interval [csum[i-1], csum[i]) contains u. side="right"
        # selects the first slot with csum[i] > u, which must have p_i > 0.
        # The default side="left" could return a zero-priority slot when u
        # lands exactly on a boundary (e.g. u == 0 with priority[0] == 0).
        # That would give an infinite IS weight and corrupt the whole batch.
        csum = jnp.cumsum(priority_alpha)
        val = jax.random.uniform(key, (self.sample_batch_size,)) * csum[-1]
        indices = jnp.searchsorted(csum, val, side="right")
        # Safety net (e.g. an all-zero priority vector): stay inside the
        # filled region.
        indices = jnp.clip(indices, 0, buffer_state.buffer_size - 1)

        batch = tree_get(buffer_state.data, indices)

        # IS weights: w_i = (N * P(i)) ** -beta, normalized by the batch max.
        n = jnp.maximum(1, buffer_state.buffer_size)
        probs = priority_alpha[indices] / total_priority
        weights = (n * probs) ** (-beta)
        weights = weights / jnp.max(weights)

        new_state = PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=buffer_state.priority,
            max_priority=buffer_state.max_priority,
            sample_indices=indices,
        )

        return batch, weights, new_state

    def update_priority(
        self,
        buffer_state: PrioritizedReplayBufferState,
        priority: chex.Array,
    ) -> PrioritizedReplayBufferState:
        """Update priorities at the last sampled indices.

        Args:
            buffer_state: Current buffer state (must have valid
                sample_indices, i.e. come from ``sample()``).
            priority: New raw priority values (``alpha`` is applied later in
                ``sample()``), shape (sample_batch_size,).

        Returns:
            Updated buffer state with new priorities and ``max_priority``
            raised to at least ``priority.max()``.
        """
        # NOTE: if the sampled batch contains the same slot more than once,
        # JAX leaves the order of conflicting scatter writes unspecified, so
        # any one of the supplied values may end up stored.
        new_priority = buffer_state.priority.at[buffer_state.sample_indices].set(
            priority
        )
        new_max_priority = jnp.maximum(buffer_state.max_priority, priority.max())

        return PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=new_priority,
            max_priority=new_max_priority,
            sample_indices=buffer_state.sample_indices,
        )

    def reset_max_priority(
        self, buffer_state: PrioritizedReplayBufferState
    ) -> PrioritizedReplayBufferState:
        """Recompute max_priority from the entries currently in the buffer.

        For a non-empty buffer the new value is the largest stored priority,
        floored at 1e-5. This keeps the priority that ``add()`` assigns to
        new entries strictly positive. For an empty buffer there is nothing
        to recompute from, so the current ``max_priority`` is kept as is.

        Args:
            buffer_state: Current buffer state.

        Returns:
            Buffer state with the recomputed ``max_priority``.
        """
        valid = jnp.arange(self.capacity) < buffer_state.buffer_size
        buffer_max = jnp.where(valid, buffer_state.priority, -jnp.inf).max()
        # Without the emptiness check, an empty buffer would yield
        # max(-inf, 1e-5) == 1e-5. Later additions would then get a
        # near-zero priority and would hardly ever be sampled.
        new_max_priority = jnp.where(
            buffer_state.buffer_size > 0,
            jnp.maximum(buffer_max, 1e-5),
            buffer_state.max_priority,
        )

        return PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=buffer_state.priority,
            max_priority=new_max_priority,
            sample_indices=buffer_state.sample_indices,
        )