Source code for evorl.algorithms.erl.cemrl_workflow

  1import copy
  2import logging
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4from omegaconf import DictConfig
  5
  6import chex
  7import jax
  8import jax.numpy as jnp
  9import optax
 10
 11from evorl.agent import AgentStateAxis
 12from evorl.metrics import MetricBase, WorkflowMetric, metric_field
 13from evorl.types import PyTreeDict, State
 14from evorl.utils import running_statistics
 15from evorl.utils.jax_utils import tree_stop_gradient
 16from evorl.utils.rl_toolkits import flatten_rollout_trajectory
 17from evorl.evaluators import Evaluator, EpisodeCollector
 18from evorl.sample_batch import SampleBatch
 19from evorl.replay_buffers import AbstractReplayBuffer, ReplayBufferState
 20from evorl.agent import Agent, AgentState, RandomAgent
 21from evorl.envs import create_env, AutoresetMode, Env
 22from evorl.workflows import Workflow
 23from evorl.rollout import rollout
 24from evorl.ec.optimizers import EvoOptimizer, ECState
 25
 26from ..offpolicy_utils import clean_trajectory
 27
 28logger = logging.getLogger(__name__)
 29
 30
class CEMRLTrainMetric(MetricBase):
    """Per-iteration training metrics reported by CEMRL workflows.

    Only the field names, annotations and defaults are defined here; the
    values are filled in by the concrete workflow's training step.
    """

    # Replay buffer size (presumably the number of stored transitions --
    # confirm against the workflow that populates it).
    rb_size: chex.Array
    # Episode returns / lengths of the population members
    # (array layout is decided by the producer -- not visible here).
    pop_episode_returns: chex.Array
    pop_episode_lengths: chex.Array
    # Optional nested metrics of the RL update; None when absent.
    rl_metrics: MetricBase | None = None
    # Extra information from the evolutionary optimizer; defaults to an
    # empty PyTreeDict created per instance via default_factory.
    ec_info: PyTreeDict = metric_field(default_factory=PyTreeDict)
