evorl.workflows.rl_workflow
Module Contents
Classes
OffPolicyWorkflow | Workflow template for Off-Policy RL algorithms.
OnPolicyWorkflow | Workflow template for On-Policy RL algorithms.
RLWorkflow | Base Workflow for RL algorithms.
API
- class evorl.workflows.rl_workflow.OffPolicyWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, replay_buffer: evorl.replay_buffers.AbstractReplayBuffer, config: omegaconf.DictConfig)
Bases: evorl.workflows.rl_workflow.RLWorkflow
Workflow template for Off-Policy RL algorithms.
This class provides the shared setup() and evaluate() methods for Off-Policy RL algorithms, which learn from transitions stored in the replay_buffer. Concrete algorithms implement step().
- evaluate(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State]
- setup(key: chex.PRNGKey) → evorl.types.State
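To make the split between template and subclass concrete, here is a pure-Python sketch of the kind of step() an off-policy subclass might implement. Everything in it is hypothetical: OffPolicyState, collect(), and the scalar "agent" are toy stand-ins, not evorl's Agent, replay-buffer, or State APIs. It only illustrates the data flow: collect, store in a bounded buffer, sample a minibatch that may include older data, update.

```python
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class OffPolicyState:
    """Hypothetical workflow state (not evorl.types.State)."""
    buffer: tuple      # stored transitions; here just scalar rewards
    estimate: float    # toy "agent parameter": running estimate of the mean reward
    env_steps: int


def collect(rng, n):
    """Stand-in for environment interaction: draws n random rewards."""
    return [rng.uniform(0.0, 1.0) for _ in range(n)]


def off_policy_step(state, rng, rollout_len=4, batch_size=8, capacity=32, lr=0.1):
    # 1. Collect fresh transitions and append them to the replay buffer,
    #    keeping only the most recent `capacity` entries (FIFO eviction).
    buffer = (state.buffer + tuple(collect(rng, rollout_len)))[-capacity:]
    # 2. Sample a minibatch (with replacement) from the whole buffer. It may
    #    contain data gathered under older parameters: that is what makes it off-policy.
    batch = rng.choices(buffer, k=batch_size)
    # 3. Move the toy parameter toward the minibatch mean.
    target = sum(batch) / len(batch)
    estimate = state.estimate + lr * (target - state.estimate)
    metrics = {"buffer_size": len(buffer), "target": target}
    new_state = replace(state, buffer=buffer, estimate=estimate,
                        env_steps=state.env_steps + rollout_len)
    return metrics, new_state


rng = random.Random(0)
state = OffPolicyState(buffer=(), estimate=0.0, env_steps=0)
for _ in range(20):
    metrics, state = off_policy_step(state, rng)
```

Returning a new state instead of mutating the old one mirrors the (metrics, state) return convention used throughout this page, a functional style that suits jax.jit-compiled code.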
- class evorl.workflows.rl_workflow.OnPolicyWorkflow(env: evorl.envs.Env, agent: evorl.agent.Agent, optimizer: optax.GradientTransformation, evaluator: evorl.evaluators.Evaluator, config: omegaconf.DictConfig)
Bases: evorl.workflows.rl_workflow.RLWorkflow
Workflow template for On-Policy RL algorithms.
This class provides the shared setup() and evaluate() methods for On-Policy RL algorithms. Unlike OffPolicyWorkflow, it takes no replay_buffer. Concrete algorithms implement step().
- evaluate(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State]
- setup(key: chex.PRNGKey) → evorl.types.State
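For contrast, a matching sketch of an on-policy step() under the same toy assumptions (hypothetical names, a scalar "agent", random rewards in place of an environment). There is no replay buffer: each update uses only the batch just collected, and that batch is then dropped.

```python
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class OnPolicyState:
    """Hypothetical workflow state; note there is no replay buffer."""
    estimate: float
    env_steps: int


def rollout(rng, n):
    """Stand-in for running the current policy for n steps.

    This toy version ignores the policy and just draws random rewards.
    """
    return [rng.uniform(0.0, 1.0) for _ in range(n)]


def on_policy_step(state, rng, rollout_len=16, lr=0.1):
    # 1. Collect a fresh batch with the current policy.
    batch = rollout(rng, rollout_len)
    # 2. Update on exactly this batch.
    target = sum(batch) / len(batch)
    estimate = state.estimate + lr * (target - state.estimate)
    # 3. The batch is not stored: after the update it no longer reflects the
    #    current policy, so the next step collects new data.
    metrics = {"batch_size": len(batch), "target": target}
    return metrics, replace(state, estimate=estimate,
                            env_steps=state.env_steps + rollout_len)


rng = random.Random(0)
state = OnPolicyState(estimate=0.0, env_steps=0)
for _ in range(10):
    metrics, state = on_policy_step(state, rng)
```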
- class evorl.workflows.rl_workflow.RLWorkflow(config: omegaconf.DictConfig)
Bases: evorl.workflows.workflow.Workflow
Base Workflow for RL algorithms.
- classmethod build_from_config(config: omegaconf.DictConfig, enable_multi_devices: bool = False, enable_jit: bool = True) → typing_extensions.Self
Build an RL workflow instance from the config.
- Parameters:
config – Config of the workflow.
enable_multi_devices – Whether multi-device training is enabled.
enable_jit – Whether JIT compilation is enabled.
- Returns:
The constructed workflow instance.
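One way a classmethod factory like build_from_config can work is to configure the class itself (JIT or multi-device wrapping) before instantiating it. The sketch below is hypothetical: mark_compiled stands in for jax.jit so the example runs without JAX, a plain dict replaces omegaconf.DictConfig, and the order of the calls is an assumption, not evorl's implementation.

```python
import functools


def mark_compiled(fn):
    """Hypothetical stand-in for jax.jit: wraps fn and tags the wrapper."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapper.compiled = True
    return wrapper


class ToyRLWorkflow:
    def __init__(self, config):
        self.config = config

    @classmethod
    def build_from_config(cls, config, enable_multi_devices=False, enable_jit=True):
        # Class-level setup runs before the instance is created, so the
        # instance returned below already sees the wrapped methods.
        if enable_jit:
            cls.enable_jit()
        if enable_multi_devices:
            cls.enable_shmap(axis_name="devices")  # hypothetical axis name
        return cls(config)

    @classmethod
    def enable_jit(cls):
        cls.step = mark_compiled(cls.step)

    @classmethod
    def enable_shmap(cls, axis_name):
        # Placeholder: a real version would wrap methods with shard_map.
        cls.axis_name = axis_name

    def step(self, state):
        return {"loss": 0.0}, state + 1


wf = ToyRLWorkflow.build_from_config({"lr": 3e-4})
```

Because the wrapping is applied to the class rather than the instance, it affects every instance created afterwards, and calling the factory twice wraps step() twice. In JAX code the stand-in would typically be something like jax.jit(cls.step, static_argnums=(0,)), which treats self as a static (hashable) argument.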
- classmethod enable_jit() → None
Define which methods should be JIT-compiled.
By default, the workflow’s step() and evaluate() methods are jitted.
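Defining enable_jit() as a classmethod on a base class works for subclasses because cls is bound to whichever subclass it is called on, so cls.step resolves to that subclass's override. In the sketch below, fake_jit is a hypothetical counting wrapper standing in for jax.jit, and BaseWorkflow and CounterWorkflow are illustrative names only.

```python
import functools


def fake_jit(fn):
    """Hypothetical stand-in for jax.jit that counts how often it is called."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)
    wrapper.calls = 0
    return wrapper


class BaseWorkflow:
    @classmethod
    def enable_jit(cls):
        # `cls` is the subclass this is called on, so these lookups find
        # the subclass's overrides and replace them on that subclass only.
        cls.step = fake_jit(cls.step)
        cls.evaluate = fake_jit(cls.evaluate)

    def step(self, state):
        raise NotImplementedError

    def evaluate(self, state):
        raise NotImplementedError


class CounterWorkflow(BaseWorkflow):
    def step(self, state):
        return {"train_steps": state + 1}, state + 1

    def evaluate(self, state):
        return {"eval_state": state}, state


CounterWorkflow.enable_jit()
wf = CounterWorkflow()
metrics, state = wf.step(0)
metrics, state = wf.step(state)
eval_metrics, state = wf.evaluate(state)
```

BaseWorkflow itself is left untouched, so other subclasses are not affected until they call enable_jit() themselves.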
- property enable_multi_devices: bool
Whether multi-device training is enabled.
- classmethod enable_shmap(axis_name: str) → None
Define which methods should be wrapped with shard_map (shmap).
This method defines the multi-device behavior. By default, the workflow’s step() and evaluate() methods are shmapped.
- Parameters:
axis_name – The axis name used by shmap.
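Conceptually, shard_map runs the same function on each device's slice of the inputs, and collectives such as jax.lax.pmean(x, axis_name) combine results across the named axis, which is presumably why enable_shmap() takes an axis_name. The pure-Python simulation below only mimics that data flow with lists (no JAX, no real devices); shard, local_step, and pmean are hypothetical helpers.

```python
def shard(batch, num_devices):
    """Split a batch into equal contiguous shards, one per simulated device."""
    if len(batch) % num_devices:
        raise ValueError("batch size must be divisible by the number of devices")
    size = len(batch) // num_devices
    return [batch[i * size:(i + 1) * size] for i in range(num_devices)]


def local_step(shard_data):
    """Per-device work: here, just the mean of the local shard."""
    return sum(shard_data) / len(shard_data)


def pmean(values):
    """Simulated cross-device mean over the named axis."""
    return sum(values) / len(values)


batch = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
local = [local_step(s) for s in shard(batch, num_devices=4)]
global_mean = pmean(local)
```

Requiring equal-sized shards matters here: the mean of per-shard means equals the global mean only when every shard has the same size.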
- abstract evaluate(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State]
Customize the evaluation logic for the workflow.
- Parameters:
state – State of the workflow.
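A concrete evaluate() only has to honour the (metrics, state) contract in its signature. The sketch below is hypothetical throughout: WorkflowState and EvalMetrics stand in for evorl.types.State and evorl.metrics.MetricBase, and Gaussian noise replaces real evaluator rollouts. It returns an updated copy of the state (here, an evaluation counter) rather than mutating it.

```python
import random
import statistics
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class WorkflowState:
    """Hypothetical state holding only the fields this sketch needs."""
    policy_mean: float
    eval_count: int = 0


@dataclass(frozen=True)
class EvalMetrics:
    """Hypothetical metrics container standing in for MetricBase."""
    episode_return_mean: float
    num_episodes: int


def evaluate(state, num_episodes=5, seed=0):
    # Seed from the eval counter: reproducible, but different on each call.
    rng = random.Random(seed + state.eval_count)
    # Stand-in for evaluator rollouts: each "episode return" is noise
    # around the current policy parameter.
    returns = [state.policy_mean + rng.gauss(0.0, 0.1) for _ in range(num_episodes)]
    metrics = EvalMetrics(statistics.fmean(returns), num_episodes)
    return metrics, replace(state, eval_count=state.eval_count + 1)


metrics, state = evaluate(WorkflowState(policy_mean=1.0))
```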
- abstract step(state: evorl.types.State) → tuple[evorl.metrics.MetricBase, evorl.types.State]
Customize the training logic of one iteration.
- Parameters:
state – State of the workflow.
- Returns:
Tuple of (metrics, state).
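An outer loop that drives step() and evaluate() might look like the sketch below. ToyWorkflow and train() are hypothetical names; the point is that each call consumes a state and returns a new one, so the loop threads the state through every iteration and interleaves evaluation at a fixed interval.

```python
class ToyWorkflow:
    """Hypothetical concrete workflow: the state is just an integer counter."""

    def step(self, state):
        new_state = state + 1
        return {"iteration": new_state}, new_state

    def evaluate(self, state):
        return {"evaluated_at": state}, state


def train(workflow, state, num_iters, eval_interval):
    """Call step() once per iteration and evaluate() every eval_interval steps."""
    eval_history = []
    for i in range(num_iters):
        train_metrics, state = workflow.step(state)
        if (i + 1) % eval_interval == 0:
            eval_metrics, state = workflow.evaluate(state)
            eval_history.append(eval_metrics)
    return state, eval_history


final_state, history = train(ToyWorkflow(), state=0, num_iters=10, eval_interval=5)
```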