evorl.workflows.rl_workflow

Module Contents

Classes

OffPolicyWorkflow

Workflow template for Off-Policy RL algorithms.

OnPolicyWorkflow

Workflow template for On-Policy RL algorithms.

RLWorkflow

Base Workflow for RL algorithms.

API

class evorl.workflows.rl_workflow.OffPolicyWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)[source]

Bases: evorl.workflows.rl_workflow.RLWorkflow

Workflow template for Off-Policy RL algorithms.

This class constructs the template for Off-Policy RL algorithms, providing the general setup() and evaluate() methods.

evaluate(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State][source]

setup(key: chex.PRNGKey) → evorl.types.State[source]

class evorl.workflows.rl_workflow.OnPolicyWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, config: omegaconf.DictConfig)[source]

Bases: evorl.workflows.rl_workflow.RLWorkflow

Workflow template for On-Policy RL algorithms.

This class constructs the template for On-Policy RL algorithms, providing the general setup() and evaluate() methods.

evaluate(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State][source]

setup(key: chex.PRNGKey) → evorl.types.State[source]

class evorl.workflows.rl_workflow.RLWorkflow(config: omegaconf.DictConfig)[source]

Bases: evorl.workflows.workflow.Workflow

Base Workflow for RL algorithms.

classmethod build_from_config(config: omegaconf.DictConfig, enable_multi_devices: bool = False, enable_jit: bool = True) → typing_extensions.Self[source]

Build the RL workflow instance from the config.

Parameters:
  • config – Config of the workflow.

  • enable_multi_devices – Whether multi-device training is enabled.

  • enable_jit – Whether JIT compilation is enabled.
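The signature above describes an alternate constructor: a classmethod that takes a config plus two flags and returns an instance of whichever class it was called on (the `Self` return type). The following is a minimal sketch of that pattern in plain Python, not EvoRL's implementation. `ToyWorkflow`, `ToyPPOWorkflow`, the plain-dict config, and the `_multi_devices` and `jitted` attributes are hypothetical stand-ins; how the real class stores or acts on the flags is not shown on this page.

```python
from typing import Any, Mapping


class ToyWorkflow:
    """Hypothetical stand-in for an RLWorkflow-like class (not EvoRL code)."""

    def __init__(self, config: Mapping[str, Any]):
        self.config = dict(config)
        # Assumption: the flags are simply recorded on the instance.
        self._multi_devices = False
        self.jitted = False

    @property
    def enable_multi_devices(self) -> bool:
        # Read-only view of the flag, mirroring the documented property name.
        return self._multi_devices

    @classmethod
    def build_from_config(
        cls,
        config: Mapping[str, Any],
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> "ToyWorkflow":
        # `cls` is the class the method was called on, so a subclass
        # receives an instance of its own type.
        workflow = cls(config)
        workflow._multi_devices = enable_multi_devices
        workflow.jitted = enable_jit
        return workflow


class ToyPPOWorkflow(ToyWorkflow):
    """Hypothetical subclass used to show that the factory preserves the type."""


workflow = ToyPPOWorkflow.build_from_config({"seed": 42}, enable_jit=False)
```

Calling the factory on the subclass yields a `ToyPPOWorkflow`, which is the behavior the `Self` annotation promises.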

classmethod enable_jit() → None[source]

Define which methods should be jitted.

By default, the workflow’s step() and evaluate() methods are jitted.
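One way a classmethod can "define which methods should be jitted" is to rebind the chosen methods on the class with a wrapped version. The sketch below illustrates that idea under stated assumptions: `fake_jit` is a hypothetical stand-in for a compiler transform such as `jax.jit` (it only tags the function, so the snippet runs without JAX), and `ToyWorkflow` is not EvoRL code. Whether EvoRL rebinds methods in exactly this way is not confirmed by this page.

```python
import functools


def fake_jit(fn):
    """Hypothetical stand-in for a transform like jax.jit; it only tags the function."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapped.is_compiled = True
    return wrapped


class ToyWorkflow:
    """Hypothetical workflow with two transformable methods and one helper."""

    def step(self, state):
        return {"loss": 0.0}, state + 1

    def evaluate(self, state):
        return {"episode_return": 1.0}, state

    def helper(self, state):
        return state

    @classmethod
    def enable_jit(cls) -> None:
        # Rebind only the selected methods on the class itself; a subclass
        # could override this classmethod to select a different set.
        cls.step = fake_jit(cls.step)
        cls.evaluate = fake_jit(cls.evaluate)


ToyWorkflow.enable_jit()
metrics, state = ToyWorkflow().step(0)
```

After the call, `step` and `evaluate` carry the tag while `helper` is left untouched, and the wrapped `step` still returns the same `(metrics, state)` pair.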

property enable_multi_devices: bool

Whether multi-device training is enabled.

classmethod enable_shmap(axis_name: str) → None[source]

Define which methods should be wrapped with shmap.

This method defines the multi-device behavior. By default, the workflow’s step() and evaluate() methods are wrapped with shmap.

Parameters:

axis_name – The axis_name for shmap.

abstract evaluate(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State][source]

Customize the evaluation logic for the workflow.

Parameters:

state – State of the workflow.

abstract step(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State][source]

Customize the training logic of one iteration.

Parameters:

state – State of the workflow.

Returns:

Tuple of (metrics, state).
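Both abstract methods share the same contract: they take a state and return a `(metrics, state)` tuple, so the caller threads the returned state into the next call. The sketch below shows that contract with plain Python stand-ins. `ToyRLWorkflow`, `CountingWorkflow`, and the dict-based state and metrics are hypothetical; the real types are `evorl.types.State` and `evorl.metrics.MetricBase`, and the real training loop is not shown on this page.

```python
from abc import ABC, abstractmethod


class ToyRLWorkflow(ABC):
    """Hypothetical base class mirroring the two abstract methods above."""

    @abstractmethod
    def step(self, state):
        """Run one training iteration and return (metrics, state)."""

    @abstractmethod
    def evaluate(self, state):
        """Run evaluation and return (metrics, state)."""


class CountingWorkflow(ToyRLWorkflow):
    """Toy subclass: 'training' just increments an iteration counter."""

    def step(self, state):
        # Build a new state instead of mutating the input, in the
        # functional style that JAX-based code typically follows.
        new_state = {**state, "iterations": state["iterations"] + 1}
        metrics = {"iterations": new_state["iterations"]}
        return metrics, new_state

    def evaluate(self, state):
        metrics = {"score": 10 * state["iterations"]}
        return metrics, state


workflow = CountingWorkflow()
state = {"iterations": 0}
for _ in range(3):
    # Thread the returned state into the next iteration.
    train_metrics, state = workflow.step(state)
eval_metrics, state = workflow.evaluate(state)
```

Because both methods are abstract, instantiating the base class directly raises `TypeError`; only a subclass that implements `step()` and `evaluate()` can be constructed.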