Source code for evorl.algorithms.contrib.pop_ppo
1import logging
2from functools import partial
3import math
4from typing_extensions import Self # pytype: disable=not-supported-yet
5from omegaconf import DictConfig
6
7import chex
8import jax
9import jax.tree_util as jtu
10
11from evorl.types import State, MISSING_REWARD
12from evorl.metrics import MetricBase
13from evorl.utils.jax_utils import scan_and_last
14from evorl.recorders import add_prefix, get_1d_array_statistics, get_1d_array
15
16from evorl.algorithms.ppo import PPOWorkflow
17
18logger = logging.getLogger(__name__)
19
20
[docs]
class PopPPOWorkflow(PPOWorkflow):
    """Population variant of ``PPOWorkflow``.

    ``setup``, ``step`` and ``evaluate`` of the parent workflow are wrapped in
    ``jax.vmap``. ``setup`` maps over ``config.pop_size`` PRNG keys, so the
    resulting ``State`` carries a leading population axis of that size, and
    the other methods map over that axis.
    """

    @classmethod
    def name(cls):
        """Return the workflow's display name, ``"PopPPO"``."""
        return "PopPPO"

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Build the workflow from ``config``.

        Args:
            config: Workflow configuration, forwarded unchanged to the parent.
            enable_multi_devices: Must be ``False``; see ``Raises``.
            enable_jit: Forwarded unchanged to the parent builder.

        Returns:
            The workflow instance created by ``PPOWorkflow.build_from_config``.

        Raises:
            NotImplementedError: If ``enable_multi_devices`` is true, or if
                ``jax.local_devices()`` reports more than one device (even
                when multi-device mode was not requested).
        """
        devices = jax.local_devices()

        if enable_multi_devices or len(devices) > 1:
            raise NotImplementedError("Multi-devices is not supported yet.")

        return super().build_from_config(config, enable_multi_devices, enable_jit)

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the batched initial state.

        ``key`` is split into ``config.pop_size`` sub-keys and the parent's
        ``setup`` is vmapped over them.

        Args:
            key: PRNG key used to derive one key per population member.

        Returns:
            The vmapped output of ``PPOWorkflow.setup``.
        """
        member_keys = jax.random.split(key, self.config.pop_size)
        return jax.vmap(super().setup)(member_keys)

    def evaluate(self, state):
        """Apply the parent's ``evaluate`` across the population axis.

        Args:
            state: Batched state as produced by ``setup`` / ``step``.

        Returns:
            The vmapped output of ``PPOWorkflow.evaluate``; ``learn`` unpacks
            it as ``(eval_metrics, state)``.
        """
        return jax.vmap(super().evaluate)(state)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Apply the parent's ``step`` across the population axis.

        Args:
            state: Batched state as produced by ``setup`` / ``step``.

        Returns:
            ``(train_metrics, state)``, both batched over the population.
        """
        return jax.vmap(super().step)(state)

    def _multi_steps(self, state):
        """Run ``config.fold_iters`` consecutive ``step`` calls.

        Returns:
            ``(train_metrics, state)`` as handed back by ``scan_and_last``
            (presumably the metrics of the last step only -- confirm against
            ``evorl.utils.jax_utils.scan_and_last``).
        """

        def _step(state, _):
            # Reorder step's (metrics, state) into the (carry, output) pair
            # expected by the scan helper.
            train_metrics, state = self.step(state)
            return state, train_metrics

        state, train_metrics = scan_and_last(
            _step, state, (), length=self.config.fold_iters
        )

        return train_metrics, state

    def learn(self, state: State) -> State:
        """Run the training loop until ``config.total_timesteps`` is reached.

        Each pass runs ``_multi_steps``, writes workflow and training metrics
        to ``self.recorder``, evaluates when the iteration counter is a
        multiple of ``config.eval_interval`` or on the final pass, and then
        saves a checkpoint (forced on the final pass).

        Args:
            state: Batched workflow state to start from.

        Returns:
            The state after the last pass (``state`` unchanged if no pass is
            needed).
        """
        # Timesteps consumed by one pass of the loop below.
        timesteps_per_pass = (
            self.config.rollout_length * self.config.num_envs * self.config.fold_iters
        )
        # Counters are read from the first population member.
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()[0]
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps) / timesteps_per_pass
        )

        for i in range(num_iters):
            # Detect the final pass from the loop index. `iters` below is the
            # counter stored in the state: it is read back after `fold_iters`
            # steps and the incoming state may already carry progress, so
            # comparing it with `num_iters` (the number of passes of THIS
            # call) does not reliably identify the last pass.
            is_last_pass = i == num_iters - 1

            train_metrics, state = self._multi_steps(state)

            iters = state.metrics.iterations.tolist()[0]

            # Workflow counters: log the first member's values only.
            workflow_metrics_data = jtu.tree_map(
                lambda x: x[0],
                state.metrics.to_local_dict(),
            )
            self.recorder.write(workflow_metrics_data, iters)

            train_metric_data = train_metrics.to_local_dict()
            # Drop placeholder entries before computing statistics.
            train_episode_return = train_metric_data["train_episode_return"]
            train_episode_return = train_episode_return[
                train_episode_return != MISSING_REWARD
            ]
            # Explicit length check: the value is an array, whose truthiness
            # is not a valid emptiness test.
            if len(train_episode_return) > 0:
                train_metric_data["train_episode_return"] = train_episode_return
            else:
                # tree_map treats None as an empty subtree and skips it.
                train_metric_data["train_episode_return"] = None

            train_metric_data = jtu.tree_map(
                partial(get_1d_array_statistics, histogram=True),
                train_metric_data,
            )
            self.recorder.write(train_metric_data, iters)

            if iters % self.config.eval_interval == 0 or is_last_pass:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = jtu.tree_map(
                    get_1d_array,
                    eval_metrics.to_local_dict(),
                )

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            self.checkpoint_manager.save(
                iters,
                state,
                force=is_last_pass,
            )

        return state