37 38
class CEMRLWorkflowBase(Workflow):
    """Base Class for CEMRL, equipped with many useful methods.

    Subclasses are expected to implement ``_build_from_config``,
    ``_setup_agent_and_optimizer``, ``_rl_injection`` and ``evaluate``;
    the versions defined here raise ``NotImplementedError``.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        agent_state_vmap_axes: AgentStateAxis,
        optimizer: optax.GradientTransformation,
        ec_optimizer: EvoOptimizer,
        collector: EpisodeCollector,
        evaluator: Evaluator,  # to evaluate the pop-mean actor
        replay_buffer: AbstractReplayBuffer,
        config: DictConfig,
    ):
        """Store the components of the workflow on the instance.

        All arguments are keyword-only.

        Args:
            env: Environment; its ``obs_space`` / ``action_space`` are read
                when the replay buffer is initialized and pre-filled.
            agent: Agent kept on the workflow (used by subclasses).
            agent_state_vmap_axes: Axis specification for the agent state
                (used by subclasses).
            optimizer: optax gradient transformation (used by subclasses).
            ec_optimizer: Evolutionary optimizer (used by subclasses).
            collector: Episode collector (used by subclasses).
            evaluator: Evaluator for the pop-mean actor.
            replay_buffer: Replay buffer implementation; ``init`` and
                ``add`` are called by this class.
            config: Workflow configuration, forwarded to the base class.
        """
        super().__init__(config)

        self.env = env
        self.agent = agent
        self.agent_state_vmap_axes = agent_state_vmap_axes
        self.optimizer = optimizer
        self.ec_optimizer = ec_optimizer
        self.collector = collector
        self.evaluator = evaluator
        self.replay_buffer = replay_buffer

        # Only the first local device is recorded.
        self.devices = jax.local_devices()[:1]

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Build a workflow from a config.

        Args:
            config: Workflow configuration. It is deep-copied first, so the
                caller's object is never modified.
            enable_multi_devices: Requesting multi-device execution is not
                supported and raises.
            enable_jit: When True, ``cls.enable_jit()`` is called before the
                workflow is constructed.

        Returns:
            The workflow produced by ``cls._build_from_config``.

        Raises:
            NotImplementedError: If ``enable_multi_devices`` is True or more
                than one local device is visible.
        """
        config = copy.deepcopy(config)  # avoid in-place modification

        devices = jax.local_devices()

        # NOTE(review): this also raises when several local devices are
        # merely visible, even with enable_multi_devices=False, although
        # __init__ pins self.devices to the first local device. Confirm that
        # refusing to run on multi-device hosts is intended.
        if enable_multi_devices or len(devices) > 1:
            raise NotImplementedError("Multi-devices is not supported yet.")

        if enable_jit:
            cls.enable_jit()

        workflow = cls._build_from_config(config)

        return workflow

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the concrete workflow; must be overridden."""
        raise NotImplementedError

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state.

        The key is split three ways: one part is kept in the returned state,
        one initializes agent/optimizers, one initializes the replay buffer.
        When ``config.random_timesteps > 0`` the replay buffer is then
        pre-filled via ``_postsetup_replaybuffer``.

        Args:
            key: PRNG key for the whole setup.

        Returns:
            A ``State`` with fields ``key``, ``metrics``, ``agent_state``,
            ``opt_state``, ``ec_opt_state`` and ``replay_buffer_state``.
        """
        key, agent_key, rb_key = jax.random.split(key, 3)

        # [#pop, ...]
        agent_state, opt_state, ec_opt_state = self._setup_agent_and_optimizer(
            agent_key
        )

        workflow_metrics = self._setup_workflow_metrics()

        replay_buffer_state = self._setup_replaybuffer(rb_key)

        # =======================

        state = State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            opt_state=opt_state,
            ec_opt_state=ec_opt_state,
            replay_buffer_state=replay_buffer_state,
        )

        if self.config.random_timesteps > 0:
            logger.info("Start replay buffer post-setup")
            state = self._postsetup_replaybuffer(state)
            logger.info("Complete replay buffer post-setup")

        return state

    def _setup_workflow_metrics(self) -> MetricBase:
        """Return a fresh default ``WorkflowMetric``."""
        return WorkflowMetric()

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Create ``(agent_state, opt_state, ec_opt_state)``; must be overridden."""
        raise NotImplementedError

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Initialize the replay buffer from one dummy transition.

        Args:
            key: PRNG key used to sample the dummy observation.

        Returns:
            The state returned by ``self.replay_buffer.init``.
        """
        action_space = self.env.action_space
        obs_space = self.env.obs_space

        # create dummy data to initialize the replay buffer
        dummy_action = jnp.zeros(action_space.shape)
        dummy_obs = obs_space.sample(key)

        dummy_reward = jnp.zeros(())
        dummy_done = jnp.zeros(())

        dummy_sample_batch = SampleBatch(
            obs=dummy_obs,
            actions=dummy_action,
            rewards=dummy_reward,
            # next_obs=dummy_obs,
            # dones=dummy_done,
            extras=PyTreeDict(
                policy_extras=PyTreeDict(),
                env_extras=PyTreeDict(
                    {"ori_obs": dummy_obs, "termination": dummy_done}
                ),
            ),
        )
        replay_buffer_state = self.replay_buffer.init(dummy_sample_batch)

        return replay_buffer_state

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Pre-fill the replay buffer with transitions from a random agent.

        A dedicated auto-resetting environment is created from the config, a
        ``RandomAgent`` is rolled out in it, and the cleaned, flattened
        trajectory is added to the replay buffer. If the agent state has an
        observation preprocessor state, it is updated with the collected
        observations.

        The workflow metrics are deliberately left untouched (see the note
        near the end of the body).

        Args:
            state: Workflow state as produced by ``setup``.

        Returns:
            ``state`` with ``key``, ``agent_state`` and
            ``replay_buffer_state`` replaced.
        """
        action_space = self.env.action_space
        obs_space = self.env.obs_space
        config = self.config

        replay_buffer_state = state.replay_buffer_state
        agent_state = state.agent_state

        # We need a separate autoreset env to fill the replay buffer
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        # ==== fill random transitions ====

        key, env_key, rollout_key = jax.random.split(state.key, num=3)
        random_agent = RandomAgent()
        # A fixed key is passed to the random agent's init; the rollout
        # itself is driven by rollout_key.
        random_agent_state = random_agent.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )
        # Integer division: a remainder of random_timesteps that does not
        # fill a whole step across all envs is dropped, and the result is 0
        # when random_timesteps < num_envs.
        rollout_length = config.random_timesteps // config.num_envs

        env_state = env.reset(env_key)
        trajectory, env_state = rollout(
            env_fn=env.step,
            action_fn=random_agent.compute_actions,
            env_state=env_state,
            agent_state=random_agent_state,
            key=rollout_key,
            rollout_length=rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # sampled_timesteps = jnp.uint32(rollout_length * config.num_envs)
        # # Since we sample from autoreset env, this metric might not be accurate:
        # sampled_episodes = trajectory.dones.astype(jnp.uint32).sum()

        # [T, B, ...] -> [T*B, ...]
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory)

        # Skip the statistics update when there is no preprocessor state or
        # when the rollout collected nothing.
        if agent_state.obs_preprocessor_state is not None and rollout_length > 0:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                )
            )

        # Note: we don't count the random transitions in the metrics, to align
        # the x-axis (sampled_episodes/iterations) over different runs
        # workflow_metrics = state.metrics.replace(
        #     sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
        #     sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
        # )

        return state.replace(
            key=key,
            # metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
        )

    def _rl_injection(self, *args, **kwargs):
        """RL-injection hook; must be overridden."""
        raise NotImplementedError

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the workflow; must be overridden.

        Returns:
            A ``(metrics, state)`` tuple in concrete implementations.
        """
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Replace ``step``, ``evaluate`` and ``_postsetup_replaybuffer`` with jitted versions.

        The replacement happens on the class itself. Argument 0 (``self``)
        is marked static in each ``jax.jit`` call.
        """
        cls.step = jax.jit(cls.step, static_argnums=(0,))
        cls.evaluate = jax.jit(cls.evaluate, static_argnums=(0,))
        cls._postsetup_replaybuffer = jax.jit(
            cls._postsetup_replaybuffer, static_argnums=(0,)
        )