evorl.replay_buffers
Package Contents
Classes
AbstractReplayBuffer | A ReplayBuffer interface.
LAPReplayBuffer | ReplayBuffer with Loss-Adjusted Prioritization (LAP) from the TD7 paper.
PrioritizedReplayBuffer | ReplayBuffer with proportional prioritized sampling (PER).
PrioritizedReplayBufferState | State for the prioritized replay buffer.
ReplayBuffer | ReplayBuffer with uniform sampling.
ReplayBufferState | Contains data related to a replay buffer.
API
- class evorl.replay_buffers.AbstractReplayBuffer
Bases: evorl.types.PyTreeNode
A ReplayBuffer interface.
- abstract add(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, xs: chex.ArrayTree) -> evorl.replay_buffers.replay_buffer.ReplayBufferState
Add data to the replay buffer.
- Parameters:
buffer_state – The current state of the replay buffer.
xs – The data to add to the replay buffer.
- Returns:
Updated state of the replay buffer.
- abstract can_sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) -> bool
Check if the current replay buffer state can be used to sample.
- Parameters:
buffer_state – The current state of the replay buffer.
- Returns:
Whether the replay buffer is ready to call sample().
- abstract init(sample_spec: chex.ArrayTree) -> evorl.replay_buffers.replay_buffer.ReplayBufferState
Initialize the state of the replay buffer.
- Parameters:
sample_spec – A single sample, or a sample spec, that provides the pytree structure together with the dtype and shape of each leaf.
- Returns:
The initial state of the replay buffer.
- abstract is_full(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) -> bool
- abstract sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, key: chex.PRNGKey) -> chex.ArrayTree
Sample a batch of data from the replay buffer.
- Parameters:
buffer_state – The current state of the replay buffer.
key – JAX PRNGKey.
- Returns:
A batch of data sampled from the replay buffer.
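The methods above follow a functional, state-passing style: each call takes the buffer state and returns a new state or a result instead of mutating the buffer. The sketch below illustrates that call order with a toy uniform buffer. The `toy_*` functions, the dict-based state layout, and the `CAPACITY`, `BATCH_SIZE`, and `MIN_SAMPLE_TIMESTEPS` constants are hypothetical stand-ins chosen for illustration; they are not evorl's implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical settings for this sketch (not evorl defaults).
CAPACITY = 8
BATCH_SIZE = 4
MIN_SAMPLE_TIMESTEPS = 4


def toy_init(sample_spec):
    """Allocate zeroed storage with a leading capacity axis for every leaf."""
    data = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros((CAPACITY, *leaf.shape), leaf.dtype), sample_spec
    )
    return {
        "data": data,
        "current_index": jnp.zeros((), jnp.int32),
        "buffer_size": jnp.zeros((), jnp.int32),
    }


def toy_add(state, xs):
    """Write a batch (leading axis = batch) at the write pointer, wrapping around."""
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    idx = (state["current_index"] + jnp.arange(n)) % CAPACITY
    data = jax.tree_util.tree_map(
        lambda buf, x: buf.at[idx].set(x), state["data"], xs
    )
    return {
        "data": data,
        "current_index": (state["current_index"] + n) % CAPACITY,
        "buffer_size": jnp.minimum(state["buffer_size"] + n, CAPACITY),
    }


def toy_can_sample(state):
    """Sampling is allowed once enough entries have been stored."""
    return state["buffer_size"] >= MIN_SAMPLE_TIMESTEPS


def toy_sample(state, key):
    """Draw BATCH_SIZE indices uniformly from the filled part of the buffer."""
    idx = jax.random.randint(key, (BATCH_SIZE,), 0, state["buffer_size"])
    return jax.tree_util.tree_map(lambda buf: buf[idx], state["data"])


# Call order: init -> add -> can_sample -> sample.
spec = {"obs": jnp.zeros((3,)), "reward": jnp.zeros(())}
state = toy_init(spec)
ready_before = bool(toy_can_sample(state))  # False: the buffer is still empty

batch_in = {"obs": jnp.ones((6, 3)), "reward": jnp.arange(6.0)}
state = toy_add(state, batch_in)

key = jax.random.PRNGKey(0)
if bool(toy_can_sample(state)):
    batch = toy_sample(state, key)  # pytree with a leading axis of BATCH_SIZE
```

Because every function returns a fresh state, the caller has to rebind `state` after each `toy_add`, which mirrors how `buffer_state` is threaded through the interface documented above.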
- class evorl.replay_buffers.LAPReplayBuffer
Bases: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBuffer
ReplayBuffer with Loss-Adjusted Prioritization (LAP) from the TD7 paper.
LAP is a variation of Prioritized Experience Replay (PER) that uses proportional sampling but explicitly drops Importance Sampling (IS) weights, setting them uniformly to 1.0. It leverages the parent PrioritizedReplayBuffer class for alpha-scaled priority sampling.
- Variables:
capacity – the maximum capacity of the replay buffer.
sample_batch_size – the batch size for sample().
min_sample_timesteps – the minimum number of timesteps before sampling.
alpha – the exponent determining how much prioritization is used (0 = uniform, 1 = full).
- sample(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, key: chex.PRNGKey, beta: float | chex.Array = 0.0) -> tuple[chex.ArrayTree, chex.Array, evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState]
Sample a batch proportional to alpha-scaled priorities without IS weights.
- Parameters:
buffer_state – Current buffer state.
key – PRNG key.
beta – Ignored; accepted only to match the parent interface.
- Returns:
A tuple of (batch, weights, updated_buffer_state). The IS weights are uniformly set to 1.0.
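As a rough illustration of the behaviour described above, the sketch below draws indices with probability proportional to `priority ** alpha` and returns weights fixed at 1.0. The function name `lap_sample_sketch` and the flat priority array are hypothetical; evorl's actual implementation may differ in detail.

```python
import jax
import jax.numpy as jnp


def lap_sample_sketch(priority, key, batch_size, alpha=0.6):
    """Sample indices proportional to priority**alpha; weights are all ones."""
    scaled = priority ** alpha
    probs = scaled / scaled.sum()
    idx = jax.random.choice(key, priority.shape[0], (batch_size,), p=probs)
    weights = jnp.ones((batch_size,), dtype=priority.dtype)
    return idx, weights, probs


priority = jnp.array([1.0, 2.0, 4.0, 1.0])
key = jax.random.PRNGKey(0)

# alpha = 1: probabilities follow the priorities, [0.125, 0.25, 0.5, 0.125].
idx, weights, probs_full = lap_sample_sketch(priority, key, 3, alpha=1.0)

# alpha = 0: every entry is equally likely, 0.25 each.
_, _, probs_uniform = lap_sample_sketch(priority, key, 3, alpha=0.0)
```

The two calls show the range of `alpha` stated in the variable list: 0 recovers uniform sampling and 1 samples in direct proportion to priority.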
- class evorl.replay_buffers.PrioritizedReplayBuffer
Bases: evorl.replay_buffers.replay_buffer.ReplayBuffer
ReplayBuffer with proportional prioritized sampling (PER).
Samples proportionally to alpha-scaled priorities and returns Importance Sampling (IS) weights; LAPReplayBuffer builds on this class for Loss-Adjusted Prioritization. New samples are assigned max_priority. Priorities are updated based on TD-error after critic training.
- Variables:
capacity – the maximum capacity of the replay buffer.
sample_batch_size – the batch size for sample().
min_sample_timesteps – the minimum number of timesteps before the replay buffer can sample.
alpha – the exponent determining how much prioritization is used (0 = uniform, 1 = full).
- add(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, xs: chex.ArrayTree, mask: chex.Array | None = None) -> evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState
- alpha: float
0.6
- init(spec: chex.ArrayTree) -> evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState
- reset_max_priority(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState) -> evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState
Recompute max_priority from current buffer entries.
- sample(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, key: chex.PRNGKey, beta: float | chex.Array = 0.4) -> tuple[chex.ArrayTree, chex.Array, evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState]
Sample a batch proportional to priorities with IS weights.
- Parameters:
buffer_state – Current buffer state.
key – PRNG key.
beta – Importance sampling exponent.
- Returns:
A tuple of (batch, weights, updated_buffer_state). The weights are the computed Importance Sampling (IS) weights.
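For readers unfamiliar with the role of `beta`, the sketch below computes IS weights in the form commonly used for PER, `(N * P(i)) ** -beta`, rescaled so that the largest weight in the batch is 1. This page does not state how evorl normalizes its weights, so the normalization, the use of the buffer size as `N`, and the name `is_weights_sketch` are assumptions made for illustration.

```python
import jax.numpy as jnp


def is_weights_sketch(probs, idx, buffer_size, beta=0.4):
    """Importance-sampling weights (N * P(i)) ** -beta, scaled so the max is 1."""
    w = (buffer_size * probs[idx]) ** (-beta)
    return w / w.max()


probs = jnp.array([0.125, 0.25, 0.5, 0.125])  # sampling probabilities
idx = jnp.array([0, 2])                       # indices of the sampled entries

# beta = 1 fully corrects for non-uniform sampling: approximately [1.0, 0.25].
weights_full = is_weights_sketch(probs, idx, buffer_size=4, beta=1.0)

# beta = 0 applies no correction: all weights equal 1.
weights_none = is_weights_sketch(probs, idx, buffer_size=4, beta=0.0)
```

Entries that are sampled more often than uniform (index 2 here) receive a smaller weight, which is what offsets their over-representation in the loss.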
- update_priority(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, priority: chex.Array) -> evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState
Update priorities at the last sampled indices.
- Parameters:
buffer_state – Current buffer state (must have valid sample_indices).
priority – New priority values, shape (sample_batch_size,).
- Returns:
Updated buffer state with new priorities.
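The interplay between update_priority() and reset_max_priority() can be sketched as follows. The sketch assumes that max_priority is kept as a running maximum that only grows during updates, which would explain why a separate recomputation method exists; that assumption, and the `*_sketch` function names, are illustrative and not taken from evorl's source.

```python
import jax.numpy as jnp


def update_priority_sketch(priority, max_priority, sample_indices, new_priority):
    """Write new priorities at the last sampled indices and track the running max."""
    priority = priority.at[sample_indices].set(new_priority)
    max_priority = jnp.maximum(max_priority, new_priority.max())
    return priority, max_priority


def reset_max_priority_sketch(priority):
    """Recompute the maximum from the priorities currently stored."""
    return priority.max()


priority = jnp.ones(4)
max_priority = jnp.array(1.0)

# First update: entries 0 and 2 receive new priorities.
priority, max_priority = update_priority_sketch(
    priority, max_priority, jnp.array([0, 2]), jnp.array([5.0, 0.5])
)
# priority is now [5.0, 1.0, 0.5, 1.0] and max_priority is 5.0.

# Second update lowers entry 0; the running max does not shrink on its own.
priority, max_priority = update_priority_sketch(
    priority, max_priority, jnp.array([0]), jnp.array([2.0])
)
# max_priority is still 5.0 although the largest stored priority is 2.0.

fresh_max = reset_max_priority_sketch(priority)  # 2.0
```

Under this assumption, calling the reset periodically keeps newly added samples from being assigned a stale, overly large initial priority.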
- class evorl.replay_buffers.PrioritizedReplayBufferState
Bases: evorl.replay_buffers.replay_buffer.ReplayBufferState
State for the prioritized replay buffer.
- Variables:
priority – Priority values for each entry in the buffer.
max_priority – Current maximum priority value.
sample_indices – Indices of the last sampled batch (for priority updates).
- max_priority: chex.Array
‘ones(…)’
- priority: chex.Array
‘zeros(…)’
- sample_indices: chex.Array
‘zeros(…)’
- class evorl.replay_buffers.ReplayBuffer
Bases: evorl.replay_buffers.replay_buffer.AbstractReplayBuffer
ReplayBuffer with uniform sampling.
Data are added and sampled in a 1D-like structure.
- Variables:
capacity – the maximum capacity of the replay buffer.
sample_batch_size – the batch size for sample().
min_sample_timesteps – the minimum number of timesteps before the replay buffer can sample.
- add(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, xs: chex.ArrayTree, mask: chex.Array | None = None) -> evorl.replay_buffers.replay_buffer.ReplayBufferState
- can_sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) -> bool
- capacity: int
None
- init(spec: chex.ArrayTree) -> evorl.replay_buffers.replay_buffer.ReplayBufferState
- is_full(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) -> bool
- min_sample_timesteps: int
0
- sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, key: chex.ArrayTree) -> chex.ArrayTree
- sample_batch_size: int
None
- class evorl.replay_buffers.ReplayBufferState
Bases: evorl.types.PyTreeData
Contains data related to a replay buffer.
- Variables:
data – the stored replay buffer data.
current_index – the pointer used for adding data.
buffer_size – the current size of the replay buffer.
- buffer_size: chex.Array
‘zeros(…)’
- current_index: chex.Array
‘zeros(…)’
- data: chex.ArrayTree
None