evorl.algorithms.a2c
Module Contents
Classes
- A2CAgent
- A2CNetworkParams: Holds the policy and value network parameters.
- A2CWorkflow
Functions
- make_mlp_a2c_agent
API
- class evorl.algorithms.a2c.A2CAgent
Bases: evorl.agent.Agent
- compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
- compute_values(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch) -> chex.Array
- continuous_action: bool
- evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
- init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) -> evorl.agent.AgentState
- loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
- property normalize_obs
- obs_preprocessor: Any = pytree_field(…)
- policy_network: flax.linen.Module
- value_network: flax.linen.Module
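The `loss` method returns an `evorl.types.LossDict`. As a hedged illustration of what an advantage actor-critic loss usually contains, the sketch below computes the textbook A2C terms in plain JAX: a policy-gradient term weighted by the advantage, a squared-error value term, and an entropy bonus. The function name `a2c_loss_terms`, the categorical (discrete-action) policy, and the dictionary keys are assumptions made for illustration; evorl's actual `loss` may normalize advantages, switch to a Gaussian policy when `continuous_action` is true, or weight the terms differently.

```python
import jax
import jax.numpy as jnp


def a2c_loss_terms(logits, actions, returns, values):
    """Hypothetical textbook A2C loss terms for a categorical policy.

    logits:  [T, num_actions] unnormalized policy scores
    actions: [T] integer actions that were taken
    returns: [T] return targets for the critic
    values:  [T] critic predictions V(s_t)
    """
    log_probs = jax.nn.log_softmax(logits)
    # Log-probability of the action actually taken at each step.
    taken_log_probs = jnp.take_along_axis(log_probs, actions[:, None], axis=-1)[:, 0]
    # Advantage estimate; stop_gradient keeps the actor term from updating the critic.
    advantages = jax.lax.stop_gradient(returns - values)
    actor_loss = -jnp.mean(taken_log_probs * advantages)
    # The critic regresses V(s_t) toward the return targets.
    critic_loss = 0.5 * jnp.mean((returns - values) ** 2)
    # Entropy of the categorical policy, averaged over time steps.
    entropy = -jnp.mean(jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1))
    return {"actor_loss": actor_loss, "critic_loss": critic_loss, "entropy": entropy}


# Toy batch: a uniform policy over two actions and a critic that predicts 0.
logits = jnp.zeros((2, 2))
actions = jnp.array([0, 1])
returns = jnp.array([1.0, 1.0])
values = jnp.array([0.0, 0.0])
losses = a2c_loss_terms(logits, actions, returns, values)
```

A scalar training objective would typically combine these as `actor_loss + c_v * critic_loss - c_e * entropy`; the coefficients `c_v` and `c_e` are hypothetical here and would normally come from the workflow configuration.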
- class evorl.algorithms.a2c.A2CNetworkParams
Bases: evorl.types.PyTreeData
Holds the policy and value network parameters.
- policy_params: evorl.types.Params
- value_params: evorl.types.Params
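`A2CNetworkParams` subclasses `evorl.types.PyTreeData`, so both parameter sets can travel through JAX transformations as one pytree. The sketch below imitates that idea with a standard `typing.NamedTuple`, which JAX already treats as a pytree; `NetworkParamsSketch` and the toy parameter dictionaries are hypothetical stand-ins, not evorl's class. It applies one plain SGD step to both networks with a single `jax.tree_util.tree_map` call.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NetworkParamsSketch(NamedTuple):
    """Hypothetical stand-in for A2CNetworkParams (NamedTuples are JAX pytrees)."""

    policy_params: dict
    value_params: dict


params = NetworkParamsSketch(
    policy_params={"w": jnp.ones((2,)), "b": jnp.zeros(())},
    value_params={"w": jnp.full((2,), 2.0)},
)
# Gradients with exactly the same tree structure as the parameters.
grads = NetworkParamsSketch(
    policy_params={"w": jnp.ones((2,)), "b": jnp.ones(())},
    value_params={"w": jnp.ones((2,))},
)

learning_rate = 0.1
# One tree_map walks both sub-trees, so policy and value parameters are
# updated together and the NetworkParamsSketch structure is preserved.
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# The leaves are the individual arrays from both networks.
num_leaves = len(jax.tree_util.tree_leaves(new_params))
```

In `A2CWorkflow` the update rule would come from the `optax.GradientTransformation` passed to the constructor rather than from hand-written SGD; the manual step here only shows why a single pytree container is convenient.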
- class evorl.algorithms.a2c.A2CWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, config: omegaconf.DictConfig)
Bases: evorl.workflows.OnPolicyWorkflow
- learn(state: evorl.types.State) -> evorl.types.State
- step(state: evorl.types.State) -> tuple[evorl.metrics.MetricBase, evorl.types.State]
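`A2CWorkflow` extends `evorl.workflows.OnPolicyWorkflow`, and `step` returns metrics together with the updated state. A typical A2C update first turns a short rollout into return targets for the critic. The sketch below computes bootstrapped n-step discounted returns with `jax.lax.scan` running backward over time. The function name `nstep_returns`, its argument layout, and the choice of n-step returns (rather than, for example, GAE) are assumptions, since this page does not say which estimator `A2CWorkflow` uses.

```python
import jax
import jax.numpy as jnp


def nstep_returns(rewards, dones, bootstrap_value, gamma):
    """Hypothetical helper: discounted returns over a rollout of length T.

    rewards, dones: [T] float arrays (dones is 1.0 where the episode ended at that step)
    bootstrap_value: V(s_T), the critic's estimate for the state after the rollout
    """

    def backward_step(next_return, inputs):
        reward, done = inputs
        # Do not bootstrap across an episode boundary.
        ret = reward + gamma * (1.0 - done) * next_return
        return ret, ret

    # reverse=True scans from the last step to the first, but the stacked
    # outputs stay aligned with the original time order.
    _, returns = jax.lax.scan(
        backward_step, bootstrap_value, (rewards, dones), reverse=True
    )
    return returns


returns = nstep_returns(
    rewards=jnp.array([1.0, 1.0, 1.0]),
    dones=jnp.array([0.0, 0.0, 0.0]),
    bootstrap_value=jnp.zeros(()),
    gamma=0.5,
)
```

Working backward with gamma = 0.5 and a zero bootstrap value gives 1.0, then 1 + 0.5 * 1.0 = 1.5, then 1 + 0.5 * 1.5 = 1.75, so `returns` holds `[1.75, 1.5, 1.0]`. Using `lax.scan` instead of a Python loop keeps the computation traceable under `jax.jit`.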
- evorl.algorithms.a2c.make_mlp_a2c_agent(action_space: evorl.envs.Space, actor_hidden_layer_sizes: tuple[int] = (256, 256), critic_hidden_layer_sizes: tuple[int] = (256, 256), normalize_obs: bool = False, policy_obs_key: str = '', value_obs_key: str = '') -> evorl.algorithms.a2c.A2CAgent
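`make_mlp_a2c_agent` builds the actor and the critic as multilayer perceptrons whose hidden widths come from `actor_hidden_layer_sizes` and `critic_hidden_layer_sizes` (both `(256, 256)` by default). The sketch below shows, in plain JAX rather than the `flax.linen` modules the agent actually stores, how such a tuple of hidden sizes maps to a stack of dense layers. The helpers `init_mlp` and `mlp_apply`, the tanh activation, the initialization scheme, and the observation and action sizes are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, input_size, hidden_layer_sizes, output_size):
    """Hypothetical MLP initializer: one (weights, bias) pair per dense layer."""
    sizes = (input_size, *hidden_layer_sizes, output_size)
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, (fan_in, fan_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        # Scaled normal initialization; real libraries offer other schemes.
        w = jax.random.normal(k, (fan_in, fan_out)) / jnp.sqrt(fan_in)
        params.append((w, jnp.zeros((fan_out,))))
    return params


def mlp_apply(params, x):
    """Dense layers with tanh on the hidden layers and a linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


obs_size, num_actions = 8, 4  # illustrative sizes, not taken from evorl
actor_key, critic_key = jax.random.split(jax.random.PRNGKey(0))
# Hidden sizes mirror the (256, 256) defaults of make_mlp_a2c_agent.
actor_params = init_mlp(actor_key, obs_size, (256, 256), num_actions)
critic_params = init_mlp(critic_key, obs_size, (256, 256), 1)

obs_batch = jnp.ones((5, obs_size))
logits = mlp_apply(actor_params, obs_batch)  # shape (5, 4): one score per action
values = mlp_apply(critic_params, obs_batch)[:, 0]  # shape (5,): one value per observation
```

Each entry in the hidden-size tuple adds one hidden layer of that width, so `(256, 256)` yields three dense layers in total once the output layer is counted. The actor's output width depends on the action space, while the critic always ends in a single value.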