evorl.algorithms.td7

Module Contents

Classes

Functions

avg_l1_norm

Average L1 Norm used in TD7.

make_td7_agent

API

class evorl.algorithms.td7.TD7Actor[source]

Bases: flax.linen.Module

action_size: int

None

hidden_layer_sizes: Sequence[int]

(256, 256)

state_emb_dim: int

256

z_s_dim: int

256

class evorl.algorithms.td7.TD7Agent[source]

Bases: evorl.agent.Agent

The Agent for TD7.

actor_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → evorl.types.LossDict[source]
actor_network: flax.linen.Module

None

clip_policy_noise: float

0.5

compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo][source]
critic_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → evorl.types.LossDict[source]
critic_network: flax.linen.Module

None

discount: float

0.99

encoder_loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → evorl.types.LossDict[source]
encoder_network: flax.linen.Module

None

evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo][source]
exploration_epsilon: float

0.1

init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) → evorl.agent.AgentState[source]
min_priority: float

1.0

property normalize_obs
obs_preprocessor: Any

‘pytree_field(…)’

policy_noise: float

0.2

class evorl.algorithms.td7.TD7Critic[source]

Bases: flax.linen.Module

hidden_layer_sizes: Sequence[int]

(256, 256)

state_action_emb_dim: int

256

z_s_dim: int

256

z_sa_dim: int

256

class evorl.algorithms.td7.TD7Encoder[source]

Bases: flax.linen.Module

f_layer_sizes: Sequence[int]

(256, 256)

g_layer_sizes: Sequence[int]

(256, 256)

setup()[source]
z_s_dim: int

256

z_sa_dim: int

256

zs(obs: jax.Array) → jax.Array[source]
zsa(z_s: jax.Array, action: jax.Array) → jax.Array[source]
class evorl.algorithms.td7.TD7NetworkParams[source]

Bases: evorl.types.PyTreeData

actor_params: evorl.types.Params

None

checkpoint_actor_params: evorl.types.Params

None

checkpoint_encoder_params: evorl.types.Params

None

critic_params: evorl.types.Params

None

encoder_params: evorl.types.Params

None

fixed_encoder_params: evorl.types.Params

None

fixed_encoder_target_params: evorl.types.Params

None

target_actor_params: evorl.types.Params

None

target_critic_params: evorl.types.Params

None

class evorl.algorithms.td7.TD7TrainMetric[source]

Bases: evorl.metrics.MetricBase

actor_loss: chex.Array

None

critic_loss: chex.Array

None

encoder_loss: chex.Array

None

raw_loss_dict: evorl.types.LossDict

‘metric_field(…)’

class evorl.algorithms.td7.TD7Workflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)[source]

Bases: evorl.algorithms.offpolicy_utils.OffPolicyWorkflowTemplate

learn(state: evorl.types.State) → evorl.types.State[source]
classmethod name()[source]
step(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State][source]
evorl.algorithms.td7.avg_l1_norm(x: jax.Array, eps: float = 1e-08) → jax.Array[source]

Average L1 Norm used in TD7.

evorl.algorithms.td7.make_td7_agent(action_space: evorl.envs.Space, z_s_dim: int = 256, z_sa_dim: int = 256, f_layer_sizes: Sequence[int] = (256, 256), g_layer_sizes: Sequence[int] = (256, 256), state_emb_dim: int = 256, state_action_emb_dim: int = 256, critic_hidden_layer_sizes: Sequence[int] = (256, 256), actor_hidden_layer_sizes: Sequence[int] = (256, 256), discount: float = 0.99, exploration_epsilon: float = 0.1, policy_noise: float = 0.2, clip_policy_noise: float = 0.5, min_priority: float = 1.0, normalize_obs: bool = False)[source]