# Source code for ding.utils.default_helper

import copy
import logging
import random
from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any
from functools import lru_cache  # in python3.9, we can change to cache

import numpy as np
import torch


def lists_to_dicts(
        data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
        recursive: bool = False,
) -> Union[Mapping[object, object], NamedTuple]:
    r"""
    Overview:
        Transform a list of dicts into a dict of lists. A list of namedtuples is transformed \
        into a single namedtuple (of the same type) whose fields are tuples.
    Arguments:
        - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`):
            A non-empty list (or tuple) of dicts or namedtuples to be transformed. The keys are \
            taken from the first element, so every element must provide them.
        - recursive (:obj:`bool`): Whether dict-valued entries are transformed as well. When enabled, \
            this applies to every nesting level, not only to the first one.
    Returns:
        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): A dict of lists (or a namedtuple \
            of tuples) as a result
    Raises:
        - ValueError: If ``data`` is empty.
        - TypeError: If the first element is neither a dict nor a namedtuple.
    Example:
        >>> from ding.utils import *
        >>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
        {1: [1, 2], 10: [3, 4]}
        >>> lists_to_dicts([{'a': {'b': 1}}, {'a': {'b': 2}}], recursive=True)
        {'a': {'b': [1, 2]}}
    """
    if len(data) == 0:
        raise ValueError("empty data")
    first = data[0]
    if isinstance(first, dict):
        new_data = {}
        for k in first.keys():
            column = [item[k] for item in data]
            if recursive and isinstance(first[k], dict):
                # Forward ``recursive`` explicitly: without it the nested call falls back to the
                # default (False) and anything deeper than one level stays a list of dicts.
                new_data[k] = lists_to_dicts(column, recursive=True)
            else:
                new_data[k] = column
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        # Transpose the rows and rebuild one namedtuple of the same type out of the columns.
        new_data = type(first)(*list(zip(*data)))
    else:
        raise TypeError("not support element type: {}".format(type(first)))
    return new_data


def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
    r"""
    Overview:
        Transform a dict of lists to a list of dicts.

    Arguments:
        - data (:obj:`Mapping[object, list]`): A dict of lists need to be transformed

    Returns:
        - newdata (:obj:`List[Mapping[object, object]]`): A list of dicts as a result. Its length is \
            the length of the shortest list in ``data``, because the lists are zipped together.

    Example:
        >>> from ding.utils import *
        >>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
        [{1: 1, 10: 3}, {1: 2, 10: 4}]
    """
    keys = list(data.keys())
    columns = list(data.values())
    rows = []
    # Each zipped tuple holds the i-th entry of every list, i.e. one output dict.
    for row in zip(*columns):
        rows.append(dict(zip(keys, row)))
    return rows


def override(cls: type) -> Callable[[
        Callable,
], Callable]:
    """
    Overview:
        Decorator factory used to document (and verify) that a method overrides one of ``cls``.

    Arguments:
        - cls (:obj:`type`): The superclass that is expected to provide the overridden method. \
            If the decorated method's name is not listed in ``dir(cls)``, a ``NameError`` is raised.
    Returns:
        - check_override (:obj:`Callable`): A decorator that returns the method unchanged after the check.
    """

    def check_override(method: Callable) -> Callable:
        # The check runs once, at decoration time; the method itself is not wrapped.
        if method.__name__ in dir(cls):
            return method
        raise NameError("{} does not override any method of {}".format(method, cls))

    return check_override


def squeeze(data: object) -> object:
    """
    Overview:
        Squeeze data from tuple, list or dict to single object.
        A tuple/list with exactly one element is replaced by that element, any other tuple/list is
        returned as a tuple. A dict with exactly one item is replaced by its value. Everything else
        is returned unchanged.
    Example:
        >>> a = (4, )
        >>> a = squeeze(a)
        >>> print(a)
        4
    """
    if isinstance(data, (tuple, list)):
        return data[0] if len(data) == 1 else tuple(data)
    if isinstance(data, dict) and len(data) == 1:
        # Single-item dict: hand back its only value.
        return next(iter(data.values()))
    return data


# Names for which ``default_get`` has already logged a "use default value" warning.
default_get_set = set()


def default_get(
        data: dict,
        name: str,
        default_value: Optional[Any] = None,
        default_fn: Optional[Callable] = None,
        judge_fn: Optional[Callable] = None
) -> Any:
    r"""
    Overview:
        Look up ``name`` in ``data``. If it exists, the stored value is returned directly. \
        Otherwise a default is produced by calling ``default_fn`` (preferred when given) or by \
        taking ``default_value``, optionally validated by ``judge_fn``. The first time a default \
        is used for a given ``name``, a warning is logged and ``name`` is recorded in \
        ``default_get_set`` so the warning is not repeated.
    Arguments:
        - data(:obj:`dict`): Data input dictionary
        - name(:obj:`str`): Key name
        - default_value(:obj:`Optional[Any]`): Fallback value, used when ``default_fn`` is ``None``
        - default_fn(:obj:`Optional[Callable]`): Zero-argument callable that generates the fallback value
        - judge_fn(:obj:`Optional[Callable]`): Predicate that the fallback value must satisfy
    Returns:
        - value(:obj:`Any`): ``data[name]`` if present, else the generated default value
    Raises:
        - AssertionError: If ``name`` is missing and neither ``default_value`` nor ``default_fn`` is \
            given, or if ``judge_fn`` rejects the default value.
    """
    if name in data:
        return data[name]

    # From here on the key is missing and a fallback has to be produced.
    assert default_value is not None or default_fn is not None
    if default_fn is not None:
        value = default_fn()
    else:
        value = default_value
    if judge_fn:
        assert judge_fn(value), "defalut value({}) is not accepted by judge_fn".format(type(value))
    already_warned = name in default_get_set
    if not already_warned:
        logging.warning("{} use default value {}".format(name, value))
        default_get_set.add(name)
    return value


