1from collections.abc import Callable
2
3import chex
4import jax
5import optax
6
7"""Training gradient utility functions.
8
9Modified from https://github.com/google/brax/blob/main/brax/training/gradients.py
10"""
11
12
[docs]
13def loss_and_pgrad(
14 loss_fn: Callable[..., float], dp_axis_name: str | None, has_aux: bool = False
15):
16 g = jax.value_and_grad(loss_fn, has_aux=has_aux)
17
18 def h(*args, **kwargs):
19 value, grads = g(*args, **kwargs)
20 return value, jax.lax.pmean(grads, axis_name=dp_axis_name)
21
22 return g if dp_axis_name is None else h
23
24
[docs]
def gradient_update(
    loss_fn: Callable[..., float],
    optimizer: optax.GradientTransformation,
    dp_axis_name: str | None,
    has_aux: bool = False,
):
    """Wrap a loss function into a function that applies one gradient update.

    Args:
      loss_fn: The loss function, ``(params, ...) -> loss`` (or
        ``(params, ...) -> (loss, aux)`` when ``has_aux`` is True).
      optimizer: The optax optimizer used to turn gradients into updates.
      dp_axis_name: If not ``None``, the name of the mapped (e.g. ``pmap``)
        axis over which gradients are averaged before the update.
      has_aux: Whether ``loss_fn`` returns auxiliary data alongside the loss.

    Returns:
      A function ``f(opt_state, params, *args, **kwargs)`` that takes the
      optimizer state followed by the same arguments as the loss function.
      It returns ``(value, new_params, new_opt_state)``, where ``value`` is
      the loss (or ``(loss, aux)`` when ``has_aux`` is True).
    """
    loss_and_pgrad_fn = loss_and_pgrad(
        loss_fn, dp_axis_name=dp_axis_name, has_aux=has_aux
    )

    def f(opt_state, params, *args, **kwargs):
        value, grads = loss_and_pgrad_fn(params, *args, **kwargs)
        # Pass the current params: transformations such as weight decay
        # (``optax.add_decayed_weights`` / ``optax.adamw``) require them and
        # raise otherwise; transformations that don't need params ignore them.
        params_update, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, params_update)
        return value, params, opt_state

    return f
59
60
def _attach_params_to_agent_state(agent_state, params):
    """Return ``agent_state.replace(params=params)``.

    Default ``attach_fn``: installs ``params`` into the agent state via its
    ``replace`` method (presumably a functional struct/dataclass update —
    confirm against the agent-state type).
    """
    updated_state = agent_state.replace(params=params)
    return updated_state
63
64
def _detach_params_to_agent_state(agent_state):
    """Return the ``params`` attribute of ``agent_state``.

    Default ``detach_fn``. Despite the name, this extracts params *from* the
    agent state; the name is kept because it is referenced as a default.
    """
    params = agent_state.params
    return params
67
68
[docs]
def agent_gradient_update(
    loss_fn: Callable[..., float],
    optimizer: optax.GradientTransformation,
    dp_axis_name: str | None = None,
    has_aux: bool = False,
    attach_fn: Callable[
        [chex.ArrayTree, chex.ArrayTree], chex.ArrayTree
    ] = _attach_params_to_agent_state,
    detach_fn: Callable[
        [chex.ArrayTree], chex.ArrayTree
    ] = _detach_params_to_agent_state,
):
    """Build a gradient-update step that operates on an agent state.

    The trainable params are extracted with ``detach_fn``, differentiated
    through ``loss_fn``, updated by ``optimizer``, and written back with
    ``attach_fn``.

    Args:
      loss_fn: Called as ``loss_fn(agent_state, sample_batch, key)``, where
        ``agent_state`` already holds the params being differentiated.
      optimizer: The optax optimizer applied to the params' gradients.
      dp_axis_name: If not ``None``, the mapped axis over which gradients are
        averaged before the update.
      has_aux: Whether ``loss_fn`` returns ``(loss, aux)``.
      attach_fn: ``(agent_state, params) -> agent_state``; installs params.
      detach_fn: ``agent_state -> params``; extracts the trainable params.

    Returns:
      A function ``f(opt_state, agent_state, sample_batch, key)`` returning
      ``(value, new_agent_state, new_opt_state)``. The extra arguments after
      ``agent_state`` are forwarded to an inner loss that expects exactly
      ``sample_batch`` and ``key``.
    """

    def _loss_fn(params, agent_state, sample_batch, key):
        # Differentiate w.r.t. ``params`` by installing them into the state.
        agent_state = attach_fn(agent_state, params)
        return loss_fn(agent_state, sample_batch, key)

    _gradient_update_fn = gradient_update(
        _loss_fn, optimizer, dp_axis_name=dp_axis_name, has_aux=has_aux
    )

    def f(opt_state, agent_state, *args, **kwargs):
        params = detach_fn(agent_state)
        # ``agent_state`` still carries the old params here; ``_loss_fn``
        # overwrites them with the differentiated ``params`` before use.
        value, params, opt_state = _gradient_update_fn(
            opt_state, params, agent_state, *args, **kwargs
        )
        agent_state = attach_fn(agent_state, params)

        return value, agent_state, opt_state

    return f