evorl.algorithms.dqn

Module Contents

Classes

Functions

API

class evorl.algorithms.dqn.DQNAgent[source]

Bases: evorl.agent.Agent

compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo][source]
discount: float

0.99

evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo][source]
init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) → evorl.agent.AgentState[source]
loss(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → evorl.types.LossDict[source]
property normalize_obs
obs_preprocessor: Any

‘pytree_field(…)’

q_network: flax.linen.Module

None

target_type: str

‘DDQN’

class evorl.algorithms.dqn.DQNNetworkParams[source]

Bases: evorl.types.PyTreeData

exploration_epsilon: float

None

q_params: evorl.types.Params

None

target_q_params: evorl.types.Params

None

class evorl.algorithms.dqn.DQNTrainMetric[source]

Bases: evorl.metrics.MetricBase

loss: chex.Array

‘zeros(…)’

raw_loss_dict: evorl.types.LossDict

‘metric_field(…)’

class evorl.algorithms.dqn.DQNWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)[source]

Bases: evorl.algorithms.offpolicy_utils.OffPolicyWorkflowTemplate

classmethod name()[source]
step(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State][source]
class evorl.algorithms.dqn.DQNWorkflowMetric[source]

Bases: evorl.metrics.WorkflowMetric

training_updates: chex.Array

‘zeros(…)’

evorl.algorithms.dqn.make_mlp_discrete_dqn_agent(action_space: evorl.envs.Space, discount: float = 0.99, target_type: str = 'DDQN', q_hidden_layer_sizes: tuple[int] = (256, 256), normalize_obs: bool = False, value_obs_key: str = '')[source]