1import logging
2from functools import partial
3from omegaconf import DictConfig
4
5import jax
6import jax.tree_util as jtu
7
8from evorl.agent import Agent, AgentState, AgentStateAxis
9from evorl.evaluators import Evaluator
10from evorl.metrics import EvaluateMetric, MetricBase
11from evorl.types import State
12from evorl.envs import Env
13from evorl.ec.optimizers import EvoOptimizer
14from evorl.recorders import get_1d_array_statistics
15from evorl.workflows import ECWorkflowTemplate
16
17logger = logging.getLogger(__name__)
18
19
[docs]
class ESWorkflowTemplate(ECWorkflowTemplate):
    """Workflow template that also evaluates the population center.

    On top of ``ECWorkflowTemplate`` this class keeps a second, independent
    ``evaluator`` that is used by :meth:`evaluate` on the agent state returned
    by :meth:`_get_pop_center`. Subclasses must implement
    :meth:`_get_pop_center`.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        ec_optimizer: EvoOptimizer,
        ec_evaluator: Evaluator,
        evaluator: Evaluator,
        agent_state_vmap_axes: AgentStateAxis = 0,
        config: DictConfig,
    ):
        """Build the workflow.

        Args:
            env: Forwarded to ``ECWorkflowTemplate``.
            agent: Forwarded to ``ECWorkflowTemplate``.
            ec_optimizer: Forwarded to ``ECWorkflowTemplate``.
            ec_evaluator: Forwarded to ``ECWorkflowTemplate``.
            evaluator: Independent evaluator used for the population center.
            agent_state_vmap_axes: Forwarded to ``ECWorkflowTemplate``.
            config: Forwarded to ``ECWorkflowTemplate``; this class reads
                ``eval_episodes``, ``eval_interval`` and ``num_iters`` from it.
        """
        super().__init__(
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
            config=config,
        )

        self.evaluator = evaluator  # independent evaluator for pop_center

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state of the population center.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def _record_callback(self, state: State, iters: int) -> None:
        """Hook called once per iteration by :meth:`learn`; no-op by default."""
        pass

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the policy with the mean of ES.

        Args:
            state: Current workflow state; its ``key`` is split so that one
                half drives the evaluation and the other is stored back.

        Returns:
            A tuple ``(eval_metrics, new_state)`` where ``eval_metrics`` holds
            the mean episode return and mean episode length (reduced over
            ``self.dp_axis_name``) and ``new_state`` is ``state`` with the
            refreshed key.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        agent_state = self._get_pop_center(state)

        # [#episodes]
        raw_eval_metrics = self.evaluator.evaluate(
            agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            episode_returns=raw_eval_metrics.episode_returns.mean(),
            episode_lengths=raw_eval_metrics.episode_lengths.mean(),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return eval_metrics, state.replace(key=key)

    def learn(self, state: State) -> State:
        """Run the training loop up to ``config.num_iters`` iterations.

        Each iteration performs one :meth:`step`, records workflow and
        training metrics, periodically evaluates the population center,
        invokes :meth:`_record_callback` and saves a checkpoint.

        Args:
            state: Workflow state to start (or resume) from; the loop starts
                at ``state.metrics.iterations``.

        Returns:
            The workflow state after the last executed iteration.
        """
        start_iteration = state.metrics.iterations
        num_iters = self.config.num_iters

        for i in range(start_iteration, num_iters):
            iters = i + 1  # 1-based iteration counter used for all records
            is_last_iter = iters == num_iters

            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            self.recorder.write(workflow_metrics.to_local_dict(), iters)

            train_metrics_dict = jtu.tree_map(
                partial(get_1d_array_statistics, histogram=True),
                train_metrics.to_local_dict(),
            )
            self.recorder.write(train_metrics_dict, iters)

            if iters % self.config.eval_interval == 0 or is_last_iter:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    {"eval/pop_center": eval_metrics.to_local_dict()}, iters
                )

            self._record_callback(state, iters)

            # Compare against `iters`, not `i`: `i` stops at num_iters - 1, so
            # a comparison on `i` would never force the final checkpoint.
            self.checkpoint_manager.save(
                iters,
                state,
                force=is_last_iter,
            )

        return state

    @classmethod
    def enable_jit(cls) -> None:
        """Enable JIT via the parent class and additionally jit ``evaluate``."""
        super().enable_jit()
        # `self` (argument 0) is marked static for jax.jit.
        cls.evaluate = jax.jit(cls.evaluate, static_argnums=(0,))

    @classmethod
    def enable_shmap(cls, axis_name) -> None:
        """Delegate shmap setup to the parent class.

        Args:
            axis_name: Forwarded unchanged to the parent ``enable_shmap``.
        """
        super().enable_shmap(axis_name)
        # NOTE(review): `evaluate` is not wrapped here (unlike in enable_jit);
        # presumably the base class covers it -- confirm against
        # ECWorkflowTemplate.enable_shmap.