evorl.utils.jax_utils

Module Contents

Functions

disable_gpu_preallocation

Disable GPU memory preallocation for XLA.

enable_deterministic_mode

Enable deterministic mode for JAX.

has_nan

Check if the array has NaN values.

invert_permutation

Helper function that inverts a permutation array.

is_jitted

Detect if a function is wrapped by jit or pmap.

jit_method

A decorator for jax.jit with arguments.

optimize_gpu_utilization

Apply possible optimizations for NVIDIA GPUs.

pmap_method

A decorator for jax.pmap with arguments.

right_shift_with_padding

Shift the array to the right with padding.

rng_split

Unified version of jax.random.split for both single keys and batched keys.

rng_split_by_shape

Split the key into multiple keys according to the shape.

rng_split_like_tree

Split the key according to the structure of the target pytree.

scan_and_last

Scan and return the results of the last iteration.

scan_and_mean

Scan with mean aggregation.

sliding_window

Slide a window over the first axis of the array.

tree_astype

Pytree version of jnp.astype.

tree_concat

Pytree version of jnp.concatenate.

tree_deepcopy

Deep copy the pytree.

tree_get

Get the elements of each array in the pytree.

tree_has_nan

Check if the pytree has NaN values.

tree_last

Get the last element of each array in the pytree.

tree_ones_like

Pytree version of jnp.ones_like.

tree_set

Set part of each array in the pytree.

tree_stop_gradient

Pytree version of jax.lax.stop_gradient.

tree_zeros_like

Pytree version of jnp.zeros_like.

API

evorl.utils.jax_utils.disable_gpu_preallocation()

Disable GPU memory preallocation for XLA.

Call this method at the beginning of your script.
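As a rough illustration, assuming the function works through JAX's standard XLA_PYTHON_CLIENT_PREALLOCATE environment variable (the library may do more), the manual equivalent looks like this. The variable is read when JAX initializes its GPU backend, which is why the call has to come first.

```python
import os

# Assumption: disable_gpu_preallocation() amounts to setting this JAX/XLA
# environment variable. It only takes effect if it is set before JAX
# initializes its GPU backend, hence "call at the beginning of your script".
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax.numpy as jnp  # imported only after the variable is set

x = jnp.ones(3)  # on a GPU, memory is now allocated as needed
```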

evorl.utils.jax_utils.enable_deterministic_mode()

Enable deterministic mode for JAX.

Call this method at the beginning of your script.
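A minimal sketch of one way to request deterministic GPU kernels, assuming the function relies on the XLA flag --xla_gpu_deterministic_ops (the actual implementation may set other flags too). Like the preallocation switch, XLA_FLAGS must be set before JAX initializes its backend.

```python
import os

# Assumption: deterministic mode is enabled through XLA_FLAGS. This flag asks
# XLA to choose deterministic GPU kernels; it must be set before JAX starts.
DETERMINISTIC_FLAG = "--xla_gpu_deterministic_ops=true"

existing = os.environ.get("XLA_FLAGS", "")
if DETERMINISTIC_FLAG not in existing.split():
    # Append rather than overwrite, so flags set elsewhere are preserved.
    os.environ["XLA_FLAGS"] = f"{existing} {DETERMINISTIC_FLAG}".strip()
```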

evorl.utils.jax_utils.has_nan(x: jax.Array) → bool

Check if the array has NaN values.

evorl.utils.jax_utils.invert_permutation(i: jax.Array) → jax.Array

Helper function that inverts a permutation array.
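The inverse of a permutation i is the array inv with inv[i[k]] == k for every position k. A linear-time way to build it is a single scatter; invert_permutation_sketch is a hypothetical helper, not the library's code.

```python
import jax.numpy as jnp

def invert_permutation_sketch(i):
    """Return inv such that inv[i[k]] == k for every position k."""
    # Scatter each position k into slot i[k].
    return jnp.empty_like(i).at[i].set(jnp.arange(i.shape[0], dtype=i.dtype))

perm = jnp.array([2, 0, 1])
inv = invert_permutation_sketch(perm)  # [1, 2, 0]
```

jnp.argsort(perm) yields the same array at O(n log n) cost; the scatter version is O(n).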

evorl.utils.jax_utils.is_jitted(func: collections.abc.Callable)

Detect if a function is wrapped by jit or pmap.

evorl.utils.jax_utils.jit_method(*, static_argnums: int | collections.abc.Sequence[int] | None = None, static_argnames: str | collections.abc.Iterable[str] | None = None, donate_argnums: int | collections.abc.Sequence[int] | None = None, donate_argnames: str | collections.abc.Iterable[str] | None = None, **kwargs)

A decorator for jax.jit with arguments.

Parameters:

static_argnums – The positional argument indices that are constant across different calls to the function.

Returns:

A decorator for jax.jit with arguments.
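The idea of a decorator "with arguments" can be sketched by pre-binding jax.jit options with functools.partial. jit_method_sketch is a hypothetical stand-in; the real jit_method may additionally treat the self argument of methods specially.

```python
import functools

import jax
import jax.numpy as jnp

def jit_method_sketch(*, static_argnums=None, static_argnames=None,
                      donate_argnums=None, donate_argnames=None, **kwargs):
    """Hypothetical stand-in: bind jax.jit options now, apply them later."""
    options = dict(
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        donate_argnums=donate_argnums,
        donate_argnames=donate_argnames,
    )
    # Drop unset options so jax.jit falls back to its own defaults.
    options = {k: v for k, v in options.items() if v is not None}
    return functools.partial(jax.jit, **options, **kwargs)

@jit_method_sketch(static_argnums=(1,))
def power(x, n):
    # n is static: each distinct value of n triggers its own compilation.
    return x ** n

y = power(jnp.arange(3.0), 2)
```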

evorl.utils.jax_utils.optimize_gpu_utilization()

Apply possible optimizations for NVIDIA GPUs.

This function is not tested.

evorl.utils.jax_utils.pmap_method(axis_name, *, static_broadcasted_argnums=(), donate_argnums=(), **kwargs)

A decorator for jax.pmap with arguments.
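Analogously, a sketch of pre-binding jax.pmap options, assuming pmap_method forwards its arguments to jax.pmap (pmap_method_sketch is hypothetical). The example runs on any machine: the leading axis is sized to the number of local devices, which is 1 on a plain CPU.

```python
import functools

import jax
import jax.numpy as jnp

def pmap_method_sketch(axis_name, *, static_broadcasted_argnums=(),
                       donate_argnums=(), **kwargs):
    """Hypothetical stand-in: pre-bind jax.pmap options for use as a decorator."""
    return functools.partial(
        jax.pmap,
        axis_name=axis_name,
        static_broadcasted_argnums=static_broadcasted_argnums,
        donate_argnums=donate_argnums,
        **kwargs,
    )

n_dev = jax.local_device_count()

@pmap_method_sketch("devices")
def global_sum(x):
    # psum reduces over the named axis, i.e. across all mapped devices.
    return jax.lax.psum(x, axis_name="devices")

out = global_sum(jnp.ones((n_dev, 3)))  # every entry equals n_dev
```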

evorl.utils.jax_utils.right_shift_with_padding(x: chex.Array, shift: int, fill_value: None | chex.Scalar = None)

Shift the array to the right with padding.
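A minimal sketch, assuming the shift is applied along the first axis and that fill_value=None means zero padding (neither is stated above):

```python
import jax.numpy as jnp

def right_shift_with_padding_sketch(x, shift, fill_value=None):
    """Shift x right by `shift` along axis 0, filling the vacated front slots."""
    if fill_value is None:
        fill_value = 0  # assumption: pad with zeros by default
    n = x.shape[0]
    shift = min(shift, n)  # shifting by >= n leaves only padding
    padding = jnp.full((shift,) + x.shape[1:], fill_value, dtype=x.dtype)
    # Drop the last `shift` elements and prepend the padding.
    return jnp.concatenate([padding, x[: n - shift]], axis=0)

x = jnp.array([1, 2, 3, 4])
y = right_shift_with_padding_sketch(x, 1, fill_value=-1)  # [-1, 1, 2, 3]
```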

evorl.utils.jax_utils.rng_split(key: chex.PRNGKey, num: int = 2) → chex.PRNGKey

Unified version of jax.random.split for both single keys and batched keys.
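A sketch of how one function can serve both cases with raw uint32 keys. The output layout for batched keys ([num, B, 2], so that k1, k2 = ... unpacks the same way in both cases) is an assumption, not documented behavior.

```python
import jax

def rng_split_sketch(key, num=2):
    """Split a single raw key (shape [2]) or a batch of raw keys (shape [B, 2])."""
    if key.ndim == 1:
        return jax.random.split(key, num)  # -> [num, 2]
    # Assumption: put the `num` axis first so unpacking works uniformly.
    return jax.vmap(lambda k: jax.random.split(k, num), out_axes=1)(key)  # -> [num, B, 2]

key = jax.random.PRNGKey(0)
k1, k2 = rng_split_sketch(key)        # two keys of shape [2]
batch = jax.random.split(key, 4)      # [4, 2]
b1, b2 = rng_split_sketch(batch)      # two batches of shape [4, 2]
```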

evorl.utils.jax_utils.rng_split_by_shape(key: chex.PRNGKey, shape: tuple[int]) → chex.PRNGKey

Split the key into multiple keys according to the shape.
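A plausible sketch, assuming the result has shape shape + key.shape: split into prod(shape) keys, then reshape. The helper name is hypothetical.

```python
import math

import jax

def rng_split_by_shape_sketch(key, shape):
    """Split one raw key into prod(shape) keys laid out as shape + key.shape."""
    keys = jax.random.split(key, math.prod(shape))
    return keys.reshape(*shape, *key.shape)

keys = rng_split_by_shape_sketch(jax.random.PRNGKey(0), (3, 4))
# keys[i, j] is an independent key, e.g. one per (population member, env).
```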

evorl.utils.jax_utils.rng_split_like_tree(key: chex.PRNGKey, target: chex.ArrayTree, is_leaf=None) → chex.ArrayTree

Split the key according to the structure of the target pytree.
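One way to get "one key per leaf" is to split into as many keys as the target has leaves and unflatten them with the target's tree structure. This is a sketch under that assumption, with a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def rng_split_like_tree_sketch(key, target, is_leaf=None):
    """Return a pytree shaped like `target` with one fresh key per leaf."""
    treedef = jax.tree_util.tree_structure(target, is_leaf=is_leaf)
    keys = jax.random.split(key, treedef.num_leaves)
    # Iterating over the [num_leaves, 2] array yields one key per leaf.
    return jax.tree_util.tree_unflatten(treedef, list(keys))

params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)}
keys = rng_split_like_tree_sketch(jax.random.PRNGKey(0), params)
# Typical use: sample independent noise for every leaf.
noise = jax.tree_util.tree_map(lambda k, p: jax.random.normal(k, p.shape), keys, params)
```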

evorl.utils.jax_utils.scan_and_last(*args, **kwargs)

Scan and return the results of the last iteration.

Usage: same as jax.lax.scan, but returns only the results of the last scan iteration.
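A minimal sketch, assuming the function returns the final carry together with the per-step output of the last iteration. This simple version still stacks all outputs before slicing; a memory-conscious implementation could keep only the latest output in the carry instead.

```python
import jax
import jax.numpy as jnp

def scan_and_last_sketch(f, init, xs, length=None):
    """Like jax.lax.scan, but keep only the outputs of the final iteration."""
    carry, ys = jax.lax.scan(f, init, xs, length=length)
    last = jax.tree_util.tree_map(lambda y: y[-1], ys)
    return carry, last

def step(total, x):
    total = total + x
    return total, total  # (new carry, per-step output)

carry, last = scan_and_last_sketch(step, jnp.zeros(()), jnp.arange(5.0))
```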

evorl.utils.jax_utils.scan_and_mean(*args, **kwargs)

Scan with mean aggregation.

Usage: same as jax.lax.scan, but the scan results are averaged over the iterations.
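A sketch under the assumption that the carry is returned unchanged and the stacked per-step outputs (any pytree) are averaged over the scan axis:

```python
import jax
import jax.numpy as jnp

def scan_and_mean_sketch(f, init, xs, length=None):
    """Like jax.lax.scan, but average the stacked outputs over the scan axis."""
    carry, ys = jax.lax.scan(f, init, xs, length=length)
    mean = jax.tree_util.tree_map(lambda y: jnp.mean(y, axis=0), ys)
    return carry, mean

def step(count, x):
    return count + 1, {"double": 2.0 * x}  # outputs may be any pytree

count, stats = scan_and_mean_sketch(step, jnp.zeros((), jnp.int32), jnp.arange(4.0))
# stats["double"] is the mean of [0, 2, 4, 6]
```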

evorl.utils.jax_utils.sliding_window(arr, length, stride)

Slide a window over the first axis of the array.

Changes the shape from [T, …] to [L, W, …], where L is the window length, S is the stride, and W = (T - L) // S + 1 is the number of windows.
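A gather-based sketch that follows the [L, W, …] layout stated above, assuming the first window starts at index 0 so that window w covers rows w*S through w*S + L - 1:

```python
import jax.numpy as jnp

def sliding_window_sketch(arr, length, stride):
    """Windows along axis 0: [T, ...] -> [L, W, ...] with W = (T - L) // S + 1."""
    T = arr.shape[0]
    num_windows = (T - length) // stride + 1
    starts = jnp.arange(num_windows) * stride  # [W] start index of each window
    offsets = jnp.arange(length)               # [L] position inside a window
    idx = offsets[:, None] + starts[None, :]   # [L, W] gather indices
    return arr[idx]

arr = jnp.arange(6)
windows = sliding_window_sketch(arr, length=3, stride=2)
# windows[:, w] is the w-th window: [0, 1, 2] and [2, 3, 4]
```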

evorl.utils.jax_utils.tree_astype(tree: chex.ArrayTree, dtype)

Pytree version of jnp.astype.

evorl.utils.jax_utils.tree_concat(nest1: chex.ArrayTree, nest2: chex.ArrayTree, axis: int = 0)

Pytree version of jnp.concatenate.
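"Pytree version" here means the operation is applied leaf by leaf over two pytrees with the same structure. A sketch with jax.tree_util.tree_map (hypothetical helper name), e.g. for merging two batches of transitions:

```python
import jax
import jax.numpy as jnp

def tree_concat_sketch(nest1, nest2, axis=0):
    """Concatenate matching leaves of two pytrees with the same structure."""
    return jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=axis), nest1, nest2
    )

batch1 = {"obs": jnp.zeros((2, 4)), "reward": jnp.zeros(2)}
batch2 = {"obs": jnp.ones((3, 4)), "reward": jnp.ones(3)}
merged = tree_concat_sketch(batch1, batch2)
```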

evorl.utils.jax_utils.tree_deepcopy(tree: chex.ArrayTree) → chex.ArrayTree

Deep copy the pytree.

Useful for pytrees built from mutable containers such as dict: the returned pytree contains deep copies of those containers as well.
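One way to obtain this behavior, sketched with a hypothetical helper: tree_map rebuilds every container on the way back up, and copying each leaf makes the arrays independent too.

```python
import jax
import jax.numpy as jnp

def tree_deepcopy_sketch(tree):
    """Rebuild every container and copy every array leaf."""
    # tree_map constructs fresh containers (dicts, lists, ...), so mutating
    # the result never touches the original structure.
    return jax.tree_util.tree_map(lambda x: jnp.array(x, copy=True), tree)

state = {"metrics": {"loss": jnp.array(1.0)}}
copied = tree_deepcopy_sketch(state)
copied["metrics"]["loss"] = jnp.array(0.0)  # mutates only the copy's dict
```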

evorl.utils.jax_utils.tree_get(tree: chex.ArrayTree, idx_or_slice)

Get the elements of each array in the pytree.
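A sketch of leaf-wise indexing with a hypothetical helper: the same integer, slice, or index array is applied to every leaf, which is handy for pulling one step or a range of steps out of a trajectory pytree.

```python
import jax
import jax.numpy as jnp

def tree_get_sketch(tree, idx_or_slice):
    """Index every leaf with the same integer, slice, or index array."""
    return jax.tree_util.tree_map(lambda x: x[idx_or_slice], tree)

traj = {"obs": jnp.arange(12).reshape(4, 3), "done": jnp.array([0, 0, 1, 0])}
first_two = tree_get_sketch(traj, slice(0, 2))
step2 = tree_get_sketch(traj, 2)
```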

evorl.utils.jax_utils.tree_has_nan(tree: chex.ArrayTree) → chex.ArrayTree

Check if the pytree has NaN values.
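Since the return annotation is chex.ArrayTree rather than bool, the check is presumably per leaf; that reading is an assumption. The sketch below returns one boolean per leaf and then reduces the flags to a single answer.

```python
import jax
import jax.numpy as jnp

def tree_has_nan_sketch(tree):
    """Per-leaf NaN check: returns a pytree of scalar booleans."""
    return jax.tree_util.tree_map(lambda x: jnp.isnan(x).any(), tree)

grads = {"w": jnp.array([1.0, jnp.nan]), "b": jnp.zeros(2)}
flags = tree_has_nan_sketch(grads)
any_nan = any(bool(f) for f in jax.tree_util.tree_leaves(flags))
```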

evorl.utils.jax_utils.tree_last(tree: chex.ArrayTree)

Get the last element of each array in the pytree.

evorl.utils.jax_utils.tree_ones_like(nest: chex.ArrayTree, dtype=None) → chex.ArrayTree

Pytree version of jnp.ones_like.

evorl.utils.jax_utils.tree_set(src: chex.ArrayTree, target: chex.ArrayTree, idx_or_slice, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None)

Set part of each array in the pytree.

A pytree version of src[idx_or_slice] = target.

Parameters:
  • src – The source pytree.

  • target – The pytree of values to write into src.

  • idx_or_slice – The indices or slices to be set.

  • indices_are_sorted – Whether the indices are sorted.

  • unique_indices – Whether the indices are unique.

  • mode – The mode to set the values.

Returns:

The updated source pytree.
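The keyword parameters share their names with the options of JAX's functional update x.at[idx].set(...), so a natural sketch, assuming they are forwarded there leaf by leaf, looks like this (hypothetical helper; JAX arrays are immutable, so a new pytree is returned):

```python
import jax
import jax.numpy as jnp

def tree_set_sketch(src, target, idx_or_slice, indices_are_sorted=False,
                    unique_indices=False, mode=None):
    """Functional src[idx_or_slice] = target, applied leaf by leaf."""
    return jax.tree_util.tree_map(
        lambda s, t: s.at[idx_or_slice].set(
            t,
            indices_are_sorted=indices_are_sorted,
            unique_indices=unique_indices,
            mode=mode,
        ),
        src,
        target,
    )

buffer = {"obs": jnp.zeros((5, 2)), "reward": jnp.zeros(5)}
new = {"obs": jnp.ones((2, 2)), "reward": jnp.array([1.0, 2.0])}
buffer = tree_set_sketch(buffer, new, jnp.array([1, 3]))  # write rows 1 and 3
```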

evorl.utils.jax_utils.tree_stop_gradient(nest: chex.ArrayTree) → chex.ArrayTree

Pytree version of jax.lax.stop_gradient.

evorl.utils.jax_utils.tree_zeros_like(nest: chex.ArrayTree, dtype=None) → chex.ArrayTree

Pytree version of jnp.zeros_like.