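evorl.algorithms.td7
Module Contents
Classes
TD7Actor |
TD7Agent | The Agent for TD7.
TD7Critic |
TD7Encoder |
TD7NetworkParams |
TD7TrainMetric |
TD7Workflow |
Functions
avg_l1_norm | Average L1 Norm used in TD7.
make_td7_agent |
API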
- class evorl.algorithms.td7.TD7Actor
Bases: flax.linen.Module
- action_size: int
None
(256, 256)
- state_emb_dim: int
256
- z_s_dim: int
256
- class evorl.algorithms.td7.TD7Agent
Bases: evorl.agent.Agent
The Agent for TD7.
- actor_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
- actor_network: flax.linen.Module
None
- clip_policy_noise: float
0.5
- compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
- critic_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
- critic_network: flax.linen.Module
None
- discount: float
0.99
- encoder_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> evorl.types.LossDict
- encoder_network: flax.linen.Module
None
- evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) -> tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]
- exploration_epsilon: float
0.1
- init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) -> evorl.agent.AgentState
- min_priority: float
1.0
- property normalize_obs
- obs_preprocessor: Any
‘pytree_field(…)’
- policy_noise: float
0.2
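The `policy_noise`, `clip_policy_noise`, and `min_priority` fields match hyperparameters of the published TD7 algorithm (target policy smoothing inherited from TD3, and the LAP Huber loss). The following is a minimal sketch of how such values are commonly used, not a copy of the evorl implementation. The function names `smoothed_target_action` and `lap_huber` are hypothetical, and the action bounds of [-1, 1] are an assumption.

```python
import jax
import jax.numpy as jnp


def smoothed_target_action(key, target_action, policy_noise=0.2, clip_policy_noise=0.5):
    """Target policy smoothing: add clipped Gaussian noise to the target action."""
    noise = jax.random.normal(key, target_action.shape) * policy_noise
    noise = jnp.clip(noise, -clip_policy_noise, clip_policy_noise)
    # Assumption: actions are normalized to [-1, 1].
    return jnp.clip(target_action + noise, -1.0, 1.0)


def lap_huber(td_error, min_priority=1.0):
    """Element-wise Huber-style loss whose threshold is min_priority."""
    abs_td = jnp.abs(td_error)
    return jnp.where(abs_td < min_priority, 0.5 * td_error**2, min_priority * abs_td)


key = jax.random.PRNGKey(0)
target_action = jnp.zeros((4, 3))
noisy_action = smoothed_target_action(key, target_action)
losses = lap_huber(jnp.array([0.5, 2.0]))
```

With a zero target action, every entry of `noisy_action` stays within `clip_policy_noise` of zero, which makes the effect of the clip easy to check.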
- class evorl.algorithms.td7.TD7Critic
Bases: flax.linen.Module
(256, 256)
- state_action_emb_dim: int
256
- z_s_dim: int
256
- z_sa_dim: int
256
- class evorl.algorithms.td7.TD7Encoder
Bases: flax.linen.Module
- f_layer_sizes: Sequence[int]
(256, 256)
- g_layer_sizes: Sequence[int]
(256, 256)
- z_s_dim: int
256
- z_sa_dim: int
256
- class evorl.algorithms.td7.TD7NetworkParams
Bases: evorl.types.PyTreeData
- actor_params: evorl.types.Params
None
- checkpoint_actor_params: evorl.types.Params
None
- checkpoint_encoder_params: evorl.types.Params
None
- critic_params: evorl.types.Params
None
- encoder_params: evorl.types.Params
None
- fixed_encoder_params: evorl.types.Params
None
- fixed_encoder_target_params: evorl.types.Params
None
- target_actor_params: evorl.types.Params
None
- target_critic_params: evorl.types.Params
None
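In the published TD7 algorithm, target networks are refreshed by periodic hard copies rather than soft (Polyak) averaging, and the fixed encoder lags the trained encoder by one refresh. The sketch below shows that copy pattern on a plain `NamedTuple` that reuses the field names above. Whether evorl performs the update exactly this way is an assumption; `ParamsSketch` and `hard_update_targets` are hypothetical names, and the checkpoint fields are omitted for brevity.

```python
from typing import Any, NamedTuple


class ParamsSketch(NamedTuple):
    """Hypothetical container mirroring a subset of TD7NetworkParams."""

    actor_params: Any
    critic_params: Any
    encoder_params: Any
    target_actor_params: Any
    target_critic_params: Any
    fixed_encoder_params: Any
    fixed_encoder_target_params: Any


def hard_update_targets(p: ParamsSketch) -> ParamsSketch:
    # All right-hand sides read the OLD values, so the fixed-encoder target
    # receives the previous fixed encoder, not the freshly copied one.
    return p._replace(
        target_actor_params=p.actor_params,
        target_critic_params=p.critic_params,
        fixed_encoder_target_params=p.fixed_encoder_params,
        fixed_encoder_params=p.encoder_params,
    )


before = ParamsSketch("a1", "c1", "e2", "a0", "c0", "e1", "e0")
after = hard_update_targets(before)
```

Strings stand in for parameter pytrees here; the same function works unchanged when the fields hold nested dictionaries of arrays.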
- class evorl.algorithms.td7.TD7TrainMetric
Bases: evorl.metrics.MetricBase
- actor_loss: chex.Array
None
- critic_loss: chex.Array
None
- encoder_loss: chex.Array
None
- raw_loss_dict: evorl.types.LossDict
‘metric_field(…)’
- class evorl.algorithms.td7.TD7Workflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)
Bases: evorl.algorithms.offpolicy_utils.OffPolicyWorkflowTemplate
- learn(state: evorl.types.State) -> evorl.types.State
- step(state: evorl.types.State) -> tuple[evorl.metrics.MetricBase, evorl.types.State]
- evorl.algorithms.td7.avg_l1_norm(x: jax.Array, eps: float = 1e-08) -> jax.Array
Average L1 Norm used in TD7.
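For reference, the TD7 paper defines AvgL1Norm as dividing a vector by the mean of its absolute values, with `eps` as a lower bound on the divisor. The following sketch implements that definition under the same signature. It is written from the paper's formula rather than from the evorl source, and the choice of the last axis as the feature axis is an assumption.

```python
import jax.numpy as jnp


def avg_l1_norm_sketch(x, eps=1e-8):
    """x / max(mean(|x|), eps), computed over the last axis."""
    scale = jnp.mean(jnp.abs(x), axis=-1, keepdims=True)
    return x / jnp.maximum(scale, eps)


x = jnp.array([[1.0, -3.0], [0.0, 0.0]])
normed = avg_l1_norm_sketch(x)
```

The first row has mean absolute value 2, so it is halved; the all-zero second row is divided by `eps` and stays zero instead of producing NaN.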
- evorl.algorithms.td7.make_td7_agent(action_space: evorl.envs.Space, z_s_dim: int = 256, z_sa_dim: int = 256, f_layer_sizes: Sequence[int] = (256, 256), g_layer_sizes: Sequence[int] = (256, 256), state_emb_dim: int = 256, state_action_emb_dim: int = 256, critic_hidden_layer_sizes: Sequence[int] = (256, 256), actor_hidden_layer_sizes: Sequence[int] = (256, 256), discount: float = 0.99, exploration_epsilon: float = 0.1, policy_noise: float = 0.2, clip_policy_noise: float = 0.5, min_priority: float = 1.0, normalize_obs: bool = False)