evorl.distributed.comm

Module Contents

Functions

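all_gather: All-gather the data across all devices.
get_global_ranks: Return the global rank of each device.
get_process_id: Return the node (process) ID in a multi-node distributed environment.
is_dist_initialized: Return whether JAX's distributed environment has been initialized.
pmax: Maximum of the input across devices.
pmean: Mean of the input across devices.
pmin: Minimum of the input across devices.
psum: Sum of the input across devices.
split_key_to_devices: Split a PRNG key into one subkey per device.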

API

evorl.distributed.comm.all_gather(x, axis_name: str | None = None, **kwargs)

All-gather the data across all devices.
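To illustrate what an all-gather does, here is a minimal sketch, assuming all_gather forwards to `jax.lax.all_gather` when `axis_name` is given and returns `x` unchanged when it is `None` (both are assumptions, not stated on this page). `all_gather_sketch` is a hypothetical name. `jax.vmap` with a named axis stands in for real devices so the sketch runs on a single CPU.

```python
from __future__ import annotations

import jax
import jax.numpy as jnp


def all_gather_sketch(x, axis_name: str | None = None, **kwargs):
    """Hypothetical stand-in for evorl.distributed.comm.all_gather."""
    if axis_name is None:
        # Assumed behavior: outside a mapped context there is nothing to gather.
        return x
    # Every participant receives the values of all participants, stacked on axis 0.
    return jax.lax.all_gather(x, axis_name, **kwargs)


# vmap with a named axis emulates 4 "devices", each holding one scalar.
per_device = jnp.arange(4.0)
gathered = jax.vmap(lambda v: all_gather_sketch(v, axis_name="i"), axis_name="i")(per_device)
print(gathered.shape)  # (4, 4): every emulated device now holds all 4 values
```

Under `jax.pmap` the same call would gather across physical devices instead of emulated ones; extra keyword arguments such as `tiled` are passed through to `jax.lax.all_gather` in this sketch.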

evorl.distributed.comm.get_global_ranks()

Return the global rank for each device.

Returns:

The ranks, sharded across devices, with a unique rank on each device.
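One plausible way to compute such ranks is sketched below, assuming every process owns the same number of local devices and ranks are laid out as `process_index * local_device_count + local_index` (an assumption; this page does not specify the layout). `global_ranks_sketch` is a hypothetical name, and `jax.pmap` is used only to place one rank on each local device.

```python
import jax
import jax.numpy as jnp


def global_ranks_sketch():
    """Hypothetical stand-in for evorl.distributed.comm.get_global_ranks."""
    n_local = jax.local_device_count()
    # Assumed layout: process p owns ranks [p * n_local, (p + 1) * n_local).
    offset = jax.process_index() * n_local
    ranks = offset + jnp.arange(n_local, dtype=jnp.int32)
    # pmap maps the leading axis over the local devices, so each device holds one rank.
    return jax.pmap(lambda r: r)(ranks)


ranks = global_ranks_sketch()
print(ranks)  # one rank per local device
```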

evorl.distributed.comm.get_process_id()

Return the node (process) ID in a multi-node distributed environment.
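In a multi-node JAX job, each node typically runs one process, and the public functions `jax.process_index()` and `jax.process_count()` report the current process's index and the total number of processes. A minimal sketch, assuming get_process_id is a thin wrapper over `jax.process_index()` (an assumption; `process_id_sketch` is hypothetical):

```python
import jax


def process_id_sketch() -> int:
    """Hypothetical stand-in for evorl.distributed.comm.get_process_id (assumed wrapper)."""
    # Index of the current process, in the range [0, jax.process_count()).
    return jax.process_index()


pid = process_id_sketch()
if pid == 0:
    # Common pattern: only the first process writes logs or checkpoints.
    print(f"process {pid} of {jax.process_count()} acts as the chief")
```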

evorl.distributed.comm.is_dist_initialized()

Return whether JAX's distributed environment has been initialized.
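`jax.distributed.initialize()` may only be called once per process, so a check like this is typically used to guard it. A minimal sketch, assuming the check amounts to asking whether JAX's distributed client exists: it calls `jax.distributed.is_initialized()` when that function is available (newer JAX releases) and otherwise falls back to the private `jax._src.distributed.global_state`, which may change between versions. `dist_initialized_sketch` is a hypothetical name.

```python
import jax


def dist_initialized_sketch() -> bool:
    """Hypothetical stand-in for evorl.distributed.comm.is_dist_initialized."""
    is_init = getattr(jax.distributed, "is_initialized", None)
    if is_init is not None:
        return bool(is_init())
    # Fallback for older JAX (assumption): private state that may change between releases.
    from jax._src import distributed

    return distributed.global_state.client is not None


print("distributed initialized:", dist_initialized_sketch())
```

In a multi-node launch script, wrapping the call as `if not dist_initialized_sketch(): jax.distributed.initialize(...)` avoids initializing twice; the arguments depend on the cluster setup.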

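evorl.distributed.comm.pmax(x, axis_name: str | None = None)

Compute the elementwise maximum of x across the devices mapped along axis_name.

evorl.distributed.comm.pmean(x, axis_name: str | None = None)

Compute the mean of x across the devices mapped along axis_name.

evorl.distributed.comm.pmin(x, axis_name: str | None = None)

Compute the elementwise minimum of x across the devices mapped along axis_name.

evorl.distributed.comm.psum(x, axis_name: str | None = None)

Compute the sum of x across the devices mapped along axis_name.
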
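The four reductions pmax, pmean, pmin and psum share one pattern. A minimal sketch, assuming each forwards to the `jax.lax` collective of the same name when `axis_name` is given and returns `x` unchanged when it is `None` (assumed, not stated on this page); `make_reduction` and the `*_sketch` names are hypothetical. `jax.vmap` with a named axis emulates four devices so the results can be checked on one CPU.

```python
from __future__ import annotations

import jax
import jax.numpy as jnp


def make_reduction(collective):
    """Build a hypothetical wrapper around a jax.lax collective."""

    def reduce(x, axis_name: str | None = None):
        if axis_name is None:
            # Assumed behavior: no mapped axis, so there is nothing to reduce over.
            return x
        return collective(x, axis_name)

    return reduce


pmax_sketch = make_reduction(jax.lax.pmax)
pmean_sketch = make_reduction(jax.lax.pmean)
pmin_sketch = make_reduction(jax.lax.pmin)
psum_sketch = make_reduction(jax.lax.psum)

# Each of the 4 emulated devices holds one value; every device receives the reduced result.
values = jnp.array([1.0, 2.0, 3.0, 4.0])


def run(fn):
    return jax.vmap(lambda v: fn(v, axis_name="i"), axis_name="i")(values)


print(run(psum_sketch))   # [10. 10. 10. 10.]
print(run(pmean_sketch))  # [2.5 2.5 2.5 2.5]
print(run(pmax_sketch))   # [4. 4. 4. 4.]
print(run(pmin_sketch))   # [1. 1. 1. 1.]
```

Unlike an ordinary reduction, each collective returns the reduced value to every participant, which is why every output above has one entry per emulated device.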
evorl.distributed.comm.split_key_to_devices(key: chex.PRNGKey, devices: collections.abc.Sequence[jax.Device])

Split the key into one subkey for each device in devices.
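Splitting one key into independent subkeys gives each device its own random stream, so devices do not repeat each other's samples. A minimal sketch, assuming the function splits `key` into `len(devices)` subkeys with `jax.random.split` and places one subkey on each device (assumed behavior; `split_key_sketch` is a hypothetical name). The real function may return a single sharded array rather than a list.

```python
import jax


def split_key_sketch(key, devices):
    """Hypothetical stand-in for evorl.distributed.comm.split_key_to_devices."""
    # One independent subkey per target device.
    subkeys = jax.random.split(key, len(devices))
    # Commit subkey i to device i. Returning a list (an assumption) keeps the
    # sketch independent of any particular sharding API.
    return [jax.device_put(k, d) for k, d in zip(subkeys, devices)]


device_keys = split_key_sketch(jax.random.PRNGKey(42), jax.local_devices())
for k in device_keys:
    print(k.devices())
```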