1import logging
2import math
3from typing import Any
4
5import chex
6import jax
7import jax.tree_util as jtu
8
9from evorl.agent import AgentState, AgentActionFn, RandomAgent
10from evorl.envs import Env, EnvState, EnvStepFn, create_env, AutoresetMode
11from evorl.sample_batch import SampleBatch
12from evorl.types import PyTreeNode
13from evorl.utils.jax_utils import rng_split
14from evorl.utils import running_statistics
15
16
# Logger for this module, named after the module's import path.
logger = logging.getLogger(name=__name__)
18
19
[docs]
class ObsPreprocessor(PyTreeNode):
    """Settings for the observation preprocessor.

    Invariant (checked in ``__post_init__``): when ``static`` is True,
    ``random_timesteps`` must be greater than 0.
    """

    # Number of random-policy timesteps used to warm up the preprocessor
    # statistics (presumably consumed by the random-rollout initialiser in
    # this module -- confirm against callers).
    random_timesteps: int = 0
    # Set True to use VBN (e.g. OpenES).
    static: bool = False

    def __post_init__(self):
        # Static mode requires a positive warm-up budget. The message must name
        # the real field, ``random_timesteps``; it previously referred to a
        # non-existent ``init_timesteps``.
        if self.static:
            assert self.random_timesteps > 0, (
                "random_timesteps should be greater than 0 if static is True"
            )
29
30
[docs]
def init_obs_preprocessor(agent_state, config, key, dp_axis_name=None):
    """Warm up the agent's observation-preprocessor state with random rollouts.

    Builds a fresh environment from ``config`` (with ``AutoresetMode.NORMAL``),
    runs ``config.random_timesteps`` random-policy timesteps in it, and folds
    the collected observations into ``agent_state.obs_preprocessor_state``.

    Args:
        agent_state: Agent state exposing ``obs_preprocessor_state`` and a
            ``replace`` method (presumably an immutable PyTree node -- confirm).
        config: Config providing ``env`` (with ``max_episode_steps``),
            ``num_envs`` and ``random_timesteps``.
        key: PRNG key for the warm-up rollout.
        dp_axis_name: Forwarded unchanged to the statistics update; presumably
            the name of a data-parallel mapped axis, or None.

    Returns:
        ``agent_state`` with ``obs_preprocessor_state`` replaced by the
        updated statistics.

    Raises:
        AssertionError: If ``config.random_timesteps`` is not positive.
    """
    assert config.random_timesteps > 0, "random_timesteps should be greater than 0"

    warmup_env = create_env(
        config.env,
        episode_length=config.env.max_episode_steps,
        parallel=config.num_envs,
        autoreset_mode=AutoresetMode.NORMAL,
    )

    warmed_up_stats = init_obs_preprocessor_with_random_timesteps(
        agent_state.obs_preprocessor_state,
        config.random_timesteps,
        warmup_env,
        key,
        dp_axis_name=dp_axis_name,
    )

    return agent_state.replace(obs_preprocessor_state=warmed_up_stats)
52
53
[docs]
def init_obs_preprocessor_with_random_timesteps(
    obs_preprocessor_state: Any,
    timesteps: int,
    env: Env,
    key: chex.PRNGKey,
    dp_axis_name: str | None = None,
) -> Any:
    """Update preprocessor statistics with observations from a random policy.

    Resets ``env``, drives it with a ``RandomAgent`` for
    ``ceil(timesteps / env.num_envs)`` steps, and feeds every collected
    observation to ``running_statistics.update``. Because of the rounding up,
    at least ``timesteps`` observations are used. If the computed step count
    is not positive, the input state is returned unchanged.

    Args:
        obs_preprocessor_state: Current running-statistics state.
        timesteps: Requested number of observations (summed over envs).
        env: Environment to roll out; ``env.num_envs`` parallel copies.
        key: PRNG key, split into reset, agent-init and rollout keys.
        dp_axis_name: Forwarded to ``running_statistics.update``.

    Returns:
        The updated running-statistics state.
    """
    reset_key, init_key, scan_key = jax.random.split(key, num=3)
    first_env_state = env.reset(reset_key)

    random_agent = RandomAgent()
    random_agent_state = random_agent.init(env.obs_space, env.action_space, init_key)

    # One observation per parallel env per step, so round up to cover `timesteps`.
    num_steps = math.ceil(timesteps / env.num_envs)
    if num_steps <= 0:
        return obs_preprocessor_state

    # stacked_obs leaves: (num_steps, num_envs, ...)
    stacked_obs, _ = rollout_obs(
        env.step,
        random_agent.compute_actions,
        first_env_state,
        random_agent_state,
        scan_key,
        rollout_length=num_steps,
    )

    # Merge the time and env axes into one batch axis: (num_steps * num_envs, ...)
    flat_obs = jtu.tree_map(lambda leaf: jax.lax.collapse(leaf, 0, 2), stacked_obs)

    return running_statistics.update(
        obs_preprocessor_state, flat_obs, dp_axis_name=dp_axis_name
    )
88
89
[docs]
def rollout_obs(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
) -> tuple[chex.ArrayTree, EnvState]:
    """Roll out ``action_fn`` in the environment and collect observations.

    Runs ``rollout_length`` steps under ``jax.lax.scan``. At each step the
    current observation is wrapped in a ``SampleBatch`` and passed, together
    with ``agent_state`` and a fresh key, to ``action_fn``. The resulting
    actions are then applied with ``env_fn``. Any extras returned by
    ``action_fn`` are discarded.

    Args:
        env_fn: Environment step function ``(env_state, actions) -> env_state``.
        action_fn: Policy ``(agent_state, sample_batch, key) -> (actions, extras)``.
        env_state: Initial environment state.
        agent_state: Agent state, held fixed for the whole rollout.
        key: PRNG key; a new subkey is drawn for every step.
        rollout_length: Number of steps to run.

    Returns:
        ``(obs, env_state)``. ``obs`` stacks the pre-step observation of every
        step along a new leading time axis ``[T, #envs, ...]``. ``env_state``
        is the state after the final step.
    """

    def _step(carry, _):
        state, step_key = carry
        carry_key, action_key = rng_split(step_key, 2)
        obs_t = state.obs
        actions, _extras = action_fn(agent_state, SampleBatch(obs=obs_t), action_key)
        next_state = env_fn(state, actions)
        # Emit the observation seen *before* this step (obs_t).
        return (next_state, carry_key), obs_t

    (final_env_state, _), collected_obs = jax.lax.scan(
        _step, (env_state, key), (), length=rollout_length
    )
    return collected_obs, final_env_state