Source code for evorl.algorithms.random_agent

  1import jax
  2import chex
  3from omegaconf import DictConfig
  4from typing_extensions import Self  # pytype: disable=not-supported-yet
  5
  6from evorl.distributed import split_key_to_devices
  7from evorl.workflows import RLWorkflow
  8from evorl.agent import RandomAgent, Agent
  9from evorl.metrics import MetricBase, EvaluateMetric
 10from evorl.types import State
 11from evorl.envs import create_env, AutoresetMode
 12from evorl.evaluators import Evaluator
 13from evorl.recorders import add_prefix
 14
 15
class RandomTrainMetric(MetricBase):
    """Empty train-metric container returned by ``RandomAgentWorkflow.step``.

    It adds no fields on top of ``MetricBase``.
    """
18 19
class RandomAgentWorkflow(RLWorkflow):
    """Workflow that only evaluates an agent; its train step is a no-op.

    ``_build_from_config`` wires it up with a ``RandomAgent`` and an
    ``Evaluator`` whose environment has auto-reset disabled.
    """

    def __init__(self, agent: Agent, evaluator: Evaluator, config: DictConfig):
        """Store the agent and evaluator after initializing the base workflow.

        Args:
            agent: Agent whose ``init`` is called in ``setup``.
            evaluator: Evaluator used by ``setup`` (for its ``env``) and
                by ``evaluate``.
            config: Workflow configuration, forwarded to ``RLWorkflow``.
        """
        super().__init__(config)

        self.evaluator = evaluator
        self.agent = agent

    @classmethod
    def name(cls):
        """Return the workflow's name, ``"Random"``."""
        return "Random"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Create the workflow, its ``RandomAgent`` and its evaluator.

        Reads ``config.env``, ``config.env.max_episode_steps`` and
        ``config.num_eval_envs``.
        """
        episode_limit = config.env.max_episode_steps
        random_agent = RandomAgent()

        env_for_eval = create_env(
            config.env,
            episode_length=episode_limit,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        return cls(
            random_agent,
            Evaluator(
                env=env_for_eval,
                action_fn=random_agent.evaluate_actions,
                max_episode_steps=episode_limit,
            ),
            config,
        )

    def setup(self, key: chex.PRNGKey) -> State:
        """Build the initial workflow state from a PRNG key.

        Args:
            key: PRNG key; split once to obtain the agent-initialization key.

        Returns:
            A ``State`` holding the remaining key and the agent state.
        """
        key, init_key = jax.random.split(key)

        eval_env = self.evaluator.env
        agent_state = self.agent.init(
            eval_env.obs_space, eval_env.action_space, init_key
        )

        if self.enable_multi_devices:
            # An empty PartitionSpec replicates agent_state over the mesh.
            mesh = jax.sharding.Mesh(self.devices, (self.dp_axis_name,))
            replicated = jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec()
            )
            agent_state = jax.device_put(agent_state, replicated)

            # key and env_state should be different over devices
            # NOTE(review): the flattened listing does not show whether this
            # call is inside the multi-device branch; placed here because
            # per-device keys only apply in that case -- confirm upstream.
            key = split_key_to_devices(key, self.devices)

        return State(key=key, agent_state=agent_state)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Dummy step function for random agent.

        Returns an empty ``RandomTrainMetric`` and ``state.replace()``.
        """
        train_metric = RandomTrainMetric()
        return train_metric, state.replace()

    def learn(self, state: State) -> State:
        """Dummy learn function for random agent.

        Runs one evaluation, writes the metrics to the recorder under the
        ``"eval"`` prefix, and returns the state produced by the evaluation.
        """
        metrics, state = self.evaluate(state)

        prefixed = add_prefix(metrics.to_local_dict(), "eval")
        # Second positional argument is presumably the step index -- confirm
        # against the recorder's ``write`` signature.
        self.recorder.write(prefixed, 0)

        return state.replace()

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the agent and return averaged metrics with the new state.

        Args:
            state: Current workflow state; ``state.key`` and
                ``state.agent_state`` are read.

        Returns:
            A tuple of the reduced ``EvaluateMetric`` and the state whose
            ``key`` has been replaced by the unused half of the split.
        """
        next_key, eval_key = jax.random.split(state.key, num=2)

        # Per-episode metrics; the original annotates the shape as [#episodes].
        per_episode = self.evaluator.evaluate(
            state.agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        mean_return = per_episode.episode_returns.mean()
        mean_length = per_episode.episode_lengths.mean()
        averaged = EvaluateMetric(
            episode_returns=mean_return,
            episode_lengths=mean_length,
        )
        reduced = averaged.all_reduce(dp_axis_name=self.dp_axis_name)

        return reduced, state.replace(key=next_key)