Source code for evorl.ec.operators.selection.tournament_selection

 1import chex
 2import jax
 3import jax.numpy as jnp
 4
 5from evorl.types import PyTreeNode
 6
 7
 8def tournament_selection(
 9    fitnesses: chex.Array,
10    num_offsprings: int,
11    key: chex.PRNGKey,
12    *,
13    tournament_size: int,
14):
15    """Tournament selection operator for single objective."""
16    chex.assert_shape(fitnesses, (fitnesses.shape[0],))
17    assert num_offsprings > 0, "num_offsprings must be positive"
18    assert tournament_size > 1, "tournament_size must be greater than 1"
19
20    ranks = jnp.argsort(fitnesses, descending=True)
21    pop_size = len(ranks)
22
23    selected_indices = jnp.min(
24        jax.random.randint(key, (num_offsprings, tournament_size), 0, pop_size), axis=-1
25    )
26    return ranks[selected_indices]
27
28
class TournamentSelection(PyTreeNode):
    """Callable operator wrapping ``tournament_selection``.

    Holds the tournament size as configuration and forwards every call to
    the module-level ``tournament_selection`` function.

    Attributes:
        tournament_size: Number of contestants per tournament (default 2).
    """

    # Passed on as the keyword-only ``tournament_size`` argument, where it is
    # used inside an array shape. NOTE(review): confirm PyTreeNode keeps this
    # field a concrete int (not a traced leaf) if the operator goes through jit.
    tournament_size: int = 2

    def __call__(self, fitnesses, num_offsprings, key):
        """Return the indices selected by ``num_offsprings`` tournaments.

        Args:
            fitnesses: Per-individual fitness values.
            num_offsprings: Number of tournaments / selected indices.
            key: PRNG key used to sample contestants.

        Returns:
            The result of ``tournament_selection`` for these arguments.
        """
        size = self.tournament_size
        return tournament_selection(
            fitnesses,
            num_offsprings,
            key,
            tournament_size=size,
        )