evorl.distributed.gradients

Module Contents

Functions

agent_gradient_update

gradient_update

Wrapper of the loss function that applies gradient updates.

loss_and_pgrad

API

evorl.distributed.gradients.agent_gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None = None, has_aux: bool = False, attach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _attach_params_to_agent_state, detach_fn: collections.abc.Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree] = _detach_params_to_agent_state)

evorl.distributed.gradients.gradient_update(loss_fn: collections.abc.Callable[..., float], optimizer: optax.GradientTransformation, dp_axis_name: str | None, has_aux: bool = False)

Wrapper of the loss function that applies gradient updates.

Parameters:
  • loss_fn – The loss function, with signature (params, …) -> loss.

  • optimizer – The optimizer used to apply the gradients.

  • dp_axis_name – The name of the pmap axis across which gradients are synchronized, if any; pass None when no synchronization is needed.

  • has_aux – Whether loss_fn returns auxiliary data in addition to the loss.

Returns:

A function that takes the same arguments as the loss function, plus the optimizer state. It returns the loss, the new parameters, and the new optimizer state.

evorl.distributed.gradients.loss_and_pgrad(loss_fn: collections.abc.Callable[..., float], dp_axis_name: str | None, has_aux: bool = False)