evorl.distributed.comm
Module Contents
Functions
- All-gather the data across all devices.
- Return the global rank of each device.
- Return the node ID in a multi-node distributed environment.
- Return whether JAX's distributed environment has been initialized.
- Split the PRNG key so that each device receives its own key.
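A minimal sketch of the idea behind splitting a key across devices, using only public JAX calls (`jax.random.split`, `jax.local_device_count`, `jax.pmap`). The helper name `split_key_to_devices_sketch` is hypothetical, and this is an illustration of the technique, not evorl's implementation. The `XLA_FLAGS` line only simulates several CPU devices for demonstration.

```python
import os

# Demo only: simulate 4 CPU devices. This must run before JAX is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax


def split_key_to_devices_sketch(key: jax.Array) -> jax.Array:
    """Hypothetical helper: return one independent PRNG key per local device."""
    n = jax.local_device_count()
    # The leading axis has length n, so pmap assigns key i to local device i.
    return jax.random.split(key, n)


keys = split_key_to_devices_sketch(jax.random.PRNGKey(0))

# Each device samples with its own key, so the rows differ across devices.
samples = jax.pmap(lambda k: jax.random.normal(k, (3,)))(keys)
print(samples.shape)  # (number of local devices, 3)
```

Splitting (rather than reusing one key on every device) matters because identical keys would make every device draw identical random numbers.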
API¶
- evorl.distributed.comm.all_gather(x, axis_name: str | None = None, **kwargs)
All-gather the data across all devices.
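The collective can be illustrated with JAX's `jax.lax.all_gather` inside `jax.pmap`, where `axis_name` names the mapped device axis. This sketch assumes evorl's `all_gather` behaves like (or forwards to) `jax.lax.all_gather`; the signature with `axis_name` and `**kwargs` suggests this, but it is an assumption. The axis name `"devices"` is arbitrary.

```python
import os

# Demo only: simulate 4 CPU devices. This must run before JAX is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp

n = jax.local_device_count()

# One row per device: device i holds the local value [i, i].
x = jnp.arange(n, dtype=jnp.float32)[:, None] * jnp.ones((n, 2))


def gather_fn(local_x):
    # Each device receives the values from all devices, stacked along a new axis 0.
    return jax.lax.all_gather(local_x, axis_name="devices")


gathered = jax.pmap(gather_fn, axis_name="devices")(x)
print(gathered.shape)  # (n, n, 2): every device holds a full copy of x
```

In `jax.lax.all_gather`, passing `tiled=True` concatenates along the gather axis instead of stacking. Whether evorl's `**kwargs` are forwarded in this way is not stated here.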
- evorl.distributed.comm.get_global_ranks()
Return the global rank for each device.
- Returns:
The sharded ranks across devices. Each device has a unique rank.
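A sketch of how unique per-device global ranks could be produced in plain JAX. It assumes every process has the same number of local devices and uses `rank = process_index * local_device_count + local_index`. Both this formula and the helper name `global_ranks_sketch` are assumptions, not taken from evorl.

```python
import os

# Demo only: simulate 4 CPU devices. This must run before JAX is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp


def global_ranks_sketch() -> jax.Array:
    """Hypothetical: one unique rank per local device, offset by the process index."""
    n_local = jax.local_device_count()
    # Assumes equal device counts per process, so each process owns a contiguous block.
    offset = jax.process_index() * n_local
    local = jnp.arange(n_local, dtype=jnp.int32) + offset
    # An identity pmap places element i on local device i, giving a sharded array.
    return jax.pmap(lambda r: r)(local)


ranks = global_ranks_sketch()
print(ranks.tolist())  # single process, 4 simulated devices: [0, 1, 2, 3]
```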