Source code for evorl.utils.jax_utils

  1import os
  2from collections.abc import Iterable, Sequence, Callable
  3from functools import partial
  4import math
  5import copy
  6
  7import chex
  8import jax
  9import jax.numpy as jnp
 10import jax.tree_util as jtu
 11
 12
def disable_gpu_preallocation():
    """Turn off XLA's GPU memory preallocation.

    Sets the ``XLA_PYTHON_CLIENT_PREALLOCATE`` environment variable to
    ``"false"`` for the current process.

    Call this method at the beginning of your script.
    """
    os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")
19 20
def optimize_gpu_utilization():
    """Apply environment tweaks that may help on Nvidia GPUs.

    Writes three NCCL settings into the process environment. This function
    is not tested.
    """
    # Disabled experiment, kept for reference:
    # xla_flags = os.getenv("XLA_FLAGS", "")
    # print(f"current XLA_FLAGS: {xla_flags}")
    # if len(xla_flags) > 0:
    #     xla_flags = xla_flags + " "
    # os.environ['XLA_FLAGS'] = xla_flags + (
    #     '--xla_gpu_enable_triton_softmax_fusion=true '
    #     '--xla_gpu_triton_gemm_any=True '
    #     # '--xla_gpu_enable_async_collectives=true '
    #     # '--xla_gpu_enable_latency_hiding_scheduler=true '
    #     # '--xla_gpu_enable_highest_priority_async_stream=true '
    # )

    # used for single-host multi-device computations on Nvidia GPUs
    nccl_settings = (
        ("NCCL_LL128_BUFFSIZE", "-2"),
        ("NCCL_LL_BUFFSIZE", "-2"),
        ("NCCL_PROTO", "SIMPLE,LL,LL128"),
    )
    for name, value in nccl_settings:
        os.environ[name] = value
46 47
def enable_deterministic_mode():
    """Enable deterministic mode for JAX.

    Appends ``--xla_gpu_deterministic_ops=true`` to the ``XLA_FLAGS``
    environment variable, keeping any flags that are already set. The call
    is idempotent: if the most recent ``--xla_gpu_deterministic_ops=...``
    entry in ``XLA_FLAGS`` is already ``true``, the variable is left untouched
    instead of growing by one duplicate flag per call.

    Call this method at the beginning of your script.
    """
    flag_prefix = "--xla_gpu_deterministic_ops="
    flag = flag_prefix + "true"

    xla_flags = os.getenv("XLA_FLAGS", "")

    # Only the last occurrence matters: an earlier "=true" may have been
    # overridden by a later "=false", in which case we still need to append.
    previous = [tok for tok in xla_flags.split() if tok.startswith(flag_prefix)]
    if previous and previous[-1] == flag:
        return

    os.environ["XLA_FLAGS"] = f"{xla_flags} {flag}" if xla_flags else flag
# use chex.set_n_cpu_devices(n) instead
# def set_host_device_count(n):
#     """
#     By default, XLA considers all CPU cores as one device. This utility tells XLA
#     that there are `n` host (CPU) devices available to use. As a consequence, this
#     allows parallel mapping in JAX :func:`jax.pmap` to work on the CPU platform.
#
#     .. note:: This utility only takes effect at the beginning of your program.
#         Under the hood, this sets the environment variable
#         `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
#         `[num_devices]` is the desired number of CPU devices `n`.
#
#     .. warning:: Our understanding of the side effects of using the
#         `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
#         observe some strange phenomenon when using this utility, please let us
#         know through our issue or forum page. More information is available in this
#         `JAX issue <https://github.com/google/jax/issues/1408>`_.
#
#     :param int n: number of devices to use.
#     """
#     xla_flags = os.getenv("XLA_FLAGS", "")
#     xla_flags = re.sub(
#         r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
#     os.environ["XLA_FLAGS"] = " ".join(
#         ["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
def tree_zeros_like(nest: chex.ArrayTree, dtype=None) -> chex.ArrayTree:
    """Pytree version of `jnp.zeros_like`.

    Args:
        nest: Pytree whose leaves expose ``shape`` and ``dtype``.
        dtype: Optional dtype for every output leaf. When omitted, each
            leaf keeps its own dtype.

    Returns:
        A pytree with the same structure as ``nest`` whose leaves are zeros.
    """

    def _zeros(leaf):
        return jnp.zeros(leaf.shape, dtype or leaf.dtype)

    return jtu.tree_map(_zeros, nest)
