Source code for evorl.ec.evox_algorithm.sort_utils
1import jax.numpy as jnp
2
3
[docs]
def sort_by_key(keys, *vals):
    """Sort ``keys`` ascending and reorder every array in ``vals`` the same way.

    The permutation is ``jnp.argsort(keys)``; it is applied along the first
    axis of ``keys`` and of each value, so every value must have the same
    leading length as ``keys``.

    Args:
        keys: 1-D array of sort keys.
        *vals: Arrays whose first axis is aligned with ``keys``; each one is
            reordered with the same permutation as ``keys``.

    Returns:
        A tuple ``(sorted_keys, *sorted_vals)``. With no ``vals`` this is a
        1-tuple holding only the sorted keys.

    Raises:
        ValueError: If ``keys`` is not 1-D, or if any value's leading
            dimension differs from ``len(keys)``.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, and 2-D keys would then fancy-index into garbage.
    if jnp.ndim(keys) != 1:
        raise ValueError(
            f"Expect keys to be a 1d-vector, got shape {jnp.shape(keys)}."
        )

    n = jnp.shape(keys)[0]
    for i, v in enumerate(vals):
        # JAX clamps out-of-bounds gather indices rather than raising, so a
        # value shorter than ``keys`` would silently come back with duplicated
        # rows. Shapes are static, so this check is also safe under jax.jit.
        if jnp.shape(v)[:1] != (n,):
            raise ValueError(
                f"Expect every value to have leading dimension {n} "
                f"(matching keys), got shape {jnp.shape(v)} for value {i}."
            )

    order = jnp.argsort(keys)
    return (keys[order], *(v[order] for v in vals))