Source code for evorl.distributed.gradients

 1from collections.abc import Callable
 2
 3import chex
 4import jax
 5import optax
 6
"""Training gradient utility functions.

Modified from https://github.com/google/brax/blob/main/brax/training/gradients.py
"""
11
12
[docs] 13def loss_and_pgrad( 14 loss_fn: Callable[..., float], dp_axis_name: str | None, has_aux: bool = False 15): 16 g = jax.value_and_grad(loss_fn, has_aux=has_aux) 17 18 def h(*args, **kwargs): 19 value, grads = g(*args, **kwargs) 20 return value, jax.lax.pmean(grads, axis_name=dp_axis_name) 21 22 return g if dp_axis_name is None else h
23 24
def gradient_update(
    loss_fn: Callable[..., float],
    optimizer: optax.GradientTransformation,
    dp_axis_name: str | None,
    has_aux: bool = False,
):
    """Wrap a loss function into a function that applies one gradient update.

    Args:
        loss_fn: The loss function, ``(params, ...) -> loss``. When
            ``has_aux`` is True it returns ``(loss, aux)`` instead.
        optimizer: The optimizer used to turn gradients into updates.
        dp_axis_name: If relevant, the name of the pmap axis over which
            gradients are synchronized. None disables synchronization.
        has_aux: Whether ``loss_fn`` returns auxiliary data.

    Returns:
        A function ``f(opt_state, params, *args, **kwargs)`` that takes the
        optimizer state followed by the arguments of the loss function. It
        returns ``(value, params, opt_state)``: the output of the loss
        function (``(loss, aux)`` when ``has_aux`` is True), the updated
        parameters, and the new optimizer state.
    """
    loss_and_pgrad_fn = loss_and_pgrad(
        loss_fn, dp_axis_name=dp_axis_name, has_aux=has_aux
    )

    def f(opt_state, params, *args, **kwargs):
        value, grads = loss_and_pgrad_fn(params, *args, **kwargs)
        # Pass the current params to the optimizer: transformations such as
        # weight decay (e.g. optax.adamw) need them and fail when given None.
        params_update, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, params_update)
        return value, params, opt_state

    return f
def _attach_params_to_agent_state(agent_state, params):
    """Default attach function: put ``params`` into ``agent_state``.

    Returns whatever ``agent_state.replace(params=params)`` returns
    (presumably an updated copy of the state; depends on the state type).
    """
    updated_state = agent_state.replace(params=params)
    return updated_state


def _detach_params_to_agent_state(agent_state):
    """Default detach function: read the ``params`` attribute of ``agent_state``."""
    current_params = agent_state.params
    return current_params
def agent_gradient_update(
    loss_fn: Callable[..., float],
    optimizer: optax.GradientTransformation,
    dp_axis_name: str | None = None,
    has_aux: bool = False,
    attach_fn: Callable[
        [chex.ArrayTree, chex.ArrayTree], chex.ArrayTree
    ] = _attach_params_to_agent_state,
    detach_fn: Callable[[chex.ArrayTree], chex.ArrayTree] = _detach_params_to_agent_state,
):
    """Wrap an agent-state loss function into a gradient update step.

    The trainable parameters are extracted from the agent state with
    ``detach_fn``, differentiated and updated through ``gradient_update``,
    and written back into the agent state with ``attach_fn``.

    Args:
        loss_fn: The loss function, called as
            ``loss_fn(agent_state, sample_batch, key)``. When ``has_aux`` is
            True it returns ``(loss, aux)`` instead of just the loss.
        optimizer: The optimizer used to turn gradients into updates.
        dp_axis_name: If relevant, the name of the pmap axis over which
            gradients are synchronized. None disables synchronization.
        has_aux: Whether ``loss_fn`` returns auxiliary data.
        attach_fn: ``(agent_state, params) -> agent_state``; returns the
            agent state carrying the given parameters.
        detach_fn: ``(agent_state) -> params``; returns the parameters to
            differentiate and update.

    Returns:
        A function ``f(opt_state, agent_state, *args, **kwargs)`` whose
        remaining arguments are bound to ``sample_batch`` and ``key`` of the
        loss function. It returns ``(value, agent_state, opt_state)``: the
        output of the loss function, the agent state with updated parameters,
        and the new optimizer state.
    """

    def _loss_fn(params, agent_state, sample_batch, key):
        # Gradients are taken w.r.t. ``params`` (the first argument), so the
        # parameters are re-attached to the state before evaluating the loss.
        agent_state = attach_fn(agent_state, params)
        return loss_fn(agent_state, sample_batch, key)

    _gradient_update_fn = gradient_update(
        _loss_fn, optimizer, dp_axis_name=dp_axis_name, has_aux=has_aux
    )

    def f(opt_state, agent_state, *args, **kwargs):
        params = detach_fn(agent_state)
        value, params, opt_state = _gradient_update_fn(
            opt_state, params, agent_state, *args, **kwargs
        )
        # Write the updated parameters back into the agent state.
        agent_state = attach_fn(agent_state, params)

        return value, agent_state, opt_state

    return f