1import os
2from collections.abc import Iterable, Sequence, Callable
3from functools import partial
4import math
5import copy
6
7import chex
8import jax
9import jax.numpy as jnp
10import jax.tree_util as jtu
11
12
[docs]
def disable_gpu_preallocation():
    """Turn off XLA's up-front GPU memory preallocation.

    Sets the ``XLA_PYTHON_CLIENT_PREALLOCATE`` environment variable to
    ``"false"``. Call this at the very beginning of your script.
    """
    env_key = "XLA_PYTHON_CLIENT_PREALLOCATE"
    os.environ[env_key] = "false"
19
20
[docs]
def optimize_gpu_utilization():
    """Apply possibly-helpful runtime settings for Nvidia GPUs.

    Only NCCL environment variables are set; they target single-host,
    multi-device computations. This function is not tested.
    """
    # NOTE: the following XLA_FLAGS were considered but are currently NOT
    # applied (they used to live here as commented-out code):
    #   --xla_gpu_enable_triton_softmax_fusion=true
    #   --xla_gpu_triton_gemm_any=True
    #   --xla_gpu_enable_async_collectives=true
    #   --xla_gpu_enable_latency_hiding_scheduler=true
    #   --xla_gpu_enable_highest_priority_async_stream=true
    nccl_settings = {
        "NCCL_LL128_BUFFSIZE": "-2",
        "NCCL_LL_BUFFSIZE": "-2",
        "NCCL_PROTO": "SIMPLE,LL,LL128",
    }
    for name, value in nccl_settings.items():
        os.environ[name] = value
46
47
[docs]
def enable_deterministic_mode():
    """Enable deterministic mode for JAX.

    Appends ``--xla_gpu_deterministic_ops=true`` to the ``XLA_FLAGS``
    environment variable, keeping any flags that are already set. Calling
    this more than once does not add the flag again.

    Call this method at the beginning of your script.
    """
    flag = "--xla_gpu_deterministic_ops=true"
    xla_flags = os.getenv("XLA_FLAGS", "")
    # Only look at whitespace-separated tokens to detect the flag; the
    # existing string itself is left untouched so quoted values survive.
    if flag in xla_flags.split():
        return
    if xla_flags:
        xla_flags = xla_flags + " "
    os.environ["XLA_FLAGS"] = xla_flags + flag
58
59
60# use chex.set_n_cpu_devices(n) instead
61# def set_host_device_count(n):
62# """
63# By default, XLA considers all CPU cores as one device. This utility tells XLA
64# that there are `n` host (CPU) devices available to use. As a consequence, this
65# allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
66
67# .. note:: This utility only takes effect at the beginning of your program.
68# Under the hood, this sets the environment variable
69# `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
70# `[num_device]` is the desired number of CPU devices `n`.
71
72# .. warning:: Our understanding of the side effects of using the
73# `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
74# observe some strange phenomenon when using this utility, please let us
75# know through our issue or forum page. More information is available in this
76# `JAX issue <https://github.com/google/jax/issues/1408>`_.
77
78# :param int n: number of devices to use.
79# """
80# xla_flags = os.getenv("XLA_FLAGS", "")
81# xla_flags = re.sub(
82# r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
83# os.environ["XLA_FLAGS"] = " ".join(
84# ["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
85
86
[docs]
def tree_zeros_like(nest: chex.ArrayTree, dtype=None) -> chex.ArrayTree:
    """Pytree version of `jnp.zeros_like`.

    Args:
        nest: Pytree whose leaves expose ``.shape`` and ``.dtype``.
        dtype: Optional dtype for every output leaf. When ``None``, each leaf
            keeps its own dtype.

    Returns:
        A pytree with the same structure as ``nest`` in which every leaf is a
        zero-filled array of the same shape.
    """
    # Compare with None explicitly: `np.dtype` objects can be falsy (their
    # `len()` is the number of fields, 0 for plain scalar dtypes), so
    # `dtype or x.dtype` could silently ignore an explicitly requested dtype.
    return jtu.tree_map(
        lambda x: jnp.zeros(x.shape, dtype if dtype is not None else x.dtype),
        nest,
    )
90
91
[docs]
def tree_ones_like(nest: chex.ArrayTree, dtype=None) -> chex.ArrayTree:
    """Pytree version of `jnp.ones_like`.

    Args:
        nest: Pytree whose leaves expose ``.shape`` and ``.dtype``.
        dtype: Optional dtype for every output leaf. When ``None``, each leaf
            keeps its own dtype.

    Returns:
        A pytree with the same structure as ``nest`` in which every leaf is a
        one-filled array of the same shape.
    """
    # Compare with None explicitly: `np.dtype` objects can be falsy (their
    # `len()` is the number of fields, 0 for plain scalar dtypes), so
    # `dtype or x.dtype` could silently ignore an explicitly requested dtype.
    return jtu.tree_map(
        lambda x: jnp.ones(x.shape, dtype if dtype is not None else x.dtype),
        nest,
    )
