Source code for evorl.algorithms.contrib.td3_ep
1from evorl.types import State
2from evorl.recorders import add_prefix
3
4from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
5from evorl.algorithms.td3 import TD3Workflow
6
7
[docs]
class TD3WorkflowMod(TD3Workflow):
    """TD3Workflow with total_episode termination condition."""

    def learn(self, state: State) -> State:
        """Train until ``config.total_episodes`` episodes have been sampled.

        Each loop iteration calls ``self._multi_steps`` once, writes the
        training metrics and the workflow metrics to ``self.recorder``, runs
        ``self.evaluate`` when the iteration count is a multiple of
        ``config.eval_interval`` or when the episode budget has just been
        reached, and passes the state to ``self.checkpoint_manager.save``.
        After the loop, the last state is saved once more with ``force=True``.

        Args:
            state: Workflow state. ``state.metrics.sampled_episodes`` and
                ``state.metrics.iterations`` are read to drive the loop.

        Returns:
            The state produced by the last training/evaluation step, or the
            input state unchanged when the episode budget was already met on
            entry.
        """
        sampled_episodes = state.metrics.sampled_episodes.tolist()

        # If the budget is already exhausted (e.g. a resumed, finished run),
        # the loop below would never execute and the final forced save would
        # reference the unbound locals ``iterations`` / ``saved_state``.
        # There is nothing new to train or checkpoint, so return as-is.
        if not sampled_episodes < self.config.total_episodes:
            return state

        while sampled_episodes < self.config.total_episodes:
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = state.metrics.iterations.tolist()
            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            sampled_episodes = state.metrics.sampled_episodes.tolist()

            # Evaluate periodically, and always on the iteration that
            # reaches the episode budget so the final policy is scored.
            if (
                iterations % self.config.eval_interval == 0
                or sampled_episodes >= self.config.total_episodes
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            # Optionally strip the replay buffer from what gets checkpointed;
            # the live training state keeps it.
            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(iterations, saved_state)

        # Final checkpoint of the last iteration, requested with force=True.
        self.checkpoint_manager.save(iterations, saved_state, force=True)

        return state