evorl.algorithms.impala
Module Contents
Classes
IMPALANetworkParams | Contains training state for the learner.
IMPALAAgent |
IMPALAWorkflow | Synchronous version of IMPALA (A2C/PPO with V-trace).
Functions
compute_pg_advantage |
compute_vtrace |
make_mlp_impala_agent |
API
- class evorl.algorithms.impala.IMPALAAgent
Bases: evorl.agent.Agent
- adv_mode: str
‘pytree_field(…)’
- clip_c_threshold: float
1.0
- clip_pg_rho_threshold: float
1.0
- clip_rho_threshold: float
1.0
- compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
- continuous_action: bool
None
- discount: float
0.99
- evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
- init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) → evorl.agent.AgentState
- loss(agent_state: evorl.agent.AgentState, trajectory: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → evorl.types.LossDict
Compute the IMPALA loss.
- Parameters:
trajectory – a time-major sequence of transitions with shape [T, B, …] (T timesteps, B trajectories); timesteps must stay in order, not be shuffled.
- property normalize_obs
- obs_preprocessor: Any
‘pytree_field(…)’
- policy_network: flax.linen.Module
None
- value_network: flax.linen.Module
None
- vtrace_lambda: float
1.0
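IMPALAAgent.loss expects a time-major trajectory because the V-trace correction is a backward recursion over the leading T axis; shuffling timesteps would mix unrelated transitions into that recursion. The sketch below is illustrative only: `to_time_major` and `importance_ratios` are hypothetical helpers, not part of evorl, and it assumes rollouts are stored batch-major as [B, T, …] and that per-action log-probabilities under the behaviour policy μ and the current policy π are available. In the V-trace formulation, the ratio ρ_t = π(a_t|x_t) / μ(a_t|x_t) is the quantity that the clip_rho_threshold, clip_c_threshold and clip_pg_rho_threshold truncate.

```python
import numpy as np


def to_time_major(x):
    """Hypothetical helper: swap a [B, T, ...] rollout array to [T, B, ...]."""
    return np.swapaxes(np.asarray(x), 0, 1)


def importance_ratios(target_logp, behavior_logp):
    """Hypothetical helper: rho_t = pi(a_t|x_t) / mu(a_t|x_t).

    Computed as exp(log pi - log mu) so tiny probabilities do not
    underflow before the division.
    """
    target_logp = np.asarray(target_logp, dtype=np.float64)
    behavior_logp = np.asarray(behavior_logp, dtype=np.float64)
    return np.exp(target_logp - behavior_logp)


# Toy rollout: B=2 environments, T=3 timesteps, stored batch-major.
rewards_bt = np.array([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
# Shape (3, 2): row t holds step t of every environment.
rewards_tb = to_time_major(rewards_bt)

# Log-probabilities of the taken actions, already laid out as [T, B].
behavior_logp = np.log([[0.5, 0.5], [0.25, 0.5], [0.5, 0.5]])
target_logp = np.log([[0.5, 0.25], [0.5, 0.5], [0.5, 1.0]])
rho_t = importance_ratios(target_logp, behavior_logp)  # shape (3, 2)
```

Working in log space is a common design choice because policies usually expose log-probabilities directly, and subtracting them avoids dividing two very small numbers.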
- class evorl.algorithms.impala.IMPALANetworkParams
Bases: evorl.types.PyTreeData
Contains training state for the learner.
- policy_params: evorl.types.Params
None
- value_params: evorl.types.Params
None
- class evorl.algorithms.impala.IMPALAWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, config: omegaconf.DictConfig)
Bases: evorl.workflows.OnPolicyWorkflow
Synchronous version of IMPALA (A2C/PPO with V-trace).
- learn(state: evorl.types.State) → evorl.types.State
- step(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State]
- evorl.algorithms.impala.compute_pg_advantage(vtrace, v_t, v_t_plus_1, rewards, terminations, discount=0.99, lambda_=1.0, mode='official')
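For reference, the IMPALA paper (Espeholt et al., 2018) builds the policy-gradient advantage from the V-trace target v_{t+1} and the value estimate V(x_t), then weights the score function by a separately clipped importance ratio. The block below states that formula with a termination flag d_t zeroing the bootstrap. Reading \bar{\rho}_{pg} as clip_pg_rho_threshold, and treating this as what mode='official' computes, are assumptions not confirmed by this page. A λ-weighted variant, also an assumption here, bootstraps from λ v_{t+1} + (1 − λ) V(x_{t+1}) instead of v_{t+1}.

```latex
\[
\begin{aligned}
A_t &= r_t + \gamma\,(1 - d_t)\,v_{t+1} - V(x_t) \\
\Delta\theta &\propto \min\!\left(\bar{\rho}_{\mathrm{pg}},\ \frac{\pi_\theta(a_t \mid x_t)}{\mu(a_t \mid x_t)}\right) A_t\,\nabla_\theta \log \pi_\theta(a_t \mid x_t)
\end{aligned}
\]
```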
- evorl.algorithms.impala.compute_vtrace(rho_t, v_t, v_t_plus_1, rewards, dones, terminations, discount=0.99, lambda_=1.0, clip_rho_threshold=1.0, clip_c_threshold=1.0)
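The V-trace targets from the IMPALA paper can be computed with a backward recursion: v_t − V(x_t) = δ_t + γ c_t (v_{t+1} − V(x_{t+1})), where δ_t = min(ρ̄, ρ_t)(r_t + γ V(x_{t+1}) − V(x_t)) and c_t = λ min(c̄, ρ_t). The NumPy sketch below mirrors compute_vtrace's argument names but is a hypothetical `vtrace_sketch`, not evorl's implementation. It assumes [T, B] inputs, that `terminations` marks true terminal states (no bootstrap from v_t_plus_1), and that `dones` marks any episode end, termination or truncation, so the trace is cut at every boundary.

```python
import numpy as np


def vtrace_sketch(rho_t, v_t, v_t_plus_1, rewards, dones, terminations,
                  discount=0.99, lambda_=1.0,
                  clip_rho_threshold=1.0, clip_c_threshold=1.0):
    """Hypothetical V-trace target computation over time-major [T, B] arrays."""
    rho_t, v_t, v_t_plus_1, rewards, dones, terminations = (
        np.asarray(x, dtype=np.float64)
        for x in (rho_t, v_t, v_t_plus_1, rewards, dones, terminations)
    )

    # Truncated importance weights: rho-bar for the TD error, c-bar for the trace.
    clipped_rho = np.minimum(clip_rho_threshold, rho_t)
    c_t = lambda_ * np.minimum(clip_c_threshold, rho_t)

    # Assumption: no bootstrap from V(x_{t+1}) after a true termination.
    bootstrap_discount = discount * (1.0 - terminations)
    deltas = clipped_rho * (rewards + bootstrap_discount * v_t_plus_1 - v_t)

    # Assumption: any episode end (dones) stops the correction from flowing
    # back across the boundary into the previous episode.
    trace_discount = discount * (1.0 - dones) * c_t

    vs_minus_v = np.zeros_like(v_t)
    acc = np.zeros_like(v_t[0])  # v_{t+1} - V(x_{t+1}); zero past the rollout end
    for t in reversed(range(v_t.shape[0])):
        acc = deltas[t] + trace_discount[t] * acc
        vs_minus_v[t] = acc
    return v_t + vs_minus_v


# On-policy toy rollout: T=3, B=1, unit rewards, zero value estimates.
targets = vtrace_sketch(rho_t=np.ones((3, 1)), v_t=np.zeros((3, 1)),
                        v_t_plus_1=np.zeros((3, 1)), rewards=np.ones((3, 1)),
                        dones=np.zeros((3, 1)), terminations=np.zeros((3, 1)),
                        discount=0.5)
```

With ρ_t = 1 and λ = 1 the targets reduce to discounted returns bootstrapped from the last v_t_plus_1, so the toy call above yields [1.75, 1.5, 1.0] along T. In JAX code the Python loop would typically become a reverse `jax.lax.scan`.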
- evorl.algorithms.impala.make_mlp_impala_agent(action_space: evorl.envs.Space, discount: float = 0.99, vtrace_lambda: float = 1.0, clip_rho_threshold: float = 1.0, clip_c_threshold: float = 1.0, clip_pg_rho_threshold: float = 1.0, adv_mode: str = 'official', actor_hidden_layer_sizes: tuple[int] = (256, 256), critic_hidden_layer_sizes: tuple[int] = (256, 256), normalize_obs: bool = False, policy_obs_key: str = '', value_obs_key: str = '')