Source code for evorl.algorithms.contrib.a2c_v2
1import logging
2import math
3from collections.abc import Sequence
4
5import jax.tree_util as jtu
6import numpy as np
7
8from evorl.types import MISSING_REWARD, State
9from evorl.utils.rl_toolkits import fold_multi_steps
10from evorl.recorders import add_prefix
11
12from evorl.algorithms.a2c import A2CWorkflow as _A2CWorkflow
13
14logger = logging.getLogger(__name__)
15
16
[docs]
class A2CWorkflow(_A2CWorkflow):
    """A2C workflow whose ``learn`` loop trains in folded chunks of steps.

    Each chunk is built with ``fold_multi_steps`` over ``self.step`` and is
    followed by metric recording, an evaluation pass and a checkpoint save.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "A2C-V2"

    def learn(self, state: State) -> State:
        """Run the training loop and return the final workflow state.

        Args:
            state: Workflow state to start training from.

        Returns:
            The state produced after the last chunk has been trained,
            evaluated and checkpointed.
        """
        cfg = self.config

        # Environment timesteps consumed by a single training iteration.
        timesteps_per_iter = cfg.rollout_length * cfg.num_envs
        total_iters = math.ceil(cfg.total_timesteps / timesteps_per_iter)

        # One chunk spans `eval_interval` iterations; evaluation happens
        # once per chunk.
        # NOTE(review): presumably the folded function calls `self.step`
        # `eval_interval` times and stacks the per-step metrics -- confirm
        # against `fold_multi_steps`.
        iters_per_chunk = cfg.eval_interval
        run_chunk = fold_multi_steps(self.step, iters_per_chunk)

        # Rounded up, so the loop may run more than `total_iters` iterations
        # when `total_iters` is not a multiple of `eval_interval`.
        num_chunks = math.ceil(total_iters / iters_per_chunk)
        final_chunk = num_chunks - 1

        for chunk_idx in range(num_chunks):
            stacked_train_metrics, state = run_chunk(state)

            # Keep only the last entry of every leaf in the metrics pytree.
            last_train_metrics = jtu.tree_map(
                lambda leaf: leaf[-1], stacked_train_metrics
            )

            wf_metrics = state.metrics
            iterations = wf_metrics.iterations.tolist()
            self.recorder.write(wf_metrics.to_local_dict(), iterations)

            train_record = last_train_metrics.to_local_dict()
            # Strip MISSING_REWARD placeholders before logging the return.
            raw_episode_return = train_record["train_episode_return"]
            train_record["train_episode_return"] = get_train_episode_return(
                raw_episode_return
            )
            self.recorder.write(train_record, iterations)

            eval_metrics, state = self.evaluate(state)
            eval_record = add_prefix(eval_metrics.to_local_dict(), "eval")
            self.recorder.write(eval_record, iterations)

            # Force a checkpoint on the final chunk.
            self.checkpoint_manager.save(
                iterations,
                state,
                force=(chunk_idx == final_chunk),
            )

        return state
59
60
def _default_episode_return_reduce_fn(x):
    """Return the last element of ``x``.

    Used as the default reducer of ``get_train_episode_return``.
    """
    last_entry = x[-1]
    return last_entry
63
64
[docs]
def get_train_episode_return(
    episode_return_arr: Sequence[float], reduce_fn=_default_episode_return_reduce_fn
):
    """Reduce episode returns after dropping ``MISSING_REWARD`` entries.

    Intended for arrays that may contain ``MISSING_REWARD`` placeholders,
    e.g. values collected from multiple calls of
    ``average_episode_discount_return``.

    Args:
        episode_return_arr: Episode returns, possibly containing
            ``MISSING_REWARD`` entries.
        reduce_fn: Callable applied to the array of remaining returns. The
            default picks the last one.

    Returns:
        ``None`` when no entry differs from ``MISSING_REWARD`` (which
        includes empty input); otherwise the result of ``reduce_fn``
        converted through ``.tolist()``.
    """
    returns = np.array(episode_return_arr)
    is_valid = returns != MISSING_REWARD

    # Nothing usable: every entry is the placeholder (or there are none).
    if not is_valid.any():
        return None

    reduced = reduce_fn(returns[is_valid])
    return reduced.tolist()