Source code for evorl.envs.env
1from abc import ABC, abstractmethod
2from collections.abc import Callable
3from typing import Any
4
5import chex
6
7from evorl.types import (
8 Action,
9 Done,
10 EnvInternalState,
11 Observation,
12 PyTreeData,
13 PyTreeDict,
14 Reward,
15 pytree_field,
16)
17
18from .space import Space
19
20
[docs]
class EnvState(PyTreeData):
    """Snapshot of an environment at one point in time.

    Bundles everything needed to describe the environment's current state
    so it can be passed explicitly to and returned from ``Env.reset`` /
    ``Env.step``.

    Attributes:
        env_state: Internal state of the underlying environment.
        obs: Current observation.
        reward: Reward associated with this state.
        done: Done flag for this state.
        info: Extra information reported by the environment; defaults to
            ``PyTreeDict()``.
        _internal: Extra data reserved for internal use; defaults to
            ``PyTreeDict()``.
    """

    # NOTE: field order is significant for the generated constructor;
    # keep the required fields before the defaulted ones.
    env_state: EnvInternalState
    obs: Observation
    reward: Reward
    done: Done
    # Info reported by the environment.
    info: PyTreeDict = pytree_field(default_factory=PyTreeDict)
    # Extra data for internal use.
    _internal: PyTreeDict = pytree_field(default_factory=PyTreeDict)
43
44
[docs]
class Env(ABC):
    """Abstract interface shared by all EvoRL environments.

    Concrete environments must implement ``reset``, ``step`` and the two
    space properties. State is passed in and returned explicitly as an
    :class:`EnvState`.
    """

    @property
    @abstractmethod
    def action_space(self) -> Space:
        """Space of the actions accepted by :meth:`step`."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def obs_space(self) -> Space:
        """Space of the observations produced by the environment."""
        raise NotImplementedError()

    @abstractmethod
    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Build an initial environment state.

        Args:
            key: PRNG key supplied to the reset; how it is consumed is up to
                the concrete subclass.

        Returns:
            The initial :class:`EnvState`.
        """
        raise NotImplementedError()

    @abstractmethod
    def step(self, state: EnvState, action: Action) -> EnvState:
        """Advance the environment by one step.

        Args:
            state: The current environment state.
            action: The action to apply.

        Returns:
            The :class:`EnvState` resulting from applying ``action``.
        """
        raise NotImplementedError()
69
70
[docs]
class EnvAdapter(Env):
    """Base class for adapters exposing third-party envs via EvoRL's API.

    An adapter holds an environment object from another package. Subclasses
    implement :class:`Env`'s abstract methods on top of it; this base class
    leaves them abstract.
    """

    def __init__(self, env: Any):
        # Keep a handle on the foreign environment being adapted.
        self.env = env

    @property
    def unwrapped(self) -> Any:
        """Return the adapted environment object stored in ``self.env``."""
        inner_env = self.env
        return inner_env
83
84
# Callable with the same signature as a bound ``Env.reset``:
# PRNG key -> initial EnvState.
EnvResetFn = Callable[[chex.PRNGKey], EnvState]
# Callable with the same signature as a bound ``Env.step``:
# (EnvState, Action) -> next EnvState.
EnvStepFn = Callable[[EnvState, Action], EnvState]