import copy
import logging
import time
from collections import namedtuple
from functools import partial, wraps
from typing import Any, Union, Callable, List, Dict, Optional, Tuple

from easydict import EasyDict

from ding.torch_utils import build_checkpoint_helper, CountVar, auto_checkpoint, build_log_buffer
from ding.utils import build_logger, EasyTimer, pretty_print, deep_merge_dicts, import_module, LEARNER_REGISTRY, \
    get_rank, get_world_size
from ding.utils.autolog import LoggedValue, LoggedModel, NaturalTime, TickTime, TimeMode
from ding.utils.data import AsyncDataLoader, default_collate
from .learner_hook import build_learner_hook_by_cfg, add_learner_hook, merge_hooks, LearnerHook
# Module-level side effect: the module-level ``logging.info`` function calls ``logging.basicConfig()``
# when the root logger has no handlers yet (documented stdlib behavior), so importing this module
# configures root logging with the default format.
# NOTE(review): presumably this is why the call is marked "necessary" — confirm which logging depends on it.
logging.info('') # necessary
@LEARNER_REGISTRY.register('base')
class BaseLearner(object):
    r"""
    Overview:
        Base class for model learning.
    Interface:
        train, call_hook, register_hook, save_checkpoint, start, setup_dataloader, close
    Property:
        learn_info, priority_info, last_iter, train_iter, name, rank, world_size, policy,
        monitor, log_buffer, logger, tb_logger, ckpt_name
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Return a deep copy of the class-level ``config`` as an ``EasyDict``, with ``cfg_type`` set to
            ``'<ClassName>Dict'``.
        Returns:
            - cfg (:obj:`EasyDict`): The default learner config.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        train_iterations=int(1e9),
        dataloader=dict(num_workers=0, ),
        # --- Hooks ---
        hook=dict(
            load_ckpt_before_run='',
            log_show_after_iter=100,
            save_ckpt_after_iter=10000,
            save_ckpt_after_run=True,
        ),
    )

    _name = "BaseLearner"  # override this variable for sub-class learner

    def __init__(
            self,
            cfg: EasyDict,
            policy: namedtuple = None,
            tb_logger: Optional['SummaryWriter'] = None,  # noqa
            dist_info: Tuple[int, int] = None,
    ) -> None:
        """
        Overview:
            Init method. Load config and use ``self._cfg`` to build common learner components,
            e.g. logger, log buffer, hooks.
            If ``policy`` is given it is set here through the ``policy`` setter; otherwise it must be set
            afterwards (before ``train`` is called) through the same setter.
        Arguments:
            - cfg (:obj:`EasyDict`): Learner config, you can view `cfg <../../../configuration/index.html>`_ for ref. \
                When ``world_size > 1``, ``cfg.hook.log_reduce_after_iter`` is set to ``True`` on this object.
            - policy (:obj:`namedtuple`): Optional policy to train with.
            - tb_logger (:obj:`SummaryWriter`): Optional externally created tensorboard logger. Only used by the \
                rank 0 learner; when ``None``, rank 0 builds its own with ``build_logger``.
            - dist_info (:obj:`Tuple[int, int]`): Optional ``(rank, world_size)``. When ``None``, they are \
                obtained from ``get_rank()`` and ``get_world_size()``.
        Notes:
            If you want to debug in sync CUDA mode, please add the following code at the beginning of ``__init__``.

            .. code:: python

                os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # for debug async CUDA
        """
        self._cfg = cfg
        self._instance_name = self._name + '_' + time.ctime().replace(' ', '_').replace(':', '_')
        self._ckpt_name = None
        self._timer = EasyTimer()

        # These 2 attributes are only used in parallel mode.
        self._end_flag = False
        self._learner_done = False
        if dist_info is None:
            self._rank = get_rank()
            self._world_size = get_world_size()
        else:
            # Learner rank. Used to discriminate which GPU it uses.
            self._rank, self._world_size = dist_info
        if self._world_size > 1:
            self._cfg.hook.log_reduce_after_iter = True

        # Logger (Monitor will be initialized in policy setter)
        # Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output.
        if self._rank == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger('./log/learner', 'learner', need_tb=False)
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger('./log/learner', 'learner')
        else:
            self._logger, _ = build_logger('./log/learner', 'learner', need_tb=False)
            self._tb_logger = None
        self._log_buffer = {
            'scalar': build_log_buffer(),
            'scalars': build_log_buffer(),
            'histogram': build_log_buffer(),
        }

        # Setup policy
        if policy is not None:
            self.policy = policy

        # Learner hooks. Used to do specific things at specific time point. Will be set in ``_setup_hook``
        self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
        # Last iteration. Used to record current iter.
        self._last_iter = CountVar(init_val=0)

        # Setup time wrapper and hook.
        self._setup_wrapper()
        self._setup_hook()

    def _setup_hook(self) -> None:
        """
        Overview:
            Setup hook for base_learner. Hook is the way to implement some functions at specific time point
            in base_learner. Hooks built from ``self._cfg.hook`` are merged into ``self._hooks``.
            You can refer to ``learner_hook.py``.
        """
        if hasattr(self, '_hooks'):
            self._hooks = merge_hooks(self._hooks, build_learner_hook_by_cfg(self._cfg.hook))
        else:
            self._hooks = build_learner_hook_by_cfg(self._cfg.hook)

    def _setup_wrapper(self) -> None:
        """
        Overview:
            Use ``_time_wrapper`` to get ``train_time``.
        Note:
            ``data_time`` is wrapped in ``setup_dataloader``.
        """
        self._wrapper_timer = EasyTimer()
        self.train = self._time_wrapper(self.train, 'scalar', 'train_time')

    def _time_wrapper(self, fn: Callable, var_type: str, var_name: str) -> Callable:
        """
        Overview:
            Wrap a function and record the time it used in ``_log_buffer[var_type][var_name]``.
            All wrapped functions share ``self._wrapper_timer``; in ``start`` they run sequentially, not nested.
        Arguments:
            - fn (:obj:`Callable`): Function to be time_wrapped.
            - var_type (:obj:`str`): Variable type, e.g. ['scalar', 'scalars', 'histogram'].
            - var_name (:obj:`str`): Variable name, e.g. ['cur_lr', 'total_loss'].
        Returns:
            - wrapper (:obj:`Callable`): The wrapper to acquire a function's time.
        """

        @wraps(fn)
        def wrapper(*args, **kwargs) -> Any:
            with self._wrapper_timer:
                ret = fn(*args, **kwargs)
            # Only reached when ``fn`` returns normally; an exception propagates without logging a time.
            self._log_buffer[var_type][var_name] = self._wrapper_timer.value
            return ret

        return wrapper

    def register_hook(self, hook: LearnerHook) -> None:
        """
        Overview:
            Add a new learner hook.
        Arguments:
            - hook (:obj:`LearnerHook`): The hook to be added.
        """
        add_learner_hook(self._hooks, hook)

    def train(self, data: dict, envstep: int = -1) -> None:
        """
        Overview:
            Given training data, implement network update for one iteration and update related variables.
            Learner's API for serial entry.
            Also called in ``start`` for each iteration's training.
        Arguments:
            - data (:obj:`dict`): Training data which is retrieved from replay buffer. It is passed to \
                ``self._policy.forward`` as-is. When the policy returns a ``priority`` entry, ``data`` is also \
                iterated and every element must support ``.get`` (i.e. a list of dict-like samples).
            - envstep (:obj:`int`): Collector env step, stored as ``self._collector_envstep`` before the \
                ``after_iter`` hooks run. Defaults to ``-1``.

        .. note::

            ``_policy`` must be set before calling this method.

            ``_policy.forward`` method contains: forward, backward, grad sync(if in multi-gpu mode) and
            parameter update.

            ``before_iter`` and ``after_iter`` hooks are called at the beginning and ending.
        """
        assert hasattr(self, '_policy'), "please set learner policy"
        self.call_hook('before_iter')

        # Forward
        log_vars = self._policy.forward(data)

        # Update replay buffer's priority info
        priority = log_vars.pop('priority', None)
        if priority is not None:
            replay_buffer_idx = [d.get('replay_buffer_idx', None) for d in data]
            replay_unique_id = [d.get('replay_unique_id', None) for d in data]
            self.priority_info = {
                'priority': priority,
                'replay_buffer_idx': replay_buffer_idx,
                'replay_unique_id': replay_unique_id,
            }

        # Discriminate vars in scalar, scalars and histogram type.
        # A var is scalar by default. Scalars / histogram vars must be annotated with a "[scalars]" /
        # "[histogram]" prefix, which is stripped (everything up to the last ']') before logging.
        scalars_vars, histogram_vars = {}, {}
        for k in list(log_vars.keys()):  # copy keys: entries are popped while iterating
            if "[scalars]" in k:
                new_k = k.split(']')[-1]
                scalars_vars[new_k] = log_vars.pop(k)
            elif "[histogram]" in k:
                new_k = k.split(']')[-1]
                histogram_vars[new_k] = log_vars.pop(k)

        # Update log_buffer
        self._log_buffer['scalar'].update(log_vars)
        self._log_buffer['scalars'].update(scalars_vars)
        self._log_buffer['histogram'].update(histogram_vars)

        self._collector_envstep = envstep
        self.call_hook('after_iter')
        self._last_iter.add(1)

    @auto_checkpoint
    def start(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Learner's API for parallel entry.
            For each iteration, learner will get data through ``_next_data`` and call ``train`` to train.
            The loop stops early when ``_end_flag`` is set (e.g. by ``close``).

        .. note::

            ``before_run`` and ``after_run`` hooks are called at the beginning and ending.
        """
        self._end_flag = False
        self._learner_done = False
        # before run hook
        self.call_hook('before_run')

        for _ in range(self._cfg.train_iterations):
            data = self._next_data()
            if self._end_flag:
                break
            self.train(data)

        self._learner_done = True
        # after run hook
        self.call_hook('after_run')

    def setup_dataloader(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Setup learner's dataloader.

        .. note::

            Only in parallel mode will we use attributes ``get_data`` and ``_dataloader`` to get data from file system;
            Instead, in serial version, we can fetch data from memory directly.

            In parallel mode, ``get_data`` is set by ``LearnerCommHelper``, and should be callable.
            Users don't need to know the related details if not necessary.
        """
        cfg = self._cfg.dataloader
        batch_size = self._policy.get_attribute('batch_size')
        device = self._policy.get_attribute('device')
        chunk_size = cfg.chunk_size if 'chunk_size' in cfg else batch_size
        self._dataloader = AsyncDataLoader(
            self.get_data, batch_size, device, chunk_size, collate_fn=lambda x: x, num_workers=cfg.num_workers
        )
        self._next_data = self._time_wrapper(self._next_data, 'scalar', 'data_time')

    def _next_data(self) -> Any:
        """
        Overview:
            [Only Used In Parallel Mode] Call ``_dataloader``'s ``__next__`` method to return next training data.
        Returns:
            - data (:obj:`Any`): Next training data from dataloader.
        """
        return next(self._dataloader)

    def close(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Close the related resources, e.g. dataloader, tensorboard logger, etc.
            Idempotent: only the first call does any work. Also invoked from ``__del__``, so it tolerates a
            partially constructed instance (attributes missing because ``__init__`` raised).
        """
        if getattr(self, '_end_flag', False):
            return
        self._end_flag = True
        if hasattr(self, '_dataloader'):
            self._dataloader.close()
        tb_logger = getattr(self, '_tb_logger', None)
        if tb_logger is not None:
            tb_logger.flush()
            tb_logger.close()

    def __del__(self) -> None:
        self.close()

    def call_hook(self, name: str) -> None:
        """
        Overview:
            Call the corresponding hook plugins according to position name.
        Arguments:
            - name (:obj:`str`): Hooks in which position to call, \
                should be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
        """
        for hook in self._hooks[name]:
            hook(self)

    def info(self, s: str) -> None:
        """
        Overview:
            Log string info by ``self._logger.info``, prefixed with this learner's rank.
        Arguments:
            - s (:obj:`str`): The message to add into the logger.
        """
        self._logger.info('[RANK{}]: {}'.format(self._rank, s))

    def debug(self, s: str) -> None:
        """
        Overview:
            Log string info by ``self._logger.debug``, prefixed with this learner's rank.
        Arguments:
            - s (:obj:`str`): The message to add into the logger.
        """
        self._logger.debug('[RANK{}]: {}'.format(self._rank, s))

    def save_checkpoint(self, ckpt_name: str = None) -> None:
        """
        Overview:
            Directly call ``save_ckpt_after_run`` hook to save checkpoint.
        Arguments:
            - ckpt_name (:obj:`str`): Optional name exposed through ``self.ckpt_name`` while the hook runs. \
                ``self.ckpt_name`` is always reset to ``None`` afterwards, even if saving fails.
        Note:
            Must guarantee that "save_ckpt_after_run" is registered in "after_run" hook.
            This method is called in:

                - ``auto_checkpoint`` (``torch_utils/checkpoint_helper.py``), which is designed for \
                    saving checkpoint whenever an exception raises.
                - ``serial_pipeline`` (``entry/serial_entry.py``). Used to save checkpoint when reaching \
                    new highest evaluation reward.
        """
        if ckpt_name is not None:
            self.ckpt_name = ckpt_name
        try:
            names = [h.name for h in self._hooks['after_run']]
            assert 'save_ckpt_after_run' in names, "'save_ckpt_after_run' hook is not registered in 'after_run'"
            idx = names.index('save_ckpt_after_run')
            self._hooks['after_run'][idx](self)
        finally:
            # Reset even on failure so a stale name cannot leak into a later save.
            self.ckpt_name = None

    @property
    def learn_info(self) -> dict:
        """
        Overview:
            Get current info dict, which will be sent to commander, e.g. replay buffer priority update,
            current iteration, hyper-parameter adjustment, whether task is finished, etc.
        Returns:
            - info (:obj:`dict`): Current learner info dict.
        """
        ret = {
            'learner_step': self._last_iter.val,
            'priority_info': self.priority_info,
            'learner_done': self._learner_done
        }
        return ret

    @property
    def last_iter(self) -> CountVar:
        """Counter of finished training iterations."""
        return self._last_iter

    @property
    def train_iter(self) -> int:
        """Current value of ``last_iter``."""
        return self._last_iter.val

    @property
    def monitor(self) -> 'TickMonitor':  # noqa
        """Training monitor. Only assigned in the ``policy`` setter and only on rank 0."""
        return self._monitor

    @property
    def log_buffer(self) -> dict:  # LogDict
        """Dict of log buffers keyed by 'scalar', 'scalars' and 'histogram'."""
        return self._log_buffer

    @log_buffer.setter
    def log_buffer(self, _log_buffer: Dict[str, Dict[str, Any]]) -> None:
        self._log_buffer = _log_buffer

    @property
    def logger(self) -> logging.Logger:
        """Text logger."""
        return self._logger

    @property
    def tb_logger(self) -> 'SummaryWriter':  # noqa
        """Tensorboard logger; ``None`` on non-zero ranks."""
        return self._tb_logger

    @property
    def name(self) -> str:
        """Instance name: ``_name`` plus the creation time."""
        return self._instance_name

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    @property
    def policy(self) -> 'Policy':  # noqa
        return self._policy

    @policy.setter
    def policy(self, _policy: 'Policy') -> None:  # noqa
        """
        Note:
            Policy variable monitor is set alongside with policy, because variables are determined by specific policy.
        """
        self._policy = _policy
        if self._rank == 0:
            self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
        self.info(self._policy.info())

    @property
    def priority_info(self) -> dict:
        """Latest priority info produced by ``train``; an empty dict until one is produced."""
        if not hasattr(self, '_priority_info'):
            self._priority_info = {}
        return self._priority_info

    @priority_info.setter
    def priority_info(self, _priority_info: dict) -> None:
        self._priority_info = _priority_info

    @property
    def ckpt_name(self) -> str:
        """Checkpoint name set temporarily by ``save_checkpoint``; ``None`` otherwise."""
        return self._ckpt_name

    @ckpt_name.setter
    def ckpt_name(self, _ckpt_name: str) -> None:
        self._ckpt_name = _ckpt_name
def create_learner(cfg: EasyDict, **kwargs) -> BaseLearner:
    """
    Overview:
        Build a learner instance from ``LEARNER_REGISTRY``. A derived learner must first be registered
        (e.g. with ``@LEARNER_REGISTRY.register('<name>')``) under the name given by ``cfg.type``;
        only then can ``create_learner`` build it.
    Arguments:
        - cfg (:obj:`EasyDict`): Learner config. Keys read here: ``type`` (required, the registered learner \
            name) and ``import_names`` (optional list of module names, imported before building; defaults to ``[]``).
        - kwargs: Extra keyword arguments forwarded, together with ``cfg=cfg``, to the learner constructor.
    Returns:
        - learner (:obj:`BaseLearner`): The created learner instance.
    Note:
        What happens for an unregistered ``cfg.type`` is decided by ``LEARNER_REGISTRY.build``.
    """
    # Presumably imported so that learners defined in those modules get registered before ``build``.
    import_module(cfg.get('import_names', []))
    return LEARNER_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
class TickMonitor(LoggedModel):
    """
    Overview:
        Tick-time based monitor for basic training variables (data/train time and total step/episode/sample
        counters declared below). Values are first recorded in the learner's ``log_buffer``; learner hooks then
        update this monitor from ``log_buffer``, and its contents are printed to the text and tensorboard loggers.
        For every declared property, an ``'avg'`` attribute value (mean over the recorded range, ``0`` when empty)
        and a ``'val'`` attribute value (the most recent record) are registered via ``register_attribute_value``.
    Interface:
        __init__ (everything else is inherited from ``LoggedModel``)
    Property:
        time, expire
    """
    data_time = LoggedValue(float)
    train_time = LoggedValue(float)
    total_collect_step = LoggedValue(float)
    total_step = LoggedValue(float)
    total_episode = LoggedValue(float)
    total_sample = LoggedValue(float)
    total_duration = LoggedValue(float)

    def __init__(self, time_: 'BaseTime', expire: Union[int, float]):  # noqa
        LoggedModel.__init__(self, time_, expire)
        self.__register()

    def __register(self):
        # Each record has the form ((begin_time, end_time), value).
        def _mean_of(prop_name: str) -> float:
            running_total, count = 0, 0
            for (_start, _stop), value in self.range_values[prop_name]():
                running_total += value
                count += 1
            return running_total / count if count else 0

        def _latest_of(prop_name: str) -> float:
            history = self.range_values[prop_name]()
            return history[-1][1]

        for name in getattr(self, '_LoggedModel__properties'):
            self.register_attribute_value('avg', name, partial(_mean_of, prop_name=name))
            self.register_attribute_value('val', name, partial(_latest_of, prop_name=name))
def get_simple_monitor_type(properties: Optional[List[str]] = None) -> type:
    """
    Overview:
        Besides basic training variables provided in ``TickMonitor``, many policies have their own customized
        ones to record and monitor. This function returns a tick monitor *class* that also tracks them.
        Compared with ``TickMonitor``, ``SimpleTickMonitor`` can record extra ``properties`` passed in by a policy.
    Arguments:
        - properties (:obj:`Optional[List[str]]`): Customized property names to monitor. ``None`` or an empty \
            sequence means no extra properties.
    Returns:
        - monitor_type (:obj:`type`): ``TickMonitor`` itself when there are no extra properties, otherwise a \
            dynamically created ``TickMonitor`` subclass named ``SimpleTickMonitor``. Callers instantiate it.
    """
    # ``None`` default avoids a mutable list literal shared across calls.
    if not properties:
        return TickMonitor
    base_properties = [
        'data_time', 'train_time', 'sample_count', 'total_collect_step', 'total_step', 'total_sample',
        'total_episode', 'total_duration'
    ]
    # Every name becomes a float ``LoggedValue`` class attribute; duplicated names keep the last definition.
    attrs = {p_name: LoggedValue(float) for p_name in base_properties + list(properties)}
    return type('SimpleTickMonitor', (TickMonitor, ), attrs)