# {py:mod}`evorl.utils.jax_utils`

```{py:module} evorl.utils.jax_utils
```

```{autodoc2-docstring} evorl.utils.jax_utils
:parser: autodoc2_docstrings_parser
:allowtitles:
```

## Module Contents

### Functions

````{list-table}
:class: autosummary longtable
:align: left

* - {py:obj}`disable_gpu_preallocation <evorl.utils.jax_utils.disable_gpu_preallocation>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.disable_gpu_preallocation
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`enable_deterministic_mode <evorl.utils.jax_utils.enable_deterministic_mode>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.enable_deterministic_mode
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`has_nan <evorl.utils.jax_utils.has_nan>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.has_nan
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`invert_permutation <evorl.utils.jax_utils.invert_permutation>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.invert_permutation
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`is_jitted <evorl.utils.jax_utils.is_jitted>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.is_jitted
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`jit_method <evorl.utils.jax_utils.jit_method>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.jit_method
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`optimize_gpu_utilization <evorl.utils.jax_utils.optimize_gpu_utilization>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.optimize_gpu_utilization
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`pmap_method <evorl.utils.jax_utils.pmap_method>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.pmap_method
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`right_shift_with_padding <evorl.utils.jax_utils.right_shift_with_padding>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.right_shift_with_padding
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`rng_split <evorl.utils.jax_utils.rng_split>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.rng_split
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`rng_split_by_shape <evorl.utils.jax_utils.rng_split_by_shape>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.rng_split_by_shape
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`rng_split_like_tree <evorl.utils.jax_utils.rng_split_like_tree>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.rng_split_like_tree
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`scan_and_last <evorl.utils.jax_utils.scan_and_last>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.scan_and_last
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`scan_and_mean <evorl.utils.jax_utils.scan_and_mean>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.scan_and_mean
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`sliding_window <evorl.utils.jax_utils.sliding_window>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.sliding_window
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_astype <evorl.utils.jax_utils.tree_astype>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_astype
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_concat <evorl.utils.jax_utils.tree_concat>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_concat
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_deepcopy <evorl.utils.jax_utils.tree_deepcopy>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_deepcopy
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_get <evorl.utils.jax_utils.tree_get>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_get
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_has_nan <evorl.utils.jax_utils.tree_has_nan>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_has_nan
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_last <evorl.utils.jax_utils.tree_last>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_last
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_ones_like <evorl.utils.jax_utils.tree_ones_like>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_ones_like
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_set <evorl.utils.jax_utils.tree_set>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_set
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_stop_gradient <evorl.utils.jax_utils.tree_stop_gradient>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_stop_gradient
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
* - {py:obj}`tree_zeros_like <evorl.utils.jax_utils.tree_zeros_like>`
  - ```{autodoc2-docstring} evorl.utils.jax_utils.tree_zeros_like
    :parser: autodoc2_docstrings_parser
    :summary:
    ```
````

### API

````{py:function} disable_gpu_preallocation()
:canonical: evorl.utils.jax_utils.disable_gpu_preallocation
```{autodoc2-docstring} evorl.utils.jax_utils.disable_gpu_preallocation
:parser: autodoc2_docstrings_parser
```
````

````{py:function} enable_deterministic_mode()
:canonical: evorl.utils.jax_utils.enable_deterministic_mode

```{autodoc2-docstring} evorl.utils.jax_utils.enable_deterministic_mode
:parser: autodoc2_docstrings_parser
```
````

````{py:function} has_nan(x: jax.Array) -> bool
:canonical: evorl.utils.jax_utils.has_nan

```{autodoc2-docstring} evorl.utils.jax_utils.has_nan
:parser: autodoc2_docstrings_parser
```
````

````{py:function} invert_permutation(i: jax.Array) -> jax.Array
:canonical: evorl.utils.jax_utils.invert_permutation

```{autodoc2-docstring} evorl.utils.jax_utils.invert_permutation
:parser: autodoc2_docstrings_parser
```
````

````{py:function} is_jitted(func: collections.abc.Callable)
:canonical: evorl.utils.jax_utils.is_jitted

```{autodoc2-docstring} evorl.utils.jax_utils.is_jitted
:parser: autodoc2_docstrings_parser
```
````

````{py:function} jit_method(*, static_argnums: int | collections.abc.Sequence[int] | None = None, static_argnames: str | collections.abc.Iterable[str] | None = None, donate_argnums: int | collections.abc.Sequence[int] | None = None, donate_argnames: str | collections.abc.Iterable[str] | None = None, **kwargs)
:canonical: evorl.utils.jax_utils.jit_method

```{autodoc2-docstring} evorl.utils.jax_utils.jit_method
:parser: autodoc2_docstrings_parser
```
````

````{py:function} optimize_gpu_utilization()
:canonical: evorl.utils.jax_utils.optimize_gpu_utilization

```{autodoc2-docstring} evorl.utils.jax_utils.optimize_gpu_utilization
:parser: autodoc2_docstrings_parser
```
````

````{py:function} pmap_method(axis_name, *, static_broadcasted_argnums=(), donate_argnums=(), **kwargs)
:canonical: evorl.utils.jax_utils.pmap_method

```{autodoc2-docstring} evorl.utils.jax_utils.pmap_method
:parser: autodoc2_docstrings_parser
```
````

````{py:function} right_shift_with_padding(x: chex.Array, shift: int, fill_value: None | chex.Scalar = None)
:canonical: evorl.utils.jax_utils.right_shift_with_padding

```{autodoc2-docstring} evorl.utils.jax_utils.right_shift_with_padding
:parser: autodoc2_docstrings_parser
```
````

````{py:function} rng_split(key: chex.PRNGKey, num: int = 2) -> chex.PRNGKey
:canonical: evorl.utils.jax_utils.rng_split

```{autodoc2-docstring} evorl.utils.jax_utils.rng_split
:parser: autodoc2_docstrings_parser
```
````

````{py:function} rng_split_by_shape(key: chex.PRNGKey, shape: tuple[int]) -> chex.PRNGKey
:canonical: evorl.utils.jax_utils.rng_split_by_shape

```{autodoc2-docstring} evorl.utils.jax_utils.rng_split_by_shape
:parser: autodoc2_docstrings_parser
```
````

````{py:function} rng_split_like_tree(key: chex.PRNGKey, target: chex.ArrayTree, is_leaf=None) -> chex.ArrayTree
:canonical: evorl.utils.jax_utils.rng_split_like_tree

```{autodoc2-docstring} evorl.utils.jax_utils.rng_split_like_tree
:parser: autodoc2_docstrings_parser
```
````

````{py:function} scan_and_last(*args, **kwargs)
:canonical: evorl.utils.jax_utils.scan_and_last

```{autodoc2-docstring} evorl.utils.jax_utils.scan_and_last
:parser: autodoc2_docstrings_parser
```
````

````{py:function} scan_and_mean(*args, **kwargs)
:canonical: evorl.utils.jax_utils.scan_and_mean

```{autodoc2-docstring} evorl.utils.jax_utils.scan_and_mean
:parser: autodoc2_docstrings_parser
```
````

````{py:function} sliding_window(arr, length, stride)
:canonical: evorl.utils.jax_utils.sliding_window

```{autodoc2-docstring} evorl.utils.jax_utils.sliding_window
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_astype(tree: chex.ArrayTree, dtype)
:canonical: evorl.utils.jax_utils.tree_astype

```{autodoc2-docstring} evorl.utils.jax_utils.tree_astype
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_concat(nest1: chex.ArrayTree, nest2: chex.ArrayTree, axis: int = 0)
:canonical: evorl.utils.jax_utils.tree_concat

```{autodoc2-docstring} evorl.utils.jax_utils.tree_concat
:parser: autodoc2_docstrings_parser
```
````
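
The signature `tree_concat(nest1, nest2, axis=0)` suggests a leaf-wise concatenation of two pytrees that share the same structure. The docstring is not reproduced on this page, so the following is only a sketch under that assumption. `concat_trees_sketch` is a hypothetical name and is not the evorl implementation; it uses only `jax.tree_util.tree_map` and `jnp.concatenate`.

```python
import jax
import jax.numpy as jnp


def concat_trees_sketch(nest1, nest2, axis=0):
    """Hypothetical helper: concatenate matching leaves of two pytrees."""
    # tree_map walks both trees in parallel; they must have the same structure.
    return jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=axis), nest1, nest2
    )


# Two small batches with identical keys but different leading dimensions.
batch_a = {"obs": jnp.zeros((2, 3)), "reward": jnp.zeros((2,))}
batch_b = {"obs": jnp.ones((4, 3)), "reward": jnp.ones((4,))}

merged = concat_trees_sketch(batch_a, batch_b)
print(jax.tree_util.tree_map(lambda x: x.shape, merged))
```

Under this reading, every leaf grows along `axis` while the tree structure itself is unchanged, which is the usual pattern when merging batches of trajectories or parameters.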
````{py:function} tree_deepcopy(tree: chex.ArrayTree) -> chex.ArrayTree
:canonical: evorl.utils.jax_utils.tree_deepcopy

```{autodoc2-docstring} evorl.utils.jax_utils.tree_deepcopy
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_get(tree: chex.ArrayTree, idx_or_slice)
:canonical: evorl.utils.jax_utils.tree_get

```{autodoc2-docstring} evorl.utils.jax_utils.tree_get
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_has_nan(tree: chex.ArrayTree) -> chex.ArrayTree
:canonical: evorl.utils.jax_utils.tree_has_nan

```{autodoc2-docstring} evorl.utils.jax_utils.tree_has_nan
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_last(tree: chex.ArrayTree)
:canonical: evorl.utils.jax_utils.tree_last

```{autodoc2-docstring} evorl.utils.jax_utils.tree_last
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_ones_like(nest: chex.ArrayTree, dtype=None) -> chex.ArrayTree
:canonical: evorl.utils.jax_utils.tree_ones_like

```{autodoc2-docstring} evorl.utils.jax_utils.tree_ones_like
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_set(src: chex.ArrayTree, target: chex.ArrayTree, idx_or_slice, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None)
:canonical: evorl.utils.jax_utils.tree_set

```{autodoc2-docstring} evorl.utils.jax_utils.tree_set
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_stop_gradient(nest: chex.ArrayTree) -> chex.ArrayTree
:canonical: evorl.utils.jax_utils.tree_stop_gradient

```{autodoc2-docstring} evorl.utils.jax_utils.tree_stop_gradient
:parser: autodoc2_docstrings_parser
```
````

````{py:function} tree_zeros_like(nest: chex.ArrayTree, dtype=None) -> chex.ArrayTree
:canonical: evorl.utils.jax_utils.tree_zeros_like

```{autodoc2-docstring} evorl.utils.jax_utils.tree_zeros_like
:parser: autodoc2_docstrings_parser
```
````
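
Several of the `tree_*` signatures above, such as `tree_zeros_like(nest, dtype=None)` and `tree_get(tree, idx_or_slice)`, read like leaf-wise applications of a standard array operation over a pytree. The docstrings are not shown on this page, so the sketch below only illustrates that general pattern with plain `jax.tree_util.tree_map`. The names `zeros_like_tree_sketch` and `index_tree_sketch` are hypothetical and do not claim to match evorl's behavior.

```python
import jax
import jax.numpy as jnp


def zeros_like_tree_sketch(nest, dtype=None):
    """Hypothetical helper: build a tree of zeros with the same leaf shapes."""
    return jax.tree_util.tree_map(lambda x: jnp.zeros_like(x, dtype=dtype), nest)


def index_tree_sketch(tree, idx_or_slice):
    """Hypothetical helper: apply the same index or slice to every leaf."""
    return jax.tree_util.tree_map(lambda x: x[idx_or_slice], tree)


params = {"w": jnp.arange(6.0).reshape(3, 2), "b": jnp.arange(3.0)}

zeros = zeros_like_tree_sketch(params)
first_two = index_tree_sketch(params, slice(0, 2))

print(jax.tree_util.tree_map(lambda x: x.shape, zeros))
print(jax.tree_util.tree_map(lambda x: x.shape, first_two))
```

Both helpers preserve the tree structure and only change leaf contents, so their outputs can be passed to other pytree-aware JAX transformations without extra bookkeeping.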