Source code for evorl.algorithms.contrib.pop_ppo

  1import logging
  2from functools import partial
  3import math
  4from typing_extensions import Self  # pytype: disable=not-supported-yet]
  5from omegaconf import DictConfig
  6
  7import chex
  8import jax
  9import jax.tree_util as jtu
 10
 11from evorl.types import State, MISSING_REWARD
 12from evorl.metrics import MetricBase
 13from evorl.utils.jax_utils import scan_and_last
 14from evorl.recorders import add_prefix, get_1d_array_statistics, get_1d_array
 15
 16from evorl.algorithms.ppo import PPOWorkflow
 17
 18logger = logging.getLogger(__name__)
 19
 20
class PopPPOWorkflow(PPOWorkflow):
    """Population variant of :class:`PPOWorkflow`.

    ``setup``, ``step`` and ``evaluate`` of the parent workflow are wrapped in
    ``jax.vmap`` so that every array in the workflow state carries a leading
    axis of size ``config.pop_size`` (one slice per population member).
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "PopPPO"

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Build the workflow from ``config``.

        Args:
            config: Workflow configuration.
            enable_multi_devices: Must be ``False``; multi-device execution is
                not implemented for this workflow.
            enable_jit: Forwarded unchanged to the parent implementation.

        Raises:
            NotImplementedError: If ``enable_multi_devices`` is set or more
                than one local JAX device is visible.
        """
        devices = jax.local_devices()

        if enable_multi_devices or len(devices) > 1:
            raise NotImplementedError("Multi-devices is not supported yet.")

        return super().build_from_config(config, enable_multi_devices, enable_jit)

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the batched initial state, one member per split PRNG key."""
        member_keys = jax.random.split(key, self.config.pop_size)
        return jax.vmap(super().setup)(member_keys)

    def evaluate(self, state: State):
        """Run the parent ``evaluate`` for every population member via vmap."""
        return jax.vmap(super().evaluate)(state)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one parent training step for every population member via vmap."""
        return jax.vmap(super().step)(state)

    def _multi_steps(self, state):
        """Run ``config.fold_iters`` consecutive ``step`` calls in one scan.

        Returns:
            ``(train_metrics, state)``. The metrics are whatever
            ``scan_and_last`` yields for the scanned outputs -- presumably
            those of the final step only; confirm against
            ``evorl.utils.jax_utils.scan_and_last``.
        """

        def _step(state, _):
            train_metrics, state = self.step(state)
            return state, train_metrics

        state, train_metrics = scan_and_last(
            _step, state, (), length=self.config.fold_iters
        )

        return train_metrics, state

    def learn(self, state: State) -> State:
        """Train until ``config.total_timesteps`` have been sampled.

        Each outer iteration runs ``_multi_steps`` once, writes workflow and
        training metrics to ``self.recorder``, evaluates when the iteration
        counter is a multiple of ``config.eval_interval`` or on the last outer
        iteration, and hands the state to ``self.checkpoint_manager``.

        Args:
            state: Batched workflow state, as produced by ``setup`` or
                restored from a checkpoint.

        Returns:
            The state after the final outer iteration.
        """
        # Timesteps consumed by one outer iteration (one `_multi_steps` call).
        one_step_timesteps = (
            self.config.rollout_length * self.config.num_envs * self.config.fold_iters
        )
        # Counters are read from population member 0 only.
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()[0]
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps) / one_step_timesteps
        )

        for i in range(num_iters):
            # `num_iters` counts the outer iterations remaining in THIS call,
            # while `state.metrics.iterations` is a global counter stored in
            # the state. Detect the last iteration from the loop index so the
            # final evaluation and forced checkpoint still fire when the two
            # are not aligned (e.g. after resuming from a checkpoint).
            is_last_iter = i == num_iters - 1

            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # Global iteration counter of member 0; used as the recorder step.
            iters = state.metrics.iterations.tolist()[0]

            # Keep only member 0's slice of each workflow metric.
            workflow_metrics_data = jtu.tree_map(
                lambda x: x[0],
                workflow_metrics.to_local_dict(),
            )

            self.recorder.write(workflow_metrics_data, iters)

            train_metric_data = train_metrics.to_local_dict()
            # Drop placeholder entries before computing statistics.
            train_episode_return = train_metric_data["train_episode_return"]
            train_episode_return = train_episode_return[
                train_episode_return != MISSING_REWARD
            ]
            if len(train_episode_return) > 0:
                train_metric_data["train_episode_return"] = train_episode_return
            else:
                # `None` is an empty pytree node, so the tree_map below
                # leaves this entry as None instead of computing statistics.
                train_metric_data["train_episode_return"] = None

            train_metric_data = jtu.tree_map(
                partial(get_1d_array_statistics, histogram=True),
                train_metric_data,
            )
            self.recorder.write(train_metric_data, iters)

            if iters % self.config.eval_interval == 0 or is_last_iter:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = jtu.tree_map(
                    get_1d_array,
                    eval_metrics.to_local_dict(),
                )

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            # Called every outer iteration; a save is forced on the last one.
            self.checkpoint_manager.save(
                iters,
                state,
                force=is_last_iter,
            )

        return state