Source code for ding.worker.replay_buffer.naive_buffer

import copy
from typing import Union, Any, Optional, List
import numpy as np

from ding.worker.replay_buffer import IBuffer
from ding.utils import LockContext, LockContextType, BUFFER_REGISTRY
from .utils import UsedDataRemover, generate_id


@BUFFER_REGISTRY.register('naive')
class NaiveReplayBuffer(IBuffer):
    r"""
    Overview:
        Naive replay buffer, can store and sample data.
        A naive implementation of replay buffer with no priority or any other advanced features.
        The buffer's state is guarded by a thread lock, so that methods like ``sample``, ``push``
        and ``clear`` are mutually exclusive with each other.
    Interface:
        start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
    Property:
        replay_buffer_size, push_count
    """

    config = dict(
        type='naive',
        name='default',
        replay_buffer_size=10000,
        deepcopy=False,
        # default `False` for serial pipeline
        enable_track_used_data=False,
    )

    def __init__(
            self,
            cfg: 'EasyDict',  # noqa
            name: str = 'default',
    ) -> None:
        """
        Overview:
            Initialize the buffer
        Arguments:
            - cfg (:obj:`dict`): Config dict. The fields ``replay_buffer_size``, ``deepcopy`` and \
                ``enable_track_used_data`` are read from it by attribute access.
            - name (:obj:`Optional[str]`): Buffer name, used to generate unique data id and logger name.
        """
        self.name = name
        self._cfg = cfg
        self._replay_buffer_size = self._cfg.replay_buffer_size
        self._deepcopy = self._cfg.deepcopy
        # ``_data`` is a circular queue to store data (full data or meta data).
        # A ``None`` slot means "empty"; ``_append`` relies on this to count valid elements.
        self._data = [None for _ in range(self._replay_buffer_size)]
        # Current valid data count, indicating how many elements in ``self._data`` are valid.
        self._valid_count = 0
        # How many pieces of data have been pushed into this buffer, should be no less than ``_valid_count``.
        self._push_count = 0
        # Point to the tail position where next data can be inserted, i.e. latest inserted data's next position.
        self._tail = 0
        # Lock to guarantee thread safety
        self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
        self._end_flag = False
        self._enable_track_used_data = self._cfg.enable_track_used_data
        if self._enable_track_used_data:
            self._used_data_remover = UsedDataRemover()

    def start(self) -> None:
        """
        Overview:
            Start the buffer's used_data_remover thread if enables track_used_data.
        """
        if self._enable_track_used_data:
            self._used_data_remover.start()

    def close(self) -> None:
        """
        Overview:
            Clear the buffer; Join the buffer's used_data_remover thread if enables track_used_data.
        """
        self.clear()
        if self._enable_track_used_data:
            self._used_data_remover.close()

    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
        r"""
        Overview:
            Push a data into buffer.
        Arguments:
            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \
                (in `Any` type), or many (in `List[Any]` type).
            - cur_collector_envstep (:obj:`int`): Collector's current env step. \
                Not used in naive buffer, but preserved for compatibility.
        """
        if isinstance(data, list):
            self._extend(data, cur_collector_envstep)
        else:
            self._append(data, cur_collector_envstep)

    def sample(self, size: int, cur_learner_iter: int) -> Optional[list]:
        r"""
        Overview:
            Sample data with length ``size``.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled.
            - cur_learner_iter (:obj:`int`): Learner's current iteration. \
                Not used in naive buffer, but preserved for compatibility.
        Returns:
            - sample_data (:obj:`list`): A list of data with length ``size``; an empty list if ``size`` is 0; \
                ``None`` if the buffer currently holds fewer than ``size`` valid elements.
        """
        if size == 0:
            return []
        with self._lock:
            # The availability check must happen under the same lock as the sampling itself;
            # otherwise a concurrent ``clear`` could empty the buffer between check and sample.
            if not self._sample_check(size):
                return None
            indices = self._get_indices(size)
            return self._sample_with_indices(indices, cur_learner_iter)

    def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None:
        r"""
        Overview:
            Append a data item into ``self._data``.
        Arguments:
            - ori_data (:obj:`Any`): The data which will be inserted.
            - cur_collector_envstep (:obj:`int`): Not used in this method, but preserved for compatibility.
        """
        with self._lock:
            data = copy.deepcopy(ori_data) if self._deepcopy else ori_data
            self._push_count += 1
            evicted = self._data[self._tail]
            if evicted is None:
                # The slot was empty, so the buffer gains one valid element.
                self._valid_count += 1
            elif self._enable_track_used_data:
                # The slot held older data that is about to be overwritten.
                self._used_data_remover.add_used_data(evicted)
            self._data[self._tail] = data
            self._tail = (self._tail + 1) % self._replay_buffer_size

    def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None:
        r"""
        Overview:
            Extend a data list into the circular queue, wrapping around to the head when the
            end of ``self._data`` is reached.
        Arguments:
            - ori_data (:obj:`List[Any]`): The data list.
            - cur_collector_envstep (:obj:`int`): Not used in this method, but preserved for compatibility.
        """
        with self._lock:
            data = copy.deepcopy(ori_data) if self._deepcopy else ori_data
            length = len(data)
            # Copy ``data`` in contiguous chunks. Each chunk ends either at the end of ``data``
            # or at the end of ``self._data``; in the latter case writing resumes from index 0.
            write_pos = self._tail
            data_start = 0
            residual_num = length
            while residual_num > 0:
                space = self._replay_buffer_size - write_pos
                chunk = min(space, residual_num)
                if self._valid_count != self._replay_buffer_size:
                    self._valid_count += chunk
                elif self._enable_track_used_data:
                    # Buffer is already full: every slot in this chunk holds data being overwritten.
                    for i in range(chunk):
                        self._used_data_remover.add_used_data(self._data[write_pos + i])
                self._push_count += chunk
                self._data[write_pos:write_pos + chunk] = data[data_start:data_start + chunk]
                residual_num -= chunk
                data_start += chunk
                write_pos = 0
            # Update ``tail`` after the whole list is pushed into buffer.
            self._tail = (self._tail + length) % self._replay_buffer_size

    def _sample_check(self, size: int) -> bool:
        r"""
        Overview:
            Check whether this buffer has at least ``size`` valid data to sample.
        Arguments:
            - size (:obj:`int`): Number of data that will be sampled.
        Returns:
            - can_sample (:obj:`bool`): Whether this buffer can sample enough data.
        """
        if self._valid_count < size:
            print("No enough elements for sampling (expect: {} / current: {})".format(size, self._valid_count))
            return False
        return True

    def update(self, info: dict) -> None:
        r"""
        Overview:
            Naive Buffer does not need to update any info, but this method is preserved for compatibility.
        Arguments:
            - info (:obj:`dict`): Ignored.
        """
        # Adjacent literals are used instead of a backslash continuation inside the string,
        # which would embed the source indentation into the printed message.
        print(
            '[BUFFER WARNING] Naive Buffer does not need to update any info, '
            'but `update` method is preserved for compatibility.'
        )

    def clear(self) -> None:
        """
        Overview:
            Clear all the data and reset the related variables.
        """
        with self._lock:
            for i, item in enumerate(self._data):
                if item is not None:
                    if self._enable_track_used_data:
                        self._used_data_remover.add_used_data(item)
                    self._data[i] = None
            self._valid_count = 0
            self._push_count = 0
            self._tail = 0

    def __del__(self) -> None:
        """
        Overview:
            Call ``close`` to delete the object.
        """
        self.close()

    def _get_indices(self, size: int) -> list:
        r"""
        Overview:
            Get the sample index list.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled
        Returns:
            - index_list (:obj:`list`): A list including all the sample indices, whose length should equal \
                to ``size``. Indices are drawn without replacement.
        """
        assert self._valid_count <= self._replay_buffer_size
        # When the buffer is full every slot is valid; otherwise only ``[0, tail)`` is sampled from.
        if self._valid_count == self._replay_buffer_size:
            tail = self._replay_buffer_size
        else:
            tail = self._tail
        indices = list(np.random.choice(a=tail, size=size, replace=False))
        return indices

    def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list:
        r"""
        Overview:
            Sample data with ``indices``.
        Arguments:
            - indices (:obj:`List[int]`): A list including all the sample indices.
            - cur_learner_iter (:obj:`int`): Not used in this method, but preserved for compatibility.
        Returns:
            - data (:obj:`list`) Sampled data. Each item is deep-copied when ``deepcopy`` is enabled, \
                otherwise the stored object itself is returned.
        """
        data = []
        for idx in indices:
            assert self._data[idx] is not None, idx
            if self._deepcopy:
                copy_data = copy.deepcopy(self._data[idx])
            else:
                copy_data = self._data[idx]
            data.append(copy_data)
        return data

    def count(self) -> int:
        """
        Overview:
            Count how many valid datas there are in the buffer.
        Returns:
            - count (:obj:`int`): Number of valid data.
        """
        return self._valid_count

    def state_dict(self) -> dict:
        """
        Overview:
            Provide a state dict to keep a record of current buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                With the dict, one can easily reproduce the buffer. Note that ``'data'`` is the buffer's \
                own list, not a copy.
        """
        return {
            'data': self._data,
            'tail': self._tail,
            'valid_count': self._valid_count,
            'push_count': self._push_count,
        }

    def load_state_dict(self, _state_dict: dict) -> None:
        """
        Overview:
            Load state dict to reproduce the buffer.
        Arguments:
            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                It must contain the key ``'data'``. If that is the only key, the data list is pushed into \
                the buffer via ``_extend``; otherwise every ``key`` is assigned to the attribute ``_key``.
        """
        assert 'data' in _state_dict
        if set(_state_dict.keys()) == {'data'}:
            # ``_extend`` acquires the lock by itself.
            self._extend(_state_dict['data'])
        else:
            # Restore all fields atomically with respect to ``push`` / ``sample`` / ``clear``.
            with self._lock:
                for k, v in _state_dict.items():
                    setattr(self, '_{}'.format(k), v)

    @property
    def replay_buffer_size(self) -> int:
        return self._replay_buffer_size

    @property
    def push_count(self) -> int:
        return self._push_count