Source code for evorl.algorithms.offpolicy_utils

  1import logging
  2import math
  3from omegaconf import DictConfig
  4
  5import jax
  6import jax.numpy as jnp
  7import chex
  8
  9from evorl.replay_buffers import ReplayBufferState
 10from evorl.envs import Discrete
 11from evorl.distributed.comm import psum
 12from evorl.workflows import OffPolicyWorkflow
 13from evorl.sample_batch import SampleBatch
 14from evorl.types import State, PyTreeDict
 15from evorl.rollout import rollout
 16from evorl.utils.rl_toolkits import flatten_rollout_trajectory
 17from evorl.utils import running_statistics
 18from evorl.utils.jax_utils import tree_stop_gradient, scan_and_last
 19from evorl.agent import RandomAgent
 20from evorl.recorders import add_prefix
 21
 22
 23logger = logging.getLogger(__name__)
 24
 25
class OffPolicyWorkflowTemplate(OffPolicyWorkflow):
    """Wrapping some common template for off-policy RL with TD Learning."""

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Divide the global size settings in ``config`` by the device count.

        The following fields are replaced in place by their floor division by
        ``jax.device_count()``: ``num_envs``, ``num_eval_envs``,
        ``replay_buffer_capacity``, ``random_timesteps`` and
        ``learning_start_timesteps``. A warning is logged for every field that
        is not an exact multiple of the device count.

        Args:
            config: Workflow config; mutated in place.
        """
        num_devices = jax.device_count()

        rescaled_fields = (
            "num_envs",
            "num_eval_envs",
            "replay_buffer_capacity",
            "random_timesteps",
            "learning_start_timesteps",
        )

        # First pass: read every field and warn; nothing is mutated yet, so a
        # missing field fails before the config has been partially rescaled.
        per_device_values = {}
        for name in rescaled_fields:
            value = getattr(config, name)
            if value % num_devices != 0:
                logger.warning(
                    f"{name}({value}) cannot be divided by num_devices({num_devices}), "
                    f"rescale {name} to {value // num_devices}"
                )
            per_device_values[name] = value // num_devices

        # Second pass: write the rescaled values back.
        for name, value in per_device_values.items():
            setattr(config, name, value)

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Initialize the replay buffer state from a dummy transition.

        Args:
            key: PRNG key used to sample the dummy observation.

        Returns:
            The state returned by ``self.replay_buffer.init``.
        """
        action_space = self.env.action_space
        obs_space = self.env.obs_space

        # create dummy data to initialize the replay buffer
        if isinstance(action_space, Discrete):
            dummy_action = jnp.zeros((), dtype=jnp.int32)
        else:
            dummy_action = jnp.zeros(action_space.shape)
        dummy_obs = obs_space.sample(key)
        dummy_reward = jnp.zeros(())
        dummy_done = jnp.zeros(())

        # next_obs and dones are deliberately left unset here; the rollouts
        # added in _postsetup_replaybuffer have those two fields cleared by
        # clean_trajectory as well.
        dummy_sample_batch = SampleBatch(
            obs=dummy_obs,
            actions=dummy_action,
            rewards=dummy_reward,
            extras=PyTreeDict(
                policy_extras=PyTreeDict(),
                env_extras=PyTreeDict(
                    {"ori_obs": dummy_obs, "termination": dummy_done}
                ),
            ),
        )
        replay_buffer_state = self.replay_buffer.init(dummy_sample_batch)

        return replay_buffer_state

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Pre-fill the replay buffer before learning starts.

        Two rollouts are collected and added to the replay buffer:

        1. ``config.random_timesteps`` worth of steps from a ``RandomAgent``.
        2. The remaining steps up to ``config.learning_start_timesteps`` from
           ``self.agent`` with the current agent state.

        After each rollout the observation preprocessor state (when present)
        is updated, and the sampled-timesteps metric is increased.

        Args:
            state: Workflow state holding the key, metrics, agent state and
                replay buffer state.

        Returns:
            The state with updated key, metrics, agent state and replay
            buffer state.
        """
        action_space = self.env.action_space
        obs_space = self.env.obs_space
        config = self.config
        replay_buffer_state = state.replay_buffer_state
        agent_state = state.agent_state

        def _rollout(agent, agent_state, key, rollout_length):
            env_key, rollout_key = jax.random.split(key)

            env_state = self.env.reset(env_key)

            trajectory, env_state = rollout(
                env_fn=self.env.step,
                action_fn=agent.compute_actions,
                env_state=env_state,
                agent_state=agent_state,
                key=rollout_key,
                rollout_length=rollout_length,
                env_extra_fields=("ori_obs", "termination"),
            )

            # [T, B, ...] -> [T*B, ...]
            trajectory = clean_trajectory(trajectory)
            trajectory = flatten_rollout_trajectory(trajectory)
            trajectory = tree_stop_gradient(trajectory)

            return trajectory

        def _update_obs_preprocessor(agent_state, trajectory):
            # Skip the update for agents without a preprocessor and for
            # empty (zero-length) rollouts.
            if (
                agent_state.obs_preprocessor_state is not None
                and len(trajectory.obs) > 0
            ):
                agent_state = agent_state.replace(
                    obs_preprocessor_state=running_statistics.update(
                        agent_state.obs_preprocessor_state,
                        trajectory.obs,
                        dp_axis_name=self.dp_axis_name,
                    )
                )
            return agent_state

        # ==== fill random transitions ====

        key, random_rollout_key, rollout_key = jax.random.split(state.key, num=3)
        random_agent = RandomAgent()
        random_agent_state = random_agent.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )
        rollout_length = config.random_timesteps // config.num_envs

        trajectory = _rollout(
            random_agent,
            random_agent_state,
            key=random_rollout_key,
            rollout_length=rollout_length,
        )

        agent_state = _update_obs_preprocessor(agent_state, trajectory)
        replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory)

        rollout_timesteps = rollout_length * config.num_envs
        sampled_timesteps = psum(
            jnp.uint32(rollout_timesteps), axis_name=self.dp_axis_name
        )

        # ==== fill transition state from init agent ====
        # Clamp at zero: when random_timesteps already exceeds
        # learning_start_timesteps the difference is negative, which would
        # otherwise yield a negative rollout length and a negative value
        # passed to jnp.uint32 below.
        remaining_timesteps = max(
            config.learning_start_timesteps - rollout_timesteps, 0
        )
        rollout_length = math.ceil(remaining_timesteps / config.num_envs)

        trajectory = _rollout(
            self.agent,
            agent_state,
            key=rollout_key,
            rollout_length=rollout_length,
        )

        agent_state = _update_obs_preprocessor(agent_state, trajectory)
        replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory)

        rollout_timesteps = rollout_length * config.num_envs
        sampled_timesteps = sampled_timesteps + psum(
            jnp.uint32(rollout_timesteps), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
        )

    def _multi_steps(self, state):
        """Run ``self.step`` ``config.fold_iters`` times via ``scan_and_last``.

        Returns:
            A ``(train_metrics, state)`` tuple as produced by
            ``scan_and_last`` over the folded steps.
        """

        def _step(state, _):
            train_metrics, state = self.step(state)
            return state, train_metrics

        state, train_metrics = scan_and_last(
            _step, state, (), length=self.config.fold_iters
        )

        return train_metrics, state

    def learn(self, state: State) -> State:
        """Run the training loop and return the final workflow state.

        The number of outer iterations is derived from the timesteps still
        missing to reach ``config.total_timesteps``, where each outer
        iteration accounts for
        ``rollout_length * num_envs * fold_iters * jax.device_count()``
        timesteps. Every outer iteration records train and workflow metrics,
        evaluates when the iteration counter is a multiple of
        ``config.eval_interval`` or equals the final iteration, and hands the
        state to the checkpoint manager.

        Args:
            state: Workflow state to start (or resume) training from.

        Returns:
            The workflow state after the last iteration.
        """
        num_devices = jax.device_count()
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps)
            / (one_step_timesteps * self.config.fold_iters * num_devices)
        )
        start_iteration = state.metrics.iterations.tolist()
        # NOTE(review): this assumes the iteration counter advances by one per
        # outer loop iteration -- confirm against how step() updates
        # metrics.iterations.
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = state.metrics.iterations.tolist()
            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            if (
                iterations % self.config.eval_interval == 0
                or iterations == final_iteration
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            # Optionally drop the replay buffer from what is checkpointed.
            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(
                iterations, saved_state, force=iterations == final_iteration
            )

        return state

    @classmethod
    def enable_jit(cls) -> None:
        """JIT-compile the replay-buffer warm-up and the folded step loop.

        Calls the parent ``enable_jit`` first, then wraps
        ``_postsetup_replaybuffer`` and ``_multi_steps`` with ``jax.jit``,
        treating ``self`` (argument 0) as static.
        """
        super().enable_jit()
        cls._postsetup_replaybuffer = jax.jit(
            cls._postsetup_replaybuffer, static_argnums=(0,)
        )
        cls._multi_steps = jax.jit(cls._multi_steps, static_argnums=(0,))
252 253
def skip_replay_buffer_state(state: State) -> State:
    """Return ``state`` with its ``replay_buffer_state`` replaced by ``None``.

    Usually used when saving the off-policy workflow state to disk.
    """
    state_without_buffer = state.replace(replay_buffer_state=None)
    return state_without_buffer
260 261
def clean_trajectory(trajectory: SampleBatch) -> SampleBatch:
    """Clear ``next_obs`` and ``dones`` so the trajectory suits the replay buffer."""
    cleared_fields = {"next_obs": None, "dones": None}
    return trajectory.replace(**cleared_fields)