Source code for evorl.algorithms.contrib.td3_ep

 1from evorl.types import State
 2from evorl.recorders import add_prefix
 3
 4from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
 5from evorl.algorithms.td3 import TD3Workflow
 6
 7
class TD3WorkflowMod(TD3Workflow):
    """TD3Workflow with total_episode termination condition.

    Training loops until the number of sampled episodes reported by
    ``state.metrics.sampled_episodes`` reaches ``config.total_episodes``.
    """

    def learn(self, state: State) -> State:
        """Train until ``config.total_episodes`` episodes have been sampled.

        Each loop iteration:

        * calls ``self._multi_steps(state)``;
        * writes the returned train metrics and ``state.metrics`` to
          ``self.recorder`` under the current iteration;
        * runs ``self.evaluate`` when the iteration is a multiple of
          ``config.eval_interval`` or when this is the last iteration; and
        * passes a checkpoint to ``self.checkpoint_manager.save``.

        After the loop, one more ``save`` call is made with ``force=True``.

        Args:
            state: Workflow state to continue training from.

        Returns:
            The workflow state after training. If the episode budget is
            already exhausted on entry, the input ``state`` is returned and
            no checkpoint is written.
        """
        sampled_episodes = state.metrics.sampled_episodes.tolist()

        # Both stay None if the loop body never runs (e.g. resuming from a
        # state whose episode budget is already spent). Without this, the
        # final forced save below raised UnboundLocalError.
        iterations = None
        saved_state = None

        while sampled_episodes < self.config.total_episodes:
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = state.metrics.iterations.tolist()
            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            sampled_episodes = state.metrics.sampled_episodes.tolist()

            # The second condition equals the loop's exit test, so the final
            # iteration is always evaluated regardless of eval_interval.
            if (
                iterations % self.config.eval_interval == 0
                or sampled_episodes >= self.config.total_episodes
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                # presumably strips the replay buffer from the checkpointed
                # state to keep checkpoints small — confirm in offpolicy_utils
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(iterations, saved_state)

        if saved_state is not None:
            # NOTE(review): force=True presumably bypasses the manager's
            # save-interval policy so the last iteration is always persisted.
            self.checkpoint_manager.save(iterations, saved_state, force=True)

        return state