95
96
[docs]
def tree_concat(nest1: chex.ArrayTree, nest2: chex.ArrayTree, axis: int = 0):
    """Concatenate two pytrees leaf by leaf (pytree version of `jnp.concatenate`).

    Args:
        nest1: First pytree; its leaves come first in each result.
        nest2: Second pytree with the same structure as ``nest1``.
        axis: Axis along which each pair of leaves is joined.

    Returns:
        A pytree of concatenated leaves with the structure of ``nest1``.
    """

    def _concat_pair(left, right):
        return jnp.concatenate([left, right], axis=axis)

    return jtu.tree_map(_concat_pair, nest1, nest2)
100
101
[docs]
def tree_stop_gradient(nest: chex.ArrayTree) -> chex.ArrayTree:
    """Block gradients through every leaf of a pytree.

    Pytree version of `jax.lax.stop_gradient`: the values are returned
    unchanged, but no gradient flows back through them.
    """
    block_grad = jax.lax.stop_gradient
    return jtu.tree_map(block_grad, nest)
105
106
[docs]
def tree_astype(tree: chex.ArrayTree, dtype):
    """Cast every leaf of a pytree to ``dtype`` (pytree version of `astype`)."""

    def _cast(leaf):
        return leaf.astype(dtype)

    return jtu.tree_map(_cast, tree)
110
111
[docs]
def tree_last(tree: chex.ArrayTree):
    """Index every leaf of the pytree with ``[-1]`` (its last leading entry)."""

    def _take_last(leaf):
        return leaf[-1]

    return jtu.tree_map(_take_last, tree)
115
116
[docs]
def tree_get(tree: chex.ArrayTree, idx_or_slice):
    """Index every leaf of the pytree with ``idx_or_slice``.

    Args:
        tree: Pytree of indexable leaves.
        idx_or_slice: Any index expression accepted by the leaves' ``[]``.

    Returns:
        A pytree with the same structure holding ``leaf[idx_or_slice]``.
    """

    def _select(leaf):
        return leaf[idx_or_slice]

    return jtu.tree_map(_select, tree)
120
121
[docs]
def tree_set(
    src: chex.ArrayTree,
    target: chex.ArrayTree,
    idx_or_slice,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | None = None,
):
    """Functionally write ``target`` into part of every leaf of ``src``.

    Pytree version of ``src[idx_or_slice] = target``, implemented with the
    out-of-place ``leaf.at[idx_or_slice].set(...)`` update.

    Args:
        src: Pytree of arrays to update.
        target: Pytree with the same structure as ``src`` holding the values
            to write.
        idx_or_slice: Index expression selecting the region to overwrite.
        indices_are_sorted: Forwarded to ``.at[...].set``.
        unique_indices: Forwarded to ``.at[...].set``.
        mode: Forwarded to ``.at[...].set``.

    Returns:
        A new pytree with the updated leaves; ``src`` itself is not modified.
    """
    set_options = dict(
        indices_are_sorted=indices_are_sorted,
        unique_indices=unique_indices,
        mode=mode,
    )

    def _write_leaf(dst_leaf, value_leaf):
        return dst_leaf.at[idx_or_slice].set(value_leaf, **set_options)

    return jtu.tree_map(_write_leaf, src, target)
155
156
[docs]
def scan_and_mean(*args, **kwargs):
    """Run `jax.lax.scan` and average the stacked outputs over the scan axis.

    Takes exactly the same arguments as `jax.lax.scan`. The final carry is
    returned as-is; every leaf of the stacked outputs is reduced with a mean
    over axis 0.
    """
    final_carry, stacked = jax.lax.scan(*args, **kwargs)
    averaged = jtu.tree_map(lambda leaf: leaf.mean(axis=0), stacked)
    return final_carry, averaged
164
165
[docs]
def scan_and_last(*args, **kwargs):
    """Run `jax.lax.scan` and keep only the final iteration's outputs.

    Takes exactly the same arguments as `jax.lax.scan`. Every leaf of the
    stacked outputs is replaced by its last entry; a leaf with an empty
    leading axis (zero-length scan) is returned unchanged.
    """
    final_carry, stacked = jax.lax.scan(*args, **kwargs)

    def _final_step(leaf):
        if leaf.shape[0] == 0:
            return leaf
        return leaf[-1]

    return final_carry, jtu.tree_map(_final_step, stacked)
