1import logging
2import numpy as np
3from omegaconf import DictConfig
4from typing_extensions import Self # pytype: disable=not-supported-yet]
5
6import jax
7import jax.numpy as jnp
8
9from evorl.types import State, Params
10from evorl.envs import AutoresetMode, create_env
11from evorl.evaluators import BraxEvaluator
12from evorl.agent import AgentState
13from evorl.ec.optimizers.ec_optimizer import ECState
14from evorl.utils.ec_utils import ParamVectorSpec
15from evorl.recorders import get_1d_array_statistics
16from evorl.workflows import MultiObjectiveECWorkflowTemplate
17
18from ..obs_utils import init_obs_preprocessor
19from ..ec_agent import make_deterministic_ec_agent
20
21logger = logging.getLogger(__name__)
22
23
[docs]
class NSGA2Workflow(MultiObjectiveECWorkflowTemplate):
    """Multi-objective neuroevolution workflow driven by EvoX's NSGA-II.

    The policy parameters of a deterministic agent are flattened into a
    vector (see ``ParamVectorSpec``) that NSGA-II evolves. Candidates are
    scored by a ``BraxEvaluator`` on the metrics in ``config.metric_names``.
    """

    @classmethod
    def name(cls):
        """Return the identifier of this workflow."""
        return "NSGA2"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.random_timesteps`` down to a multiple of the device count.

        The parent class' rescaling runs first. ``config`` is mutated in place.

        Args:
            config: Workflow configuration; ``random_timesteps`` is rewritten.
        """
        super()._rescale_config(config)

        num_devices = jax.device_count()
        random_timesteps = config.random_timesteps
        rescaled_timesteps = (random_timesteps // num_devices) * num_devices

        if random_timesteps % num_devices != 0:
            # The message names the option that is actually checked and
            # rescaled here (random_timesteps, not pop_size).
            logger.warning(
                "When enable_multi_devices=True, random_timesteps (%s) should be "
                "divisible by num_devices (%s); rounding down to %s.",
                random_timesteps,
                num_devices,
                rescaled_timesteps,
            )

        config.random_timesteps = rescaled_timesteps

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Build the env, agent, NSGA-II optimizer and evaluator from ``config``.

        Args:
            config: Workflow configuration. This method reads ``env``,
                ``num_envs``, ``agent_network``, ``normalize_obs``, ``seed``,
                ``metric_names``, ``pop_size``, ``explore`` and ``discount``.

        Returns:
            A workflow instance wired with the constructed components.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            use_bias=config.agent_network.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=config.agent_network.norm_layer_type,
        )

        # Dummy agent_state: only its policy params are used, to derive the
        # layout/size of the flat parameter vector.
        agent_key = jax.random.PRNGKey(config.seed)
        agent_state = agent.init(env.obs_space, env.action_space, agent_key)
        param_vec_spec = ParamVectorSpec(agent_state.params.policy_params)

        # EvoX is imported at function scope, as in the original code.
        from evorl.ec.optimizers.evox_wrapper import EvoXAlgorithmAdapter
        from evox.algorithms import NSGA2

        ec_optimizer = EvoXAlgorithmAdapter(
            algorithm=NSGA2(
                # Same lower/upper bound for every entry of the vector.
                lb=jnp.full((param_vec_spec.vec_size,), config.agent_network.lb),
                ub=jnp.full((param_vec_spec.vec_size,), config.agent_network.ub),
                n_objs=len(config.metric_names),
                pop_size=config.pop_size,
            ),
            param_vec_spec=param_vec_spec,
        )

        if config.explore:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_evaluator = BraxEvaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            discount=config.discount,
            metric_names=tuple(config.metric_names),
        )

        # Presumably vmap in_axes for the agent state: params carry a leading
        # population axis, the obs preprocessor state is shared (not mapped).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )

    def _setup_agent_and_optimizer(self, key: jax.Array) -> tuple[AgentState, ECState]:
        """Initialise the agent state and the optimizer state.

        Args:
            key: PRNG key, split between agent and optimizer initialisation.

        Returns:
            ``(agent_state, ec_opt_state)``; the agent state's policy params
            are set to ``None``.
        """
        agent_key, ec_key = jax.random.split(key)
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        ec_opt_state = self.ec_optimizer.init(ec_key)

        # Remove the freshly initialised policy params from the agent state.
        agent_state = self._replace_actor_params(agent_state, params=None)

        return agent_state, ec_opt_state

    def _postsetup(self, state: State) -> State:
        """Initialise the observation preprocessor when normalisation is enabled.

        Args:
            state: Workflow state produced by setup.

        Returns:
            ``state`` unchanged when ``config.normalize_obs`` is false;
            otherwise a copy with the updated agent state and a fresh PRNG key.
        """
        if not self.config.normalize_obs:
            return state

        key, obs_key = jax.random.split(state.key, 2)
        agent_state = init_obs_preprocessor(
            agent_state=state.agent_state,
            config=self.config,
            key=obs_key,
            dp_axis_name=self.dp_axis_name,
        )

        # Note: we don't count these random timesteps in state.metrics
        return state.replace(
            agent_state=agent_state,
            key=key,
        )

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Return ``agent_state`` with its policy params replaced by ``params``.

        ``params`` may be ``None`` (see ``_setup_agent_and_optimizer``).
        """
        return agent_state.replace(
            params=agent_state.params.replace(policy_params=params)
        )

    def learn(self, state: State) -> State:
        """Run the training loop up to ``config.num_iters`` iterations.

        Each iteration performs one ``step``, records the workflow metrics,
        records per-objective statistics for the whole population and for its
        first non-dominated front, and saves a checkpoint.

        Args:
            state: Workflow state to start (or resume) from;
                ``state.metrics.iterations`` is the first iteration index.

        Returns:
            The workflow state after the last iteration.
        """
        # Function-scope import, as in the original; hoisted out of the loop.
        from evox.operators import non_dominated_sort

        metric_names = self.config.metric_names
        num_iters = self.config.num_iters
        # Loop-invariant: the device used for front ranking.
        cpu_device = jax.devices("cpu")[0]

        def per_metric_statistics(values):
            # One statistics entry per objective column, keyed by metric name.
            return {
                metric_name: get_1d_array_statistics(values[:, col], histogram=True)
                for col, metric_name in enumerate(metric_names)
            }

        start_iteration = state.metrics.iterations

        for i in range(start_iteration, num_iters):
            iters = i + 1
            train_metrics, state = self.step(state)

            self.recorder.write(state.metrics.to_local_dict(), iters)

            with jax.default_device(cpu_device):
                objectives = jax.device_put(train_metrics.objectives, cpu_device)
                # Objectives are negated before ranking -- presumably because
                # the sort treats lower values as better; confirm against EvoX.
                pf_rank = non_dominated_sort(-objectives, "scan")

            # Select the rank-0 front from the CPU copy rather than from the
            # original array in train_metrics.
            objectives = np.asarray(objectives)
            pf_objectives = objectives[np.asarray(pf_rank) == 0]

            train_metrics_dict = {
                "objectives": per_metric_statistics(objectives),
                "pf_objectives": per_metric_statistics(pf_objectives),
                "num_pf": pf_objectives.shape[0],
            }
            self.recorder.write(train_metrics_dict, iters)

            self.checkpoint_manager.save(
                iters,
                state,
                # `i` never reaches num_iters inside this range; the 1-based
                # `iters` does, on the final iteration.
                force=iters == num_iters,
            )

        return state