Source code for evorl.distributed.sharding

 1from collections.abc import Callable
 2
 3import chex
 4import jax
 5import jax.tree_util as jtu
 6from jax import shard_map
 7
 8
def tree_device_put(tree: chex.ArrayTree, device_or_sharding):
    """Apply `jax.device_put` to every leaf of a pytree.

    Args:
        tree: Pytree whose leaves are each passed to `jax.device_put`.
        device_or_sharding: Placement target, forwarded unchanged as the
            second argument of `jax.device_put` for every leaf.

    Returns:
        A pytree with the same structure as `tree`, holding the results of
        `jax.device_put` for each leaf.
    """

    def _put_leaf(leaf):
        # The same target is used for every leaf.
        return jax.device_put(leaf, device_or_sharding)

    return jtu.tree_map(_put_leaf, tree)
12 13
def shmap_vmap(fn: Callable, mesh, in_specs, out_specs, **kwargs):
    """Wrap `fn` so that it is vmapped inside a `shard_map` call.

    Args:
        fn: Function to vectorize with `jax.vmap`. It is invoked with
            positional arguments only.
        mesh: Forwarded to `shard_map` as `mesh`.
        in_specs: Forwarded to `shard_map` as `in_specs`.
        out_specs: Forwarded to `shard_map` as `out_specs`.
        **kwargs: Any additional keyword arguments for `shard_map`.

    Returns:
        The function produced by `shard_map`.
    """

    def per_shard(*shard_args):
        # Vectorize `fn` over the arguments that shard_map hands to this shard.
        vectorized = jax.vmap(fn)
        return vectorized(*shard_args)

    sharded_fn = shard_map(
        per_shard,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        **kwargs,
    )
    return sharded_fn
23 24
def shmap_map(fn: Callable, mesh, in_specs, out_specs, **kwargs):
    """Wrap `fn` so that it runs through `jax.lax.map` inside `shard_map`.

    Within each `shard_map` invocation the positional arguments are passed
    as a tuple to `jax.lax.map`, which applies `fn` to the unpacked elements
    along the leading axis.

    Args:
        fn: Function to be executed; only positional arguments are supported.
        mesh: Forwarded to `shard_map` as `mesh`.
        in_specs: Forwarded to `shard_map` as `in_specs`.
        out_specs: Forwarded to `shard_map` as `out_specs`.
        **kwargs: Any additional keyword arguments for `shard_map`.

    Returns:
        A wrapped function.
    """

    def per_shard(*shard_args):
        # lax.map hands over one slice of the argument tuple at a time;
        # unpack it back into positional arguments for `fn`.
        return jax.lax.map(lambda packed: fn(*packed), shard_args)

    sharded_fn = shard_map(
        per_shard,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        **kwargs,
    )
    return sharded_fn