Source code for evorl.distributed.comm
1from collections.abc import Sequence
2
3import chex
4import jax
5import jax.numpy as jnp
6
7from jax._src.distributed import global_state
8
9
[docs]
10def pmean(x, axis_name: str | None = None):
11 if axis_name is None:
12 return x
13 else:
14 return jax.lax.pmean(x, axis_name)
15
16
[docs]
17def psum(x, axis_name: str | None = None):
18 if axis_name is None:
19 return x
20 else:
21 return jax.lax.psum(x, axis_name)
22
23
[docs]
24def pmin(x, axis_name: str | None = None):
25 if axis_name is None:
26 return x
27 else:
28 return jax.lax.pmin(x, axis_name)
29
30
[docs]
31def pmax(x, axis_name: str | None = None):
32 if axis_name is None:
33 return x
34 else:
35 return jax.lax.pmax(x, axis_name)
36
37
[docs]
38def all_gather(x, axis_name: str | None = None, **kwargs):
39 """All-gather the data across all devices."""
40 if axis_name is None:
41 return x
42 else:
43 return jax.lax.all_gather(x, axis_name, **kwargs)
44
45
[docs]
def split_key_to_devices(key: chex.PRNGKey, devices: Sequence[jax.Device]):
    """Split the key to each device.

    Args:
        key: PRNG key to split.
        devices: Devices that should each receive one sub-key.

    Returns:
        The ``len(devices)`` sub-keys produced by ``jax.random.split``,
        placed with a sharding that partitions the leading axis over a 1-D
        mesh (axis name ``"devices"``) built from ``devices``.
    """
    # 1-D mesh over the given devices; the leading axis of the split keys
    # is partitioned along it.
    mesh = jax.sharding.Mesh(devices, ("devices",))
    spec = jax.sharding.PartitionSpec("devices")
    per_device_keys = jax.random.split(key, len(devices))
    return jax.device_put(per_device_keys, jax.sharding.NamedSharding(mesh, spec))
53
54
[docs]
def is_dist_initialized():
    """Whether the JAX's distributed setting is initialized.

    Returns:
        ``True`` when ``global_state.coordinator_address`` is not ``None``,
        ``False`` otherwise.
    """
    # NOTE: `global_state` lives in `jax._src`, a JAX-internal API that is
    # not stable and may change between releases.
    coordinator = global_state.coordinator_address
    return coordinator is not None
59
60
[docs]
def get_process_id():
    """Return the node id in multi-node distributed env.

    Returns:
        ``global_state.process_id`` when the distributed setting is
        initialized (per ``is_dist_initialized``), otherwise ``0``.
    """
    # Outside a distributed run, this process is treated as node 0.
    return global_state.process_id if is_dist_initialized() else 0
67
68
[docs]
def get_global_ranks():
    """Return the global rank for each device.

    The rank of local device ``i`` is computed as
    ``process_id * local_device_count + i``.

    Returns:
        An int32 array of ranks, one per local device, placed with a
        sharding that partitions it over a 1-D mesh (axis name
        ``"devices"``) of ``jax.local_devices()``.
    """
    local_count = jax.local_device_count()

    # NOTE(review): the offset below yields unique ranks across processes
    # only if every process has the same number of local devices - confirm.
    rank_offset = get_process_id() * local_count
    local_ranks = rank_offset + jnp.arange(local_count, dtype=jnp.int32)

    local_mesh = jax.sharding.Mesh(jax.local_devices(), ("devices",))
    per_device = jax.sharding.NamedSharding(
        local_mesh, jax.sharding.PartitionSpec("devices")
    )
    return jax.device_put(local_ranks, per_device)