173
174
[docs]
175def jit_method(
176 *,
177 static_argnums: int | Sequence[int] | None = None,
178 static_argnames: str | Iterable[str] | None = None,
179 donate_argnums: int | Sequence[int] | None = None,
180 donate_argnames: str | Iterable[str] | None = None,
181 **kwargs,
182):
183 """A decorator for `jax.jit` with arguments.
184
185 Args:
186 static_argnums: The positional argument indices that are constant across
187 different calls to the function.
188
189 Returns:
190 A decorator for `jax.jit` with arguments.
191 """
192 return partial(
193 jax.jit,
194 static_argnums=static_argnums,
195 static_argnames=static_argnames,
196 donate_argnums=donate_argnums,
197 donate_argnames=donate_argnames,
198 **kwargs,
199 )
200
201
[docs]
def pmap_method(
    axis_name,
    *,
    static_broadcasted_argnums=(),
    donate_argnums=(),
    **kwargs,
):
    """Build a `jax.pmap` decorator with pre-bound options.

    Example::

        @pmap_method("batch")
        def step(x):
            return jax.lax.psum(x, "batch")

    Args:
        axis_name: Name of the mapped axis, usable by collectives such as
            `jax.lax.psum` inside the decorated function.
        static_broadcasted_argnums: Positional argument indices treated as
            static and broadcast to every device.
        donate_argnums: Positional argument indices whose buffers may be
            donated to the computation.
        **kwargs: Any other keyword arguments, forwarded to `jax.pmap`.

    Returns:
        A callable that, applied to a function ``fn``, returns
        ``jax.pmap(fn, axis_name=axis_name, ...)``.
    """
    # `jax.pmap` takes the function to map as its FIRST positional argument,
    # and `partial` prepends its positional args. Binding `axis_name`
    # positionally produced `jax.pmap(axis_name, fn)`, so bind it by keyword.
    return partial(
        jax.pmap,
        axis_name=axis_name,
        static_broadcasted_argnums=static_broadcasted_argnums,
        donate_argnums=donate_argnums,
        **kwargs,
    )
217
218
def _vmap_rng_split(key: chex.PRNGKey, num: int = 2) -> chex.PRNGKey:
    """Split every key in a batch of raw keys.

    Extends `jax.random.split` to inputs with any number of leading batch
    dimensions.

    Args:
        key: Keys with shape ``(*batch, 2)``, e.g. ``(B, 2)``.
        num: Number of new keys drawn from each input key.

    Returns:
        Keys with shape ``(num, *batch, 2)``, e.g. ``(num, B, 2)``; the split
        axis comes first.
    """
    chex.assert_shape(key, (..., 2))

    split_fn = jax.random.split
    remaining_batch_dims = key.ndim - 1
    # Each vmap layer maps over one batch dim; out_axes=1 keeps the `num`
    # axis produced by the inner split in front of the batch dims.
    while remaining_batch_dims > 0:
        split_fn = jax.vmap(split_fn, in_axes=(0, None), out_axes=1)
        remaining_batch_dims -= 1

    return split_fn(key, num)
237
238
[docs]
def rng_split(key: chex.PRNGKey, num: int = 2) -> chex.PRNGKey:
    """Split a single key or a batch of keys with one entry point.

    A 1-D input must be a single raw key of shape ``(2,)`` and is split with
    `jax.random.split`. Any higher-rank input is treated as a batch of keys
    and is handled by `_vmap_rng_split`; the ``num`` axis comes first.
    """
    if key.ndim != 1:
        return _vmap_rng_split(key, num)
    chex.assert_shape(key, (2,))
    return jax.random.split(key, num)
246
247
[docs]
def rng_split_by_shape(
    key: chex.PRNGKey, shape: Sequence[int] | int
) -> chex.PRNGKey:
    """Split one key into a grid of keys laid out with the given shape.

    Args:
        key: A single raw key of shape ``(2,)``.
        shape: Batch shape of the output keys. Any sequence of ints (tuple
            or list) is accepted, as is a single int for a 1-D batch.

    Returns:
        Keys with shape ``(*shape, 2)``, holding ``prod(shape)`` keys obtained
        from one `jax.random.split` call.
    """
    chex.assert_shape(key, (2,))
    # Normalise to a tuple: `shape + (2,)` used to raise TypeError for a
    # list or a bare integer.
    try:
        batch_shape = tuple(shape)
    except TypeError:  # a single integer
        batch_shape = (shape,)
    keys = jax.random.split(key, math.prod(batch_shape))
    return jnp.reshape(keys, batch_shape + (2,))
253
254
[docs]
def rng_split_like_tree(
    key: chex.PRNGKey, target: chex.ArrayTree, is_leaf=None
) -> chex.ArrayTree:
    """Give every leaf position of ``target`` its own key.

    Args:
        key: Key to split.
        target: Pytree whose structure (not values) defines the output.
        is_leaf: Optional predicate forwarded to `jax.tree_util.tree_structure`
            to decide what counts as a leaf.

    Returns:
        A pytree with ``target``'s structure whose leaves are the rows of
        ``jax.random.split(key, num_leaves)``.
    """
    structure = jtu.tree_structure(target, is_leaf=is_leaf)
    leaf_keys = jax.random.split(key, structure.num_leaves)
    return structure.unflatten(leaf_keys)
