Source code for evorl.replay_buffers.prioritized_replay_buffer

  1"""Prioritized Replay Buffer with LAP (Loss-Adjusted Prioritization) support."""
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7
  8from evorl.utils.jax_utils import tree_get
  9
 10from .replay_buffer import ReplayBuffer, ReplayBufferState
 11
 12
class PrioritizedReplayBufferState(ReplayBufferState):
    """Replay-buffer state extended with prioritization bookkeeping.

    Attributes:
        priority: Priority value stored for each buffer entry.
        max_priority: Largest priority currently tracked.
        sample_indices: Indices of the most recently sampled batch, kept
            so that their priorities can be updated afterwards.
    """

    # Scalar placeholders only; real shapes are supplied when a state is built.
    priority: chex.Array = jnp.zeros((), dtype=jnp.float32)
    max_priority: chex.Array = jnp.ones((), dtype=jnp.float32)
    sample_indices: chex.Array = jnp.zeros((), dtype=jnp.int32)
25 26
class PrioritizedReplayBuffer(ReplayBuffer):
    """ReplayBuffer with proportional prioritized sampling (LAP).

    Uses Loss-Adjusted Prioritization from TD7 paper.
    New samples are assigned max_priority. Priorities are updated
    based on TD-error after critic training.

    Attributes:
        capacity: the maximum capacity of the replay buffer.
        sample_batch_size: the batch size for `sample()`.
        min_sample_timesteps: the minimum number of timesteps before the
            replay buffer can sample.
        alpha: the exponent determining how much prioritization is used
            (0 = uniform, 1 = full).
    """

    alpha: float = 0.6

    def init(self, spec: chex.ArrayTree) -> PrioritizedReplayBufferState:
        """Create an empty buffer state whose storage mirrors ``spec``.

        Args:
            spec: A pytree describing one entry; every leaf gains a leading
                axis of length ``capacity``.

        Returns:
            A state with zero size, all priorities zero, ``max_priority`` of
            one and ``sample_indices`` of shape ``(sample_batch_size,)``.
        """
        data = jtu.tree_map(
            lambda x: jnp.broadcast_to(jnp.empty_like(x), (self.capacity, *x.shape)),
            spec,
        )

        return PrioritizedReplayBufferState(
            data=data,
            current_index=jnp.zeros((), jnp.int32),
            buffer_size=jnp.zeros((), jnp.int32),
            priority=jnp.zeros(self.capacity, jnp.float32),
            max_priority=jnp.ones((), jnp.float32),
            sample_indices=jnp.zeros(self.sample_batch_size, jnp.int32),
        )

    def add(
        self,
        buffer_state: PrioritizedReplayBufferState,
        xs: chex.ArrayTree,
        mask: chex.Array | None = None,
    ) -> PrioritizedReplayBufferState:
        """Insert a batch and give every new entry the current max priority.

        Args:
            buffer_state: Current buffer state.
            xs: Batch of entries; the leading axis of each leaf is the batch.
            mask: Optional per-row flag; rows where it is false get no
                priority written here.

        Returns:
            The state produced by the parent ``add`` with ``priority`` updated
            at the written slots. ``max_priority`` and ``sample_indices`` are
            carried over unchanged.
        """
        # The write slots must be derived from the cursor *before* the parent
        # call advances it.
        # NOTE(review): this assumes the parent writes kept rows contiguously
        # from `current_index` with wrap-around - confirm against
        # ReplayBuffer.add.
        if mask is not None:
            # Running count of kept rows: the k-th kept row lands k - 1 slots
            # past the cursor.
            cumsum_mask = jnp.cumsum(mask, axis=0, dtype=jnp.int32)
            indices = (buffer_state.current_index + cumsum_mask - 1) % self.capacity
            # Masked-out rows get the out-of-range index `capacity`, which the
            # mode="drop" scatter below discards.
            indices = jnp.where(mask, indices, self.capacity)
        else:
            batch_size = jtu.tree_leaves(xs)[0].shape[0]
            indices = (
                buffer_state.current_index + jnp.arange(batch_size, dtype=jnp.int32)
            ) % self.capacity

        # Delegate storage of the data itself to the parent implementation.
        new_state = super().add(buffer_state, xs, mask)

        priority = buffer_state.priority.at[indices].set(
            buffer_state.max_priority,
            mode="drop",
        )

        return PrioritizedReplayBufferState(
            data=new_state.data,
            current_index=new_state.current_index,
            buffer_size=new_state.buffer_size,
            priority=priority,
            max_priority=buffer_state.max_priority,
            sample_indices=buffer_state.sample_indices,
        )

    def sample(
        self,
        buffer_state: PrioritizedReplayBufferState,
        key: chex.PRNGKey,
        beta: float | chex.Array = 0.4,
    ) -> tuple[chex.ArrayTree, chex.Array, PrioritizedReplayBufferState]:
        """Sample a batch proportional to priorities with IS weights.

        Entry ``i`` is drawn with probability
        ``P(i) = p_i ** alpha / sum_j p_j ** alpha`` over the first
        ``buffer_size`` slots. If no valid slot has positive mass, sampling
        falls back to uniform over the valid slots.

        Args:
            buffer_state: Current buffer state.
            key: PRNG key.
            beta: Importance sampling exponent.

        Returns:
            A tuple of (batch, weights, updated_buffer_state).
            The weights are the Importance Sampling (IS) weights
            ``(N * P(i)) ** -beta`` normalized by the largest weight in the
            batch; a draw whose probability is zero gets weight 0 instead of
            NaN/inf. The returned state records the sampled indices in
            ``sample_indices``.
        """
        valid = jnp.arange(self.capacity) < buffer_state.buffer_size

        # Mask both before and after the power: 0 ** alpha is 1 when
        # alpha == 0, which would otherwise give unused slots sampling mass.
        masked_priority = jnp.where(valid, buffer_state.priority, 0.0)
        priority_alpha = jnp.where(valid, masked_priority**self.alpha, 0.0)

        # Degenerate case: no valid slot carries mass (e.g. all priorities
        # were updated to 0). Sample uniformly rather than always returning
        # the same slot with undefined weights.
        uniform_mass = valid.astype(priority_alpha.dtype)
        mass = jnp.where(jnp.sum(priority_alpha) > 0, priority_alpha, uniform_mass)

        # Inverse-CDF sampling on the unnormalized cumulative mass.
        csum = jnp.cumsum(mass)
        total = csum[-1]
        val = jax.random.uniform(key, (self.sample_batch_size,)) * total
        # side="right" picks the first slot whose cumulative mass is strictly
        # greater than the draw, so a zero-mass slot is never selected (the
        # default side="left" returns slot 0 for a draw of exactly 0).
        indices = jnp.searchsorted(csum, val, side="right")
        # Clamp into the filled region; the upper bound is floored at 0 so an
        # empty buffer does not produce a negative (wrapping) index.
        last_valid = jnp.maximum(buffer_state.buffer_size - 1, 0)
        indices = jnp.clip(indices, 0, last_valid)

        batch = tree_get(buffer_state.data, indices)

        # IS weights: w_i = (N * P(i)) ** -beta / max_j w_j
        num_valid = jnp.maximum(1, buffer_state.buffer_size)
        probs = mass[indices] / jnp.where(total > 0, total, 1.0)
        has_prob = probs > 0
        # The inner where keeps the power finite on the branch that is
        # discarded.
        safe_probs = jnp.where(has_prob, probs, 1.0)
        weights = jnp.where(has_prob, (num_valid * safe_probs) ** (-beta), 0.0)
        # Normalize by the largest weight in the batch (guarding the all-zero
        # case that arises only for an empty buffer).
        peak = jnp.max(weights)
        weights = weights / jnp.where(peak > 0, peak, 1.0)

        # Remember which slots were drawn so update_priority can address them.
        new_state = PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=buffer_state.priority,
            max_priority=buffer_state.max_priority,
            sample_indices=indices,
        )

        return batch, weights, new_state

    def update_priority(
        self,
        buffer_state: PrioritizedReplayBufferState,
        priority: chex.Array,
    ) -> PrioritizedReplayBufferState:
        """Update priorities at the last sampled indices.

        Args:
            buffer_state: Current buffer state (must have valid sample_indices).
            priority: New priority values, shape (sample_batch_size,).

        Returns:
            Updated buffer state with new priorities. ``max_priority`` is
            raised to the largest new value if that exceeds it; it is never
            lowered here (see ``reset_max_priority``).
        """
        new_priority = buffer_state.priority.at[buffer_state.sample_indices].set(
            priority
        )
        new_max_priority = jnp.maximum(buffer_state.max_priority, priority.max())

        return PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=new_priority,
            max_priority=new_max_priority,
            sample_indices=buffer_state.sample_indices,
        )

    def reset_max_priority(
        self, buffer_state: PrioritizedReplayBufferState
    ) -> PrioritizedReplayBufferState:
        """Recompute max_priority from current buffer entries.

        Returns:
            The state with ``max_priority`` set to the largest priority among
            the first ``buffer_size`` slots, floored at ``1e-5``. For an empty
            buffer it is reset to ``1.0``, the same value ``init`` uses.
        """
        mask = jnp.arange(self.capacity) < buffer_state.buffer_size
        # Unused slots are excluded from the max by mapping them to -inf.
        valid_priority = jnp.where(mask, buffer_state.priority, -jnp.inf)
        floored_max = jnp.maximum(valid_priority.max(), 1e-5)
        # With no valid entries the max above is -inf and only the tiny floor
        # would survive; fall back to the initial max priority instead.
        new_max_priority = jnp.where(buffer_state.buffer_size > 0, floored_max, 1.0)

        return PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=buffer_state.priority,
            max_priority=new_max_priority,
            sample_indices=buffer_state.sample_indices,
        )