90 91
def tree_ones_like(nest: chex.ArrayTree, dtype=None) -> chex.ArrayTree:
    """Pytree version of `jnp.ones_like`.

    Args:
        nest: Pytree whose leaves expose ``shape`` and ``dtype``.
        dtype: Optional dtype for every output leaf. When omitted, each
            leaf keeps its own dtype.

    Returns:
        A pytree with the same structure as ``nest`` whose leaves are ones.
    """

    def _ones(leaf):
        return jnp.ones(leaf.shape, dtype or leaf.dtype)

    return jtu.tree_map(_ones, nest)
95 96
def tree_concat(nest1: chex.ArrayTree, nest2: chex.ArrayTree, axis: int = 0):
    """Pytree version of `jnp.concatenate`.

    Args:
        nest1: First pytree.
        nest2: Second pytree, with the same structure as ``nest1``.
        axis: Axis along which matching leaves are joined.

    Returns:
        A pytree whose leaves are ``nest1``'s leaves followed by ``nest2``'s
        leaves along ``axis``.
    """

    def _join(left, right):
        return jnp.concatenate((left, right), axis=axis)

    return jtu.tree_map(_join, nest1, nest2)
100 101
def tree_stop_gradient(nest: chex.ArrayTree) -> chex.ArrayTree:
    """Pytree version of `jax.lax.stop_gradient`.

    Args:
        nest: Pytree of arrays.

    Returns:
        A pytree with the same structure, with `jax.lax.stop_gradient`
        applied to every leaf.
    """
    leaves, treedef = jtu.tree_flatten(nest)
    blocked = [jax.lax.stop_gradient(leaf) for leaf in leaves]
    return jtu.tree_unflatten(treedef, blocked)
105 106
def tree_astype(tree: chex.ArrayTree, dtype):
    """Pytree version of `jnp.astype`.

    Args:
        tree: Pytree of arrays.
        dtype: Target dtype passed to each leaf's ``astype``.

    Returns:
        A pytree with the same structure whose leaves are cast to ``dtype``.
    """

    def _cast(leaf):
        return leaf.astype(dtype)

    return jtu.tree_map(_cast, tree)
110 111
def tree_last(tree: chex.ArrayTree):
    """Get the last element of each array in the pytree.

    Args:
        tree: Pytree of arrays.

    Returns:
        A pytree with the same structure where each leaf is ``leaf[-1]``.
    """

    def _tail(leaf):
        return leaf[-1]

    return jtu.tree_map(_tail, tree)
115 116
def tree_get(tree: chex.ArrayTree, idx_or_slice):
    """Get the elements of each array in the pytree.

    Args:
        tree: Pytree of arrays.
        idx_or_slice: Index expression applied to every leaf.

    Returns:
        A pytree with the same structure where each leaf is
        ``leaf[idx_or_slice]``.
    """

    def _select(leaf):
        return leaf[idx_or_slice]

    return jtu.tree_map(_select, tree)
120 121
def tree_set(
    src: chex.ArrayTree,
    target: chex.ArrayTree,
    idx_or_slice,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | None = None,
):
    """Set part of each array in the pytree.

    A Pytree version of `src[idx_or_slice]=target`, implemented with
    ``leaf.at[idx_or_slice].set(...)`` so a new pytree is returned.

    Args:
        src: The source pytree.
        target: The pytree holding the values to write; same structure as
            ``src``.
        idx_or_slice: The indices or slices to be set.
        indices_are_sorted: Forwarded to ``.at[...].set``.
        unique_indices: Forwarded to ``.at[...].set``.
        mode: Forwarded to ``.at[...].set``.

    Returns:
        The updated source pytree.
    """
    scatter_options = {
        "indices_are_sorted": indices_are_sorted,
        "unique_indices": unique_indices,
        "mode": mode,
    }

    def _assign(dst_leaf, new_values):
        return dst_leaf.at[idx_or_slice].set(new_values, **scatter_options)

    return jtu.tree_map(_assign, src, target)
155 156
def scan_and_mean(*args, **kwargs):
    """Scan with mean aggregation.

    Usage: same as `jax.lax.scan`; all arguments are forwarded unchanged.

    Returns:
        A tuple ``(last_carry, mean_ys)`` where every leaf of the stacked scan
        outputs is averaged over its leading axis.
    """
    final_carry, stacked = jax.lax.scan(*args, **kwargs)
    averaged = jtu.tree_map(lambda out: jnp.mean(out, axis=0), stacked)
    return final_carry, averaged
