Source code for evorl.algorithms.ec.mo.nsga2_brax

  1import logging
  2import numpy as np
  3from omegaconf import DictConfig
  4from typing_extensions import Self  # pytype: disable=not-supported-yet]
  5
  6import jax
  7import jax.numpy as jnp
  8
  9from evorl.types import State, Params
 10from evorl.envs import AutoresetMode, create_env
 11from evorl.evaluators import BraxEvaluator
 12from evorl.agent import AgentState
 13from evorl.ec.optimizers.ec_optimizer import ECState
 14from evorl.utils.ec_utils import ParamVectorSpec
 15from evorl.recorders import get_1d_array_statistics
 16from evorl.workflows import MultiObjectiveECWorkflowTemplate
 17
 18from ..obs_utils import init_obs_preprocessor
 19from ..ec_agent import make_deterministic_ec_agent
 20
 21logger = logging.getLogger(__name__)
 22
 23
class NSGA2Workflow(MultiObjectiveECWorkflowTemplate):
    """Multi-objective EC workflow that optimizes policy parameters with NSGA-II.

    The population is produced by EvoX's ``NSGA2`` algorithm (wrapped in
    ``EvoXAlgorithmAdapter``) and scored by a ``BraxEvaluator`` using the
    metrics listed in ``config.metric_names``.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "NSGA2"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.random_timesteps`` down to a multiple of the device count.

        Args:
            config: Workflow config; ``random_timesteps`` is modified in place.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        rescaled_timesteps = (config.random_timesteps // num_devices) * num_devices

        if config.random_timesteps % num_devices != 0:
            # The value being checked and rescaled is random_timesteps, so the
            # message must name it (the previous text said "pop_size").
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rescaling random_timesteps to %s.",
                config.random_timesteps,
                num_devices,
                rescaled_timesteps,
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow (env, agent, optimizer, evaluator) from ``config``.

        Args:
            config: Workflow config providing the env, network, NSGA-II bounds,
                population size and metric names.

        Returns:
            A new workflow instance.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
        )

        # Dummy agent_state: only used to derive the flat parameter-vector spec.
        agent_key = jax.random.PRNGKey(config.seed)
        agent_state = agent.init(env.obs_space, env.action_space, agent_key)
        param_vec_spec = ParamVectorSpec(agent_state.params.policy_params)

        # Imported locally, as in the original code, so that evox is only
        # required when this workflow is actually built.
        from evorl.ec.optimizers.evox_wrapper import EvoXAlgorithmAdapter
        from evox.algorithms import NSGA2

        vec_shape = (param_vec_spec.vec_size,)
        ec_optimizer = EvoXAlgorithmAdapter(
            algorithm=NSGA2(
                lb=jnp.full(vec_shape, config.agent_network.lb),
                ub=jnp.full(vec_shape, config.agent_network.ub),
                n_objs=len(config.metric_names),
                pop_size=config.pop_size,
            ),
            param_vec_spec=param_vec_spec,
        )

        action_fn = (
            agent.compute_actions if config.explore else agent.evaluate_actions
        )

        ec_evaluator = BraxEvaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
            metric_names=tuple(config.metric_names),
        )

        # vmap over the leading axis of params; the obs preprocessor state is
        # not mapped (axis None).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialize the agent state and the EC optimizer state.

        Args:
            key: PRNG key, split between agent and optimizer initialization.

        Returns:
            ``(agent_state, ec_opt_state)`` where the agent state's policy
            params have been replaced by ``None``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        ec_opt_state = self.ec_optimizer.init(ec_key)

        # Remove params from the agent state.
        # NOTE(review): presumably the population params come from the
        # optimizer at evaluation time -- confirm against the template.
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialize the observation preprocessor when ``normalize_obs`` is set.

        Args:
            state: Workflow state after the regular setup.

        Returns:
            The state with updated ``agent_state`` and ``key`` when
            ``config.normalize_obs`` is truthy; otherwise ``state`` unchanged.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with ``params.policy_params`` set to ``params``."""
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def learn(self, state: State) -> State:
        """Run the training loop from the current iteration to ``config.num_iters``.

        Each iteration performs one ``step``, records the workflow metrics and
        per-objective statistics (for the whole population and for the
        rank-0 Pareto front), and saves a checkpoint.

        Args:
            state: Workflow state to resume from; ``state.metrics.iterations``
                is used as the starting iteration.
        """
        # Hoisted out of the loop: both are loop-invariant.
        from evox.operators import non_dominated_sort

        cpu_device = jax.devices("cpu")[0]
        metric_names = self.config.metric_names
        start_iteration = state.metrics.iterations

        for i in range(start_iteration, self.config.num_iters):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            self.recorder.write(workflow_metrics.to_local_dict(), iters)

            # Run the non-dominated sort on CPU.
            with jax.default_device(cpu_device):
                objectives = jax.device_put(train_metrics.objectives, cpu_device)
                # NOTE(review): objectives are negated before sorting,
                # presumably because the sort assumes minimization -- confirm.
                pf_rank = non_dominated_sort(-objectives, "scan")

            # Select the Pareto front from the CPU copy so the mask and the
            # indexed array live on the same device.
            objectives = np.asarray(objectives)
            pf_objectives = objectives[np.asarray(pf_rank) == 0]

            train_metrics_dict = {
                "objectives": {
                    metric_name: get_1d_array_statistics(
                        objectives[:, col], histogram=True
                    )
                    for col, metric_name in enumerate(metric_names)
                },
                "pf_objectives": {
                    metric_name: get_1d_array_statistics(
                        pf_objectives[:, col], histogram=True
                    )
                    for col, metric_name in enumerate(metric_names)
                },
                "num_pf": pf_objectives.shape[0],
            }

            self.recorder.write(train_metrics_dict, iters)

            # `iters` (1-based) reaches num_iters on the last iteration; the
            # 0-based `i` never does, so comparing `i` never forced a save.
            self.checkpoint_manager.save(
                iters,
                state,
                force=iters == self.config.num_iters,
            )