def list_split(data: list, step: int) -> Tuple[List[list], Optional[list]]:
    r"""
    Overview:
        Split list of data into consecutive chunks of length ``step``.
    Arguments:
        - data(:obj:`list`): List of data for splitting
        - step(:obj:`int`): Length of each chunk, must be a positive integer
    Returns:
        - ret(:obj:`list`): List of the full-length chunks.
        - residual(:obj:`list`): The leftover elements that do not fill a whole chunk. This value is \
            ``None`` when ``step`` divides ``len(data)``. When ``data`` is shorter than ``step``, \
            ``ret`` is empty and ``residual`` is ``data`` itself.
    Raises:
        - ValueError: If ``step`` is not positive.
    Example:
        >>> list_split([1,2,3,4],2)
        ([[1, 2], [3, 4]], None)
        >>> list_split([1,2,3,4],3)
        ([[1, 2, 3]], [4])
    """
    if step <= 0:
        # A zero step would divide by zero below, and a negative step would silently yield
        # ``([], None)`` and drop every element, so reject both explicitly.
        raise ValueError("step must be a positive integer, but got {}".format(step))
    if len(data) < step:
        return [], data
    divide_num = len(data) // step
    ret = [data[i * step:(i + 1) * step] for i in range(divide_num)]
    covered = divide_num * step
    residual = data[covered:] if covered < len(data) else None
    return ret, residual


def error_wrapper(fn, default_ret, warning_msg=""):
    r"""
    Overview:
        Wrap the function, so that any ``Exception`` raised by it is caught and ``default_ret`` \
        is returned instead.
    Arguments:
        - fn (:obj:`Callable`): the function to be wrapped
        - default_ret (:obj:`obj`): the default return when an Exception occurred in the function
        - warning_msg (:obj:`str`): if not empty, a warning built from this message, ``default_ret`` \
            and the caught error is emitted through ``one_time_warning`` when the fallback is used
    Returns:
        - wrapper (:obj:`Callable`): the wrapped function
    Examples:
        >>> # Used to checkfor Fakelink (Refer to utils.linklink_dist_helper.py)
        >>> def get_rank():  # Get the rank of linklink model, return 0 if use FakeLink.
        >>>    if is_fake_link:
        >>>        return 0
        >>>    return error_wrapper(link.get_rank, 0)()
    """

    def wrapper(*args, **kwargs):
        try:
            ret = fn(*args, **kwargs)
        except Exception as e:
            ret = default_ret
            if warning_msg != "":
                # ``one_time_warning`` takes exactly one message argument. Passing the details as a
                # second positional argument raised TypeError here, so the caller never received
                # ``default_ret``. Build a single string instead.
                one_time_warning("{}\ndefault_ret = {}\terror = {}".format(warning_msg, default_ret, e))
        return ret

    return wrapper


class LimitedSpaceContainer:
    r"""
    Overview:
        A space simulator: a counter ``cur`` of occupied pieces, bounded by ``min_val`` and ``max_val``.
    Interface:
        ``__init__``, ``get_residual_space``, ``acquire_space``, ``release_space``, \
        ``increase_space``, ``decrease_space``
    """

    def __init__(self, min_val: int, max_val: int) -> None:
        """
        Overview:
            Set ``min_val`` and ``max_val`` of the container, also set ``cur`` to ``min_val`` for initialization.
        Arguments:
            - min_val (:obj:`int`): Min volume of the container, usually 0.
            - max_val (:obj:`int`): Max volume of the container, must not be smaller than ``min_val``.
        """
        self.min_val = min_val
        self.max_val = max_val
        assert (max_val >= min_val)
        self.cur = self.min_val

    def get_residual_space(self) -> int:
        """
        Overview:
            Get all residual pieces of space at once. Set ``cur`` to ``max_val``.
        Returns:
            - ret (:obj:`int`): Residual space, calculated by ``max_val`` - ``cur``.
        """
        ret = self.max_val - self.cur
        self.cur = self.max_val
        return ret

    def acquire_space(self) -> bool:
        """
        Overview:
            Try to get one piece of space. If there is one, return True; otherwise return False.
        Returns:
            - flag (:obj:`bool`): Whether there is any piece of residual space.
        """
        if self.cur < self.max_val:
            self.cur += 1
            return True
        else:
            return False

    def release_space(self) -> None:
        """
        Overview:
            Release only one piece of space. Decrement ``cur``, but ensure it won't drop below ``min_val``.
        """
        self.cur = max(self.min_val, self.cur - 1)

    def increase_space(self) -> None:
        """
        Overview:
            Increase one piece in space. Increment ``max_val``.
        """
        self.max_val += 1

    def decrease_space(self) -> None:
        """
        Overview:
            Decrease one piece in space. Decrement ``max_val``. ``cur`` is left untouched.
        """
        self.max_val -= 1


