Source code for evorl.ec.operators.selection.tournament_selection
1import chex
2import jax
3import jax.numpy as jnp
4
5from evorl.types import PyTreeNode
6
7
def tournament_selection(
    fitnesses: chex.Array,
    num_offsprings: int,
    key: chex.PRNGKey,
    *,
    tournament_size: int,
) -> chex.Array:
    """Tournament selection operator for single objective.

    Runs ``num_offsprings`` independent tournaments. Each tournament draws
    ``tournament_size`` contestants uniformly at random *with replacement*
    and keeps the one with the highest fitness (fitness is maximized).

    Args:
        fitnesses: 1-D array of fitness values, one per individual.
        num_offsprings: Number of tournaments, i.e. number of selected
            individuals. Must be a positive Python int (used as a static
            array shape).
        key: JAX PRNG key used to draw the tournament contestants.
        tournament_size: Number of contestants per tournament. Must be
            greater than 1 (used as a static array shape).

    Returns:
        Integer array of shape ``(num_offsprings,)`` holding the indices
        into ``fitnesses`` of the tournament winners.

    Raises:
        AssertionError: If ``fitnesses`` is not 1-D or is empty, or if
            ``num_offsprings`` / ``tournament_size`` are out of range. The
            plain ``assert`` checks are skipped under ``python -O``.
    """
    # assert_rank validates the rank directly; indexing ``fitnesses.shape[0]``
    # first would raise a bare IndexError for 0-d input.
    chex.assert_rank(fitnesses, 1)
    pop_size = fitnesses.shape[0]
    assert pop_size > 0, "fitnesses must be non-empty"
    assert num_offsprings > 0, "num_offsprings must be positive"
    assert tournament_size > 1, "tournament_size must be greater than 1"

    # lax sorting places NaN after every number. A descending argsort is a
    # reversed ascending sort, so a NaN fitness would be ranked *best*.
    # Map NaN to -inf so such individuals rank last instead. The -inf scalar
    # is weak-typed, so the dtype of ``fitnesses`` is preserved.
    if jnp.issubdtype(fitnesses.dtype, jnp.floating):
        fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)

    # ranks[r] is the index of the individual with the r-th highest fitness.
    ranks = jnp.argsort(fitnesses, descending=True)

    # Sample ranks rather than individuals. The winner of each tournament is
    # then just the contestant with the smallest (best) rank.
    contestant_ranks = jax.random.randint(
        key, (num_offsprings, tournament_size), 0, pop_size
    )
    winner_ranks = jnp.min(contestant_ranks, axis=-1)
    return ranks[winner_ranks]
27
28
[docs]
class TournamentSelection(PyTreeNode):
    """Callable wrapper around ``tournament_selection``.

    The tournament size is stored as a field. The operator can then be
    configured once and invoked as ``op(fitnesses, num_offsprings, key)``.
    """

    # Number of contestants per tournament; tournament_selection asserts > 1.
    # NOTE(review): this value is used as a static array shape inside
    # tournament_selection. Confirm how PyTreeNode registers fields: if it is
    # treated as a traced pytree leaf under jit, it could not serve as a shape.
    tournament_size: int = 2

    def __call__(self, fitnesses, num_offsprings, key):
        """Pick ``num_offsprings`` winner indices from ``fitnesses`` by tournament."""
        size = self.tournament_size
        return tournament_selection(
            fitnesses,
            num_offsprings,
            key,
            tournament_size=size,
        )