evorl.distributed

Package Contents

Functions

agent_gradient_update

all_gather

All-gather the data across all devices.

get_global_ranks

Return the global rank for each device.

get_process_id

Return the node ID in a multi-node distributed environment.

gradient_update

Wrapper around the loss function that applies gradient updates.

is_dist_initialized

Whether JAX’s distributed setup is initialized.

pmax

pmean

pmin

psum

shmap_map

Sequential execution across different GPUs.

shmap_vmap

Vmap across different GPUs.

split_key_to_devices

Split the key to each device.

tree_device_put

Pytree version of jax.device_put.

Data

DP_AXIS_NAME

POP_AXIS_NAME

API

evorl.distributed.DP_AXIS_NAME

'DP'

evorl.distributed.POP_AXIS_NAME

'POP'

evorl.distributed.agent_gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None = None, has_aux: bool = False, attach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _attach_params_to_agent_state, detach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _detach_params_to_agent_state)

evorl.distributed.all_gather(x, axis_name: str | None = None, **kwargs)

All-gather the data across all devices.
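
As a rough illustration of what an all-gather does, the sketch below uses plain jax.lax.all_gather under jax.vmap with a named axis, so it runs on a single device. The axis name "i" and the toy input are assumptions made for illustration; the sketch does not call evorl.distributed.all_gather itself.

```python
import jax
import jax.numpy as jnp


def gather_all(x):
    # Every member of the named axis "i" receives the full stacked array.
    return jax.lax.all_gather(x, axis_name="i")


x = jnp.arange(3.0)
# vmap with a named axis stands in for multiple devices here.
gathered = jax.vmap(gather_all, axis_name="i")(x)
print(gathered.shape)  # (3, 3): each of the 3 members holds all 3 values
```

On real hardware the named axis would typically come from a device mesh or pmap rather than from vmap.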

evorl.distributed.get_global_ranks()

Return the global rank for each device.

Returns:

The sharded ranks across devices. Each device has a unique rank.

evorl.distributed.get_process_id()

Return the node ID in a multi-node distributed environment.

evorl.distributed.gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None, has_aux: bool = False)

Wrapper around the loss function that applies gradient updates.

Parameters:
  • loss_fn – The loss function, with signature (params, …) -> loss.

  • optimizer – The optimizer used to apply the gradients.

  • dp_axis_name – If relevant, the name of the pmap axis along which gradients are synchronized.

  • has_aux – Whether loss_fn returns auxiliary data.

Returns:

A function that takes the same arguments as the loss function, plus the optimizer state. It returns the loss, the new parameters, and the new optimizer state.

evorl.distributed.is_dist_initialized()

Whether JAX’s distributed setup is initialized.

evorl.distributed.pmax(x, axis_name: str | None = None)

evorl.distributed.pmean(x, axis_name: str | None = None)

evorl.distributed.pmin(x, axis_name: str | None = None)

evorl.distributed.psum(x, axis_name: str | None = None)
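
These four functions share their names with the reduction collectives in jax.lax. As a hedged illustration of what such reductions compute, the sketch below calls the jax.lax versions directly under jax.vmap with a named axis. The axis name "i" and the input values are assumptions, and the sketch does not cover how the evorl wrappers treat axis_name=None.

```python
import jax
import jax.numpy as jnp


def reduce_all(x):
    # Each reduction combines x over every member of the named axis "i".
    return (
        jax.lax.psum(x, "i"),
        jax.lax.pmean(x, "i"),
        jax.lax.pmax(x, "i"),
        jax.lax.pmin(x, "i"),
    )


x = jnp.array([1.0, 2.0, 3.0, 6.0])
# vmap with a named axis stands in for four devices here.
total, mean, largest, smallest = jax.vmap(reduce_all, axis_name="i")(x)
print(total[0], mean[0], largest[0], smallest[0])  # 12.0 3.0 6.0 1.0
```

Every member of the axis sees the same reduced value, which is why each result has the same entry repeated along its leading dimension.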

evorl.distributed.shmap_map(fn: collections.abc.Callable, mesh, in_specs, out_specs, **kwargs)

Sequential execution across different GPUs.

Parameters:
  • fn – The function to execute; only positional arguments are supported.

  • mesh – The JAX device mesh.

  • in_specs, out_specs – Partition specs for the inputs and outputs of fn.

Returns:

A wrapped function.

evorl.distributed.shmap_vmap(fn: collections.abc.Callable, mesh, in_specs, out_specs, **kwargs)

Vmap across different GPUs.
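
The signature mirrors JAX's shard_map, so one plausible reading is "shard the leading axis across devices, then vmap fn over each shard". The sketch below shows that pattern with plain JAX. It is an assumption about the general technique, not the evorl implementation; row_sum and the toy batch are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def row_sum(row):
    # Per-example function; vmap lifts it over the leading axis.
    return row.sum()


devices = np.array(jax.devices())
mesh = Mesh(devices, ("POP",))

# Each device receives a slice of the leading axis and vmaps row_sum over it.
sharded_vmap = shard_map(
    jax.vmap(row_sum), mesh=mesh, in_specs=P("POP"), out_specs=P("POP")
)

rows = devices.size * 2  # leading axis must divide evenly across devices
batch = jnp.arange(rows * 3, dtype=jnp.float32).reshape(rows, 3)
result = sharded_vmap(batch)
print(result.shape)
```

The mesh axis is named "POP" here only to echo POP_AXIS_NAME; any axis name would work in the sketch.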

evorl.distributed.split_key_to_devices(key: chex.PRNGKey, devices: collections.abc.Sequence[jax.Device])

Split the key to each device.
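
A minimal sketch of the idea, assuming it means "one subkey per device, placed on that device". The return type of the evorl function is not stated here, so the sharded array below is an illustrative choice, and the mesh axis name "DP" is borrowed from DP_AXIS_NAME.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
key = jax.random.PRNGKey(0)

# One subkey per device, stacked along the leading axis.
subkeys = jax.random.split(key, len(devices))

# Shard the leading axis so that each device holds exactly one subkey.
mesh = Mesh(np.array(devices), ("DP",))
sharded_keys = jax.device_put(subkeys, NamedSharding(mesh, P("DP")))
print(sharded_keys.shape)
```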

evorl.distributed.tree_device_put(tree: chex.ArrayTree, device_or_sharding)

Pytree version of jax.device_put.
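
One way to picture this is mapping jax.device_put over every leaf of a pytree. The sketch below does that with jax.tree_util.tree_map. It is an assumption about the behavior, not the evorl source, and the example tree is hypothetical.

```python
import jax
import jax.numpy as jnp

device = jax.devices()[0]
tree = {"a": jnp.arange(3.0), "b": (jnp.ones(2), jnp.zeros(()))}

# Place each leaf on the target device while keeping the tree structure.
placed = jax.tree_util.tree_map(lambda leaf: jax.device_put(leaf, device), tree)
print(jax.tree_util.tree_structure(placed))
```

The same call pattern accepts a sharding object in place of a single device, since jax.device_put takes either.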