def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
    """
    Overview:
        Merge two dicts by calling ``deep_update``. Neither input is modified: the merge is \
        applied to a deep copy of ``original``.
    Arguments:
        - original (:obj:`dict`): Dict 1. ``None`` is treated as an empty dict.
        - new_dict (:obj:`dict`): Dict 2. ``None`` is treated as an empty dict.
    Returns:
        - merged_dict (:obj:`dict`): A new dict that is d1 and d2 deeply merged.
    """
    original = original or {}
    new_dict = new_dict or {}
    merged = copy.deepcopy(original)
    if new_dict:  # if new_dict is neither empty dict nor None
        deep_update(merged, new_dict, True, [])
    return merged


def deep_update(
        original: dict,
        new_dict: dict,
        new_keys_allowed: bool = False,
        whitelist: Optional[List[str]] = None,
        override_all_if_type_changes: Optional[List[str]] = None
):
    """
    Overview:
        Update original dict with values from new_dict recursively. ``original`` is modified in place.
    Arguments:
        - original (:obj:`dict`): Dictionary with default values.
        - new_dict (:obj:`dict`): Dictionary with values to be updated
        - new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
        - whitelist (:obj:`Optional[List[str]]`):
            List of keys that correspond to dict values where new subkeys can be introduced.
            This is only at the top level.
        - override_all_if_type_changes(:obj:`Optional[List[str]]`):
            List of top level keys with value=dict, for which we always simply override the
            entire value (:obj:`dict`), if the "type" key in that value dict changes.
    Returns:
        - original (:obj:`dict`): The same dict object that was passed in, after the update.
    Raises:
        - RuntimeError: If ``new_dict`` introduces a key that is not allowed.

    .. note::

        If new key is introduced in new_dict, then if new_keys_allowed is not
        True, an error will be thrown. Further, for sub-dicts, if the key is
        in the whitelist, then new subkeys can be introduced.
    """
    whitelist = whitelist or []
    override_all_if_type_changes = override_all_if_type_changes or []

    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(k, original.keys()))

        # Both original value and new one are dicts.
        if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Check old type vs. new one. If different, override entire value.
            if k in override_all_if_type_changes and \
                    "type" in value and "type" in original[k] and \
                    value["type"] != original[k]["type"]:
                original[k] = value
            # Whitelisted key -> ok to add new subkeys.
            elif k in whitelist:
                deep_update(original[k], value, True)
            # Non-whitelisted key.
            else:
                deep_update(original[k], value, new_keys_allowed)
        # Original value not a dict OR new value not a dict:
        # Override entire value.
        else:
            original[k] = value
    return original


def flatten_dict(data: dict, delimiter: str = "/") -> dict:
    """
    Overview:
        Flatten the dict, see example. The input is deep-copied first, so it is left untouched.
    Arguments:
        - data (:obj:`dict`): Original nested dict
        - delimiter (str): Delimiter of the keys of the new dict
    Returns:
        - data (:obj:`dict`): Flattened nested dict
    Example:
        >>> a
        {'a': {'b': 100}}
        >>> flatten_dict(a)
        {'a/b': 100}
    """
    data = copy.deepcopy(data)
    # Each pass lifts the dict-valued entries up by one level; repeat until none is left.
    while any(isinstance(v, dict) for v in data.values()):
        remove = []
        add = {}
        for key, value in data.items():
            if isinstance(value, dict):
                for subkey, v in value.items():
                    add[delimiter.join([key, subkey])] = v
                remove.append(key)
        data.update(add)
        for k in remove:
            del data[k]
    return data


def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
    """
    Overview:
        Side effect function to set seed for ``random``, ``numpy random``, and ``torch's manual seed``.
        This is usually used in entry script in the section of setting random seed for all package and instance
    Argument:
        - seed(:obj:`int`): Set seed
        - use_cuda(:obj:`bool`) Whether use cuda
    Examples:
        >>> # ../entry/xxxenv_xxxpolicy_main.py
        >>> ...
        # Set random seed for all package and instance
        >>> collector_env.seed(seed)
        >>> evaluator_env.seed(seed, dynamic_seed=False)
        >>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
        >>> ...
        # Set up RL Policy, etc.
        >>> ...
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only touch the CUDA generator when it was requested and CUDA is actually available.
    if use_cuda and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


@lru_cache()
def one_time_warning(warning_msg: str) -> None:
    """
    Overview:
        Log ``warning_msg`` through ``logging.warning``. The function is memoized with ``lru_cache``, \
        so a repeated call with a message that is still cached does not log again.
    Arguments:
        - warning_msg (:obj:`str`): The warning message to log.
    """
    logging.warning(warning_msg)