Source code for evorl.algorithms.random_agent
1import jax
2import chex
3from omegaconf import DictConfig
4from typing_extensions import Self # pytype: disable=not-supported-yet
5
6from evorl.distributed import split_key_to_devices
7from evorl.workflows import RLWorkflow
8from evorl.agent import RandomAgent, Agent
9from evorl.metrics import MetricBase, EvaluateMetric
10from evorl.types import State
11from evorl.envs import create_env, AutoresetMode
12from evorl.evaluators import Evaluator
13from evorl.recorders import add_prefix
14
15
[docs]
class RandomTrainMetric(MetricBase):
    """Field-less training metric returned by the random-agent workflow.

    The random agent has nothing to train, so this subclass adds no
    attributes on top of ``MetricBase``; it only exists so that ``step``
    has a metric object to return.
    """
18
19
[docs]
class RandomAgentWorkflow(RLWorkflow):
    """Workflow that evaluates a random agent instead of training one.

    ``step`` performs no update and ``learn`` runs a single evaluation pass
    whose results are written to the recorder under the ``"eval"`` prefix.
    """

    def __init__(
        self,
        agent: Agent,
        evaluator: Evaluator,
        config: DictConfig,
    ):
        """Store the agent and evaluator; ``config`` goes to the base class.

        Args:
            agent: Agent whose state is created in ``setup``.
            evaluator: Evaluator used by ``evaluate``; its ``env`` also
                supplies the observation/action spaces in ``setup``.
            config: Workflow configuration forwarded to ``RLWorkflow``.
        """
        super().__init__(config)

        self.agent = agent
        self.evaluator = evaluator

    @classmethod
    def name(cls):
        """Return the identifier of this workflow, ``"Random"``."""
        return "Random"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Create the agent, evaluation env and evaluator from ``config``.

        Reads ``config.env``, ``config.env.max_episode_steps`` and
        ``config.num_eval_envs``.

        Args:
            config: Workflow configuration.

        Returns:
            A workflow instance wired with a ``RandomAgent`` and an
            ``Evaluator`` driven by that agent's ``evaluate_actions``.
        """
        episode_limit = config.env.max_episode_steps

        random_agent = RandomAgent()

        # The evaluation env is created with auto-reset disabled.
        env_for_eval = create_env(
            config.env,
            episode_length=episode_limit,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        return cls(
            random_agent,
            Evaluator(
                env=env_for_eval,
                action_fn=random_agent.evaluate_actions,
                max_episode_steps=episode_limit,
            ),
            config,
        )

    def setup(self, key: chex.PRNGKey) -> State:
        """Build the initial workflow state.

        Args:
            key: PRNG key; one half seeds the agent initialisation, the
                other half is kept in the returned state.

        Returns:
            A ``State`` holding ``key`` and ``agent_state``.
        """
        workflow_key, init_key = jax.random.split(key)

        eval_env = self.evaluator.env
        agent_state = self.agent.init(
            eval_env.obs_space, eval_env.action_space, init_key
        )

        if self.enable_multi_devices:
            # Empty PartitionSpec: no axis of agent_state is partitioned
            # over the device mesh.
            device_mesh = jax.sharding.Mesh(self.devices, (self.dp_axis_name,))
            unpartitioned = jax.sharding.NamedSharding(
                device_mesh,
                jax.sharding.PartitionSpec(),
            )
            agent_state = jax.device_put(agent_state, unpartitioned)

            # Unlike agent_state, the key should differ from device to device.
            workflow_key = split_key_to_devices(workflow_key, self.devices)

        return State(key=workflow_key, agent_state=agent_state)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """No-op training step: nothing is learned by a random agent.

        Args:
            state: Current workflow state.

        Returns:
            An empty ``RandomTrainMetric`` and the result of
            ``state.replace()`` called without overrides.
        """
        train_metrics = RandomTrainMetric()
        return train_metrics, state.replace()

    def learn(self, state: State) -> State:
        """Run one evaluation pass in place of a training loop.

        The evaluation metrics are converted with ``to_local_dict``,
        prefixed with ``"eval"`` and handed to ``self.recorder.write``
        together with ``0`` (presumably the iteration index).

        Args:
            state: Current workflow state.

        Returns:
            The state produced by ``evaluate``, passed through
            ``replace()`` without overrides.
        """
        metrics, next_state = self.evaluate(state)

        prefixed_metrics = add_prefix(metrics.to_local_dict(), "eval")
        self.recorder.write(prefixed_metrics, 0)

        return next_state.replace()

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the agent and average the per-episode results.

        Args:
            state: Current workflow state; ``state.key`` is split so the
                evaluation consumes a fresh key.

        Returns:
            The ``EvaluateMetric`` of mean episode return and length after
            ``all_reduce`` over ``self.dp_axis_name``, and the state with
            its key replaced by the unused half of the split.
        """
        next_key, rollout_key = jax.random.split(state.key, num=2)

        # [#episodes]
        per_episode = self.evaluator.evaluate(
            state.agent_state, rollout_key, num_episodes=self.config.eval_episodes
        )

        averaged = EvaluateMetric(
            episode_returns=per_episode.episode_returns.mean(),
            episode_lengths=per_episode.episode_lengths.mean(),
        )
        reduced_metrics = averaged.all_reduce(dp_axis_name=self.dp_axis_name)

        return reduced_metrics, state.replace(key=next_key)