164 165
def scan_and_last(*args, **kwargs):
    """Scan and return last iteration results.

    Usage: same as `jax.lax.scan`; all arguments are forwarded unchanged.

    Returns:
        A tuple ``(last_carry, last_ys)`` where every leaf of the stacked scan
        outputs is reduced to its final entry. A leaf whose leading axis is
        empty is returned as-is.
    """

    def _final_step(out):
        if out.shape[0] == 0:
            # Nothing to index into; hand back the empty array unchanged.
            return out
        return out[-1]

    final_carry, stacked = jax.lax.scan(*args, **kwargs)
    return final_carry, jtu.tree_map(_final_step, stacked)
173 174
[docs] 175def jit_method( 176 *, 177 static_argnums: int | Sequence[int] | None = None, 178 static_argnames: str | Iterable[str] | None = None, 179 donate_argnums: int | Sequence[int] | None = None, 180 donate_argnames: str | Iterable[str] | None = None, 181 **kwargs, 182): 183 """A decorator for `jax.jit` with arguments. 184 185 Args: 186 static_argnums: The positional argument indices that are constant across 187 different calls to the function. 188 189 Returns: 190 A decorator for `jax.jit` with arguments. 191 """ 192 return partial( 193 jax.jit, 194 static_argnums=static_argnums, 195 static_argnames=static_argnames, 196 donate_argnums=donate_argnums, 197 donate_argnames=donate_argnames, 198 **kwargs, 199 )
200 201
[docs] 202def pmap_method( 203 axis_name, 204 *, 205 static_broadcasted_argnums=(), 206 donate_argnums=(), 207 **kwargs, 208): 209 """A decorator for `jax.pmap` with arguments.""" 210 return partial( 211 jax.pmap, 212 axis_name, 213 static_broadcasted_argnums=static_broadcasted_argnums, 214 donate_argnums=donate_argnums, 215 **kwargs, 216 )
def _vmap_rng_split(key: chex.PRNGKey, num: int = 2) -> chex.PRNGKey:
    """Enhanced version of `jax.random.split` that allows batched keys.

    Args:
        key: Key or batched keys; the trailing dimension must be 2, e.g.
            shape (B, 2).
        num: Number of keys to split.

    Returns:
        Split keys with the new ``num`` axis in front of the batch axes,
        e.g. shape (num, B, 2) for input shape (B, 2).
    """
    chex.assert_shape(key, (..., 2))

    # Wrap the splitter in one vmap per leading batch axis. out_axes=1 keeps
    # the `num` axis first and places each batch axis after it.
    batch_rank = key.ndim - 1
    split_fn = jax.random.split
    for _ in range(batch_rank):
        split_fn = jax.vmap(split_fn, in_axes=(0, None), out_axes=1)

    return split_fn(key, num)
def rng_split(key: chex.PRNGKey, num: int = 2) -> chex.PRNGKey:
    """Unified Version of `jax.random.split` for both single key and batched keys.

    Args:
        key: A single key of shape (2,), or batched keys with extra leading
            axes.
        num: Number of keys to split.

    Returns:
        The split keys: from `jax.random.split` for a single key, otherwise
        from the batched helper `_vmap_rng_split`.
    """
    if key.ndim != 1:
        # Anything with leading batch axes goes through the vmapped splitter.
        return _vmap_rng_split(key, num)

    chex.assert_shape(key, (2,))
    return jax.random.split(key, num)
246 247
def rng_split_by_shape(key: chex.PRNGKey, shape: Sequence[int]) -> chex.PRNGKey:
    """Split the key into multiple keys according to the shape.

    Args:
        key: A single PRNG key of shape (2,).
        shape: Desired batch shape of the result. Any sequence of ints is
            accepted (tuple or list); it is normalised to a tuple first so
            that a list no longer fails on ``shape + (2,)``.

    Returns:
        Keys of shape ``(*shape, 2)``.
    """
    chex.assert_shape(key, (2,))

    dims = tuple(shape)
    keys = jax.random.split(key, math.prod(dims))
    return jnp.reshape(keys, dims + (2,))
