evorl.algorithms.meta.pbt_sac.param_sac

Module Contents

Classes

ParamSACAgent

SAC agent with parameterized hyperparameters.

ParamSACTrainMetric

ParamSACWorkflow

Workflow for ParamSAC.

Functions

make_mlp_sac_agent

API

class evorl.algorithms.meta.pbt_sac.param_sac.ParamSACAgent

Bases: evorl.algorithms.sac.SACAgent

SAC agent with parameterized hyperparameters.
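As a rough illustration of what "parameterized hyperparameters" could mean in practice, the sketch below keeps `discount` and `alpha` in a data container that is passed into the loss computation instead of being fixed at construction time. This is an assumption about the design, not the actual EvoRL implementation: `HyperParams` and `td_target` are hypothetical names, and only the standard SAC soft Bellman target is shown.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class HyperParams:
    """Hypothetical container for values an outer PBT loop may change."""

    discount: float = 0.99
    alpha: float = 1.0


def td_target(reward, done, next_q, next_log_prob, hp):
    """Standard SAC soft target: r + gamma * (1 - done) * (Q' - alpha * log pi')."""
    soft_value = next_q - hp.alpha * next_log_prob
    return reward + hp.discount * (1.0 - done) * soft_value


# Same transition, evaluated under two hyperparameter settings.
hp = HyperParams()
target_a = td_target(1.0, 0.0, 10.0, -1.0, hp)

# A PBT-style perturbation only swaps the data, not the function.
hp_b = replace(hp, discount=0.9)
target_b = td_target(1.0, 0.0, 10.0, -1.0, hp_b)
```

Because the hyperparameters are ordinary inputs, a population-based trainer can assign each member its own values and reuse the same loss function.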

critic_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict

init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) -> evorl.agent.AgentState

class evorl.algorithms.meta.pbt_sac.param_sac.ParamSACTrainMetric

Bases: evorl.algorithms.sac.SACTrainMetric

trajectory: evorl.sample_batch.SampleBatch = None

class evorl.algorithms.meta.pbt_sac.param_sac.ParamSACWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)

Bases: evorl.algorithms.offpolicy_utils.OffPolicyWorkflowTemplate

Workflow for ParamSAC.

Note: This workflow only works together with PBTParamSACWorkflow, because the replay buffer is initialized and managed externally by PBT.
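The note above can be pictured with a small sketch in which the outer loop owns the replay buffer and hands it to each inner workflow on every step. Everything here is hypothetical: `ExternalReplayBuffer`, `InnerWorkflow`, and the idea that one buffer is passed to several population members are assumptions made for illustration, not the EvoRL API.

```python
import random
from collections import deque


class ExternalReplayBuffer:
    """Hypothetical buffer created and owned by the outer PBT loop."""

    def __init__(self, capacity):
        self.storage = deque(maxlen=capacity)

    def add(self, transitions):
        self.storage.extend(transitions)

    def sample(self, batch_size, rng):
        return rng.sample(list(self.storage), batch_size)


class InnerWorkflow:
    """Hypothetical inner workflow that never creates its own buffer."""

    def __init__(self, member_id):
        self.member_id = member_id

    def step(self, buffer, rng):
        # Store one fake transition, then sample from the buffer handed in.
        buffer.add([(self.member_id, rng.random())])
        return buffer.sample(1, rng)


rng = random.Random(0)
buffer = ExternalReplayBuffer(capacity=100)  # created by the outer loop
members = [InnerWorkflow(i) for i in range(3)]
batches = [m.step(buffer, rng) for m in members]
```

In this arrangement an inner workflow run on its own has no buffer to read from, which is one way to understand why it depends on the outer PBT workflow.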

classmethod name()

setup(key: chex.PRNGKey) -> evorl.types.State

step(state: evorl.types.State) -> tuple[evorl.metrics.MetricBase, evorl.types.State]

evorl.algorithms.meta.pbt_sac.param_sac.make_mlp_sac_agent(action_space: evorl.envs.Space, critic_hidden_layer_sizes: tuple[int] = (256, 256), actor_hidden_layer_sizes: tuple[int] = (256, 256), init_alpha: float = 1.0, discount: float = 0.99, normalize_obs: bool = False)