Source code for evorl.envs
import importlib
import importlib.util

from .space import Space, Box, Discrete, SpaceContainer, is_leaf_space
from .env import Env, EnvState, EnvStepFn, EnvResetFn
from .multi_agent_env import MultiAgentEnv
from .wrappers.training_wrapper import AutoresetMode
from .brax import create_brax_env, create_wrapped_brax_env
from .gymnasium import create_gymnasium_env
9
10
[docs]
11def create_env(env_cfg, **kwargs) -> Env:
12 """Unified env creator.
13
14 Args:
15 env_cfg: The environment configuration.
16 **kwargs: Additional keyword arguments for the environment creator.
17
18 Returns:
19 The created environment.
20 """
21 env_type = env_cfg.env_type
22 env_name = env_cfg.env_name
23
24 match env_type:
25 case "brax":
26 env = create_wrapped_brax_env(env_name, **kwargs)
27 case "playground":
28 env = create_wrapped_mujoco_playground_env(env_name, **kwargs)
29 case "gymnax":
30 env = create_wrapped_gymnax_env(env_name, **kwargs)
31 case "jumanji":
32 env = create_jumanji_env(env_name, **kwargs)
33 case "jaxmarl":
34 env = create_mabrax_env(env_name, **kwargs)
35 case "envpool":
36 env = create_envpool_env(env_name, **kwargs)
37 case "gymnasium":
38 env = create_gymnasium_env(env_name, **kwargs)
39 case _:
40 raise ValueError(f"env_type {env_type} not supported")
41
42 return env
43
44
# Names exported unconditionally: core env abstractions plus the creators for
# backends that are always imported at the top of this module.
__all__ = [
    "Env",
    "EnvState",
    "MultiAgentEnv",
    "Space",
    "Box",
    "Discrete",
    "SpaceContainer",
    "AutoresetMode",
    "is_leaf_space",
    "create_env",
    "create_brax_env",
    "create_wrapped_brax_env",
    "create_gymnasium_env",
]

# Optional backends: each submodule is imported, and its creators exported,
# only when the backing third-party package can be located.
if importlib.util.find_spec("gymnax") is not None:
    from .gymnax import create_gymnax_env, create_wrapped_gymnax_env

    __all__ += ["create_gymnax_env", "create_wrapped_gymnax_env"]

if importlib.util.find_spec("jumanji") is not None:
    from .jumanji import create_jumanji_env

    __all__ += ["create_jumanji_env"]

if importlib.util.find_spec("jaxmarl") is not None:
    from .jaxmarl import create_mabrax_env, create_wrapped_mabrax_env

    __all__ += ["create_mabrax_env", "create_wrapped_mabrax_env"]

if importlib.util.find_spec("envpool") is not None:
    from .envpool import create_envpool_env

    __all__ += ["create_envpool_env"]

if importlib.util.find_spec("mujoco_playground") is not None:
    from .mujoco_playground import (
        create_mujoco_playground_env,
        create_wrapped_mujoco_playground_env,
    )

    __all__ += [
        "create_mujoco_playground_env",
        "create_wrapped_mujoco_playground_env",
    ]