Source code for evorl.distributed.comm

 1from collections.abc import Sequence
 2
 3import chex
 4import jax
 5import jax.numpy as jnp
 6
 7from jax._src.distributed import global_state
 8
 9
[docs] 10def pmean(x, axis_name: str | None = None): 11 if axis_name is None: 12 return x 13 else: 14 return jax.lax.pmean(x, axis_name)
15 16
[docs] 17def psum(x, axis_name: str | None = None): 18 if axis_name is None: 19 return x 20 else: 21 return jax.lax.psum(x, axis_name)
22 23
[docs] 24def pmin(x, axis_name: str | None = None): 25 if axis_name is None: 26 return x 27 else: 28 return jax.lax.pmin(x, axis_name)
29 30
[docs] 31def pmax(x, axis_name: str | None = None): 32 if axis_name is None: 33 return x 34 else: 35 return jax.lax.pmax(x, axis_name)
36 37
[docs] 38def all_gather(x, axis_name: str | None = None, **kwargs): 39 """All-gather the data across all devices.""" 40 if axis_name is None: 41 return x 42 else: 43 return jax.lax.all_gather(x, axis_name, **kwargs)
44 45
def split_key_to_devices(key: chex.PRNGKey, devices: Sequence[jax.Device]):
    """Split the key to each device.

    Args:
        key: PRNG key to split.
        devices: Devices that will each receive one sub-key.

    Returns:
        The ``len(devices)`` sub-keys produced by ``jax.random.split``,
        placed with a ``NamedSharding`` that partitions the leading axis
        over a 1-D mesh of ``devices``.
    """
    subkeys = jax.random.split(key, len(devices))

    # 1-D mesh over the given devices; the leading axis of `subkeys` is
    # partitioned along it.
    device_mesh = jax.sharding.Mesh(devices, ("devices",))
    leading_axis_spec = jax.sharding.PartitionSpec("devices")
    key_sharding = jax.sharding.NamedSharding(device_mesh, leading_axis_spec)

    return jax.device_put(subkeys, key_sharding)
53 54
def is_dist_initialized():
    """Whether the JAX's distributed setting is initialized.

    Prefers the public ``jax.distributed.is_initialized()`` when the
    installed JAX provides it, and only falls back to inspecting the
    internal ``global_state`` object otherwise.

    Returns:
        ``True`` if the distributed runtime is initialized, else ``False``.
    """
    # Use the supported public API whenever this JAX version exposes it.
    public_check = getattr(jax.distributed, "is_initialized", None)
    if public_check is not None:
        return bool(public_check())

    # Note: global_state is a JAX internal API, which is not stable.
    return global_state.coordinator_address is not None
59 60
def get_process_id():
    """Return the node id in multi-node distributed env.

    Returns:
        ``global_state.process_id`` when ``is_dist_initialized()`` is true,
        otherwise ``0``.
    """
    # Without an initialized distributed runtime, report process 0.
    return global_state.process_id if is_dist_initialized() else 0
67 68
def get_global_ranks():
    """Return the global rank for each device.

    The rank of a local device is
    ``process_id * local_device_count + local_index``.

    Returns:
        The sharded ranks across devices. Each device has a unique rank.
    """
    local_count = jax.local_device_count()

    # Ranks of this process start right after those of all lower process ids.
    # NOTE(review): this gives globally unique ranks only if every process
    # has the same number of local devices — confirm for the target setup.
    rank_offset = get_process_id() * local_count
    local_indices = jnp.arange(local_count, dtype=jnp.int32)

    # One rank per local device: partition the rank vector over a 1-D mesh.
    local_mesh = jax.sharding.Mesh(jax.local_devices(), ("devices",))
    rank_sharding = jax.sharding.NamedSharding(
        local_mesh,
        jax.sharding.PartitionSpec("devices"),
    )

    return jax.device_put(rank_offset + local_indices, rank_sharding)