Key Concepts

Object-oriented functional programming model

EvoRL uses an object-oriented functional programming model, in which classes define static execution logic while their running states are stored externally. This differs from the conventional object-oriented programming model, where an object's state is stored inside the object as its attributes.

This model is designed to support JAX's functional programming style while keeping the modularity and composability of object-oriented programming. Below is a toy example that demonstrates what such code looks like.

from functools import partial

import jax
import jax.numpy as jnp


class Foo:
    def __init__(self, n, i):
        # n and i should be treated as static variables,
        # we should not change them after initialization.
        self.n = n
        self.i = i

    def init(self):
        return dict(
            a=jnp.ones((self.n,)),
            b=jnp.zeros((self.n,)),
        )

    # `self` is a static argument: jit treats foo as a compile-time constant.
    @partial(jax.jit, static_argnums=0)
    def increment(self, state):
        res = state["a"] * state["b"]

        new_state = dict(
            a=state["a"] + self.i,
            b=state["b"] + self.i,
        )

        return res, new_state

foo = Foo(n=3, i=1)
state = foo.init()
for _ in range(10):
    res, state = foo.increment(state)
    print(res)

Functional programming requires functions to be pure, i.e., free of side effects: a function must not mutate variables outside its own scope. In this example, after creating the foo object, we must not change foo.n or foo.i and should treat them as read-only. The init() method creates the initial state of foo. This state is stored outside the object, and foo.increment() applies its static logic to it, returning a new state instead of modifying the old one.
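Treating foo.n and foo.i as read-only is more than a convention. Because foo is passed to jax.jit as a static argument, and plain Python objects are hashed and compared by identity, jit keeps reusing the cached compilation as long as the same object is passed in. A mutated attribute is therefore silently ignored. The minimal sketch below demonstrates this using a hypothetical Counter class that follows the same pattern as Foo and uses only public jax APIs.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Counter:
    def __init__(self, step):
        # Static configuration: treat as read-only after construction.
        self.step = step

    def init(self):
        return {"count": jnp.zeros(())}

    @partial(jax.jit, static_argnums=0)
    def update(self, state):
        # self.step is baked into the compiled function as a constant.
        return {"count": state["count"] + self.step}


counter = Counter(step=1)
state = counter.init()
state = counter.update(state)  # count: 0 -> 1 (first call compiles)

counter.step = 10  # breaks the read-only convention
# `counter` keeps the same hash and still equals itself, so jit reuses the
# cached compilation, which adds 1 rather than 10.
state = counter.update(state)
print(state["count"])  # 2.0
```

If a configuration value really needs to change, create a new object instead. A new object hashes differently, so jit compiles a fresh version for it.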

Basic PyTree Data Containers

We provide some basic data containers to support the object-oriented functional programming model. They make code written in this style simpler to write and more flexible.

The evorl.types package provides three basic data containers, listed in the table below. All of them are registered as JAX PyTrees, which lets us specify which parts of a container's data are static.

Note

The term static comes from jax.jit. Jitted functions only accept PyTrees of arrays (such as jax.Array) as inputs. The static part of an input PyTree is treated as a compile-time constant and determines the computation graph. If the static part of the input changes in a later call, the jitted function is recompiled.

Conversely, jax.Array objects are treated as pure data: when their values change but their dtype and shape stay the same, the jitted function is not recompiled.
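This recompilation rule can be observed directly. The sketch below registers a toy PyTree, Scaled (a hypothetical class, not part of EvoRL), whose x field is a leaf and whose factor field is static aux data. It then counts how often jax.jit traces the function. Python side effects such as incrementing the counter only run during tracing, so the counter equals the number of compilations.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Scaled:
    """Toy pytree: `x` is dynamic data, `factor` is static metadata."""

    def __init__(self, x, factor):
        self.x = x
        self.factor = factor

    def tree_flatten(self):
        # Returns (children, aux_data): children are traced, aux_data is static.
        return (self.x,), self.factor

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


trace_count = 0


@jax.jit
def apply(s):
    global trace_count
    trace_count += 1  # Python side effect: only runs while tracing
    return s.x * s.factor


apply(Scaled(jnp.ones(3), 2))   # first call: trace and compile
apply(Scaled(jnp.zeros(3), 2))  # new values, same shape/dtype/static part: cached
apply(Scaled(jnp.ones(3), 3))   # static part changed: recompile
apply(Scaled(jnp.ones(4), 3))   # shape changed: recompile
print(trace_count)  # 3
```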

Type          Description                        Usage
PyTreeDict    An easydict with pytree support    Store pure data
PyTreeData    A pytree dataclass for data        Store data
PyTreeNode    A pytree dataclass for nodes       Build logic classes

  • PyTreeDict provides an easydict-like API for general storage of pure data (jax.Array).

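    import jax.numpy as jnp
    from evorl.types import PyTreeDict
    
    d = PyTreeDict({"a": jnp.ones((3,)), "b": jnp.zeros((3,))})
    print(d.a, d["b"])  # attribute and key access are equivalent
    d.c = jnp.zeros((5,))  # new entries can be added by attribute assignment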
    
  • PyTreeData provides a Python dataclass API. New data classes inherit from it and declare each field explicitly. Unlike PyTreeDict, it allows marking fields as static via pytree_field, and its fields cannot be modified after creation.

    import jax
    import jax.numpy as jnp
    from evorl.types import PyTreeData, pytree_field
    
    class SampleBatch(PyTreeData):
        obs: jax.Array | None = None
        actions: jax.Array | None = None
        rewards: jax.Array | None = None
        next_obs: jax.Array | None = None
        dones: jax.Array | None = None
    
    sample_batch = SampleBatch(obs=jnp.ones((3, 4)))
    
    class Bar(PyTreeData):
        a: jax.Array
        b: int = pytree_field(static=True, default=1) # b is static
    
    bar = Bar(a=jnp.ones((3, 4)), b=5)
    
    # The PyTreeData object is immutable:
    # bar.a = jnp.zeros((3, 4))  # raises FrozenInstanceError
    
    # To change the field, use the `replace` method,
    # which will return a new data instance.
    new_bar = bar.replace(a=jnp.zeros((3, 4)))
    
  • PyTreeNode is similar to PyTreeData, but it also allows fields to be set or changed in __post_init__, after the regular fields have been initialized. This makes it suitable for general logic classes; for example, Agent, Evaluator, and EvoOptimizer all derive from PyTreeNode.

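    from collections.abc import Callable
    
    import chex
    import optax
    
    # EvoOptimizer, ExponentialScheduleSpec, compute_centered_ranks and
    # pytree_field are provided by EvoRL.
    
    class OpenES(EvoOptimizer):
        """OpenAI ES."""
    
        pop_size: int
        lr_schedule: ExponentialScheduleSpec
        noise_std_schedule: ExponentialScheduleSpec
        mirror_sampling: bool = True
        optimizer_name: str = "adam"
        weight_decay: float | None = None
    
        fitness_shaping_fn: Callable[[chex.Array], chex.Array] = pytree_field(
            static=True, default=compute_centered_ranks
        )
        # Static and excluded from __init__: it is created in __post_init__.
        optimizer: optax.GradientTransformation = pytree_field(static=True, init=False)
    
        def __post_init__(self):
            # In this excerpt Adam is always used; optimizer_name and
            # weight_decay are not consulted.
            self.optimizer = optax.adam(
                learning_rate=self.lr_schedule.init
            )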
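The behavior of these containers rests on two ideas: frozen dataclasses, and PyTree registration that splits fields into dynamic leaves and static aux data. The sketch below approximates PyTreeData with pytree_field using only the standard library and jax.tree_util.register_pytree_node. It is not EvoRL's implementation; pytree_dataclass and static_field are hypothetical names used purely for illustration.

```python
import dataclasses

import jax
import jax.numpy as jnp


def static_field(**kwargs):
    # Hypothetical helper: tag a dataclass field as static via its metadata.
    return dataclasses.field(metadata={"static": True}, **kwargs)


def pytree_dataclass(cls):
    """Turn `cls` into a frozen dataclass registered as a JAX pytree.

    Fields tagged with static_field() go into the static aux data;
    all other fields become pytree leaves (pure data).
    """
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    static = [f.name for f in fields if f.metadata.get("static", False)]
    data = [f.name for f in fields if not f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in data)
        aux = tuple(getattr(obj, n) for n in static)
        return children, aux

    def unflatten(aux, children):
        obj = object.__new__(cls)  # bypass __init__ and __post_init__
        for name, value in zip(data + static, tuple(children) + aux):
            object.__setattr__(obj, name, value)  # frozen: set via object
        return obj

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class Bar:
    a: jax.Array
    b: int = static_field(default=1)  # static: part of the compiled graph


@jax.jit
def total(bar):
    return bar.a.sum() * bar.b  # bar.b is a plain Python int here


bar = Bar(a=jnp.ones((3, 4)), b=5)
print(len(jax.tree_util.tree_leaves(bar)))  # 1: only `a` is a leaf
print(total(bar))  # 60.0
new_bar = dataclasses.replace(bar, a=jnp.zeros((3, 4)))  # returns a new instance
```

Changing bar.b between calls would change the aux data and therefore trigger recompilation, while changing the values in bar.a would not.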
    

Agent

Agent encapsulates the learning agent and defines how it selects actions during both training and evaluation. It manages the agent's networks, including the policy network, which determines the agent's actions, and an optional value network used to estimate state or state-action values. This class also specifies the optional loss functions required for gradient-based updates.

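In summary, it has two public methods:

  • compute_actions(): computes the actions used to collect training data, e.g., in rollout().

  • evaluate_actions(): computes the actions used during evaluation, e.g., by an Evaluator.

In addition, most RL agents define one or more loss functions, which are called in the corresponding workflow's step().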
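A toy agent written in the same style can illustrate this split. The class below is hypothetical: it reuses the method names compute_actions and evaluate_actions that appear in the rollout and evaluator examples, but its signatures are simplified assumptions (it takes raw observation arrays, whereas EvoRL's methods may take a SampleBatch and return extra information). Static configuration lives on the object, the parameters live in an external agent_state, and the loss is differentiated with respect to that state.

```python
import jax
import jax.numpy as jnp


class LinearGaussianAgent:
    """Hypothetical agent: static config on the object, params in agent_state."""

    def __init__(self, obs_dim, action_dim, noise_std=0.1):
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.noise_std = noise_std

    def init(self, key):
        w = 0.01 * jax.random.normal(key, (self.obs_dim, self.action_dim))
        return {"w": w}  # agent_state is stored outside the object

    def _mean_action(self, agent_state, obs):
        return jnp.tanh(obs @ agent_state["w"])  # actions in [-1, 1]

    def compute_actions(self, agent_state, obs, key):
        # Training: add Gaussian exploration noise, then clip to [-1, 1].
        mean = self._mean_action(agent_state, obs)
        noise = self.noise_std * jax.random.normal(key, mean.shape)
        return jnp.clip(mean + noise, -1.0, 1.0)

    def evaluate_actions(self, agent_state, obs, key):
        # Evaluation: deterministic action; `key` is unused.
        return self._mean_action(agent_state, obs)

    def loss(self, agent_state, obs, target_actions):
        # Toy regression loss; a workflow step would differentiate it.
        pred = self._mean_action(agent_state, obs)
        return jnp.mean((pred - target_actions) ** 2)


agent = LinearGaussianAgent(obs_dim=4, action_dim=2)
init_key, obs_key, act_key = jax.random.split(jax.random.PRNGKey(0), 3)
agent_state = agent.init(init_key)
obs = jax.random.normal(obs_key, (16, 4))
train_actions = agent.compute_actions(agent_state, obs, act_key)
eval_actions = agent.evaluate_actions(agent_state, obs, act_key)
grads = jax.grad(agent.loss)(agent_state, obs, jnp.zeros((16, 2)))
```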

RL Environments

We provide a unified environment API, Env, that adapts multiple environment libraries.

An example of how to interact with the environment:

import jax
import jax.numpy as jnp
from brax.envs import get_environment

from evorl.envs.brax import BraxAdapter

brax_env = get_environment("hopper")
env = BraxAdapter(brax_env)

# reset the environment
env_key = jax.random.PRNGKey(42)
env_state = env.reset(env_key)

# apply one step
actions = jnp.zeros((3,))  # hopper has a 3-dimensional action space
env_nstate = env.step(env_state, actions)

We provide multiple Wrapper classes for Env in evorl.envs.wrappers. For instance, ActionSquashWrapper maps actions from [-1, 1] to the environment's [low, high] range, and VmapAutoResetWrapper turns a single env into k parallel envs.

On top of these wrappers, we provide environment creation functions for different libraries.

from evorl.envs import create_wrapped_brax_env, AutoresetMode

train_vec_env = create_wrapped_brax_env(
    "hopper", parallel=16, autoreset_mode=AutoresetMode.NORMAL
)
eval_vec_env = create_wrapped_brax_env(
    "hopper", parallel=16, autoreset_mode=AutoresetMode.DISABLED
)
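The two wrapper behaviors mentioned above reduce to simple array operations. The following minimal sketch is not EvoRL's ActionSquashWrapper or VmapAutoResetWrapper. squash_action applies the affine map \(a \mapsto low + (a + 1)(high - low)/2\), and jax.vmap turns a hypothetical single-env toy_step into a batched step over k envs. Automatic resetting is omitted.

```python
import jax
import jax.numpy as jnp


def squash_action(action, low, high):
    """Affinely map an action from [-1, 1] to [low, high]."""
    return low + (action + 1.0) * 0.5 * (high - low)


def toy_step(state, action):
    """Hypothetical single-env step: the state is a 2-D position."""
    next_state = state + action
    reward = -jnp.sum(next_state**2)
    return next_state, reward


low = jnp.array([0.0, -2.0])
high = jnp.array([1.0, 2.0])

# vmap maps toy_step over the leading (env) axis of both arguments.
batched_step = jax.vmap(toy_step)

k = 4
states = jnp.zeros((k, 2))
raw_actions = jnp.ones((k, 2))  # policy output in [-1, 1]
actions = squash_action(raw_actions, low, high)  # each row becomes [1.0, 2.0]
next_states, rewards = batched_step(states, actions)
print(next_states.shape, rewards.shape)  # (4, 2) (4,)
```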

Timestep subscript notation

The environment timestep subscript notation follows the convention of other common RL libraries. For the env_state at timestep t:

  \(s_t \xrightarrow{A_t,\ R_t,\ D_t,\ T_t} s_{t+1}\)

  • action: \(A_t\)

  • done & termination: \(D_t, T_t\)

  • reward: \(R_t\)

  • obs: \(O_{t+1}\)

Trajectory Data & Rollout

SampleBatch is a data container for trajectory data from the rollout between the agent and the environment. It is a subclass of PyTreeData and has 6 fields:

  • obs: chex.ArrayTree | None = None

  • actions: chex.ArrayTree | None = None

  • rewards: Reward | RewardDict | None = None

  • next_obs: chex.Array | None = None

  • dones: chex.Array | None = None

  • extras: ExtraInfo | None = None: Other trajectory information.

Some fields can be left empty depending on the use case. For example, SampleBatch can serve as an obs-only batch for computing actions, or store full trajectory data from a rollout.

evorl.rollout provides various functions for running a rollout with a given agent and environment. The example below demonstrates how to collect training data from a vectorized environment:

import jax

from evorl import RandomAgent
from evorl.envs import create_wrapped_brax_env, AutoresetMode
from evorl.rollout import rollout

env = create_wrapped_brax_env(
    "hopper", parallel=16, autoreset_mode=AutoresetMode.NORMAL
)
agent = RandomAgent()

key = jax.random.PRNGKey(42)
rollout_key, env_key, agent_key = jax.random.split(key, 3)
env_state = env.reset(env_key)
agent_state = agent.init(env.obs_space, env.action_space, agent_key)

# trajectory fields have shape [rollout_length, num_envs, ...], i.e. [128, 16, ...]
trajectory, env_nstate = rollout(
    env.step,
    agent.compute_actions,
    env_state,
    agent_state,
    rollout_key,
    rollout_length=128,
    env_extra_fields=("termination", "truncation", "steps"),
)
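Conceptually, a rollout of this kind can be expressed with jax.lax.scan. scan stacks the per-step outputs along a new leading time axis, which yields the [128, 16, ...] layout noted in the example above. The sketch below is a stand-in for evorl.rollout.rollout under simplifying assumptions and is not its implementation: the toy environment, the random policy, and the dict-based trajectory are hypothetical, and environments are not reset on done. It also shows the timestep alignment described earlier: at index t, obs holds \(O_t\) and next_obs holds \(O_{t+1}\).

```python
import jax
import jax.numpy as jnp


def toy_env_step(obs, action):
    """Hypothetical batched env: obs and action have shape [B, 2]."""
    next_obs = obs + action
    reward = -jnp.sum(next_obs**2, axis=-1)
    done = (jnp.abs(next_obs[:, 0]) > 1.0).astype(jnp.float32)
    return next_obs, reward, done


def random_policy(obs, key):
    return jax.random.uniform(key, obs.shape, minval=-0.1, maxval=0.1)


def toy_rollout(obs, key, rollout_length):
    def step(obs, key):
        action = random_policy(obs, key)
        next_obs, reward, done = toy_env_step(obs, action)
        # Timestep t: obs=O_t, actions=A_t, rewards=R_t, dones=D_t, next_obs=O_{t+1}
        transition = dict(
            obs=obs, actions=action, rewards=reward, next_obs=next_obs, dones=done
        )
        return next_obs, transition  # (new carry, per-step output)

    keys = jax.random.split(key, rollout_length)
    last_obs, trajectory = jax.lax.scan(step, obs, keys)
    return trajectory, last_obs


obs = jnp.zeros((16, 2))  # 16 parallel envs
trajectory, last_obs = toy_rollout(obs, jax.random.PRNGKey(42), rollout_length=128)
print(trajectory["obs"].shape)  # (128, 16, 2)
```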

Besides collecting trajectory data, algorithms also need to evaluate the agent over complete episodes. We provide various evaluators in evorl.evaluators.

import jax

from evorl import RandomAgent
from evorl.envs import create_wrapped_brax_env, AutoresetMode
from evorl.evaluators import Evaluator

# Evaluation runs complete episodes, so autoreset is disabled
env = create_wrapped_brax_env(
    "hopper", parallel=16, autoreset_mode=AutoresetMode.DISABLED
)
agent = RandomAgent()

key = jax.random.PRNGKey(42)
eval_key, env_key, agent_key = jax.random.split(key, 3)
env_state = env.reset(env_key)
agent_state = agent.init(env.obs_space, env.action_space, agent_key)

evaluator = Evaluator(
    env=env,
    agent_fn=agent.evaluate_actions,
    max_episode_steps=1000,
    discount=1,
)

eval_metrics = evaluator.evaluate(agent_state, eval_key, num_episodes=10)
print("Avg Return:", eval_metrics.episode_returns.mean())
print("Avg Length:", eval_metrics.episode_lengths.mean())
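When an evaluation env does not autoreset (AutoresetMode.DISABLED, as for eval_vec_env earlier), each environment in the batch finishes its episode at a different step. An evaluator therefore has to ignore rewards that come after an environment's first done. The sketch below shows one way to compute per-environment discounted returns and episode lengths from [T, num_envs] reward and done arrays with jax.lax.scan. The function episode_returns is hypothetical and is not the Evaluator's actual code. With discount=1, as in the example above, it reduces to the plain episode return.

```python
import jax
import jax.numpy as jnp


def episode_returns(rewards, dones, discount=1.0):
    """rewards, dones: float arrays of shape [T, num_envs].

    Returns (discounted return, episode length) per env, counting each
    step up to and including the env's first done.
    """

    def body(carry, x):
        ret, length, alive, disc = carry
        r, d = x
        ret = ret + alive * disc * r  # only accumulate while the episode runs
        length = length + alive
        disc = disc * discount
        alive = alive * (1.0 - d)  # becomes 0 after the first done
        return (ret, length, alive, disc), None

    n = rewards.shape[1]
    init = (jnp.zeros(n), jnp.zeros(n), jnp.ones(n), jnp.ones(()))
    (ret, length, _, _), _ = jax.lax.scan(body, init, (rewards, dones))
    return ret, length


rewards = jnp.ones((3, 2))
dones = jnp.array([[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]])
returns, lengths = episode_returns(rewards, dones)
print("Avg Return:", returns.mean(), "Avg Length:", lengths.mean())
```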

Workflow

A complete algorithm consists of three components:

  1. a config file,

  2. a Workflow subclass, and

  3. a corresponding Agent subclass.

Workflow defines the entire training logic of an algorithm. Its step() method encapsulates a single training iteration, while its learn() method orchestrates the whole training loop: besides calling step() repeatedly, it also handles termination checks, performance evaluation, periodic logging, and model checkpointing.
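A toy, pure-Python workflow can illustrate how the work is divided between step() and learn(). Everything in the sketch below (ToyWorkflow, its dict state, the log_fn callback) is hypothetical and much simpler than EvoRL's Workflow. It only mirrors the pattern described above: state is kept outside the object, step() performs one iteration and returns (metrics, state), and learn() loops over step() while handling the termination check and periodic evaluation and logging.

```python
class ToyWorkflow:
    """Hypothetical workflow: static config on the object, state passed in/out."""

    def __init__(self, total_iters, eval_interval):
        self.total_iters = total_iters
        self.eval_interval = eval_interval

    def init(self):
        return {"iteration": 0, "param": 0.0}

    def step(self, state):
        # One training iteration: move `param` halfway toward the target 1.0.
        new_param = state["param"] + 0.5 * (1.0 - state["param"])
        metrics = {"loss": (1.0 - new_param) ** 2}
        return metrics, {"iteration": state["iteration"] + 1, "param": new_param}

    def evaluate(self, state):
        return {"score": state["param"]}, state

    def learn(self, state, log_fn=print):
        while state["iteration"] < self.total_iters:  # termination check
            train_metrics, state = self.step(state)
            if state["iteration"] % self.eval_interval == 0:
                eval_metrics, state = self.evaluate(state)
                # Periodic logging; checkpointing would also go here.
                log_fn({"iteration": state["iteration"], **train_metrics, **eval_metrics})
        return state


workflow = ToyWorkflow(total_iters=6, eval_interval=2)
logs = []
state = workflow.learn(workflow.init(), log_fn=logs.append)
print(state)  # {'iteration': 6, 'param': 0.984375}
```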

Algorithms are implemented in evorl.algorithms. Most algorithms live in a single *.py file that contains both their Agent class and their Workflow class. A workflow receives a config object at creation, which corresponds to a *.yaml config file under configs/agent.

An example of how to use Workflow:

from hydra import compose, initialize
import jax

from evorl.algorithms.ppo import PPOWorkflow


with initialize(version_base=None, config_path="configs"):
    # choose the config file:
    # env: configs/brax/hopper.yaml
    # algorithm: configs/agent/ppo.yaml
    config = compose(config_name="config", overrides=["env=brax/hopper", "agent=ppo"])

workflow = PPOWorkflow.build_from_config(config, enable_jit=config.enable_jit)
state = workflow.init(jax.random.PRNGKey(config.seed))

# Train in one line:
# state = workflow.learn(state)

# Or manually control the training loop:
for i in range(100):
    train_metrics, state = workflow.step(state)
    if i % 10 == 0:
        eval_metrics, state = workflow.evaluate(state)
        print(
            f"Step {i} Avg Return: {eval_metrics.episode_returns.mean()} "
            f"Avg Length: {eval_metrics.episode_lengths.mean()}"
        )

# release resources like checkpoint manager, recorders.
workflow.close()