evorl.distributed
Package Contents
Functions
- all_gather: All-gather the data across all devices.
- get_global_ranks: Return the global rank for each device.
- Return the node ID in a multi-node distributed environment.
- gradient_update: Wrapper of the loss function that applies gradient updates.
- is_dist_initialized: Whether JAX's distributed setting is initialized.
- shmap_map: Sequential execution on different GPUs.
- shmap_vmap: Vmap on different GPUs.
- Split the key to each device.
- Pytree version of
Data
API
- evorl.distributed.DP_AXIS_NAME
'DP'
- evorl.distributed.POP_AXIS_NAME
'POP'
- evorl.distributed.agent_gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None = None, has_aux: bool = False, attach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _attach_params_to_agent_state, detach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _detach_params_to_agent_state)
- evorl.distributed.all_gather(x, axis_name: str | None = None, **kwargs)
All-gather the data across all devices.
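To illustrate what an all-gather does, here is a minimal sketch that calls `jax.lax.all_gather` directly inside `jax.pmap`. It is not the evorl implementation: how `evorl.distributed.all_gather` treats `axis_name=None` or forwards `**kwargs` is not stated here, and the helper name `gather` is hypothetical.

```python
import jax
import jax.numpy as jnp

DP_AXIS_NAME = "DP"  # same value as evorl.distributed.DP_AXIS_NAME

# One row of data per local device (a single row on a CPU-only machine).
n_dev = jax.local_device_count()
x = jnp.arange(n_dev * 2, dtype=jnp.float32).reshape(n_dev, 2)

def gather(shard):
    # Hypothetical helper: each device contributes its shard and
    # receives the stacked shards of every device along a new axis 0.
    return jax.lax.all_gather(shard, axis_name=DP_AXIS_NAME)

# pmap maps over axis 0 of x and binds the named axis used by the collective.
gathered = jax.pmap(gather, axis_name=DP_AXIS_NAME)(x)

# Every device now holds a full copy of x.
print(gathered.shape)  # (n_dev, n_dev, 2)
```

After the call, `gathered[i]` is the copy held by device `i`, and each copy equals the original `x`.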
- evorl.distributed.get_global_ranks()
Return the global rank for each device.
- Returns:
The sharded ranks across devices. Each device has a unique rank.
- evorl.distributed.gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None, has_aux: bool = False)
Wrapper of the loss function that applies gradient updates.
- Parameters:
loss_fn – The loss function. (params, …) -> loss
optimizer – The optimizer to apply gradients.
dp_axis_name – If relevant, the name of the pmap axis to synchronize gradients.
has_aux – Whether loss_fn returns auxiliary data in addition to the loss.
- Returns:
A function that takes the same arguments as the loss function plus the optimizer state. The output of this function is the loss, the new parameters, and the new optimizer state.
- evorl.distributed.is_dist_initialized()
Whether JAX's distributed setting is initialized.
- evorl.distributed.shmap_map(fn: collections.abc.Callable, mesh, in_specs, out_specs, **kwargs)
Sequential execution on different GPUs.
- Parameters:
fn – The function to execute. Only positional arguments are supported.
sharding – JAX sharding object
- Returns:
A wrapped function.
- evorl.distributed.shmap_vmap(fn: collections.abc.Callable, mesh, in_specs, out_specs, **kwargs)
Vmap on different GPUs.
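The difference between the two shard-map helpers can be sketched with JAX's own `shard_map`. The code below is an assumption about the general pattern, not the evorl implementation: it shards a population across devices, then either vectorizes the per-member function with `jax.vmap` or runs it sequentially with `jax.lax.map` inside each shard. The names `fn`, `vmapped` and `sequential` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases
    from jax.experimental.shard_map import shard_map

POP_AXIS_NAME = "POP"  # same value as evorl.distributed.POP_AXIS_NAME

# 1D device mesh; the population is split along axis 0.
devices = np.array(jax.devices())
mesh = Mesh(devices, (POP_AXIS_NAME,))

pop_size = 2 * devices.size  # two population members per device
x = jnp.arange(pop_size * 3, dtype=jnp.float32).reshape(pop_size, 3)

def fn(member):
    # Per-member computation (hypothetical example).
    return jnp.sum(member ** 2)

specs = dict(mesh=mesh, in_specs=(P(POP_AXIS_NAME),), out_specs=P(POP_AXIS_NAME))

# Vectorized within each device's shard.
vmapped = shard_map(jax.vmap(fn), **specs)
# Sequential within each device's shard (lower peak memory, slower).
sequential = shard_map(lambda xs: jax.lax.map(fn, xs), **specs)

out_vmap = vmapped(x)
out_seq = sequential(x)
expected = jnp.sum(x ** 2, axis=1)
```

Both variants return one value per population member. The sequential form trades speed for a smaller memory footprint on each device, which is the usual reason to prefer it.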