Source code for evorl.algorithms.contrib.a2c_v2

 1import logging
 2import math
 3from collections.abc import Sequence
 4
 5import jax.tree_util as jtu
 6import numpy as np
 7
 8from evorl.types import MISSING_REWARD, State
 9from evorl.utils.rl_toolkits import fold_multi_steps
10from evorl.recorders import add_prefix
11
12from evorl.algorithms.a2c import A2CWorkflow as _A2CWorkflow
13
14logger = logging.getLogger(__name__)
15
16
class A2CWorkflow(_A2CWorkflow):
    """A2C workflow that trains in chunks of ``config.eval_interval`` steps.

    ``learn`` wraps ``self.step`` with ``fold_multi_steps`` so that each loop
    iteration advances a chunk of ``eval_interval`` training steps. Metrics are
    recorded, an evaluation is run, and a checkpoint save is attempted once per
    chunk rather than once per step.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow variant."""
        return "A2C-V2"

    def learn(self, state: State) -> State:
        """Run the chunked training loop.

        The number of training iterations is
        ``ceil(total_timesteps / (rollout_length * num_envs))``, rounded up to
        whole chunks of ``eval_interval`` iterations.

        Args:
            state: Workflow state to train from.

        Returns:
            The workflow state after the final chunk.
        """
        # Timesteps sampled per training iteration (rollout_length x num_envs).
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / one_step_timesteps)

        steps_interval = self.config.eval_interval

        _multi_steps = fold_multi_steps(self.step, steps_interval)

        # Rounded up: assuming each ``_multi_steps`` call runs ``steps_interval``
        # iterations, the last chunk may overshoot ``num_iters`` by up to
        # ``steps_interval - 1`` iterations.
        num_fold_iters = math.ceil(num_iters / steps_interval)

        for i in range(num_fold_iters):
            # Returned metrics carry a leading per-step axis (assumed from
            # ``fold_multi_steps``; indexed with ``[-1]`` below).
            train_metrics_arr, state = _multi_steps(state)

            # Ordinary train metrics: report the last step of the chunk.
            train_metrics = jtu.tree_map(lambda x: x[-1], train_metrics_arr)

            workflow_metrics = state.metrics
            iterations = workflow_metrics.iterations.tolist()

            self.recorder.write(workflow_metrics.to_local_dict(), iterations)
            train_metric_data = train_metrics.to_local_dict()
            # Resolve the episode return over the WHOLE chunk, not only its last
            # step. ``get_train_episode_return`` skips ``MISSING_REWARD`` entries,
            # which only helps if it sees every step's value. Passing the
            # already-sliced last value would report ``None`` whenever just the
            # final step was missing, discarding valid earlier returns.
            # ``tree_map`` preserves the pytree type, so ``train_metrics_arr``
            # supports ``to_local_dict`` just like ``train_metrics``.
            train_metric_data["train_episode_return"] = get_train_episode_return(
                train_metrics_arr.to_local_dict()["train_episode_return"]
            )
            self.recorder.write(train_metric_data, iterations)

            eval_metrics, state = self.evaluate(state)
            self.recorder.write(
                add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
            )

            # ``force`` is set only on the final chunk (presumably so the end
            # state is saved regardless of the manager's save schedule).
            self.checkpoint_manager.save(
                iterations,
                state,
                force=i == num_fold_iters - 1,
            )

        return state
59 60 61def _default_episode_return_reduce_fn(x): 62 return x[-1] 63 64
def get_train_episode_return(
    episode_return_arr: Sequence[float], reduce_fn=_default_episode_return_reduce_fn
):
    """Reduce an episode-return array that may contain ``MISSING_REWARD`` entries.

    Intended for values returned from multiple calls of
    ``average_episode_discount_return``. Entries equal to ``MISSING_REWARD``
    are dropped before reducing.

    Args:
        episode_return_arr: Array-like episode returns. A scalar is treated as a
            single entry.
        reduce_fn: Callable applied to the 1-D array of non-missing entries.
            Defaults to taking the last such entry.

    Returns:
        ``None`` if every entry is ``MISSING_REWARD`` (this includes empty
        input). Otherwise ``reduce_fn``'s result converted to plain Python data
        via ``tolist()``.
    """
    episode_return_arr = np.asarray(episode_return_arr)
    mask = episode_return_arr == MISSING_REWARD
    if mask.all():
        return None
    # Wrap in asarray so that reducers returning plain Python numbers (which
    # have no ``.tolist()``) work too; numpy results convert exactly as before.
    return np.asarray(reduce_fn(episode_return_arr[~mask])).tolist()