Source code for evorl.algorithms.ec.obs_utils

  1import logging
  2import math
  3from typing import Any
  4
  5import chex
  6import jax
  7import jax.tree_util as jtu
  8
  9from evorl.agent import AgentState, AgentActionFn, RandomAgent
 10from evorl.envs import Env, EnvState, EnvStepFn, create_env, AutoresetMode
 11from evorl.sample_batch import SampleBatch
 12from evorl.types import PyTreeNode
 13from evorl.utils.jax_utils import rng_split
 14from evorl.utils import running_statistics
 15
 16
 17logger = logging.getLogger(__name__)
 18
 19
class ObsPreprocessor(PyTreeNode):
    """Settings for the observation preprocessor.

    ``__post_init__`` enforces that ``random_timesteps`` is positive whenever
    ``static`` is True.
    """

    # Presumably the number of random-policy env steps used to warm up the
    # observation statistics (cf. ``config.random_timesteps``) — confirm.
    random_timesteps: int = 0
    # Set True to use VBN (e.g. OpenES); requires random_timesteps > 0.
    static: bool = False

    def __post_init__(self):
        # Validate the field combination once construction has finished.
        if self.static:
            # Name the actual field so the failure message is actionable.
            # NOTE(review): ``assert`` is stripped under ``python -O``; kept to
            # preserve the AssertionError callers may already expect.
            assert self.random_timesteps > 0, (
                "random_timesteps should be greater than 0 if static is True"
            )
def init_obs_preprocessor(agent_state, config, key, dp_axis_name=None):
    """Return ``agent_state`` with its observation statistics warmed up.

    A dedicated environment is built from ``config.env`` (episode length
    ``config.env.max_episode_steps``, ``config.num_envs`` parallel copies,
    ``AutoresetMode.NORMAL``). ``config.random_timesteps`` random-policy steps
    from it are folded into ``agent_state.obs_preprocessor_state``.

    Args:
        agent_state: Agent state exposing ``obs_preprocessor_state`` and
            ``replace``.
        config: Config providing ``env``, ``num_envs`` and ``random_timesteps``.
        key: PRNG key forwarded to the random rollout.
        dp_axis_name: Optional axis name forwarded to the statistics update.

    Returns:
        A copy of ``agent_state`` carrying the updated preprocessor state.
    """
    assert config.random_timesteps > 0, "random_timesteps should be greater than 0"

    warmup_env = create_env(
        config.env,
        episode_length=config.env.max_episode_steps,
        parallel=config.num_envs,
        autoreset_mode=AutoresetMode.NORMAL,
    )

    warmed_stats = init_obs_preprocessor_with_random_timesteps(
        agent_state.obs_preprocessor_state,
        config.random_timesteps,
        warmup_env,
        key,
        dp_axis_name=dp_axis_name,
    )
    return agent_state.replace(obs_preprocessor_state=warmed_stats)
def init_obs_preprocessor_with_random_timesteps(
    obs_preprocessor_state: Any,
    timesteps: int,
    env: Env,
    key: chex.PRNGKey,
    dp_axis_name: str | None = None,
) -> Any:
    """Fold observations gathered by a random policy into running statistics.

    The env is reset and a ``RandomAgent`` is rolled out for
    ``ceil(timesteps / env.num_envs)`` steps. Every pre-step observation is
    passed to ``running_statistics.update``.

    Args:
        obs_preprocessor_state: Running-statistics state to update.
        timesteps: Target number of env transitions in total. It is rounded up
            to a whole number of steps per parallel env; a non-positive
            result skips collection entirely.
        env: Environment providing ``reset``, ``step``, ``num_envs``,
            ``obs_space`` and ``action_space``.
        key: PRNG key, split into env-reset, agent-init and rollout keys.
        dp_axis_name: Optional axis name forwarded to
            ``running_statistics.update`` (presumably for cross-device
            reduction under data parallelism — confirm).

    Returns:
        The updated statistics state. If no rollout step is needed, the input
        state is returned unchanged.
    """
    rollout_length = math.ceil(timesteps / env.num_envs)
    if rollout_length <= 0:
        # Nothing to collect: skip the env reset and agent init, whose
        # results would be discarded anyway.
        return obs_preprocessor_state

    env_key, agent_key, rollout_key = jax.random.split(key, num=3)
    env_state = env.reset(env_key)

    agent = RandomAgent()
    agent_state = agent.init(env.obs_space, env.action_space, agent_key)

    # obs leaves: (rollout_length, num_envs, ...); the final env state is not
    # needed here.
    obs, _ = rollout_obs(
        env.step,
        agent.compute_actions,
        env_state,
        agent_state,
        rollout_key,
        rollout_length=rollout_length,
    )

    # Merge the time and env axes into one batch axis for the statistics update.
    obs = jtu.tree_map(lambda x: jax.lax.collapse(x, 0, 2), obs)

    return running_statistics.update(
        obs_preprocessor_state, obs, dp_axis_name=dp_axis_name
    )
def rollout_obs(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
) -> tuple[chex.ArrayTree, EnvState]:
    """Step the environment ``rollout_length`` times and stack observations.

    At every step, ``action_fn`` receives the current observation (wrapped in a
    ``SampleBatch``) and picks actions, and ``env_fn`` advances the state. The
    loop runs under ``jax.lax.scan``, which stacks the per-step outputs.

    Args:
        env_fn: Environment step function ``(state, actions) -> next_state``.
        action_fn: Agent action function
            ``(agent_state, sample_batch, key) -> (actions, extras)``.
        env_state: Starting environment state.
        agent_state: Agent state, held fixed for the whole rollout.
        key: PRNG key; each step splits it into a carried key and an action key.
        rollout_length: Number of steps to run.

    Returns:
        ``(obs, final_env_state)``. ``obs`` has a leading time axis, and
        ``obs[t]`` is the observation seen *before* step ``t``. For a
        vectorized env the layout is ``[T, #envs, ...]``.
    """

    def _scan_body(carry, _):
        state, rng = carry
        rng_carry, rng_action = rng_split(rng, 2)
        observation = state.obs
        actions, _policy_extras = action_fn(
            agent_state, SampleBatch(obs=observation), rng_action
        )
        return (env_fn(state, actions), rng_carry), observation

    (final_state, _), stacked_obs = jax.lax.scan(
        _scan_body, (env_state, key), (), length=rollout_length
    )
    return stacked_obs, final_state