evorl.utils.rl_toolkits

Module Contents

Functions

approximate_kl

Approximate the KL divergence with the K3 estimator (unbiased, low variance).

average_episode_discount_return

Estimate the average episode return from a segmented trajectory.

compute_discount_return

Compute the discounted return from an episodic trajectory.

compute_episode_length

Compute the length of the episode.

compute_gae

Calculates the Generalized Advantage Estimation (GAE).

compute_gae_with_horizon

flatten_pop_rollout_episode

Flatten the trajectory from [#pop, T, B, …] to [T, #pop*B, …].

flatten_rollout_trajectory

Flatten the trajectory from [T, B, …] to [T*B, …].

fold_multi_steps

Fold multiple steps into a single step function.

shuffle_sample_batch

Shuffle the sample batch.

soft_target_update

Perform soft update on target network.

API

evorl.utils.rl_toolkits.approximate_kl(logratio: jax.Array, mode='k3', axis=-1) → jax.Array

Approximate the KL divergence with the K3 estimator (unbiased, low variance).

See http://joschu.net/blog/kl-approx.html

Parameters:

logratio – log ratio log(p(x)/q(x)), where x is sampled from q(x)

Returns:

Approximated KL(q||p) (Forward KL)
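For reference, the K3 estimator described in the linked post can be sketched as below. This is a minimal illustration and not EvoRL's implementation: `k3_estimate` is a hypothetical name, and reducing the per-sample terms with a mean over `axis` is an assumption.

```python
import jax.numpy as jnp


def k3_estimate(logratio, axis=-1):
    """Hypothetical helper: K3 estimate from logratio = log(p(x)/q(x)), x ~ q."""
    ratio = jnp.exp(logratio)
    # Per-sample K3 term (r - 1) - log r, which is non-negative for every sample.
    per_sample = (ratio - 1.0) - logratio
    # Assumption: average the per-sample terms over the given axis.
    return jnp.mean(per_sample, axis=axis)


logratio = jnp.array([0.1, -0.2, 0.05])
kl = k3_estimate(logratio)
```

Because each per-sample term is non-negative, the estimate cannot go below zero, unlike the naive estimator `-logratio`.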

evorl.utils.rl_toolkits.average_episode_discount_return(episode_discount_return: jax.Array, dones: jax.Array, dp_axis_name: str | None = None) → jax.Array

Estimate the average episode return from a segmented trajectory.

This method does not require the trajectory data from a complete episode.

evorl.utils.rl_toolkits.compute_discount_return(rewards: chex.Array, dones: chex.Array, discount: float = 1.0) → chex.Array

Compute the discounted return from an episodic trajectory.

Parameters:
  • rewards – Rewards collected from an episodic trajectory.

  • dones – Dones collected from an episodic trajectory.

  • discount – Discount factor.

Returns:

Discounted return.
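One plausible way to compute such a return is sketched below. It assumes `rewards` and `dones` have shape [T, B] and that only the first episode in each column is counted. `discounted_return_sketch` is a hypothetical name, and the real function may differ in these details.

```python
import jax
import jax.numpy as jnp


def discounted_return_sketch(rewards, dones, discount=1.0):
    """Hypothetical helper: rewards, dones are [T, B]; returns [B]."""

    def step(carry, x_t):
        ret, weight = carry
        reward_t, done_t = x_t
        ret = ret + weight * reward_t
        # After the first done, the weight drops to zero and stays there,
        # so later rewards no longer contribute.
        weight = weight * discount * (1.0 - done_t)
        return (ret, weight), None

    init = (jnp.zeros_like(rewards[0]), jnp.ones_like(rewards[0]))
    (ret, _), _ = jax.lax.scan(step, init, (rewards, dones))
    return ret


rewards = jnp.ones((4, 1))
dones = jnp.array([[0.0], [0.0], [1.0], [0.0]])
ret = discounted_return_sketch(rewards, dones, discount=0.5)  # 1 + 0.5 + 0.25
```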

evorl.utils.rl_toolkits.compute_episode_length(dones: chex.Array) → chex.Array

Compute the length of the episode.

Parameters:

dones – Dones collected from an episodic trajectory.

Returns:

Length of the episode.
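A minimal sketch of counting steps up to and including the first done is given below. It assumes `dones` has shape [T, B] and that a column with no done counts all T steps. `episode_length_sketch` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def episode_length_sketch(dones):
    """Hypothetical helper: dones is [T, B] (float); returns int32 lengths [B]."""
    not_done = 1.0 - dones
    # alive_after[t] is 1 while no done has occurred up to and including step t.
    alive_after = jnp.cumprod(not_done, axis=0)
    # A step counts if the episode was still alive before it started.
    alive_before = jnp.concatenate(
        [jnp.ones_like(dones[:1]), alive_after[:-1]], axis=0
    )
    return jnp.sum(alive_before, axis=0).astype(jnp.int32)


dones = jnp.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]])
lengths = episode_length_sketch(dones)
```

Here the first column ends at its third step and the second column never terminates within the four recorded steps.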

evorl.utils.rl_toolkits.compute_gae(rewards: jax.Array, values: jax.Array, dones: jax.Array, terminations: jax.Array, gae_lambda: float = 1.0, discount: float = 0.99) → tuple[jax.Array, jax.Array]

Calculates the Generalized Advantage Estimation (GAE).

Parameters:
  • rewards – A float32 tensor of shape [T, B] containing rewards generated by following the behaviour policy.

  • values – A float32 tensor of shape [T+1, B] with the value function estimates with respect to the target policy. values[T] is the bootstrap value.

  • dones – A float32 tensor of shape [T, B] with done signal.

  • terminations – A float32 tensor of shape [T, B] with termination signal.

  • gae_lambda – Interpolation between the 1-step TD estimate (gae_lambda=0) and the n-step estimate (gae_lambda=1).

  • discount – TD discount.

Returns:

  • Lambda returns vs_t with shape [T, B], which can be used as targets to train a baseline, e.g. with the loss (V(x_t) - vs_t)^2.

  • Advantages with shape [T, B].

Return type:

Tuple
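The standard GAE recursion under the shapes documented above can be sketched as follows. This is an illustration, not the library's code. It assumes `terminations` masks the bootstrap value, `dones` cuts the advantage accumulation across episode boundaries, and `gae_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gae_sketch(rewards, values, dones, terminations, gae_lambda=1.0, discount=0.99):
    """Hypothetical helper: rewards/dones/terminations [T, B], values [T+1, B]."""
    v_t, v_tp1 = values[:-1], values[1:]
    # TD error; no bootstrapping from the next value on true termination.
    deltas = rewards + discount * (1.0 - terminations) * v_tp1 - v_t

    def step(next_adv, x_t):
        delta_t, done_t = x_t
        # Do not propagate advantages across episode boundaries.
        adv_t = delta_t + discount * gae_lambda * (1.0 - done_t) * next_adv
        return adv_t, adv_t

    _, advantages = jax.lax.scan(
        step, jnp.zeros_like(deltas[0]), (deltas, dones), reverse=True
    )
    lambda_returns = advantages + v_t
    return lambda_returns, advantages


rewards = jnp.ones((2, 1))
values = jnp.zeros((3, 1))
flags = jnp.zeros((2, 1))
lambda_returns, advantages = gae_sketch(
    rewards, values, flags, flags, gae_lambda=1.0, discount=0.5
)
```

With zero values and no episode ends, the last step's advantage is 1 and the first step's is 1 + 0.5 * 1 = 1.5.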

evorl.utils.rl_toolkits.compute_gae_with_horizon(rewards: jax.Array, values: jax.Array, dones: jax.Array, terminations: jax.Array, gae_horizon: int = 0, gae_lambda: float = 1.0, discount: float = 0.99) → tuple[jax.Array, jax.Array]

evorl.utils.rl_toolkits.flatten_pop_rollout_episode(trajectory: evorl.sample_batch.SampleBatch)

Flatten the trajectory from [#pop, T, B, …] to [T, #pop*B, …].
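The reshape described above can be illustrated on a plain pytree. The sketch uses a dict of arrays in place of a SampleBatch (an assumption), and `flatten_pop_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def flatten_pop_sketch(trajectory):
    """Hypothetical helper: each leaf [P, T, B, ...] becomes [T, P*B, ...]."""

    def _flatten(x):
        x = jnp.swapaxes(x, 0, 1)  # [T, P, B, ...]
        return x.reshape(x.shape[0], x.shape[1] * x.shape[2], *x.shape[3:])

    return jax.tree_util.tree_map(_flatten, trajectory)


# 3 population members, 5 time steps, 2 environments each.
traj = {"obs": jnp.zeros((3, 5, 2, 4)), "rewards": jnp.zeros((3, 5, 2))}
flat = flatten_pop_sketch(traj)
```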

evorl.utils.rl_toolkits.flatten_rollout_trajectory(trajectory: evorl.sample_batch.SampleBatch) → evorl.sample_batch.SampleBatch

Flatten the trajectory from [T, B, …] to [T*B, …].
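As an illustration of the shape change, the sketch below merges the time and batch axes of every leaf in a plain dict of arrays. The dict stands in for a SampleBatch (an assumption), and `flatten_rollout_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def flatten_rollout_sketch(trajectory):
    """Hypothetical helper: each leaf [T, B, ...] becomes [T*B, ...]."""
    return jax.tree_util.tree_map(
        lambda x: x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]), trajectory
    )


traj = {"obs": jnp.zeros((5, 2, 4)), "rewards": jnp.zeros((5, 2))}
flat = flatten_rollout_sketch(traj)
```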

evorl.utils.rl_toolkits.fold_multi_steps(step_fn, num_steps)

Fold multiple steps into a single step function.
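The general idea of folding repeated steps can be sketched with `jax.lax.scan`. The calling convention of `step_fn` is not documented here, so the sketch assumes a hypothetical signature `step_fn(state) -> (new_state, output)`. The real function may expect a different one.

```python
import jax
import jax.numpy as jnp


def fold_steps_sketch(step_fn, num_steps):
    """Hypothetical helper: run step_fn num_steps times inside one function."""

    def folded(state):
        def body(carry, _):
            new_state, out = step_fn(carry)
            return new_state, out

        # Returns the final state and the per-step outputs stacked on axis 0.
        return jax.lax.scan(body, state, xs=None, length=num_steps)

    return folded


def toy_step(state):
    # Toy step (assumption): increment the state, emit twice the old state.
    return state + 1.0, state * 2.0


final_state, outputs = fold_steps_sketch(toy_step, 3)(jnp.array(0.0))
```

Folding the loop into a single traced function lets the whole sequence be compiled once under `jax.jit` instead of dispatching each step from Python.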

evorl.utils.rl_toolkits.shuffle_sample_batch(sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey)

Shuffle the sample batch.
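A common way to shuffle a batch stored as a pytree is to draw one permutation and apply it to every leaf, so that fields stay aligned. The sketch below assumes shuffling along the leading axis and uses a plain dict in place of a SampleBatch. `shuffle_batch_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def shuffle_batch_sketch(batch, key):
    """Hypothetical helper: apply one shared permutation to the leading axis."""
    n = jax.tree_util.tree_leaves(batch)[0].shape[0]
    perm = jax.random.permutation(key, n)
    return jax.tree_util.tree_map(lambda x: x[perm], batch)


batch = {"obs": jnp.arange(6), "action": jnp.arange(6) * 10}
shuffled = shuffle_batch_sketch(batch, jax.random.PRNGKey(0))
```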

evorl.utils.rl_toolkits.soft_target_update(target_params, source_params, tau: float)

Perform soft update on target network.

Parameters:
  • target_params – Target network parameters.

  • source_params – Source network parameters.

  • tau – Interpolation factor.

Returns:

Updated target network parameters.
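A soft (Polyak) update is usually an interpolation between the two parameter trees. The sketch below assumes that `tau` weights the source parameters, which is the common convention but is not confirmed by this page. `soft_update_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def soft_update_sketch(target_params, source_params, tau):
    """Hypothetical helper: new_target = (1 - tau) * target + tau * source."""
    return jax.tree_util.tree_map(
        lambda t, s: (1.0 - tau) * t + tau * s, target_params, source_params
    )


target = {"w": jnp.zeros(2)}
source = {"w": jnp.ones(2)}
updated = soft_update_sketch(target, source, tau=0.1)
```

Under this convention, a small `tau` makes the target network trail the source network slowly, and `tau=1` copies the source parameters outright.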