evorl.algorithms.td3

Module Contents

Classes

TD3Agent

The agent for TD3.

TD3NetworkParams

Network parameters (actor, critic, and their target copies) that make up the learner's training state.

TD3TrainMetric

TD3Workflow

Functions

make_mlp_td3_agent

API

class evorl.algorithms.td3.TD3Agent[source]

Bases: evorl.agent.Agent

The agent for TD3.

actor_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) evorl.types.LossDict[source]

Actor loss in TD3.

Parameters:

sample_batch – [B, …]

Returns: LossDict[actor_loss, critic_loss, actor_entropy_loss]
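For readers unfamiliar with the TD3 actor objective, the following is a minimal sketch (not EvoRL's implementation) of a deterministic policy-gradient loss that scores actions with only the first critic, which is what the default `critics_in_actor_loss='first'` suggests. The helper `td3_actor_loss` and the callables `actor_apply(params, obs)` and `critic_apply(params, obs, actions)` are hypothetical; the critic is assumed to return Q-values of shape `[B, num_critics]`.

```python
import jax
import jax.numpy as jnp


def td3_actor_loss(actor_apply, critic_apply, actor_params, critic_params, obs):
    """Deterministic policy-gradient loss using only the first critic.

    actor_apply(params, obs) -> actions, shape [B, A]                (hypothetical)
    critic_apply(params, obs, actions) -> Q-values, [B, num_critics] (hypothetical)
    """
    actions = actor_apply(actor_params, obs)
    q_values = critic_apply(critic_params, obs, actions)
    # Keep only critic 0, mirroring critics_in_actor_loss='first'.
    q_first = q_values[:, 0]
    # Gradient ascent on Q is written as gradient descent on -Q.
    return -jnp.mean(q_first)


# Differentiate with respect to actor_params only; the critic stays fixed.
actor_grad_fn = jax.grad(td3_actor_loss, argnums=2)
```

Because only `actor_params` is differentiated, the critic acts as a fixed scorer during the actor update, which is the usual way TD3 decouples the two losses.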

actor_network: flax.linen.Module

None

clip_policy_noise: float

0.5

compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) tuple[evorl.types.Action, evorl.types.PolicyExtraInfo][source]

critic_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) evorl.types.LossDict[source]

Critic loss in TD3.

Parameters:

sample_batch – [B, …]

Returns: LossDict[actor_loss, critic_loss, actor_entropy_loss]
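The attributes `discount`, `policy_noise`, and `clip_policy_noise` correspond to the standard TD3 critic target: target policy smoothing followed by clipped double-Q learning over the critic ensemble. The sketch below illustrates that target under stated assumptions; it is not EvoRL's code. The function `td3_critic_loss`, the `{"actor": ..., "critic": ...}` parameter layout, and the separate array arguments are hypothetical, actions are assumed to lie in `[-1, 1]`, and the squared error is averaged over critics (summing is an equally common choice).

```python
import jax
import jax.numpy as jnp


def td3_critic_loss(actor_apply, critic_apply, params, target_params,
                    obs, actions, rewards, next_obs, dones, key,
                    discount=0.99, policy_noise=0.2, clip_policy_noise=0.5):
    """Clipped double-Q regression with target policy smoothing.

    params / target_params: dicts with "actor" and "critic" entries (hypothetical).
    critic_apply returns Q-values of shape [B, num_critics].
    Assumes actions are normalized to [-1, 1].
    """
    # Target policy smoothing: perturb the target action with clipped Gaussian noise.
    next_actions = actor_apply(target_params["actor"], next_obs)
    noise = policy_noise * jax.random.normal(key, next_actions.shape)
    noise = jnp.clip(noise, -clip_policy_noise, clip_policy_noise)
    next_actions = jnp.clip(next_actions + noise, -1.0, 1.0)

    # Clipped double-Q: bootstrap from the smallest target-critic estimate.
    next_q = critic_apply(target_params["critic"], next_obs, next_actions)
    next_q = jnp.min(next_q, axis=-1)  # [B]
    target = rewards + discount * (1.0 - dones) * next_q
    target = jax.lax.stop_gradient(target)

    # Every online critic regresses onto the same shared target.
    q = critic_apply(params["critic"], obs, actions)  # [B, num_critics]
    return jnp.mean(jnp.square(q - target[:, None]))
```

Taking the minimum over critics counteracts Q-value overestimation, while the clipped noise on the target action smooths the target over nearby actions.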

critic_network: flax.linen.Module

None

critics_in_actor_loss: str

‘first’

discount: float

0.99

evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) tuple[evorl.types.Action, evorl.types.PolicyExtraInfo][source]

exploration_epsilon: float

0.5
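In TD3 the policy itself is deterministic, so exploration comes from noise added to its output. The sketch below shows one common scheme; it assumes, without confirmation from this reference, that `exploration_epsilon` is the standard deviation of zero-mean Gaussian noise and that actions are normalized to `[-1, 1]`. The helper `exploration_action` and the `actor_apply` callable are hypothetical.

```python
import jax
import jax.numpy as jnp


def exploration_action(actor_apply, actor_params, obs, key, exploration_epsilon=0.5):
    """Deterministic actor output plus Gaussian exploration noise (assumed scheme)."""
    actions = actor_apply(actor_params, obs)
    # Assumption: exploration_epsilon is the std of zero-mean Gaussian noise.
    noise = exploration_epsilon * jax.random.normal(key, actions.shape)
    # Assumption: the action space is normalized to [-1, 1].
    return jnp.clip(actions + noise, -1.0, 1.0)
```

Setting `exploration_epsilon=0.0` recovers the plain deterministic action, which is what an evaluation-time policy would typically use.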

init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) evorl.agent.AgentState[source]

property normalize_obs

obs_preprocessor: Any

‘pytree_field(…)’

policy_noise: float

0.2

class evorl.algorithms.td3.TD3NetworkParams[source]

Bases: evorl.types.PyTreeData

Network parameters (actor, critic, and their target copies) that make up the learner's training state.

actor_params: evorl.types.Params

None

critic_params: evorl.types.Params

None

target_actor_params: evorl.types.Params

None

target_critic_params: evorl.types.Params

None
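The `target_actor_params` and `target_critic_params` fields hold slowly moving copies of the online networks. A common way to maintain them is Polyak (soft) averaging. The sketch below assumes that scheme. `ToyTD3NetworkParams` is a hypothetical stand-in that only mirrors the four documented fields, and the rate `tau` is hypothetical because it is not one of the documented `TD3Agent` fields.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyTD3NetworkParams(NamedTuple):
    """Hypothetical stand-in mirroring the four fields of TD3NetworkParams."""
    actor_params: Any
    critic_params: Any
    target_actor_params: Any
    target_critic_params: Any


def polyak(online, target, tau):
    """Move every target leaf a fraction tau of the way toward the online leaf."""
    return jax.tree_util.tree_map(lambda o, t: tau * o + (1.0 - tau) * t, online, target)


def soft_update_targets(params, tau=0.005):
    # tau is a hypothetical soft-update rate, not a documented TD3Agent field.
    return params._replace(
        target_actor_params=polyak(params.actor_params, params.target_actor_params, tau),
        target_critic_params=polyak(params.critic_params, params.target_critic_params, tau),
    )
```

Since `jax.tree_util.tree_map` walks arbitrary pytrees, the same update works for nested Flax parameter dicts. Optax also provides `optax.incremental_update`, which performs this interpolation.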

class evorl.algorithms.td3.TD3TrainMetric[source]

Bases: evorl.metrics.MetricBase

actor_loss: chex.Array

None

critic_loss: chex.Array

None

raw_loss_dict: evorl.types.LossDict

‘metric_field(…)’

class evorl.algorithms.td3.TD3Workflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)[source]

Bases: evorl.algorithms.offpolicy_utils.OffPolicyWorkflowTemplate

classmethod name()[source]

step(state: evorl.types.State) tuple[evorl.metrics.MetricBase, evorl.types.State][source]
evorl.algorithms.td3.make_mlp_td3_agent(action_space: evorl.envs.Space, norm_layer_type: str = 'none', num_critics: int = 2, critic_hidden_layer_sizes: tuple[int] = (256, 256), actor_hidden_layer_sizes: tuple[int] = (256, 256), discount: float = 0.99, exploration_epsilon: float = 0.5, policy_noise: float = 0.2, clip_policy_noise: float = 0.5, critics_in_actor_loss: str = 'first', normalize_obs: bool = False, policy_obs_key: str = '', value_obs_key: str = '')[source]
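Judging by its name and parameters, `make_mlp_td3_agent` builds an MLP actor and an ensemble of `num_critics` MLP critics with the given hidden-layer sizes. The pure-JAX sketch below shows what such networks could look like. It is not EvoRL's implementation: it hand-rolls dense layers instead of using `flax.linen`, assumes tanh-squashed actions in `[-1, 1]`, and ignores `norm_layer_type`, `normalize_obs`, `policy_obs_key`, and `value_obs_key`. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """He-initialized dense layers for the given layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(keys[i], (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def apply_mlp(params, x):
    """ReLU hidden layers followed by a linear output layer."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def make_toy_td3_networks(key, obs_dim, action_dim, num_critics=2,
                          critic_hidden_layer_sizes=(256, 256),
                          actor_hidden_layer_sizes=(256, 256)):
    actor_key, critic_key = jax.random.split(key)
    actor_params = init_mlp(actor_key, (obs_dim, *actor_hidden_layer_sizes, action_dim))
    # One independently initialized critic per ensemble member, stacked on axis 0.
    critic_params = jax.vmap(
        lambda k: init_mlp(k, (obs_dim + action_dim, *critic_hidden_layer_sizes, 1))
    )(jax.random.split(critic_key, num_critics))

    def actor_apply(params, obs):
        # tanh squashes actions into [-1, 1] (assumed action normalization).
        return jnp.tanh(apply_mlp(params, obs))

    def critic_apply(params, obs, actions):
        x = jnp.concatenate([obs, actions], axis=-1)
        q = jax.vmap(apply_mlp, in_axes=(0, None))(params, x)  # [num_critics, B, 1]
        return jnp.transpose(q[..., 0])  # [B, num_critics]

    return actor_params, critic_params, actor_apply, critic_apply
```

Stacking the critics on a leading axis and evaluating them with `jax.vmap` keeps `num_critics` a plain integer knob, and it yields the `[B, num_critics]` layout that a clipped double-Q target can reduce with a minimum.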