262
263
[docs]
def is_jitted(func: Callable):
    """Heuristically tell whether ``func`` is wrapped by jit or pmap.

    The check is purely structural: any object exposing a ``lower``
    attribute counts as wrapped.
    """
    missing = object()
    return getattr(func, "lower", missing) is not missing
267
268
[docs]
def has_nan(x: jax.Array) -> jax.Array:
    """Check if the array has NaN values.

    Args:
        x: Array to inspect.

    Returns:
        A 0-d boolean ``jax.Array`` (not a Python ``bool``), so the function
        also works under `jax.jit`. Use ``bool(...)`` only outside of traced
        code.
    """
    return jnp.isnan(x).any()
272
273
[docs]
def tree_has_nan(tree: chex.ArrayTree) -> chex.ArrayTree:
    """Check every leaf of the pytree for NaN values.

    Returns a pytree with the same structure in which each leaf is a 0-d
    boolean array that is True when that leaf contains a NaN. The result is
    per leaf; it is not reduced to a single flag.
    """

    def _leaf_has_nan(leaf):
        return jnp.isnan(leaf).any()

    return jtu.tree_map(_leaf_has_nan, tree)
277
278
[docs]
def invert_permutation(i: jax.Array) -> jax.Array:
    """Return the inverse permutation: ``out[i[k]] == k`` for every ``k``.

    NOTE(review): assumes ``i`` is a 1-D permutation of ``0..i.size-1``;
    other inputs are not checked.
    """
    positions = jnp.arange(i.size, dtype=i.dtype)
    inverse = jnp.empty_like(i)
    # Scatter each position k to slot i[k].
    return inverse.at[i].set(positions)
282
283
284def _deepcopy(x):
285 if isinstance(x, jax.Array):
286 # we don't copy jax arrays, since they are immutable
287 return x
288 else:
289 return copy.deepcopy(x)
290
291
[docs]
def tree_deepcopy(tree: chex.ArrayTree) -> chex.ArrayTree:
    """Deep copy the pytree.

    Useful for pytrees built from mutable containers such as dicts: the
    containers are rebuilt by `tree_map`, and every non-JAX-array leaf is
    passed through `copy.deepcopy`. JAX array leaves are immutable and are
    shared rather than copied.
    """

    def _copy_leaf(leaf):
        if isinstance(leaf, jax.Array):
            return leaf
        return copy.deepcopy(leaf)

    return jtu.tree_map(_copy_leaf, tree)
298
299
[docs]
def right_shift_with_padding(
    x: chex.Array, shift: int, fill_value: None | chex.Scalar = None
):
    """Shift the array along axis 0 and pad the vacated rows.

    Args:
        x: Array shifted along its first axis.
        shift: Static Python int. A positive value shifts toward higher
            indices and pads the first ``shift`` rows. A negative value
            shifts toward lower indices and pads the last ``-shift`` rows.
            ``0`` returns the input values unchanged. A magnitude of at least
            ``x.shape[0]`` yields an all-padding array.
        fill_value: Value written into the vacated rows; ``None`` means 0.

    Returns:
        Array with the same shape and dtype as ``x``.
    """
    shifted_matrix = jnp.roll(x, shift=shift, axis=0)

    # The rows that wrapped around during the roll are the ones to overwrite.
    # For negative shifts these are at the end; the old `[:shift]` slice
    # padded everything *except* them.
    vacated = slice(None, shift) if shift >= 0 else slice(shift, None)
    region = shifted_matrix[vacated]

    if fill_value is not None:
        padding = jnp.full_like(region, fill_value)
    else:
        padding = jnp.zeros_like(region)

    return shifted_matrix.at[vacated].set(padding)
314
315
[docs]
def sliding_window(arr, length, stride):
    """Slide a window over the first axis of the array.

    Changes shape from ``[T, ...]`` to ``[W, L, ...]``. Here ``L = length``
    and ``W = (T - L) // S + 1`` (with ``S = stride``) is the number of
    windows. Window ``w`` equals ``arr[w * S : w * S + L]``.

    Args:
        arr: Array whose first axis is windowed.
        length: Window length ``L``; must be a static Python int.
        stride: Step ``S`` between consecutive window starts.

    Returns:
        The stacked windows, with the window index as the leading axis.
    """
    starts = jnp.arange(0, arr.shape[0] - length + 1, stride)
    # vmap over the start offsets stacks the windows along a new axis 0,
    # giving [W, L, ...].
    windows = jax.vmap(
        lambda start: jax.lax.dynamic_slice_in_dim(
            arr, start_index=start, slice_size=length, axis=0
        ),
    )(starts)

    return windows
328 return windows