evorl.algorithms.sac
Module Contents
Classes
Functions
API
- class evorl.algorithms.sac.SACAgent
  Bases: evorl.agent.Agent
  - actor_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
  - actor_network: flax.linen.Module = None
  - alpha_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
  - compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
  - critic_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
  - critic_network: flax.linen.Module = None
  - discount: float = 0.99
  - evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
  - init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) -> evorl.agent.AgentState
  - init_alpha: float = 1.0
  - property normalize_obs
  - obs_preprocessor: Any = 'pytree_field(…)'
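For orientation, the `discount` field and the twin critics typically enter SAC through the soft Bellman target that `critic_loss` regresses onto. The following is a minimal plain-Python sketch of that textbook formula for a single transition, not EvoRL's implementation; `soft_td_target` and all of its argument names are hypothetical.

```python
def soft_td_target(reward, done, next_q_values, next_log_prob, alpha, discount=0.99):
    """Textbook SAC critic target for one transition (hypothetical helper).

    y = r + discount * (1 - done) * (min_i Q_i(s', a') - alpha * log pi(a'|s'))
    """
    # Clipped double-Q: use the most pessimistic target-critic estimate.
    min_next_q = min(next_q_values)
    # Entropy bonus: subtracting alpha * log-prob rewards stochastic policies.
    soft_value = min_next_q - alpha * next_log_prob
    # Terminal transitions (done == 1.0) do not bootstrap from the next state.
    return reward + discount * (1.0 - done) * soft_value


target = soft_td_target(
    reward=1.0, done=0.0, next_q_values=[2.0, 3.0], next_log_prob=-1.0, alpha=0.5
)
terminal_target = soft_td_target(
    reward=1.0, done=1.0, next_q_values=[2.0, 3.0], next_log_prob=-1.0, alpha=0.5
)
```

With the default `discount` of 0.99, the non-terminal target is the reward plus 0.99 times the entropy-regularized next-state value, while the terminal target collapses to the reward alone.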
- class evorl.algorithms.sac.SACDiscreteAgent
  Bases: evorl.agent.Agent
  - actor_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
  - actor_network: flax.linen.Module = None
  - alpha_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
  - compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
  - critic_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
  - critic_network: flax.linen.Module = None
  - discount: float = 0.99
  - evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
  - init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) -> evorl.agent.AgentState
  - init_alpha: float = 1.0
  - property normalize_obs
  - obs_preprocessor: Any = 'pytree_field(…)'
  - target_entropy_ratio: float = 0.98
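`target_entropy_ratio` appears only on the discrete agent. A common convention in discrete SAC is to set the entropy target to this ratio times the maximum entropy of a categorical policy, log |A|; whether EvoRL uses exactly this formula is an assumption here. The sketch below uses hypothetical helper names and the textbook temperature loss.

```python
import math


def discrete_target_entropy(num_actions, target_entropy_ratio=0.98):
    """Assumed convention: ratio * log(num_actions), the max entropy of a categorical."""
    return target_entropy_ratio * math.log(num_actions)


def categorical_entropy(probs):
    """Shannon entropy (in nats) of a categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def toy_alpha_loss(log_alpha, probs, target_entropy):
    """Textbook temperature loss: alpha * (H(pi) - target), with alpha = exp(log_alpha)."""
    return math.exp(log_alpha) * (categorical_entropy(probs) - target_entropy)


target = discrete_target_entropy(num_actions=4)
# Uniform policy: entropy is above the target, so the loss is positive.
uniform_loss = toy_alpha_loss(0.0, [0.25] * 4, target)
# Peaked policy: entropy is below the target, so the loss is negative.
peaked_loss = toy_alpha_loss(0.0, [0.97, 0.01, 0.01, 0.01], target)
```

Under this formulation, gradient descent on `log_alpha` lowers the temperature when the policy is more random than the target and raises it when the policy is too deterministic.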
- class evorl.algorithms.sac.SACNetworkParams
  Bases: evorl.types.PyTreeData
  - actor_params: evorl.types.Params = None
  - critic_params: evorl.types.Params = None
  - log_alpha: evorl.types.Params = None
  - target_critic_params: evorl.types.Params = None
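These four fields follow the usual SAC layout: online actor and critic parameters, a slowly updated target copy of the critic, and the log of the entropy temperature. The sketch below mimics that layout with a plain `NamedTuple` and flat dicts of floats in place of real parameter pytrees. `ToyParams`, `init_toy_params`, `soft_update`, and the rate `tau` are hypothetical and not part of EvoRL.

```python
import math
from typing import NamedTuple


class ToyParams(NamedTuple):
    """Stand-in for the four SACNetworkParams fields (hypothetical)."""
    actor_params: dict
    critic_params: dict
    target_critic_params: dict
    log_alpha: float


def init_toy_params(init_alpha=1.0):
    critic = {"w": 1.0, "b": 0.0}
    return ToyParams(
        actor_params={"w": 0.5},
        critic_params=critic,
        # The target critic starts as a copy of the online critic.
        target_critic_params=dict(critic),
        # Storing log(alpha) keeps alpha = exp(log_alpha) positive under gradient steps.
        log_alpha=math.log(init_alpha),
    )


def soft_update(params, tau=0.005):
    """Polyak averaging: target <- (1 - tau) * target + tau * online."""
    new_target = {
        name: (1.0 - tau) * params.target_critic_params[name]
        + tau * params.critic_params[name]
        for name in params.critic_params
    }
    return params._replace(target_critic_params=new_target)


params = init_toy_params()
# Pretend a gradient step moved the online critic.
params = params._replace(critic_params={"w": 2.0, "b": 1.0})
params = soft_update(params, tau=0.1)
```

With `init_alpha=1.0` the stored `log_alpha` starts at zero, and each soft update moves the target critic a fraction `tau` of the way toward the online critic.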
- class evorl.algorithms.sac.SACTrainMetric
  Bases: evorl.metrics.MetricBase
  - actor_loss: chex.Array = None
  - alpha_loss: chex.Array | None = None
  - critic_loss: chex.Array = None
  - raw_loss_dict: evorl.types.LossDict = 'metric_field(…)'
- class evorl.algorithms.sac.SACWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)
  Bases: evorl.algorithms.offpolicy_utils.OffPolicyWorkflowTemplate
  - step(state: evorl.types.State) -> tuple[evorl.metrics.MetricBase, evorl.types.State]
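`step` takes a state and returns a `(metrics, state)` pair, which suggests a functional loop in which the caller threads the state through successive calls. The toy below illustrates only that calling pattern. `ToyWorkflow` and `ToyState` are hypothetical stand-ins; the real `SACWorkflow` needs an environment, agent, optimizer, evaluator, replay buffer, and config to construct.

```python
from typing import NamedTuple


class ToyState(NamedTuple):
    iterations: int
    value: float


class ToyWorkflow:
    """Hypothetical stand-in mirroring only the step(state) -> (metrics, state) shape."""

    def step(self, state):
        # Build a new state instead of mutating the old one.
        new_state = ToyState(state.iterations + 1, state.value * 0.5)
        metrics = {"value": new_state.value}
        return metrics, new_state


workflow = ToyWorkflow()
state = ToyState(iterations=0, value=8.0)
history = []
for _ in range(3):
    # The returned state becomes the input to the next call.
    metrics, state = workflow.step(state)
    history.append(metrics["value"])
```

Keeping the state explicit in this way is what makes such a step function straightforward to trace or compile in JAX-style code.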
- evorl.algorithms.sac.make_mlp_sac_agent(action_space: evorl.envs.Space, num_critics: int = 2, critic_hidden_layer_sizes: tuple[int] = (256, 256), actor_hidden_layer_sizes: tuple[int] = (256, 256), init_alpha: float = 1.0, discount: float = 0.99, target_entropy_ratio: float = 0.98, normalize_obs: bool = False, policy_obs_key: str = '', value_obs_key: str = '')