Source code for evorl.utils.ma_utils
1import chex
2import jax
3import jax.numpy as jnp
4
5from evorl.types import AgentID, Done
6
7
[docs]
def batchify(x: dict[AgentID, jax.Array], agent_list, padding=False) -> jax.Array:
    """Batchify the data for multi-agent training.

    Stacks the per-agent arrays along a new leading axis, following the order
    of ``agent_list``.

    Args:
        x: data from each agent, [batch_dims..., val]
            Note: Currently, only the last dimension is viewed as value, and
            the rest are batch dimensions.
        agent_list: list, list of agent names. Its order defines the order of
            the leading axis of the result.
        padding: bool, whether to zero-pad the data to the same length over
            the last dimension. Set to False if the data already has same
            length.

    Returns:
        Stacked data from multiple agents with shape
        [len(agent_list), batch_dims..., val]. When ``padding`` is True,
        ``val`` is the largest last-dimension size among the listed agents.
    """
    if not padding:
        return jnp.stack([x[a] for a in agent_list])

    max_dim = max(x[a].shape[-1] for a in agent_list)

    def _pad(z):
        # Right-pad the last axis of ``z`` with zeros up to ``max_dim``.
        missing = max_dim - z.shape[-1]
        if missing == 0:
            return z
        # ``shape`` is a tuple, so the extra dimension must be appended as a
        # tuple as well (tuple + list raises TypeError). The filler reuses the
        # dtype of ``z`` so that padding does not trigger dtype promotion.
        filler = jnp.zeros(z.shape[:-1] + (missing,), dtype=z.dtype)
        return jnp.concatenate([z, filler], axis=-1)

    return jnp.stack([_pad(x[a]) for a in agent_list])  # [num_actors, batch_dims..., val]
42
43
[docs]
def unbatchify(x: jax.Array, agent_list) -> dict[AgentID, jax.Array]:
    """Split batched multi-agent data back into a per-agent dict.

    The i-th slice along the leading axis of ``x`` is assigned to the i-th
    name in ``agent_list``. This assumes data such as actions has the same
    shape for every agent (stated to hold for MaBrax).

    Args:
        x: batched data, [num_actors, batch_dims..., val]
            Note: Currently, only the last dimension is viewed as value, and
            the rest are batch dimensions.
        agent_list: list, list of agent names, in the same order as the
            leading axis of ``x``.

    Returns:
        Dict {agent_name: data}
    """
    per_agent = {}
    for idx, agent in enumerate(agent_list):
        per_agent[agent] = x[idx]
    return per_agent
58
59
[docs]
def multi_agent_episode_done(done: Done) -> chex.Array:
    """Check whether the multi-agent episode is done.

    Args:
        done: mapping of done flags; only its ``"__all__"`` entry is read.

    Returns:
        The value stored under the ``"__all__"`` key of ``done``.
    """
    episode_done = done["__all__"]
    return episode_done