253 254
def rng_split_like_tree(
    key: chex.PRNGKey, target: chex.ArrayTree, is_leaf=None
) -> chex.ArrayTree:
    """Split the key according to the structure of the target pytree.

    Args:
        key: PRNG key to split.
        target: Pytree whose structure the result mirrors.
        is_leaf: Optional predicate forwarded to `jtu.tree_structure` to
            decide what counts as a leaf.

    Returns:
        A pytree with the structure of ``target`` holding one sub-key per
        leaf.
    """
    structure = jtu.tree_structure(target, is_leaf=is_leaf)
    subkeys = jax.random.split(key, structure.num_leaves)
    return jtu.tree_unflatten(structure, subkeys)
262 263
def is_jitted(func: Callable):
    """Detect if a function is wrapped by jit or pmap.

    The check is purely attribute-based: it reports whether ``func`` exposes
    a ``lower`` attribute.

    Args:
        func: The callable to inspect.

    Returns:
        True if ``func`` has a ``lower`` attribute, False otherwise.
    """
    absent = object()
    return getattr(func, "lower", absent) is not absent
267 268
def has_nan(x: jax.Array) -> bool:
    """Check if the array has NaN values.

    Args:
        x: The array to inspect.

    Returns:
        The result of reducing ``jnp.isnan(x)`` with ``any`` (a scalar boolean
        array); truthy when at least one element is NaN.
    """
    nan_mask = jnp.isnan(x)
    return jnp.any(nan_mask)
272 273
def tree_has_nan(tree: chex.ArrayTree) -> chex.ArrayTree:
    """Check if the pytree has NaN values.

    Args:
        tree: Pytree of arrays.

    Returns:
        A pytree with the same structure where each leaf is a scalar boolean
        flag telling whether the corresponding array contains a NaN.
    """

    def _leaf_has_nan(leaf):
        return jnp.isnan(leaf).any()

    return jtu.tree_map(_leaf_has_nan, tree)
277 278
def invert_permutation(i: jax.Array) -> jax.Array:
    """Helper function that inverts a permutation array.

    Args:
        i: Integer array holding a permutation of ``0 .. i.size - 1``.

    Returns:
        An array ``inv`` of the same shape and dtype with ``inv[i[k]] == k``.
    """
    positions = jnp.arange(i.size, dtype=i.dtype)
    # Every slot is overwritten by the scatter below, so the initial
    # contents of the buffer never matter.
    inverse = jnp.empty_like(i)
    return inverse.at[i].set(positions)
def _deepcopy(x):
    """Deep-copy a single leaf, passing JAX arrays through untouched."""
    # we don't copy jax arrays, since they are immutable
    return x if isinstance(x, jax.Array) else copy.deepcopy(x)


def tree_deepcopy(tree: chex.ArrayTree) -> chex.ArrayTree:
    """Deep copy the pytree.

    Useful for mutable pytree structure like dict. The containers are rebuilt
    and every leaf that is not a `jax.Array` is deep-copied; `jax.Array`
    leaves are shared with the input.

    Args:
        tree: The pytree to copy.

    Returns:
        A new pytree with the same structure as ``tree``.
    """
    leaves, treedef = jtu.tree_flatten(tree)
    return jtu.tree_unflatten(treedef, [_deepcopy(leaf) for leaf in leaves])
298 299
def right_shift_with_padding(
    x: chex.Array, shift: int, fill_value: None | chex.Scalar = None
):
    """Shift the array to the right with padding.

    The array is rolled by ``shift`` along axis 0 and the first ``shift``
    entries (the ones that wrapped around) are overwritten with the padding
    value.

    Args:
        x: The array to shift along its first axis.
        shift: Number of positions to shift.
        fill_value: Value written into the vacated leading entries. Zeros are
            used when it is None.

    Returns:
        The shifted array, same shape as ``x``.
    """
    pad_value = 0 if fill_value is None else fill_value

    rolled = jnp.roll(x, shift=shift, axis=0)
    head = jnp.full_like(rolled[:shift], pad_value)
    return rolled.at[:shift].set(head)
314 315
def sliding_window(arr, length, stride):
    """Slide a window over the first axis of the array.

    Change shape from [T, ...] to [W, L, ...], where L is the window length
    and W = (T - L) // S + 1 is the number of windows for stride S.

    Args:
        arr: Array to slice along axis 0.
        length: Window length L.
        stride: Step S between the start indices of consecutive windows.

    Returns:
        The stacked windows, one per start index, with the window axis first.
    """
    window_starts = jnp.arange(0, arr.shape[0] - length + 1, stride)

    def _take_window(start):
        return jax.lax.dynamic_slice_in_dim(
            arr, start_index=start, slice_size=length, axis=0
        )

    return jax.vmap(_take_window)(window_starts)