Source code for evorl.ec.evox_algorithm.sort_utils

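```python
import jax.numpy as jnp


def sort_by_key(keys, *vals):
    """Sort the 1-D array ``keys`` in ascending order and reorder every
    array in ``vals`` along its leading axis with the same permutation."""
    assert len(keys.shape) == 1, (
        f"Expect keys to be a 1d-vector, got shape {keys.shape}."
    )
    # Indices that would sort `keys` in ascending order.
    order = jnp.argsort(keys)
    # Apply the same permutation to the leading axis of each value array.
    vals = map(lambda v: v[order], vals)
    return keys[order], *vals
```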
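A typical use in evolutionary computation is ranking a population by fitness. The minimal usage sketch below assumes each row of a 2-D `population` array corresponds to one entry of a 1-D `fitness` vector; the data and the names `fitness`, `population`, and `ids` are hypothetical. The function is reproduced so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp


def sort_by_key(keys, *vals):
    # Same logic as the module above: sort `keys` ascending and permute
    # every array in `vals` along its leading axis with the same order.
    assert len(keys.shape) == 1, (
        f"Expect keys to be a 1d-vector, got shape {keys.shape}."
    )
    order = jnp.argsort(keys)
    vals = map(lambda v: v[order], vals)
    return keys[order], *vals


# Hypothetical data: fitness values for 3 individuals, one row per individual.
fitness = jnp.array([3.0, 1.0, 2.0])
population = jnp.array([[30, 31], [10, 11], [20, 21]])
ids = jnp.arange(3)

# Sort fitness ascending and reorder population rows and ids to match.
sorted_fit, sorted_pop, sorted_ids = sort_by_key(fitness, population, ids)

# The shape check only reads static shapes, so the function can also be
# traced under jax.jit.
jit_fit, jit_pop = jax.jit(sort_by_key)(fitness, population)
```

For a descending order (e.g. when fitness is maximized), one option is to pass `-fitness` as the key; note that the returned keys are then negated as well.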