evorl.utils.jax_utils
Module Contents
Functions
disable_gpu_preallocation | Disable GPU memory preallocation for XLA.
enable_deterministic_mode | Enable deterministic mode for JAX.
Check if the array has NaN values.
invert_permutation | Helper function that inverts a permutation array.
is_jitted | Detect if a function is wrapped by jit or pmap.
jit_method | A decorator for jax.jit with arguments.
optimize_gpu_utilization | Possible optimizations for NVIDIA GPUs.
pmap_method | A decorator for jax.pmap with arguments.
right_shift_with_padding | Shift the array to the right with padding.
rng_split | Unified version of jax.random.split for both a single key and batched keys.
rng_split_by_shape | Split the key into multiple keys according to the shape.
rng_split_like_tree | Split the key according to the structure of the target pytree.
scan_and_last | Scan and return the results of the last iteration.
scan_and_mean | Scan with mean aggregation.
sliding_window | Slide a window over the first axis of the array.
tree_astype | Pytree version of jnp.astype.
tree_concat | Pytree version of jnp.concatenate.
tree_deepcopy | Deep copy the pytree.
tree_get | Get the elements of each array in the pytree.
tree_has_nan | Check if the pytree has NaN values.
tree_last | Get the last element of each array in the pytree.
tree_ones_like | Pytree version of jnp.ones_like.
tree_set | Set part of each array in the pytree.
Pytree version of
Pytree version of
API
- evorl.utils.jax_utils.disable_gpu_preallocation()
Disable GPU memory preallocation for XLA.
Call this function at the beginning of your script.
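Since the function takes no arguments, it most likely works through process configuration. A minimal sketch of the usual mechanism, assuming the helper sets the XLA_PYTHON_CLIENT_PREALLOCATE environment variable described in JAX's GPU memory allocation guide; disable_gpu_preallocation_sketch is a hypothetical name, not the evorl implementation.

```python
import os


def disable_gpu_preallocation_sketch() -> None:
    """Hypothetical stand-in for disable_gpu_preallocation().

    Assumption: the real helper uses this XLA client environment variable.
    The variable is only read when JAX initializes its GPU backend, so it
    must be set before the first JAX computation (safest: before import jax).
    """
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


# Call at the very top of the script, before importing or using jax.
disable_gpu_preallocation_sketch()
```

This is why the docstring asks for the call at the beginning of the script: once the backend has been initialized, changing the variable has no effect.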
- evorl.utils.jax_utils.enable_deterministic_mode()
Enable deterministic mode for JAX.
Call this function at the beginning of your script.
- evorl.utils.jax_utils.invert_permutation(i: jax.Array) → jax.Array
Helper function that inverts a permutation array, i.e. returns inv such that inv[i[k]] = k for every index k.
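A minimal sketch of one way to invert a permutation in JAX, assuming i is a 1-D integer array holding a permutation of 0..n-1; invert_permutation_sketch is a hypothetical name and need not match the evorl implementation.

```python
import jax
import jax.numpy as jnp


def invert_permutation_sketch(i: jax.Array) -> jax.Array:
    """Hypothetical sketch: return inv with inv[i[k]] == k for all k."""
    # Scatter position k into slot i[k]; each slot is written exactly once
    # because i is a permutation of 0..n-1.
    return jnp.zeros_like(i).at[i].set(jnp.arange(i.shape[0], dtype=i.dtype))


perm = jnp.array([2, 0, 3, 1])
inv = invert_permutation_sketch(perm)  # [1, 3, 0, 2]
```

jnp.argsort(perm) produces the same result; the scatter form avoids a sort and makes the defining property explicit.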
- evorl.utils.jax_utils.is_jitted(func: collections.abc.Callable)
Detect if a function is wrapped by jit or pmap.
- evorl.utils.jax_utils.jit_method(*, static_argnums: int | collections.abc.Sequence[int] | None = None, static_argnames: str | collections.abc.Iterable[str] | None = None, donate_argnums: int | collections.abc.Sequence[int] | None = None, donate_argnames: str | collections.abc.Iterable[str] | None = None, **kwargs)
A decorator for jax.jit with arguments.
- Parameters:
static_argnums – The positional argument indices that are constant across different calls to the function.
- Returns:
A decorator that applies jax.jit with the given arguments.
- evorl.utils.jax_utils.optimize_gpu_utilization()
Possible optimizations for NVIDIA GPUs.
This function is not tested.
- evorl.utils.jax_utils.pmap_method(axis_name, *, static_broadcasted_argnums=(), donate_argnums=(), **kwargs)
A decorator for jax.pmap with arguments.
- evorl.utils.jax_utils.right_shift_with_padding(x: chex.Array, shift: int, fill_value: None | chex.Scalar = None)
Shift the array to the right with padding.
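A minimal sketch of the shift, assuming it acts on the first axis and that fill_value=None means zero padding; both assumptions, and the name right_shift_with_padding_sketch, are hypothetical.

```python
import jax.numpy as jnp


def right_shift_with_padding_sketch(x, shift: int, fill_value=None):
    """Hypothetical: shift x right by `shift` along axis 0 (0 <= shift <= len(x)).

    Assumption: fill_value=None pads with zeros.
    """
    if fill_value is None:
        fill_value = 0
    pad = jnp.full((shift,) + x.shape[1:], fill_value, dtype=x.dtype)
    # Keep the first T - shift entries; the last `shift` entries fall off.
    return jnp.concatenate([pad, x[: x.shape[0] - shift]], axis=0)


y = right_shift_with_padding_sketch(jnp.array([1, 2, 3, 4]), 1)  # [0, 1, 2, 3]
```

Slicing with x.shape[0] - shift instead of -shift keeps shift=0 correct (x[:-0] would be empty).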
- evorl.utils.jax_utils.rng_split(key: chex.PRNGKey, num: int = 2) → chex.PRNGKey
Unified version of jax.random.split for both a single key and batched keys.
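A sketch of how one function can serve both cases, assuming legacy raw keys from jax.random.PRNGKey (uint32, last axis of size 2) and assuming a batch of shape (B, 2) splits into shape (num, B, 2) so it can be unpacked like jax.random.split's output. rng_split_sketch and this layout are hypothetical.

```python
import jax


def rng_split_sketch(key: jax.Array, num: int = 2) -> jax.Array:
    """Hypothetical: jax.random.split for one raw key or a (B, 2) batch."""
    if key.ndim == 1:  # a single raw key of shape (2,)
        return jax.random.split(key, num)  # (num, 2)
    # Split each key independently, then put the split axis first: (num, B, 2).
    return jax.vmap(lambda k: jax.random.split(k, num), out_axes=1)(key)


single = rng_split_sketch(jax.random.PRNGKey(0), 3)  # shape (3, 2)
batch = jax.random.split(jax.random.PRNGKey(0), 4)  # shape (4, 2)
new_keys, sub_keys = rng_split_sketch(batch)  # each of shape (4, 2)
```

Putting the split axis first lets the same unpacking idiom, key, subkey = ..., work for single and batched keys.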
- evorl.utils.jax_utils.rng_split_by_shape(key: chex.PRNGKey, shape: tuple[int]) → chex.PRNGKey
Split the key into multiple keys according to the shape.
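A minimal sketch, assuming a legacy raw key and that the result holds one key per element of shape (so the output has shape shape + (2,)); rng_split_by_shape_sketch is a hypothetical name.

```python
import math

import jax


def rng_split_by_shape_sketch(key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Hypothetical: split one raw key into an array of keys laid out as `shape`."""
    keys = jax.random.split(key, math.prod(shape))  # (prod(shape), 2)
    # Reshape the leading axis into `shape`, keeping the trailing key-data axis.
    return keys.reshape(tuple(shape) + keys.shape[1:])


keys = rng_split_by_shape_sketch(jax.random.PRNGKey(0), (2, 3))  # shape (2, 3, 2)
```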
- evorl.utils.jax_utils.rng_split_like_tree(key: chex.PRNGKey, target: chex.ArrayTree, is_leaf=None) → chex.ArrayTree
Split the key according to the structure of the target pytree.
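A minimal sketch: flatten the target, split one fresh key per leaf, and rebuild the same tree structure. It assumes legacy raw keys and that is_leaf is forwarded to the flattening step; rng_split_like_tree_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rng_split_like_tree_sketch(key, target, is_leaf=None):
    """Hypothetical: one independent key per leaf, same structure as target."""
    leaves, treedef = jax.tree_util.tree_flatten(target, is_leaf=is_leaf)
    keys = jax.random.split(key, len(leaves))
    return jax.tree_util.tree_unflatten(treedef, list(keys))


params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
keys = rng_split_like_tree_sketch(jax.random.PRNGKey(0), params)
# Typical use: independent noise for every parameter array.
noise = jax.tree_util.tree_map(
    lambda k, p: jax.random.normal(k, p.shape), keys, params
)
```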
- evorl.utils.jax_utils.scan_and_last(*args, **kwargs)
Scan and return the results of the last iteration.
Usage: same as jax.lax.scan, but only the results of the last scan iteration are returned.
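A minimal sketch of the idea, assuming the helper returns the final carry together with the per-step outputs of the last iteration; scan_and_last_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def scan_and_last_sketch(*args, **kwargs):
    """Hypothetical: jax.lax.scan, keeping only the last iteration's outputs."""
    carry, ys = jax.lax.scan(*args, **kwargs)
    # ys is stacked along a leading scan axis; take its last entry per leaf.
    return carry, jax.tree_util.tree_map(lambda y: y[-1], ys)


def cumsum_step(total, x):
    total = total + x
    return total, total  # (new carry, per-step output)


carry, last = scan_and_last_sketch(cumsum_step, jnp.array(0.0), jnp.arange(1.0, 5.0))
# Per-step outputs were 1, 3, 6, 10, so carry == 10.0 and last == 10.0.
```

This version still materializes every per-step output before slicing; a leaner variant could thread only the latest output through the carry, at the cost of a more complex step function.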
- evorl.utils.jax_utils.scan_and_mean(*args, **kwargs)
Scan with mean aggregation.
Usage: same as jax.lax.scan, but the per-iteration results are averaged over iterations.
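A minimal sketch, assuming the stacked per-step outputs are averaged over the leading scan axis, leaf by leaf, while the final carry is returned unchanged; scan_and_mean_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def scan_and_mean_sketch(*args, **kwargs):
    """Hypothetical: jax.lax.scan with per-step outputs averaged over steps."""
    carry, ys = jax.lax.scan(*args, **kwargs)
    return carry, jax.tree_util.tree_map(lambda y: jnp.mean(y, axis=0), ys)


def step(count, x):
    # The carry counts iterations; the output is a pytree of per-step metrics.
    return count + 1, {"x": x, "x_sq": x * x}


count, means = scan_and_mean_sketch(step, jnp.array(0), jnp.arange(1.0, 5.0))
# means["x"] == 2.5 and means["x_sq"] == 7.5; count == 4
```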
- evorl.utils.jax_utils.sliding_window(arr, length, stride)
Slide a window over the first axis of the array.
Change the shape from [T, …] to [L, W, …], where L is the window length, S is the stride, and W = (T - L) // S + 1 is the number of windows.
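A gather-based sketch that follows the [L, W, …] layout stated above, assuming length and stride are Python ints; sliding_window_sketch is a hypothetical name, not the evorl code.

```python
import jax.numpy as jnp


def sliding_window_sketch(arr, length: int, stride: int):
    """Hypothetical: windows over axis 0, returned with shape [L, W, ...]."""
    num_windows = (arr.shape[0] - length) // stride + 1  # W
    starts = jnp.arange(num_windows) * stride  # first index of each window
    # idx[l, w] = starts[w] + l, shape (L, W); indexing keeps trailing axes.
    idx = jnp.arange(length)[:, None] + starts[None, :]
    return arr[idx]


windows = sliding_window_sketch(jnp.arange(6), length=3, stride=2)
# W = (6 - 3) // 2 + 1 = 2; column w holds window w: [0, 1, 2] and [2, 3, 4]
```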
- evorl.utils.jax_utils.tree_astype(tree: chex.ArrayTree, dtype)
Pytree version of jnp.astype.
- evorl.utils.jax_utils.tree_concat(nest1: chex.ArrayTree, nest2: chex.ArrayTree, axis: int = 0)
Pytree version of jnp.concatenate.
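A minimal sketch, assuming nest1 and nest2 share the same tree structure and matching leaves are concatenated along axis; tree_concat_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_concat_sketch(nest1, nest2, axis: int = 0):
    """Hypothetical: concatenate matching leaves of two same-structured pytrees."""
    return jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=axis), nest1, nest2
    )


buf1 = {"obs": jnp.zeros((2, 3)), "reward": jnp.zeros(2)}
buf2 = {"obs": jnp.ones((4, 3)), "reward": jnp.ones(4)}
merged = tree_concat_sketch(buf1, buf2)  # obs: (6, 3), reward: (6,)
```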
- evorl.utils.jax_utils.tree_deepcopy(tree: chex.ArrayTree) → chex.ArrayTree
Deep copy the pytree.
Useful for pytrees built from mutable containers such as dict: the returned pytree also contains deep copies of these mutable containers.
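A minimal sketch of why such a copy is cheap in JAX: tree_map rebuilds every container from the tree definition, and array leaves can be shared because jax.Array is immutable. Whether evorl also copies the array buffers is not stated here; tree_deepcopy_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_deepcopy_sketch(tree):
    """Hypothetical: rebuild every container (dict/list/tuple) of the pytree.

    The copy shares no mutable container with the input; leaves are reused
    as-is, which is safe for immutable jax.Array values.
    """
    return jax.tree_util.tree_map(lambda x: x, tree)


state = {"params": {"w": jnp.zeros(2)}, "step": jnp.array(0)}
snapshot = tree_deepcopy_sketch(state)
snapshot["params"]["w"] = jnp.ones(2)  # mutates only the copy's inner dict
```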
- evorl.utils.jax_utils.tree_get(tree: chex.ArrayTree, idx_or_slice)
Get the elements of each array in the pytree.
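A minimal sketch, assuming every leaf is indexed with the same idx_or_slice (an integer, a slice, or an index array); tree_get_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_get_sketch(tree, idx_or_slice):
    """Hypothetical: index every leaf with the same index or slice."""
    return jax.tree_util.tree_map(lambda x: x[idx_or_slice], tree)


batch = {"obs": jnp.arange(12).reshape(4, 3), "done": jnp.array([0, 0, 1, 1])}
first_two = tree_get_sketch(batch, slice(0, 2))  # obs: (2, 3), done: (2,)
picked = tree_get_sketch(batch, jnp.array([3, 1]))  # gather rows 3 and 1
```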
- evorl.utils.jax_utils.tree_has_nan(tree: chex.ArrayTree) → chex.ArrayTree
Check if the pytree has NaN values.
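The chex.ArrayTree return annotation suggests a per-leaf result rather than a single flag. A minimal sketch under that assumption, with tree_has_nan_sketch as a hypothetical name; a single flag can then be derived from the leaves.

```python
import jax
import jax.numpy as jnp


def tree_has_nan_sketch(tree):
    """Hypothetical: one boolean per leaf, True if that array contains a NaN."""
    return jax.tree_util.tree_map(lambda x: jnp.isnan(x).any(), tree)


grads = {"w": jnp.array([1.0, jnp.nan]), "b": jnp.array([0.5])}
flags = tree_has_nan_sketch(grads)  # flags["w"] is True, flags["b"] is False
# Reduce to one flag (outside jit; inside jit, use jnp.any on stacked leaves).
any_nan = any(bool(f) for f in jax.tree_util.tree_leaves(flags))
```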
- evorl.utils.jax_utils.tree_last(tree: chex.ArrayTree)
Get the last element of each array in the pytree.
- evorl.utils.jax_utils.tree_ones_like(nest: chex.ArrayTree, dtype=None) → chex.ArrayTree
Pytree version of jnp.ones_like.
- evorl.utils.jax_utils.tree_set(src: chex.ArrayTree, target: chex.ArrayTree, idx_or_slice, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None)
Set part of each array in the pytree.
A pytree version of src[idx_or_slice] = target.
- Parameters:
src – The source pytree.
target – The pytree of values to write into src.
idx_or_slice – The indices or slices to be set.
indices_are_sorted – Whether the indices are sorted.
unique_indices – Whether the indices are unique.
mode – The mode to set the values.
- Returns:
The updated source pytree.
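The keyword parameters above have the same names as those of JAX's functional update x.at[idx].set(values, ...). A minimal sketch, assuming tree_set forwards them there leaf by leaf; tree_set_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_set_sketch(src, target, idx_or_slice, indices_are_sorted=False,
                    unique_indices=False, mode=None):
    """Hypothetical: functional src[idx_or_slice] = target for every leaf.

    Assumption: the keyword options are forwarded to x.at[...].set().
    """
    return jax.tree_util.tree_map(
        lambda s, t: s.at[idx_or_slice].set(
            t,
            indices_are_sorted=indices_are_sorted,
            unique_indices=unique_indices,
            mode=mode,
        ),
        src,
        target,
    )


buffer = {"obs": jnp.zeros((4, 2)), "reward": jnp.zeros(4)}
new = {"obs": jnp.ones((2, 2)), "reward": jnp.array([5.0, 7.0])}
buffer = tree_set_sketch(buffer, new, slice(1, 3))  # rows 1 and 2 overwritten
```

Because JAX arrays are immutable, .at[...].set() returns new arrays, so the sketch returns a fresh pytree and the caller rebinds the name rather than mutating src.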