evorl.distributed.gradients
Module Contents
Functions
| agent_gradient_update | |
| gradient_update | Wrapper of the loss function that applies gradient updates. |
API
- evorl.distributed.gradients.agent_gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None = None, has_aux: bool = False, attach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _attach_params_to_agent_state, detach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _detach_params_to_agent_state)
- evorl.distributed.gradients.gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None, has_aux: bool = False)
Wrapper of the loss function that applies gradient updates.
- Parameters:
loss_fn – The loss function, with signature (params, …) -> loss.
optimizer – The optax optimizer used to apply the gradients.
dp_axis_name – The name of the pmap axis over which gradients are synchronized, or None if no synchronization is needed.
has_aux – Whether loss_fn returns auxiliary data alongside the loss.
- Returns:
A function that takes the same arguments as the loss function plus the optimizer state. It returns the loss, the new parameters, and the new optimizer state.