Source code for evorl.utils.ma_utils

 1import chex
 2import jax
 3import jax.numpy as jnp
 4
 5from evorl.types import AgentID, Done
 6
 7
def batchify(x: dict[AgentID, jax.Array], agent_list, padding=False) -> jax.Array:
    """Batchify the data for multi-agent training.

    Stacks each agent's data along a new leading axis, in the order given by
    ``agent_list``.

    Args:
        x: data from each agent, each shaped [batch_dims..., val].
            Note: only the last dimension is viewed as the value dimension; all
            leading dimensions are batch dimensions and must match across agents.
        agent_list: list of agent names; determines the order along the new
            leading axis.
        padding: whether to zero-pad every agent's data along the last dimension
            up to the largest last dimension among the agents in ``agent_list``.
            Set to False if the data already has the same length.

    Returns:
        Stacked data from the agents with shape
        [len(agent_list), batch_dims..., val].
    """
    if padding:
        max_dim = max(x[a].shape[-1] for a in agent_list)

        def _pad(z: jax.Array) -> jax.Array:
            """Right-pad ``z`` with zeros along its last axis up to ``max_dim``."""
            pad_width = max_dim - z.shape[-1]
            if pad_width == 0:
                return z
            # Shape must be a tuple (jax shapes are tuples; tuple + list raises
            # TypeError), and the pad keeps z's dtype so integer data is not
            # silently promoted to float by the concatenation.
            zeros = jnp.zeros((*z.shape[:-1], pad_width), dtype=z.dtype)
            return jnp.concatenate([z, zeros], axis=-1)

        return jnp.stack([_pad(x[a]) for a in agent_list])

    return jnp.stack([x[a] for a in agent_list])  # [num_agents, batch_dims..., val]


def unbatchify(x: jax.Array, agent_list) -> dict[AgentID, jax.Array]:
    """Split stacked multi-agent data back into a per-agent dict.

    Row ``i`` of ``x`` (along the leading axis) is assigned to
    ``agent_list[i]``. This assumes data such as actions has the same shape for
    every agent (true for MaBrax).

    Args:
        x: batched data shaped [num_agents, batch_dims..., val].
            Note: only the last dimension is viewed as the value dimension; the
            rest are batch dimensions.
        agent_list: list of agent names; position ``i`` names row ``i`` of ``x``.

    Returns:
        Dict ``{agent_name: x[i]}``.
    """
    per_agent = {}
    for idx, agent in enumerate(agent_list):
        per_agent[agent] = x[idx]
    return per_agent


def multi_agent_episode_done(done: Done) -> chex.Array:
    """Return the episode-level done flag of a multi-agent ``done`` mapping.

    Args:
        done: done flags indexed by key; must contain the ``"__all__"`` entry.

    Returns:
        The value stored under ``done["__all__"]``.
    """
    all_done = done["__all__"]
    return all_done