Source code for evorl.algorithms.ec.ec_agent

  1import logging
  2from typing import Any
  3
  4import chex
  5import flax.linen as nn
  6import jax.numpy as jnp
  7import jax.tree_util as jtu
  8
  9from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
 10from evorl.networks import make_policy_network
 11from evorl.sample_batch import SampleBatch
 12from evorl.types import (
 13    Action,
 14    Params,
 15    PolicyExtraInfo,
 16    PyTreeDict,
 17    pytree_field,
 18    PyTreeData,
 19)
 20from evorl.utils import running_statistics
 21from evorl.utils.jax_utils import tree_get
 22from evorl.envs import Space, Box, Discrete
 23
 24from evorl.agent import Agent, AgentState
 25
 26logger = logging.getLogger(__name__)
 27
 28
class ECNetworkParams(PyTreeData):
    """Parameter container stored in the agent state of an EC agent.

    Attributes:
        policy_params: Parameters of the policy network.
    """

    policy_params: Params
33 34
class StochasticECAgent(Agent):
    """Agent that draws its actions from a policy distribution.

    Continuous action spaces in [-1, 1] are handled with a TanhNormal
    distribution, discrete action spaces with a Softmax (categorical)
    distribution.
    """

    continuous_action: bool
    policy_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; one sample from it is used to
                initialize the policy network.
            action_space: Action space (not read by this method).
            key: PRNG key used for both the observation sample and the
                network initialization.

        Returns:
            An ``AgentState`` holding the policy parameters and, when
            observation normalization is enabled, the initial preprocessor
            state (``None`` otherwise).
        """
        # Draw one observation and prepend a leading axis to every leaf.
        sampled_obs = obs_space.sample(key)
        batched_obs = jtu.tree_map(lambda leaf: leaf[None, ...], sampled_obs)

        network_params = ECNetworkParams(
            policy_params=self.policy_network.init(key, batched_obs)
        )

        preprocessor_state = None
        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            preprocessor_state = running_statistics.init_state(
                tree_get(batched_obs, 0)
            )

        return AgentState(
            params=network_params, obs_preprocessor_state=preprocessor_state
        )

    def _action_distribution(
        self, agent_state: AgentState, sample_batch: SampleBatch
    ):
        """Run the policy network and wrap its output in a distribution."""
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        network_out = self.policy_network.apply(
            agent_state.params.policy_params, obs
        )

        if not self.continuous_action:
            return get_categorical_dist(network_out)

        # The last axis carries two equally sized halves, passed on in order.
        first_half, second_half = jnp.split(network_out, 2, axis=-1)
        return get_tanh_norm_dist(first_half, second_half)

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the policy distribution.

        Args:
            agent_state: State providing the policy parameters and the
                observation preprocessor state.
            sample_batch: Batch whose ``obs`` field is fed to the policy.
            key: PRNG key used as the sampling seed.

        Returns:
            The sampled actions and an empty ``PyTreeDict`` of extras.
        """
        sampled = self._action_distribution(agent_state, sample_batch).sample(
            seed=key
        )
        return sampled, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the policy distribution as the action.

        Args:
            agent_state: State providing the policy parameters and the
                observation preprocessor state.
            sample_batch: Batch whose ``obs`` field is fed to the policy.
            key: PRNG key; not used because no sampling takes place.

        Returns:
            The distribution modes and an empty ``PyTreeDict`` of extras.
        """
        greedy = self._action_distribution(agent_state, sample_batch).mode()
        return greedy, PyTreeDict()
111 112
class DeterministicECAgent(Agent):
    """Agent whose action is the direct output of the policy network.

    Intended for continuous action spaces in [-1, 1].
    """

    policy_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; one sample from it is used to
                initialize the policy network.
            action_space: Action space (not read by this method).
            key: PRNG key used for both the observation sample and the
                network initialization.

        Returns:
            An ``AgentState`` holding the policy parameters and, when
            observation normalization is enabled, the initial preprocessor
            state (``None`` otherwise).
        """
        # Draw one observation and prepend a leading axis to every leaf.
        sampled_obs = obs_space.sample(key)
        batched_obs = jtu.tree_map(lambda leaf: leaf[None, ...], sampled_obs)

        network_params = ECNetworkParams(
            policy_params=self.policy_network.init(key, batched_obs)
        )

        preprocessor_state = None
        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            preprocessor_state = running_statistics.init_state(
                tree_get(batched_obs, 0)
            )

        return AgentState(
            params=network_params, obs_preprocessor_state=preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute actions by applying the policy network to the observations.

        Args:
            agent_state: State providing the policy parameters and the
                observation preprocessor state.
            sample_batch: Batch whose ``obs`` field is fed to the policy.
            key: PRNG key; not used by this method.

        Returns:
            The network outputs as actions and an empty ``PyTreeDict``.
        """
        policy_input = sample_batch.obs
        if self.normalize_obs:
            policy_input = self.obs_preprocessor(
                policy_input, agent_state.obs_preprocessor_state
            )

        policy_params = agent_state.params.policy_params
        return self.policy_network.apply(policy_params, policy_input), PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute evaluation actions; delegates to ``compute_actions``."""
        return self.compute_actions(agent_state, sample_batch, key)
160 161
def make_stochastic_ec_agent(
    action_space: Space,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    use_bias: bool = True,
    norm_layer_type: str = "none",
    normalize_obs: bool = False,
    policy_obs_key: str = "",
) -> StochasticECAgent:
    """Build a ``StochasticECAgent`` for the given action space.

    Args:
        action_space: A ``Box`` (continuous) or ``Discrete`` action space.
        actor_hidden_layer_sizes: Hidden layer sizes of the policy network.
        use_bias: Forwarded to ``make_policy_network``.
        norm_layer_type: Forwarded to ``make_policy_network``.
        normalize_obs: When true, the agent preprocesses observations with
            ``running_statistics.normalize``.
        policy_obs_key: Forwarded to ``make_policy_network`` as ``obs_key``.

    Returns:
        The configured ``StochasticECAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    if isinstance(action_space, Box):
        # Two outputs per action dimension: the agent splits the network
        # output into two halves to parameterize the TanhNormal distribution.
        action_size = action_space.shape[0] * 2
        continuous_action = True
    elif isinstance(action_space, Discrete):
        # One output per discrete action.
        action_size = action_space.n
        continuous_action = False
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        norm_layer_type=norm_layer_type,
        use_bias=use_bias,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return StochasticECAgent(
        continuous_action=continuous_action,
        policy_network=policy_network,
        obs_preprocessor=obs_preprocessor,
    )
197 198
def make_deterministic_ec_agent(
    action_space: Space,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    use_bias: bool = True,
    norm_layer_type: str = "none",
    normalize_obs: bool = False,
    policy_obs_key: str = "",
) -> DeterministicECAgent:
    """Build a ``DeterministicECAgent`` for a continuous action space.

    Args:
        action_space: A ``Box`` action space; its first shape entry is used
            as the number of policy outputs.
        actor_hidden_layer_sizes: Hidden layer sizes of the policy network.
        use_bias: Forwarded to ``make_policy_network``.
        norm_layer_type: Forwarded to ``make_policy_network``.
        normalize_obs: When true, the agent preprocesses observations with
            ``running_statistics.normalize``.
        policy_obs_key: Forwarded to ``make_policy_network`` as ``obs_key``.

    Returns:
        The configured ``DeterministicECAgent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    if not isinstance(action_space, Box):
        # Raised explicitly rather than via `assert`, which is stripped under
        # `python -O`; AssertionError is kept for callers that already catch it.
        raise AssertionError("Only continuous action space is supported.")

    action_size = action_space.shape[0]

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        use_bias=use_bias,
        # tanh on the final layer bounds the outputs to [-1, 1].
        activation_final=nn.tanh,
        norm_layer_type=norm_layer_type,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DeterministicECAgent(
        policy_network=policy_network,
        obs_preprocessor=obs_preprocessor,
    )