Source code for evorl.replay_buffers.lap_replay_buffer

 1"""LAP (Loss-Adjusted Prioritization) Replay Buffer."""
 2
 3import jax
 4import chex
 5import jax.numpy as jnp
 6
 7from evorl.utils.jax_utils import tree_get
 8from .prioritized_replay_buffer import PrioritizedReplayBuffer, PrioritizedReplayBufferState
 9
10
11
class LAPReplayBuffer(PrioritizedReplayBuffer):
    """Replay buffer using Loss-Adjusted Prioritization (LAP) from the TD7 paper.

    LAP is a variant of Prioritized Experience Replay (PER): transitions are
    still drawn proportionally to their alpha-scaled priorities, but the
    Importance Sampling (IS) correction is dropped and every sampled
    transition receives a weight of 1.0.
    Storage and priority bookkeeping are inherited from
    `PrioritizedReplayBuffer`; only `sample()` is overridden here.

    Attributes:
        capacity: the maximum capacity of the replay buffer.
        sample_batch_size: the batch size for `sample()`.
        min_sample_timesteps: the minimum number of timesteps before sampling.
        alpha: the exponent determining how much prioritization is used
            (0 = uniform, 1 = full).
    """

    def sample(
        self,
        buffer_state: PrioritizedReplayBufferState,
        key: chex.PRNGKey,
        beta: float | chex.Array = 0.0,  # accepted for interface parity; never read
    ) -> tuple[chex.ArrayTree, chex.Array, PrioritizedReplayBufferState]:
        """Draw a batch with probability proportional to `priority ** alpha`.

        Args:
            buffer_state: Current buffer state.
            key: PRNG key used to draw the batch.
            beta: Unused. Present only so the signature accepts the same
                arguments as the parent class's `sample()`.

        Returns:
            A tuple of (batch, weights, updated_buffer_state).
            `weights` is a float32 vector of ones with `sample_batch_size`
            entries. The returned state is identical to `buffer_state` except
            that `sample_indices` holds the indices that were drawn.
        """
        # NOTE(review): the original author states that calling the parent's
        # sample() with beta=0.0 would also yield unit weights; the batch is
        # built directly here to avoid that extra arithmetic.
        batch_size = self.sample_batch_size

        # LAP discards IS weights, so they are constant and independent of the draw.
        weights = jnp.ones(batch_size, jnp.float32)

        # Slots at index >= buffer_size are treated as not yet filled.
        is_filled = jnp.arange(self.capacity) < buffer_state.buffer_size

        # Unfilled slots are zeroed both before and after the exponent:
        # 0 ** 0 evaluates to 1, so with alpha == 0 the second mask is what
        # keeps them at zero sampling mass.
        stored_priority = jnp.where(is_filled, buffer_state.priority, 0.0)
        scaled_priority = jnp.where(
            is_filled, jnp.power(stored_priority, self.alpha), 0.0
        )

        # Inverse-CDF sampling: scale uniform draws by the total mass and
        # locate each one in the running sum of scaled priorities.
        cdf = jnp.cumsum(scaled_priority)
        total_mass = cdf[-1]
        draws = jax.random.uniform(key, (batch_size,)) * total_mass

        # Keep every index inside the filled region of the buffer.
        last_filled = buffer_state.buffer_size - 1
        picked = jnp.clip(jnp.searchsorted(cdf, draws), 0, last_filled)

        batch = tree_get(buffer_state.data, picked)

        # Record the drawn indices so a later priority update can target them.
        updated_state = PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=buffer_state.priority,
            max_priority=buffer_state.max_priority,
            sample_indices=picked,
        )

        return batch, weights, updated_state