Source code for evorl.algorithms.erl.erl_workflow

  1import copy
  2import logging
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4from omegaconf import DictConfig
  5
  6import chex
  7import jax
  8import jax.numpy as jnp
  9import optax
 10
 11from evorl.agent import AgentStateAxis
 12from evorl.metrics import MetricBase, metric_field
 13from evorl.types import PyTreeDict, State
 14from evorl.utils import running_statistics
 15from evorl.utils.jax_utils import tree_stop_gradient
 16from evorl.utils.rl_toolkits import flatten_rollout_trajectory
 17from evorl.evaluators import Evaluator, EpisodeCollector
 18from evorl.sample_batch import SampleBatch
 19from evorl.replay_buffers import AbstractReplayBuffer, ReplayBufferState
 20from evorl.agent import Agent, AgentState, RandomAgent
 21from evorl.envs import create_env, AutoresetMode, Env
 22from evorl.workflows import Workflow
 23from evorl.rollout import rollout
 24from evorl.ec.optimizers import EvoOptimizer, ECState
 25
 26from ..offpolicy_utils import clean_trajectory
 27
 28logger = logging.getLogger(__name__)
 29
 30
class ERLTrainMetric(MetricBase):
    """Container for the training metrics of an ERL workflow.

    Every array-valued field is optional and defaults to ``None``;
    ``ec_info`` defaults to an empty ``PyTreeDict``.

    NOTE(review): the per-field descriptions below are inferred from the
    field names only -- confirm against the concrete workflows that fill
    this structure.
    """

    # Presumably episode returns / lengths of the evolutionary population.
    pop_episode_returns: chex.Array | None = None
    pop_episode_lengths: chex.Array | None = None

    # Presumably the current number of transitions in the replay buffer.
    rb_size: chex.Array | None = None

    # Presumably episode returns / lengths collected by the RL agent(s).
    rl_episode_returns: chex.Array | None = None
    rl_episode_lengths: chex.Array | None = None

    # Nested metrics object for the RL update, when one is reported.
    rl_metrics: MetricBase | None = None

    # Free-form information from the evolutionary optimizer; a fresh
    # PyTreeDict is created for every instance.
    ec_info: PyTreeDict = metric_field(default_factory=PyTreeDict)
class WorkflowMetric(MetricBase):
    """Workflow-level metrics stored in the workflow ``State``.

    Each field defaults to a scalar ``uint32`` zero.

    NOTE(review): the names suggest running counters that are incremented
    by the concrete workflow's ``step``; that code is not part of this
    class -- confirm there.
    """

    sampled_timesteps: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    sampled_episodes: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    rl_sampled_timesteps: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    iterations: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
