evorl.replay_buffers

Package Contents

Classes

AbstractReplayBuffer

A ReplayBuffer interface.

LAPReplayBuffer

ReplayBuffer with Loss-Adjusted Prioritization (LAP) from the TD7 paper.

PrioritizedReplayBuffer

ReplayBuffer with proportional prioritized sampling (PER).

PrioritizedReplayBufferState

State for the prioritized replay buffer.

ReplayBuffer

ReplayBuffer with uniform sampling.

ReplayBufferState

Contains data related to a replay buffer.

API

class evorl.replay_buffers.AbstractReplayBuffer

Bases: evorl.types.PyTreeNode

A ReplayBuffer interface.

abstract add(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, xs: chex.ArrayTree) → evorl.replay_buffers.replay_buffer.ReplayBufferState

Add data to the replay buffer.

Parameters:
  • buffer_state – The current state of the replay buffer.

  • xs – The data to add to the replay buffer.

Returns:

Updated state of the replay buffer.

abstract can_sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) → bool

Check if the current replay buffer state can be used to sample.

Parameters:

buffer_state – The current state of the replay buffer.

Returns:

Whether the replay buffer is ready to call sample().

abstract init(sample_spec: chex.ArrayTree) → evorl.replay_buffers.replay_buffer.ReplayBufferState

Initialize the state of the replay buffer.

Parameters:

sample_spec – A single sample, or a sample spec, that provides the pytree structure together with the dtype and shape of each leaf.

Returns:

The initial state of the replay buffer.

abstract is_full(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) → bool

abstract sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, key: chex.PRNGKey) → chex.ArrayTree

Sample a batch of data from the replay buffer.

Parameters:
  • buffer_state – The current state of the replay buffer.

  • key – JAX PRNGKey.

Returns:

A batch of data sampled from the replay buffer.
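The interface above follows a functional pattern: the buffer object holds configuration, while all data lives in a state value that is passed in and returned. The plain-Python sketch below illustrates only that calling pattern. `ToyBuffer` and `ToyState` are hypothetical names, the toy `init()` takes no sample spec, and a `random.Random` instance stands in for the JAX PRNG key; none of this is evorl's implementation.

```python
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for ReplayBufferState."""
    data: tuple = ()


@dataclass(frozen=True)
class ToyBuffer:
    """Hypothetical buffer: holds only configuration, never the data."""
    capacity: int
    sample_batch_size: int

    def init(self) -> ToyState:
        return ToyState()

    def add(self, state: ToyState, xs) -> ToyState:
        # Return a new state; the input state is left untouched.
        merged = state.data + tuple(xs)
        return replace(state, data=merged[-self.capacity:])

    def can_sample(self, state: ToyState) -> bool:
        return len(state.data) >= self.sample_batch_size

    def is_full(self, state: ToyState) -> bool:
        return len(state.data) == self.capacity

    def sample(self, state: ToyState, rng: random.Random) -> list:
        # Uniform sampling with replacement, driven by an explicit RNG.
        return [rng.choice(state.data) for _ in range(self.sample_batch_size)]


buffer = ToyBuffer(capacity=4, sample_batch_size=2)
state0 = buffer.init()
state1 = buffer.add(state0, [10, 11, 12])
state2 = buffer.add(state1, [13, 14])  # exceeds capacity, oldest item dropped
batch = buffer.sample(state2, random.Random(0))
```

Because every method returns a new state instead of mutating the old one, `state0` and `state1` remain valid after later calls, which is the property that makes this style convenient under JAX transformations.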

class evorl.replay_buffers.LAPReplayBuffer

Bases: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBuffer

ReplayBuffer with Loss-Adjusted Prioritization (LAP) from the TD7 paper.

LAP is a variation of Prioritized Experience Replay (PER) that uses proportional sampling but explicitly drops Importance Sampling (IS) weights, setting them uniformly to 1.0. It leverages the parent PrioritizedReplayBuffer class for alpha-scaled priority sampling.

Variables:
  • capacity – the maximum capacity of the replay buffer.

  • sample_batch_size – the batch size for sample().

  • min_sample_timesteps – the minimum number of timesteps before the replay buffer can sample.

  • alpha – the exponent determining how much prioritization is used (0 = uniform, 1 = full).

sample(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, key: chex.PRNGKey, beta: float | chex.Array = 0.0) → tuple[chex.ArrayTree, chex.Array, evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState]

Sample a batch proportional to alpha-scaled priorities without IS weights.

Parameters:
  • buffer_state – Current buffer state.

  • key – PRNG key.

  • beta – Ignored; accepted only to match the parent class's sample() signature.

Returns:

A tuple of (batch, weights, updated_buffer_state). The IS weights are uniformly set to 1.0.
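As a rough illustration of the behaviour described above, the sketch below draws indices in proportion to alpha-scaled priorities and returns weights fixed at 1.0. It is plain Python under stated assumptions: `sample_without_is_weights` is a hypothetical helper, a `random.Random` instance replaces the JAX PRNG key, and the way evorl derives priorities from the loss is not modelled.

```python
import random


def sample_without_is_weights(priorities, alpha, batch_size, rng):
    """Draw indices proportionally to priority ** alpha; weights are all 1.0."""
    scaled = [p ** alpha for p in priorities]
    indices = rng.choices(range(len(priorities)), weights=scaled, k=batch_size)
    weights = [1.0] * batch_size  # no importance-sampling correction
    return indices, weights


indices, weights = sample_without_is_weights(
    priorities=[1.0, 0.0, 3.0], alpha=0.6, batch_size=4, rng=random.Random(0)
)
```

In this toy call the entry with priority 0.0 can never be drawn, and the returned weights carry no information, so a caller can multiply its loss by them without changing it.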

class evorl.replay_buffers.PrioritizedReplayBuffer

Bases: evorl.replay_buffers.replay_buffer.ReplayBuffer

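ReplayBuffer with proportional prioritized sampling (PER).

New samples are assigned max_priority. Priorities are updated based on the TD error after critic training. sample() returns Importance Sampling (IS) weights; the Loss-Adjusted Prioritization (LAP) variant from the TD7 paper, which drops them, is provided by LAPReplayBuffer.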

Variables:
  • capacity – the maximum capacity of the replay buffer.

  • sample_batch_size – the batch size for sample().

  • min_sample_timesteps – the minimum number of timesteps before the replay buffer can sample.

  • alpha – the exponent determining how much prioritization is used (0 = uniform, 1 = full).
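The effect of alpha can be seen in a small worked example. The sketch assumes the standard proportional form, where entry i is drawn with probability p_i^alpha divided by the sum of all p_j^alpha; `sampling_probabilities` is a hypothetical helper written in plain Python, not a function from evorl.

```python
def sampling_probabilities(priorities, alpha):
    """Normalize alpha-scaled priorities into a probability distribution."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


priorities = [1.0, 2.0, 4.0]
uniform = sampling_probabilities(priorities, alpha=0.0)  # each entry is 1/3
full = sampling_probabilities(priorities, alpha=1.0)     # 1/7, 2/7, 4/7
default = sampling_probabilities(priorities, alpha=0.6)  # between the two
```

With alpha = 0 every scaled priority becomes 1, so sampling is uniform; with alpha = 1 the probabilities follow the raw priorities exactly.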

add(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, xs: chex.ArrayTree, mask: chex.Array | None = None) → evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState

alpha: float

0.6

init(spec: chex.ArrayTree) → evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState

reset_max_priority(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState) → evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState

Recompute max_priority from current buffer entries.
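A running maximum can go stale: once the entry that set it is overwritten or given a lower priority, the stored value no longer matches anything in the buffer. The sketch below shows one way such a recomputation could look. It assumes that the first `buffer_size` slots hold the valid entries and that an empty buffer falls back to 1.0; both are illustrative assumptions, and `recompute_max_priority` is a hypothetical helper.

```python
def recompute_max_priority(priority, buffer_size):
    """Return the largest priority among the valid entries."""
    valid = priority[:buffer_size]
    # Fall back to 1.0 when nothing is stored yet (assumed default).
    return max(valid) if valid else 1.0


priority = [0.5, 2.0, 0.0, 0.0]  # two valid entries, two unused slots
stale_max_priority = 5.0         # left over from an entry that was overwritten
fresh_max_priority = recompute_max_priority(priority, buffer_size=2)
```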

sample(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, key: chex.PRNGKey, beta: float | chex.Array = 0.4) → tuple[chex.ArrayTree, chex.Array, evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState]

Sample a batch proportional to priorities with IS weights.

Parameters:
  • buffer_state – Current buffer state.

  • key – PRNG key.

  • beta – Importance sampling exponent.

Returns:

A tuple of (batch, weights, updated_buffer_state). The weights are the computed Importance Sampling (IS) weights.
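For orientation, the usual Prioritized Experience Replay formulation computes the weight of a sampled entry i as (N * P(i))^(-beta), where N is the number of stored entries and P(i) is its sampling probability, and then rescales the weights so that the largest is 1. The sketch below follows that formulation in plain Python. Whether evorl normalizes by the batch maximum in the same way is an assumption, and `importance_weights` is a hypothetical helper.

```python
def importance_weights(probabilities, sampled_indices, beta):
    """Compute max-normalized IS weights for the sampled indices."""
    n = len(probabilities)
    raw = [(n * probabilities[i]) ** (-beta) for i in sampled_indices]
    largest = max(raw)
    return [w / largest for w in raw]


probabilities = [0.1, 0.2, 0.7]
corrected = importance_weights(probabilities, sampled_indices=[0, 2], beta=1.0)
uncorrected = importance_weights(probabilities, sampled_indices=[0, 2], beta=0.0)
```

With beta = 1 the frequently sampled entry (probability 0.7) is down-weighted to 1/7 of the rare one, which fully compensates for the sampling bias; with beta = 0 all weights equal 1 and no correction is applied.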

update_priority(buffer_state: evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState, priority: chex.Array) → evorl.replay_buffers.prioritized_replay_buffer.PrioritizedReplayBufferState

Update priorities at the last sampled indices.

Parameters:
  • buffer_state – Current buffer state (must have valid sample_indices).

  • priority – New priority values, shape (sample_batch_size,).

Returns:

Updated buffer state with new priorities.
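The update can be pictured as a scatter write: each new priority is stored at the index recorded during the last sample() call. The plain-Python sketch below shows that idea with lists. Raising `max_priority` in the same step is an assumption about the bookkeeping, not confirmed behaviour, and the function here is a hypothetical stand-in that takes the state fields as separate arguments.

```python
def update_priority(priority, max_priority, sample_indices, new_priority):
    """Write new priorities at the last sampled indices."""
    updated = list(priority)  # copy, so the caller's list is not mutated
    for idx, value in zip(sample_indices, new_priority):
        updated[idx] = value
    # Assumed bookkeeping: keep max_priority at least as large as any new value.
    return updated, max(max_priority, max(new_priority))


priority = [1.0, 1.0, 1.0, 1.0]
new_priorities, new_max = update_priority(
    priority, max_priority=1.0, sample_indices=[2, 0], new_priority=[0.3, 2.5]
)
```

Since the indices come from the previous sample() call, the update is only meaningful when it is applied to the state that sample() returned.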

class evorl.replay_buffers.PrioritizedReplayBufferState

Bases: evorl.replay_buffers.replay_buffer.ReplayBufferState

State for the prioritized replay buffer.

Variables:
  • priority – Priority values for each entry in the buffer.

  • max_priority – Current maximum priority value.

  • sample_indices – Indices of the last sampled batch (for priority updates).

max_priority: chex.Array

‘ones(…)’

priority: chex.Array

‘zeros(…)’

sample_indices: chex.Array

‘zeros(…)’

class evorl.replay_buffers.ReplayBuffer

Bases: evorl.replay_buffers.replay_buffer.AbstractReplayBuffer

ReplayBuffer with uniform sampling.

Data are added and sampled in a flat, 1-D-like structure.

Variables:
  • capacity – the maximum capacity of the replay buffer.

  • sample_batch_size – the batch size for sample().

  • min_sample_timesteps – the minimum number of timesteps before the replay buffer can sample.

add(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, xs: chex.ArrayTree, mask: chex.Array | None = None) → evorl.replay_buffers.replay_buffer.ReplayBufferState

can_sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) → bool

capacity: int

None

init(spec: chex.ArrayTree) → evorl.replay_buffers.replay_buffer.ReplayBufferState

is_full(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState) → bool

min_sample_timesteps: int

0

sample(buffer_state: evorl.replay_buffers.replay_buffer.ReplayBufferState, key: chex.ArrayTree) → chex.ArrayTree

sample_batch_size: int

None

class evorl.replay_buffers.ReplayBufferState

Bases: evorl.types.PyTreeData

Contains data related to a replay buffer.

Variables:
  • data – the stored replay buffer data.

  • current_index – the pointer used for adding data.

  • buffer_size – the current size of the replay buffer.

buffer_size: chex.Array

‘zeros(…)’

current_index: chex.Array

‘zeros(…)’

data: chex.ArrayTree

None
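The three fields above are what a fixed-capacity ring buffer typically needs. The sketch below shows how a write pointer and a size counter usually interact, using a plain list for `data`. It assumes wrap-around overwriting of the oldest entries, which is a common design and not confirmed for evorl; `ring_add` is a hypothetical helper.

```python
def ring_add(data, current_index, buffer_size, xs):
    """Write xs into a fixed-size list, wrapping around when it is full."""
    capacity = len(data)
    data = list(data)  # copy, so the previous state stays valid
    for x in xs:
        data[current_index] = x
        current_index = (current_index + 1) % capacity   # advance the pointer
        buffer_size = min(buffer_size + 1, capacity)     # saturate at capacity
    return data, current_index, buffer_size


empty = [None] * 4
data1, index1, size1 = ring_add(empty, 0, 0, [1, 2, 3])
data2, index2, size2 = ring_add(data1, index1, size1, [4, 5])  # 5 overwrites 1
```

After the second call the pointer has wrapped to slot 1 and the size stays at the capacity of 4, so the next write replaces the oldest remaining entry.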