Source code for evorl.distributed.sharding
1from collections.abc import Callable
2
3import chex
4import jax
5import jax.tree_util as jtu
6from jax import shard_map
7
8
[docs]
def tree_device_put(tree: chex.ArrayTree, device_or_sharding):
    """Apply `jax.device_put` to every leaf of a pytree.

    Args:
        tree: pytree whose leaves are passed to `jax.device_put` one by one.
        device_or_sharding: placement target, handed unchanged to
            `jax.device_put` for each leaf.

    Returns:
        A pytree with the same structure as `tree`, holding the results of
        the per-leaf `jax.device_put` calls.
    """

    def _put_leaf(leaf):
        # The same target is reused for every leaf of the tree.
        return jax.device_put(leaf, device_or_sharding)

    return jtu.tree_map(_put_leaf, tree)
12
13
[docs]
14def shmap_vmap(fn: Callable, mesh, in_specs, out_specs, **kwargs):
15 """Vmap on different gpu."""
16
17 def shmap_f(*args):
18 return jax.vmap(fn)(*args)
19
20 return shard_map(
21 shmap_f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, **kwargs
22 )
23
24
[docs]
25def shmap_map(fn: Callable, mesh, in_specs, out_specs, **kwargs):
26 """Sequential execution on different gpu.
27
28 Args:
29 fn: function to be executed, only positional arguments are supported
30 sharding: JAX sharding object
31 Returns:
32 A wrapped function.
33 """
34
35 def g(carry):
36 return fn(*carry)
37
38 def shmap_f(*args):
39 return jax.lax.map(g, args)
40
41 return shard_map(
42 shmap_f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, **kwargs
43 )