class ERLWorkflowBase(Workflow):
    """Base class for ERL workflows.

    This class wires together the environment, agent, optimizers,
    collectors, evaluator and replay buffer, and implements the parts of
    state setup that are shared: workflow metrics, replay-buffer
    initialisation, and an optional pre-fill of the replay buffer with
    random transitions.

    Subclasses must implement ``_build_from_config``,
    ``_setup_agent_and_optimizer``, ``warmup_step``, ``_rl_injection`` and
    ``evaluate``; the versions here raise ``NotImplementedError``.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        agent_state_vmap_axes: AgentStateAxis,
        optimizer: optax.GradientTransformation,
        ec_optimizer: EvoOptimizer,
        ec_collector: EpisodeCollector,
        rl_collector: EpisodeCollector,
        evaluator: Evaluator,
        replay_buffer: AbstractReplayBuffer,
        config: DictConfig,
    ):
        """Store the workflow components.

        All arguments are keyword-only.

        Args:
            env: Environment; its action and observation spaces are used
                when setting up the replay buffer.
            agent: The agent.
            agent_state_vmap_axes: vmap axes specification for the agent
                state.
            optimizer: Optax gradient transformation.
            ec_optimizer: Evolutionary optimizer.
            ec_collector: Episode collector stored as ``ec_collector``.
            rl_collector: Episode collector stored as ``rl_collector``.
            evaluator: Evaluator used to evaluate the pop-mean actor.
            replay_buffer: Replay buffer implementation.
            config: Workflow configuration, forwarded to the parent class.
        """
        super().__init__(config)

        self.env = env
        self.agent = agent
        self.agent_state_vmap_axes = agent_state_vmap_axes

        self.optimizer = optimizer
        self.ec_optimizer = ec_optimizer

        self.ec_collector = ec_collector
        self.rl_collector = rl_collector
        self.evaluator = evaluator

        self.replay_buffer = replay_buffer

        # Only the first local device is kept.
        self.devices = jax.local_devices()[:1]

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Create a workflow instance from a config.

        Args:
            config: Workflow configuration. It is deep-copied first, so
                the caller's object is never modified in place.
            enable_multi_devices: Must be ``False``; multi-device
                execution is not supported.
            enable_jit: When true, ``enable_jit`` is called on the class
                before the instance is built.

        Returns:
            The workflow produced by ``_build_from_config``.

        Raises:
            NotImplementedError: If ``enable_multi_devices`` is true or
                more than one local device is visible to JAX.
        """
        private_config = copy.deepcopy(config)

        num_local_devices = len(jax.local_devices())
        single_device_run = not enable_multi_devices and num_local_devices <= 1
        if not single_device_run:
            raise NotImplementedError("Multi-devices is not supported yet.")

        if enable_jit:
            cls.enable_jit()

        return cls._build_from_config(private_config)

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Build the concrete workflow from ``config`` (subclass hook)."""
        raise NotImplementedError

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state.

        Args:
            key: PRNG key. It is split three ways: one part is kept in
                the returned state, one seeds the agent/optimizer setup
                and one seeds the replay-buffer setup.

        Returns:
            A ``State`` holding the key, the workflow metrics, the agent
            state, both optimizer states and the replay-buffer state.
            When ``config.random_timesteps > 0`` the replay buffer has
            additionally been pre-filled by ``_postsetup_replaybuffer``.
        """
        state_key, agent_key, rb_key = jax.random.split(key, 3)

        agent_state, opt_state, ec_opt_state = self._setup_agent_and_optimizer(
            agent_key
        )
        metrics = self._setup_workflow_metrics()
        rb_state = self._setup_replaybuffer(rb_key)

        initial_state = State(
            key=state_key,
            metrics=metrics,
            agent_state=agent_state,
            opt_state=opt_state,
            ec_opt_state=ec_opt_state,
            replay_buffer_state=rb_state,
        )

        if not self.config.random_timesteps > 0:
            return initial_state

        logger.info("Start replay buffer post-setup")
        initial_state = self._postsetup_replaybuffer(initial_state)
        logger.info("Complete replay buffer post-setup")
        return initial_state

    def _setup_workflow_metrics(self) -> MetricBase:
        """Return a fresh ``WorkflowMetric`` with its default values."""
        return WorkflowMetric()

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Create the agent state and both optimizer states (subclass hook).

        ``setup`` unpacks the result as
        ``(agent_state, opt_state, ec_opt_state)``.
        """
        raise NotImplementedError

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Initialise the replay buffer from a dummy transition.

        Args:
            key: PRNG key used to sample the dummy observation from the
                environment's observation space.

        Returns:
            The state returned by ``self.replay_buffer.init``.
        """
        env = self.env

        # The dummy transition is built from a zero action shaped like the
        # action space, a sampled observation, and scalar zeros for the
        # reward and the termination flag.
        sample_obs = env.obs_space.sample(key)
        zero_action = jnp.zeros(env.action_space.shape)
        zero_reward = jnp.zeros(())
        zero_termination = jnp.zeros(())

        # next_obs and dones are intentionally not provided; the original
        # observation and the termination flag go under extras.env_extras.
        env_extras = PyTreeDict(ori_obs=sample_obs, termination=zero_termination)
        template = SampleBatch(
            obs=sample_obs,
            actions=zero_action,
            rewards=zero_reward,
            extras=PyTreeDict(
                policy_extras=PyTreeDict(),
                env_extras=env_extras,
            ),
        )

        return self.replay_buffer.init(template)

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Pre-fill the replay buffer with transitions from a random agent.

        A dedicated auto-resetting environment is created and rolled out
        for ``config.random_timesteps // config.num_envs`` steps with a
        ``RandomAgent``. The flattened trajectory is added to the replay
        buffer. If the agent state carries an observation preprocessor
        state, it is updated with the collected observations.

        The random transitions are deliberately not counted in the
        workflow metrics, so that the x-axis
        (sampled_episodes/iterations) is aligned across runs.

        Args:
            state: Current workflow state.

        Returns:
            ``state`` with a new key, agent state and replay-buffer state.
        """
        cfg = self.config

        # We need a separate autoreset env to fill the replay buffer.
        fill_env = create_env(
            cfg.env,
            episode_length=cfg.env.max_episode_steps,
            parallel=cfg.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        # ==== fill random transitions ====
        next_key, env_key, rollout_key = jax.random.split(state.key, num=3)

        random_agent = RandomAgent()
        random_agent_state = random_agent.init(
            self.env.obs_space, self.env.action_space, jax.random.PRNGKey(0)
        )
        num_steps = cfg.random_timesteps // cfg.num_envs

        trajectory, _ = rollout(
            env_fn=fill_env.step,
            action_fn=random_agent.compute_actions,
            env_state=fill_env.reset(env_key),
            agent_state=random_agent_state,
            key=rollout_key,
            rollout_length=num_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        # [T, B, ...] -> [T*B, ...]
        trajectory = tree_stop_gradient(
            flatten_rollout_trajectory(clean_trajectory(trajectory))
        )

        new_rb_state = self.replay_buffer.add(state.replay_buffer_state, trajectory)

        new_agent_state = state.agent_state
        has_preprocessor = new_agent_state.obs_preprocessor_state is not None
        if has_preprocessor and num_steps > 0:
            updated_stats = running_statistics.update(
                new_agent_state.obs_preprocessor_state, trajectory.obs
            )
            new_agent_state = new_agent_state.replace(
                obs_preprocessor_state=updated_stats
            )

        return state.replace(
            key=next_key,
            agent_state=new_agent_state,
            replay_buffer_state=new_rb_state,
        )

    def warmup_step(self, state: State) -> tuple[MetricBase, State]:
        """Run one warm-up step (subclass hook)."""
        raise NotImplementedError

    def _rl_injection(self, *args, **kwargs):
        """Inject RL agent(s) into the population (subclass hook).

        NOTE(review): the purpose is inferred from the method name --
        confirm against the concrete workflows.
        """
        raise NotImplementedError

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the current state (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Replace selected methods on ``cls`` with ``jax.jit`` wrappers.

        ``step``, ``evaluate``, ``_postsetup_replaybuffer`` and
        ``warmup_step`` are wrapped, in that order, with argument 0
        (``self``) marked as static.
        """
        jitted_method_names = (
            "step",
            "evaluate",
            "_postsetup_replaybuffer",
            "warmup_step",
        )
        for method_name in jitted_method_names:
            jitted = jax.jit(getattr(cls, method_name), static_argnums=(0,))
            setattr(cls, method_name, jitted)