evorl.algorithms.impala

Module Contents

Classes

IMPALAAgent

IMPALANetworkParams

Contains training state for the learner.

IMPALAWorkflow

Synchronous version of IMPALA (A2C/PPO with V-trace).

Functions

compute_pg_advantage

compute_vtrace

make_mlp_impala_agent

API

class evorl.algorithms.impala.IMPALAAgent

Bases: evorl.agent.Agent

adv_mode: str

‘pytree_field(…)’

clip_c_threshold: float

1.0

clip_pg_rho_threshold: float

1.0

clip_rho_threshold: float

1.0

compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]

continuous_action: bool

None

discount: float

0.99

evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]

init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) → evorl.agent.AgentState

loss(agent_state: evorl.agent.AgentState, trajectory: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → evorl.types.LossDict

IMPALA loss.

Parameters:

trajectory – a time-major [T, B, …] sequence of transitions (consecutive timesteps, not shuffled)
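The docstring does not spell out the individual loss terms. A minimal sketch of how an IMPALA-style loss is commonly assembled over time-major [T, B] arrays follows, assuming the target and behavior log-probabilities, values, V-trace targets `vs`, policy-gradient advantages and entropies are already available. The name `impala_loss_sketch` and the coefficients `vf_coef` and `ent_coef` are hypothetical and not part of evorl; only `clip_pg_rho_threshold` mirrors an `IMPALAAgent` field.

```python
import jax
import jax.numpy as jnp


def impala_loss_sketch(logp_target, logp_behavior, values, vs, pg_advantage,
                       entropy, clip_pg_rho_threshold=1.0,
                       vf_coef=0.5, ent_coef=0.01):
    """Hypothetical IMPALA-style loss over [T, B] arrays (not evorl's API)."""
    # Importance ratio pi(a|x) / mu(a|x), truncated for the policy gradient.
    rho = jnp.exp(logp_target - logp_behavior)
    pg_rho = jax.lax.stop_gradient(jnp.minimum(rho, clip_pg_rho_threshold))
    # Policy-gradient term: only log pi(a|x) carries gradient.
    adv = jax.lax.stop_gradient(pg_advantage)
    policy_loss = -jnp.mean(pg_rho * adv * logp_target)
    # Value regression toward the fixed V-trace targets vs.
    value_loss = 0.5 * jnp.mean((jax.lax.stop_gradient(vs) - values) ** 2)
    # Entropy bonus expressed as a quantity to minimize.
    entropy_loss = -jnp.mean(entropy)
    total_loss = policy_loss + vf_coef * value_loss + ent_coef * entropy_loss
    return dict(policy_loss=policy_loss, value_loss=value_loss,
                entropy_loss=entropy_loss, total_loss=total_loss)
```

Because the truncated ratio is capped at `clip_pg_rho_threshold`, a sample that the current policy favors much more than the behavior policy contributes at most its unweighted advantage, which keeps the variance of the policy gradient bounded.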

property normalize_obs

obs_preprocessor: Any

‘pytree_field(…)’

policy_network: flax.linen.Module

None

value_network: flax.linen.Module

None

vtrace_lambda: float

1.0

class evorl.algorithms.impala.IMPALANetworkParams

Bases: evorl.types.PyTreeData

Contains training state for the learner.

policy_params: evorl.types.Params

None

value_params: evorl.types.Params

None
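Keeping `policy_params` and `value_params` in one pytree means a single gradient transformation, such as the `optax.GradientTransformation` passed to `IMPALAWorkflow`, can in principle update both networks in one call. The sketch below uses a `NamedTuple`, which JAX treats as a pytree, as a hypothetical stand-in for `IMPALANetworkParams`; it assumes `evorl.types.PyTreeData` behaves like a JAX pytree container, and `sgd_update` is an illustrative helper, not an evorl function.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class NetworkParamsSketch(NamedTuple):
    """Hypothetical stand-in for IMPALANetworkParams (a JAX pytree)."""
    policy_params: Any
    value_params: Any


def sgd_update(params, grads, lr=0.1):
    # One plain SGD step applied leaf by leaf to both networks at once;
    # tree_map rebuilds the same NamedTuple structure.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

For example, `sgd_update(params, grads)` with `params` and `grads` both of type `NetworkParamsSketch` returns a new `NetworkParamsSketch` whose policy and value leaves have each been stepped.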

class evorl.algorithms.impala.IMPALAWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, config: omegaconf.DictConfig)

Bases: evorl.workflows.OnPolicyWorkflow

Synchronous version of IMPALA (A2C/PPO with V-trace).

learn(state: evorl.types.State) → evorl.types.State

classmethod name()

step(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State]

evorl.algorithms.impala.compute_pg_advantage(vtrace, v_t, v_t_plus_1, rewards, terminations, discount=0.99, lambda_=1.0, mode='official')
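`compute_pg_advantage` takes V-trace targets (`vtrace`) but no importance ratios, so the truncated-ratio weighting governed by `clip_pg_rho_threshold` is presumably applied elsewhere. A minimal sketch of the IMPALA policy-gradient advantage A_t = r_t + γ·vs_{t+1} − V(x_t) follows, with vs at the last step replaced by the bootstrap value V(x_T) and a λ-mixing between vs_{t+1} and V(x_{t+1}) similar to DeepMind's rlax. The name `pg_advantage_sketch` is hypothetical, `vs` stands for the `vtrace` argument, truncation inside a trajectory is not handled, and the `mode` switch is not modeled.

```python
import jax.numpy as jnp


def pg_advantage_sketch(vs, v_t, v_t_plus_1, rewards, terminations,
                        discount=0.99, lambda_=1.0):
    """Hypothetical V-trace policy-gradient advantage over [T, B] arrays."""
    # Target at t+1: the V-trace target vs_{t+1} inside the trajectory,
    # and the bootstrap value V(x_T) at the final step.
    next_vs = jnp.concatenate([vs[1:], v_t_plus_1[-1:]], axis=0)
    # lambda_ = 1 gives the paper's r_t + gamma * vs_{t+1};
    # lambda_ = 0 falls back to the one-step TD target.
    next_target = lambda_ * next_vs + (1.0 - lambda_) * v_t_plus_1
    # Do not bootstrap past a true termination.
    q_estimate = rewards + discount * (1.0 - terminations) * next_target
    return q_estimate - v_t
```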

evorl.algorithms.impala.compute_vtrace(rho_t, v_t, v_t_plus_1, rewards, dones, terminations, discount=0.99, lambda_=1.0, clip_rho_threshold=1.0, clip_c_threshold=1.0)
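V-trace, from the IMPALA paper, defines targets through the backward recursion vs_t − V(x_t) = δ_t + γ·c_t·(vs_{t+1} − V(x_{t+1})), where δ_t = ρ̄_t·(r_t + γ·V(x_{t+1}) − V(x_t)), ρ̄_t = min(ρ_t, clip_rho_threshold) and c_t = λ·min(ρ_t, clip_c_threshold). A minimal JAX sketch of that recursion using `jax.lax.scan(..., reverse=True)` over time-major [T, B] arrays follows. The name `vtrace_sketch` is hypothetical, and the treatment of `dones` versus `terminations` is an assumption: a termination stops bootstrapping from V(x_{t+1}), while any episode end (termination or truncation) cuts the trace.

```python
import jax
import jax.numpy as jnp


def vtrace_sketch(rho_t, v_t, v_t_plus_1, rewards, dones, terminations,
                  discount=0.99, lambda_=1.0,
                  clip_rho_threshold=1.0, clip_c_threshold=1.0):
    """Hypothetical V-trace targets vs over [T, B] arrays."""
    clipped_rho = jnp.minimum(rho_t, clip_rho_threshold)
    c_t = lambda_ * jnp.minimum(rho_t, clip_c_threshold)
    # Importance-weighted TD errors; no bootstrap past a termination.
    deltas = clipped_rho * (
        rewards + discount * (1.0 - terminations) * v_t_plus_1 - v_t)
    # Per-step trace coefficient, zeroed at any episode end.
    trace_coef = discount * (1.0 - dones) * c_t

    def backward_step(acc, xs):
        delta, coef = xs
        acc = delta + coef * acc  # acc holds vs_t - V(x_t)
        return acc, acc

    # reverse=True walks t = T-1, ..., 0 but stacks outputs in time order.
    _, vs_minus_v = jax.lax.scan(
        backward_step, jnp.zeros_like(v_t[0]), (deltas, trace_coef),
        reverse=True)
    return vs_minus_v + v_t
```

With ρ_t = 1 everywhere and both thresholds at 1.0, the recursion reduces to the on-policy n-step return bootstrapped from V(x_T); with ρ_t = 0 no correction is applied and vs_t = V(x_t).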

evorl.algorithms.impala.make_mlp_impala_agent(action_space: evorl.envs.Space, discount: float = 0.99, vtrace_lambda: float = 1.0, clip_rho_threshold: float = 1.0, clip_c_threshold: float = 1.0, clip_pg_rho_threshold: float = 1.0, adv_mode: str = 'official', actor_hidden_layer_sizes: tuple[int] = (256, 256), critic_hidden_layer_sizes: tuple[int] = (256, 256), normalize_obs: bool = False, policy_obs_key: str = '', value_obs_key: str = '')
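`make_mlp_impala_agent` takes separate `actor_hidden_layer_sizes` and `critic_hidden_layer_sizes`, matching the separate `policy_network` and `value_network` fields on `IMPALAAgent`. The plain-JAX sketch below shows what those tuples control. It assumes ReLU hidden activations, He-style initialization and a discrete action space in which the actor emits one logit per action; evorl itself builds `flax.linen` modules, and `init_mlp`/`apply_mlp` are hypothetical helpers. A continuous action space would need a different output head, which is not shown.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden_layer_sizes, out_dim):
    """Hypothetical He-initialized MLP parameters as a list of (W, b)."""
    sizes = (in_dim, *hidden_layer_sizes, out_dim)
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, m, n in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / m)
        params.append((w, jnp.zeros(n)))
    return params


def apply_mlp(params, x):
    # ReLU on hidden layers, linear output layer.
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


# Usage mirroring the default (256, 256) hidden sizes.
actor_key, critic_key = jax.random.split(jax.random.PRNGKey(0))
obs_dim, num_actions = 8, 4
actor = init_mlp(actor_key, obs_dim, (256, 256), num_actions)  # logits head
critic = init_mlp(critic_key, obs_dim, (256, 256), 1)  # V(x) head
obs = jnp.zeros((5, obs_dim))
logits = apply_mlp(actor, obs)  # shape (5, 4)
values = apply_mlp(critic, obs)[..., 0]  # shape (5,)
```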