Source code for evorl.replay_buffers.lap_replay_buffer
1"""LAP (Loss-Adjusted Prioritization) Replay Buffer."""
2
3import jax
4import chex
5import jax.numpy as jnp
6
7from evorl.utils.jax_utils import tree_get
8from .prioritized_replay_buffer import PrioritizedReplayBuffer, PrioritizedReplayBufferState
9
10
11
[docs]
class LAPReplayBuffer(PrioritizedReplayBuffer):
    """ReplayBuffer with Loss-Adjusted Prioritization from the TD7 paper.

    LAP is a variation of Prioritized Experience Replay (PER) that keeps
    proportional (alpha-scaled) priority sampling but drops Importance
    Sampling (IS) weights, setting them uniformly to 1.0.
    Only `sample()` is overridden here; everything else is inherited from
    `PrioritizedReplayBuffer`.

    Attributes:
        capacity: the maximum capacity of the replay buffer.
        sample_batch_size: the batch size for `sample()`.
        min_sample_timesteps: the minimum number of timesteps before sampling.
        alpha: the exponent determining how much prioritization is used
            (0 = uniform, 1 = full).
    """

    def sample(
        self,
        buffer_state: PrioritizedReplayBufferState,
        key: chex.PRNGKey,
        beta: float | chex.Array = 0.0,  # beta is ignored for LAP
    ) -> tuple[chex.ArrayTree, chex.Array, PrioritizedReplayBufferState]:
        """Sample a batch proportional to alpha-scaled priorities without IS weights.

        Slot ``i`` (with ``i < buffer_size``) is drawn with probability
        ``priority[i] ** alpha / sum_j priority[j] ** alpha``; slots whose
        scaled priority is zero are effectively never drawn.

        Args:
            buffer_state: Current buffer state.
            key: PRNG key.
            beta: Ignored. Accepted only so the signature matches the parent
                class's `sample`; LAP does not use IS weights.

        Returns:
            A tuple of (batch, weights, updated_buffer_state).
            The IS weights are uniformly 1.0 (float32, shape
            ``(sample_batch_size,)``). The updated state is identical to the
            input except that ``sample_indices`` holds the sampled indices.
        """
        # Only the first `buffer_size` slots hold data; the rest get zero mass.
        valid = jnp.arange(self.capacity) < buffer_state.buffer_size
        # Mask *after* the power: with alpha == 0, 0.0 ** 0.0 == 1.0 would
        # otherwise give empty slots non-zero sampling mass.
        scaled_priority = jnp.where(valid, buffer_state.priority**self.alpha, 0.0)

        # Inverse-CDF sampling: draw u ~ U[0, total) and pick the slot whose
        # half-open interval [csum[i-1], csum[i]) contains u.
        csum = jnp.cumsum(scaled_priority)
        u = jax.random.uniform(key, (self.sample_batch_size,)) * csum[-1]
        # side="right" returns the first i with csum[i] > u, so a slot with
        # zero priority (an empty interval) is never chosen -- not even when
        # u == 0.0, which side="left" would map to slot 0 unconditionally.
        indices = jnp.searchsorted(csum, u, side="right")
        # Float rounding can push u up to csum[-1], making searchsorted return
        # `capacity`; clamp back into the filled region of the buffer.
        indices = jnp.clip(indices, 0, buffer_state.buffer_size - 1)

        batch = tree_get(buffer_state.data, indices)

        # LAP explicitly ignores IS weights: every sample gets weight 1.0.
        weights = jnp.ones((self.sample_batch_size,), dtype=jnp.float32)

        # Store indices for a later priority update; all other fields are
        # passed through unchanged.
        new_state = PrioritizedReplayBufferState(
            data=buffer_state.data,
            current_index=buffer_state.current_index,
            buffer_size=buffer_state.buffer_size,
            priority=buffer_state.priority,
            max_priority=buffer_state.max_priority,
            sample_indices=indices,
        )

        return batch, weights, new_state
77 return batch, weights, new_state