Source code for evorl.envs.env

 1from abc import ABC, abstractmethod
 2from collections.abc import Callable
 3from typing import Any
 4
 5import chex
 6
 7from evorl.types import (
 8    Action,
 9    Done,
10    EnvInternalState,
11    Observation,
12    PyTreeData,
13    PyTreeDict,
14    Reward,
15    pytree_field,
16)
17
18from .space import Space
19
20
class EnvState(PyTreeData):
    """Container for everything that describes an environment at one point in time.

    Attributes:
        env_state: The internal state of the environment.
        obs: The observation of the environment.
        reward: The reward of the environment.
        done: Whether the environment is done.
        info: Extra info from the environment.
        _internal: Extra internal data for the environment.
    """

    env_state: EnvInternalState
    obs: Observation
    reward: Reward
    done: Done
    # Info reported by the env; a fresh PyTreeDict is created when omitted.
    info: PyTreeDict = pytree_field(default_factory=PyTreeDict)
    # Extra data for internal use; a fresh PyTreeDict is created when omitted.
    _internal: PyTreeDict = pytree_field(default_factory=PyTreeDict)
43 44
class Env(ABC):
    """Unified EvoRL Env API.

    Abstract interface: concrete environments must implement ``reset``,
    ``step``, ``action_space`` and ``obs_space``.
    """

    @abstractmethod
    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the environment to its initial state.

        Args:
            key: PRNG key supplied for the reset.

        Returns:
            The ``EnvState`` after the reset.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, state: EnvState, action: Action) -> EnvState:
        """Take one step in the environment.

        Args:
            state: The ``EnvState`` to step from.
            action: The action to apply.

        Returns:
            The ``EnvState`` after the step.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def action_space(self) -> Space:
        """The action space of the environment."""
        raise NotImplementedError

    @property
    @abstractmethod
    def obs_space(self) -> Space:
        """The observation space of the environment."""
        raise NotImplementedError
69 70
class EnvAdapter(Env):
    """Base class for environment adapters.

    An adapter wraps an env from another package so that it can be exposed
    through EvoRL's Env API.
    """

    def __init__(self, env: Any):
        """Store the wrapped environment.

        Args:
            env: The environment object being adapted.
        """
        self.env = env

    @property
    def unwrapped(self) -> Any:
        """The wrapped environment object passed to the constructor."""
        return self.env
# Callable signatures matching Env.reset and Env.step respectively.
EnvResetFn = Callable[[chex.PRNGKey], EnvState]
EnvStepFn = Callable[[EnvState, Action], EnvState]