from abc import ABC
from types import MethodType
from typing import Type, Union, Any, List, Callable, Iterable, Dict, Optional
from functools import partial, wraps
from easydict import EasyDict
import copy
import platform
from collections import namedtuple
import numbers
import logging
import torch
import enum
import time
import traceback
import signal
from ding.torch_utils import to_tensor, to_ndarray, to_list
from ding.utils import ENV_MANAGER_REGISTRY, import_module, deep_merge_dicts, one_time_warning
from ding.envs.env.base_env import BaseEnvTimestep
from ding.utils.time_helper import WatchDog
class EnvState(enum.IntEnum):
VOID = 0
INIT = 1
RUN = 2
RESET = 3
DONE = 4
ERROR = 5
def retry_wrapper(func: Callable = None, max_retry: int = 10, waiting_time: float = 0.1) -> Callable:
"""
Overview:
Retry the function until exceeding the maximum retry times.
"""
if func is None:
return partial(retry_wrapper, max_retry=max_retry)
if max_retry == 1:
return func
@wraps(func)
def wrapper(*args, **kwargs):
exceptions = []
for _ in range(max_retry):
try:
ret = func(*args, **kwargs)
return ret
except BaseException as e:
exceptions.append(e)
time.sleep(waiting_time)
e_info = ''.join(
[
'Retry {} failed from:\n {}\n'.format(i, ''.join(traceback.format_tb(e.__traceback__)) + str(e))
for i, e in enumerate(exceptions)
]
)
func_exception = Exception("Function {} runtime error:\n{}".format(func, e_info))
raise RuntimeError("Function {} has exceeded max retries({})".format(func, max_retry)) from func_exception
return wrapper
def timeout_wrapper(func: Callable = None, timeout: int = 10) -> Callable:
"""
Overview:
Watch the function that must be finihsed within a period of time. If timeout, raise the captured error.
"""
if func is None:
return partial(timeout_wrapper, timeout=timeout)
windows_flag = platform.system().lower() == 'windows'
if windows_flag:
one_time_warning("Timeout wrapper is not implemented in windows platform, so ignore it default")
return func
@wraps(func)
def wrapper(*args, **kwargs):
watchdog = WatchDog(timeout)
try:
watchdog.start()
except ValueError as e:
# watchdog invalid case
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except BaseException as e:
raise e
finally:
watchdog.stop()
return wrapper
[docs]@ENV_MANAGER_REGISTRY.register('base')
class BaseEnvManager(object):
"""
Overview:
Create a BaseEnvManager to manage multiple environments.
Interfaces:
reset, step, seed, close, enable_save_replay, launch, env_info, default_config
Properties:
env_num, ready_obs, done, method_name_list,active_env
"""
@classmethod
def default_config(cls: type) -> EasyDict:
cfg = EasyDict(copy.deepcopy(cls.config))
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
config = dict(
episode_num=float("inf"),
max_retry=1,
step_timeout=60,
auto_reset=True,
reset_timeout=60,
retry_waiting_time=0.1,
)
def __init__(
self,
env_fn: List[Callable],
cfg: EasyDict = EasyDict({}),
) -> None:
"""
Overview:
Initialize the BaseEnvManager.
Arguments:
- env_fn (:obj:`List[Callable]`): The function to create environment
- cfg (:obj:`EasyDict`): Config
"""
self._cfg = cfg
self._env_fn = env_fn
self._env_num = len(self._env_fn)
self._closed = True
self._env_replay_path = None
# env_ref is used to acquire some common attributes of env, like obs_shape and act_shape
self._env_ref = self._env_fn[0]()
self._env_states = {i: EnvState.VOID for i in range(self._env_num)}
self._env_seed = {i: None for i in range(self._env_num)}
self._episode_num = self._cfg.episode_num
self._max_retry = self._cfg.max_retry
self._step_timeout = self._cfg.step_timeout
self._auto_reset = self._cfg.auto_reset
self._reset_timeout = self._cfg.reset_timeout
self._retry_waiting_time = self._cfg.retry_waiting_time
@property
def env_num(self) -> int:
return self._env_num
@property
def ready_obs(self) -> Dict[int, Any]:
"""
Overview:
Get the next observations(in ``torch.Tensor`` type) and corresponding env id.
Return:
A dictionary with observations and their environment IDs.
Example:
>>> obs_dict = env_manager.ready_obs
>>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items())}
"""
return {i: self._ready_obs[i] for i in range(self.env_num) if self._env_episode_count[i] < self._episode_num}
@property
def done(self) -> bool:
return all([s == EnvState.DONE for s in self._env_states.values()])
@property
def method_name_list(self) -> list:
return ['reset', 'step', 'seed', 'close', 'enable_save_replay']
@property
def active_env(self) -> List[int]:
return [i for i, s in self._env_states.items() if s == EnvState.RUN]
def __getattr__(self, key: str) -> Any:
"""
Note:
If a python object doesn't have the attribute whose name is `key`, it will call this method.
We suppose that all envs have the same attributes.
If you need different envs, please implement other env managers.
"""
if not hasattr(self._env_ref, key):
raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key))
if isinstance(getattr(self._env_ref, key), MethodType) and key not in self.method_name_list:
raise RuntimeError("env getattr doesn't support method({}), please override method_name_list".format(key))
self._check_closed()
return [getattr(env, key) if hasattr(env, key) else None for env in self._envs]
def _check_closed(self):
"""
Overview:
Check whether the env manager is closed. Will be called in ``__getattr__`` and ``step``.
"""
assert not self._closed, "env manager is closed, please use the alive env manager"
[docs] def launch(self, reset_param: Optional[Dict] = None) -> None:
"""
Overview:
Set up the environments and their parameters.
Arguments:
- reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \
value is the cooresponding reset parameters.
"""
assert self._closed, "Please first close the env manager"
if reset_param is not None:
assert len(reset_param) == len(self._env_fn)
self._create_state()
self.reset(reset_param)
def _create_state(self) -> None:
self._env_episode_count = {i: 0 for i in range(self.env_num)}
self._ready_obs = {i: None for i in range(self.env_num)}
self._envs = [e() for e in self._env_fn]
# env_ref is used to acquire some common attributes of env, like obs_shape and act_shape
self._env_ref = self._envs[0]
assert len(self._envs) == self._env_num
self._reset_param = {i: {} for i in range(self.env_num)}
self._env_states = {i: EnvState.INIT for i in range(self.env_num)}
if self._env_replay_path is not None:
for e, s in zip(self._envs, self._env_replay_path):
e.enable_save_replay(s)
self._closed = False
[docs] def reset(self, reset_param: Optional[Dict] = None) -> None:
"""
Overview:
Reset the environments their parameters.
Arguments:
- reset_param (:obj:`List`): Dict of reset parameters for each environment, key is the env_id, \
value is the cooresponding reset parameters.
"""
self._check_closed()
# set seed if necessary
env_ids = list(range(self._env_num)) if reset_param is None else list(reset_param.keys())
for i, env_id in enumerate(env_ids): # loop-type is necessary
if self._env_seed[env_id] is not None:
if self._env_dynamic_seed is not None:
self._envs[env_id].seed(self._env_seed[env_id], self._env_dynamic_seed)
else:
self._envs[env_id].seed(self._env_seed[env_id])
self._env_seed[env_id] = None # seed only use once
# reset env
if reset_param is None:
env_range = range(self.env_num)
else:
for env_id in reset_param:
self._reset_param[env_id] = reset_param[env_id]
env_range = reset_param.keys()
for env_id in env_range:
if self._env_replay_path is not None and self._env_states[env_id] == EnvState.RUN:
logging.warning("please don't reset a unfinished env when you enable save replay, we just skip it")
continue
self._env_states[env_id] = EnvState.RESET
self._reset(env_id)
def _reset(self, env_id: int) -> None:
@retry_wrapper(max_retry=self._max_retry, waiting_time=self._retry_waiting_time)
@timeout_wrapper(timeout=self._reset_timeout)
def reset_fn():
# if self._reset_param[env_id] is None, just reset specific env, not pass reset param
if self._reset_param[env_id] is not None:
assert isinstance(self._reset_param[env_id], dict), type(self._reset_param[env_id])
return self._envs[env_id].reset(**self._reset_param[env_id])
else:
return self._envs[env_id].reset()
try:
obs = reset_fn()
except Exception as e:
self._env_states[env_id] = EnvState.ERROR
self.close()
raise e
self._ready_obs[env_id] = obs
self._env_states[env_id] = EnvState.RUN
[docs] def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
"""
Overview:
Step all environments. Reset an env if done.
Arguments:
- actions (:obj:`Dict[int, Any]`): {env_id: action}
Returns:
- timesteps (:obj:`Dict[int, namedtuple]`): {env_id: timestep}. Timestep is a \
``BaseEnvTimestep`` tuple with observation, reward, done, env_info.
Example:
>>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items())}
>>> timesteps = env_manager.step(actions_dict):
>>> for env_id, timestep in timesteps.items():
>>> pass
.. note:
- The env_id that appears in ``actions`` will also be returned in ``timesteps``.
- Once an environment is done, it is reset immediately.
"""
self._check_closed()
timesteps = {}
for env_id, act in actions.items():
timesteps[env_id] = self._step(env_id, act)
if timesteps[env_id].done:
self._env_episode_count[env_id] += 1
if self._env_episode_count[env_id] < self._episode_num and self._auto_reset:
self._env_states[env_id] = EnvState.RESET
self._reset(env_id)
else:
self._env_states[env_id] = EnvState.DONE
else:
self._ready_obs[env_id] = timesteps[env_id].obs
return timesteps
def _step(self, env_id: int, act: Any) -> namedtuple:
@retry_wrapper(max_retry=self._max_retry, waiting_time=self._retry_waiting_time)
@timeout_wrapper(timeout=self._step_timeout)
def step_fn():
return self._envs[env_id].step(act)
try:
ret = step_fn()
return ret
except Exception as e:
self._env_states[env_id] = EnvState.ERROR
raise e
[docs] def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool = None) -> None:
"""
Overview:
Set the seed for each environment.
Arguments:
- seed (:obj:`Union[Dict[int, int], List[int], int]`): List of seeds for each environment; \
Or one seed for the first environment and other seeds are generated automatically.
"""
if isinstance(seed, numbers.Integral):
seed = [seed + i for i in range(self.env_num)]
self._env_seed = seed
elif isinstance(seed, list):
assert len(seed) == self._env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self._env_num)
self._env_seed = seed
elif isinstance(seed, dict):
if not hasattr(self, '_env_seed'):
raise RuntimeError("please indicate all the seed of each env in the beginning")
for env_id, s in seed.items():
self._env_seed[env_id] = s
else:
raise TypeError("invalid seed arguments type: {}".format(type(seed)))
self._env_dynamic_seed = dynamic_seed
[docs] def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
"""
Overview:
Set each env's replay save path.
Arguments:
- replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
Or one path for all environments.
"""
if isinstance(replay_path, str):
replay_path = [replay_path] * self.env_num
self._env_replay_path = replay_path
[docs] def close(self) -> None:
"""
Overview:
Release the environment resources.
"""
if self._closed:
return
self._env_ref.close()
for env in self._envs:
env.close()
for i in range(self._env_num):
self._env_states[i] = EnvState.VOID
self._closed = True
[docs] def env_info(self) -> namedtuple:
"""
Overview:
Get one env's info, for example, action space, observation space, reward space, etc.
Returnns:
- info (:obj:`namedtuple`): Usually a namedtuple ``BaseEnvInfo``, each element is ``EnvElementInfo``.
"""
return self._env_ref.info()
[docs]def create_env_manager(manager_cfg: dict, env_fn: List[Callable]) -> BaseEnvManager:
r"""
Overview:
Create an env manager according to manager cfg and env function.
Arguments:
- manager_cfg (:obj:`EasyDict`): Env manager config.
- env_fn (:obj:` List[Callable]`): A list of envs' functions.
ArgumentsKeys:
- `manager_cfg`'s necessary: `type`
"""
manager_cfg = copy.deepcopy(manager_cfg)
if 'import_names' in manager_cfg:
import_module(manager_cfg.pop('import_names'))
manager_type = manager_cfg.pop('type')
return ENV_MANAGER_REGISTRY.build(manager_type, env_fn=env_fn, cfg=manager_cfg)
[docs]def get_env_manager_cls(cfg: EasyDict) -> type:
r"""
Overview:
Get an env manager class according to cfg.
Arguments:
- cfg (:obj:`EasyDict`): Env manager config.
ArgumentsKeys:
- necessary: `type`
"""
import_module(cfg.get('import_names', []))
return ENV_MANAGER_REGISTRY.get(cfg.type)