Source code for evorl.algorithms.ec.so.es_workflow

  1import logging
  2from functools import partial
  3from omegaconf import DictConfig
  4
  5import jax
  6import jax.tree_util as jtu
  7
  8from evorl.agent import Agent, AgentState, AgentStateAxis
  9from evorl.evaluators import Evaluator
 10from evorl.metrics import EvaluateMetric, MetricBase
 11from evorl.types import State
 12from evorl.envs import Env
 13from evorl.ec.optimizers import EvoOptimizer
 14from evorl.recorders import get_1d_array_statistics
 15from evorl.workflows import ECWorkflowTemplate
 16
 17logger = logging.getLogger(__name__)
 18
 19
class ESWorkflowTemplate(ECWorkflowTemplate):
    """Workflow template for evolution-strategy (ES) style EC algorithms.

    Extends ``ECWorkflowTemplate`` with a separate ``evaluator`` that is used
    by :meth:`evaluate` to score the population center returned by
    :meth:`_get_pop_center`. Subclasses must implement
    :meth:`_get_pop_center` and may override :meth:`_record_callback`.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        ec_optimizer: EvoOptimizer,
        ec_evaluator: Evaluator,
        evaluator: Evaluator,
        agent_state_vmap_axes: AgentStateAxis = 0,
        config: DictConfig,
    ):
        """Create the workflow.

        Args:
            env: Environment, forwarded to the base class.
            agent: Agent, forwarded to the base class.
            ec_optimizer: Evolutionary optimizer, forwarded to the base class.
            ec_evaluator: Evaluator forwarded to the base class.
            evaluator: Independent evaluator used by :meth:`evaluate` for the
                population center; stored as ``self.evaluator``.
            agent_state_vmap_axes: vmap axes spec for the agent state,
                forwarded to the base class.
            config: Workflow configuration, forwarded to the base class.
        """
        super().__init__(
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
            config=config,
        )

        self.evaluator = evaluator  # independent evaluator for pop_center

    def _get_pop_center(self, state: State) -> AgentState:
        """Return the agent state representing the population center.

        Must be implemented by subclasses; the base version raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def _record_callback(self, state: State, iters: int) -> None:
        """Per-iteration hook called by :meth:`learn`; no-op by default.

        Args:
            state: Workflow state after the current iteration.
            iters: 1-based index of the iteration just completed.
        """
        pass

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the policy with the mean of ES.

        Splits ``state.key`` into a fresh key and an evaluation key, runs
        ``self.evaluator`` on the population center for
        ``self.config.eval_episodes`` episodes, and averages episode returns
        and lengths. The averages are all-reduced over ``self.dp_axis_name``.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(eval_metrics, new_state)`` where ``new_state`` is
            ``state`` with its ``key`` replaced by the freshly split key.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        agent_state = self._get_pop_center(state)

        # [#episodes] -- presumably one entry per evaluated episode
        raw_eval_metrics = self.evaluator.evaluate(
            agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            episode_returns=raw_eval_metrics.episode_returns.mean(),
            episode_lengths=raw_eval_metrics.episode_lengths.mean(),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return eval_metrics, state.replace(key=key)

    def learn(self, state: State) -> State:
        """Run the optimization loop and return the final state.

        Resumes from ``state.metrics.iterations`` and runs until
        ``self.config.num_iters``. Each iteration:

        1. Performs one :meth:`step`.
        2. Records workflow metrics and summary statistics of the training
           metrics.
        3. Evaluates the population center every
           ``self.config.eval_interval`` iterations and on the final
           iteration.
        4. Calls :meth:`_record_callback`.
        5. Saves a checkpoint; the save on the final iteration is forced.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration. This is the input
            ``state`` unchanged if no iterations remain.
        """
        start_iteration = state.metrics.iterations
        num_iters = self.config.num_iters

        for i in range(start_iteration, num_iters):
            iters = i + 1  # 1-based index of the iteration being run
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            self.recorder.write(workflow_metrics.to_local_dict(), iters)

            # Summarize every training metric leaf into 1-D statistics
            # (including a histogram) before logging.
            train_metrics_dict = jtu.tree_map(
                partial(get_1d_array_statistics, histogram=True),
                train_metrics.to_local_dict(),
            )
            self.recorder.write(train_metrics_dict, iters)

            is_last_iter = iters == num_iters
            if iters % self.config.eval_interval == 0 or is_last_iter:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    {"eval/pop_center": eval_metrics.to_local_dict()}, iters
                )

            self._record_callback(state, iters)

            # `i` never reaches num_iters (range is exclusive), so the
            # final-iteration check must use the 1-based `iters`.
            self.checkpoint_manager.save(
                iters,
                state,
                force=is_last_iter,
            )

        return state

    @classmethod
    def enable_jit(cls) -> None:
        """Enable the base-class JIT setup and also jit :meth:`evaluate`.

        ``self`` (argument 0) is marked static so the workflow instance is
        treated as a compile-time constant.
        """
        super().enable_jit()
        cls.evaluate = jax.jit(cls.evaluate, static_argnums=(0,))

    @classmethod
    def enable_shmap(cls, axis_name) -> None:
        """Enable the base-class shard_map setup for ``axis_name``.

        :meth:`evaluate` is not wrapped here.
        """
        super().enable_shmap(axis_name)
        # evaluate is handled via the base class pattern.
        # NOTE(review): evaluate calls all_reduce over self.dp_axis_name;
        # confirm the base class actually wraps it under shard_map.