Source code for ding.model.template.q_learning

from typing import Union, Optional, Dict, Callable, List
import torch
import torch.nn as nn

from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
    QuantileHead, QRDQNHead, DistributionHead


@MODEL_REGISTRY.register('dqn')
class DQN(nn.Module):
    """
    Overview:
        Q-value network built from an observation encoder followed by a discrete-action head.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Build the encoder and the head of the DQN model from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. \
                If it is not an ``int`` after ``squeeze``, the head is wrapped in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Hidden sizes handed to the encoder. The encoder \
                output is fed to the head, so the last element is expected to match ``head_hidden_size``.
            - dueling (:obj:`bool`): ``True`` selects ``DuelingHead``, ``False`` selects ``DiscreteHead``.
            - head_hidden_size (:obj:`Optional[int]`): Hidden size given to the head. ``None`` means \
                ``encoder_hidden_size_list[-1]``.
            - head_layer_num (:obj:`int`): Layer number given to the head.
            - activation (:obj:`Optional[nn.Module]`): Activation forwarded unchanged to encoder and head.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded unchanged to encoder and head.
        Raises:
            - RuntimeError: If the squeezed ``obs_shape`` is neither an ``int`` nor a sequence of length 1 or 3.
        """
        super().__init__()
        # ``squeeze`` is applied for compatibility between spellings such as 1, (1, ) and [4, 32, 32].
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        head_hidden_size = encoder_hidden_size_list[-1] if head_hidden_size is None else head_hidden_size

        # Encoder: fully-connected for an int / length-1 shape, convolutional for a length-3 shape.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own DQN".format(obs_shape)
            )
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: a single head for an int action shape, otherwise one ``MultiHead`` over ``action_shape``.
        head_cls = DuelingHead if dueling else DiscreteHead
        if isinstance(action_shape, int):
            self.head = head_cls(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            self.head = MultiHead(
                head_cls,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Run the observation through the encoder and then through the head.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): Whatever the head returns for the encoded observation.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension \
                (key provided by the head, which is defined outside this file).
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
        Examples:
            >>> model = DQN(32, 6)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])
        """
        return self.head(self.encoder(x))
@MODEL_REGISTRY.register('c51dqn')
class C51DQN(nn.Module):
    """
    Overview:
        Distributional Q-network built from an observation encoder followed by ``DistributionHead``.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            v_min: Optional[float] = -10,
            v_max: Optional[float] = 10,
            n_atom: Optional[int] = 51,
    ) -> None:
        """
        Overview:
            Build the encoder and the distribution head of the C51 model from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape. If it is not an ``int`` after \
                ``squeeze``, the head is wrapped in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Hidden sizes handed to the encoder.
            - head_hidden_size (:obj:`Optional[int]`): Hidden size given to the head. ``None`` means \
                ``encoder_hidden_size_list[-1]`` (which is 64 for the default encoder list).
            - head_layer_num (:obj:`int`): Layer number given to the head.
            - activation (:obj:`Optional[nn.Module]`): Activation forwarded unchanged to encoder and head.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded unchanged to encoder and head.
            - v_min (:obj:`Optional[float]`): Forwarded to ``DistributionHead``; presumably the lower bound of the \
                value support, confirm against the head implementation.
            - v_max (:obj:`Optional[float]`): Forwarded to ``DistributionHead``; presumably the upper bound of the \
                value support, confirm against the head implementation.
            - n_atom (:obj:`Optional[int]`): Number of atoms in the prediction distribution.
        Raises:
            - RuntimeError: If the squeezed ``obs_shape`` is neither an ``int`` nor a sequence of length 1 or 3.
        """
        super(C51DQN, self).__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        # Same fallback as the sibling models in this file: without it a custom
        # ``encoder_hidden_size_list`` would leave the head sized for 64 features regardless of the encoder.
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own C51DQN".format(obs_shape)
            )
        # Head Type
        multi_head = not isinstance(action_shape, int)
        if multi_head:
            self.head = MultiHead(
                DistributionHead,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type,
                n_atom=n_atom,
                v_min=v_min,
                v_max=v_max,
            )
        else:
            self.head = DistributionHead(
                head_hidden_size,
                action_shape,
                head_layer_num,
                activation=activation,
                norm_type=norm_type,
                n_atom=n_atom,
                v_min=v_min,
                v_max=v_max,
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Run the observation through the encoder and then through the distribution head.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): Whatever the head returns for the encoded observation.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Q-value output per action (key provided by the head).
            - distribution (:obj:`torch.Tensor`): Per-action distribution over atoms (key provided by the head).
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``action_shape``.
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, M, P)`, where P is ``n_atom``.
        Examples:
            >>> model = C51DQN(128, 64)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 128)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> # action_shape is 64
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default n_atom: int = 51
            >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])
        """
        x = self.encoder(x)
        x = self.head(x)
        return x
@MODEL_REGISTRY.register('qrdqn')
class QRDQN(nn.Module):
    """
    Overview:
        Quantile-regression Q-network built from an observation encoder followed by ``QRDQNHead``.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            num_quantiles: int = 32,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Build the encoder and the quantile head of the QRDQN model from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape. If it is not an ``int`` after \
                ``squeeze``, the head is wrapped in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Hidden sizes handed to the encoder.
            - head_hidden_size (:obj:`Optional[int]`): Hidden size given to the head. ``None`` means \
                ``encoder_hidden_size_list[-1]``.
            - head_layer_num (:obj:`int`): Layer number given to the head.
            - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
            - activation (:obj:`Optional[nn.Module]`): Activation forwarded unchanged to encoder and head.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded unchanged to encoder and head.
        Raises:
            - RuntimeError: If the squeezed ``obs_shape`` is neither an ``int`` nor a sequence of length 1 or 3.
        """
        super().__init__()
        # ``squeeze`` is applied for compatibility between spellings such as 1, (1, ) and [4, 32, 32].
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        head_hidden_size = encoder_hidden_size_list[-1] if head_hidden_size is None else head_hidden_size

        # Encoder: fully-connected for an int / length-1 shape, convolutional for a length-3 shape.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own QRDQN".format(obs_shape)
            )
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: a single ``QRDQNHead`` for an int action shape, otherwise one ``MultiHead`` over ``action_shape``.
        if isinstance(action_shape, int):
            self.head = QRDQNHead(
                head_hidden_size,
                action_shape,
                head_layer_num,
                num_quantiles=num_quantiles,
                activation=activation,
                norm_type=norm_type,
            )
        else:
            self.head = MultiHead(
                QRDQNHead,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                num_quantiles=num_quantiles,
                activation=activation,
                norm_type=norm_type,
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Run the observation through the encoder and then through the quantile head.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): Whatever the head returns for the encoded observation.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Q-value output per action (key provided by the head).
            - q (:obj:`torch.Tensor`): Per-action quantile values (key provided by the head).
            - tau (:obj:`torch.Tensor`): Quantile fractions (key provided by the head).
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``action_shape``.
            - q (:obj:`torch.FloatTensor`): :math:`(B, M, num\_quantiles)` according to the example below.
            - tau (:obj:`torch.Tensor`): :math:`(B, num\_quantiles, 1)` according to the example below; \
                NOTE(review): shapes are determined by ``QRDQNHead``, confirm there.
        Examples:
            >>> model = QRDQN(64, 64)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default num_quantiles: int = 32
            >>> assert outputs['q'].shape == torch.Size([4, 64, 32])
            >>> assert outputs['tau'].shape == torch.Size([4, 32, 1])
        """
        return self.head(self.encoder(x))
@MODEL_REGISTRY.register('iqn')
class IQN(nn.Module):
    """
    Overview:
        Implicit-quantile Q-network built from an observation encoder followed by ``QuantileHead``.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            num_quantiles: int = 32,
            quantile_embedding_size: int = 128,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Build the encoder and the quantile head of the IQN model from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape. If it is not an ``int`` after \
                ``squeeze``, the head is wrapped in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Hidden sizes handed to the encoder.
            - head_hidden_size (:obj:`Optional[int]`): Hidden size given to the head. ``None`` means \
                ``encoder_hidden_size_list[-1]``.
            - head_layer_num (:obj:`int`): Layer number given to the head.
            - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
            - quantile_embedding_size (:obj:`int`): Forwarded to ``QuantileHead`` under the same name.
            - activation (:obj:`Optional[nn.Module]`): Activation forwarded unchanged to encoder and head.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded unchanged to encoder and head.
        Raises:
            - RuntimeError: If the squeezed ``obs_shape`` is neither an ``int`` nor a sequence of length 1 or 3.
        """
        super().__init__()
        # ``squeeze`` is applied for compatibility between spellings such as 1, (1, ) and [4, 32, 32].
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        head_hidden_size = encoder_hidden_size_list[-1] if head_hidden_size is None else head_hidden_size

        # Encoder: fully-connected for an int / length-1 shape, convolutional for a length-3 shape.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own IQN".format(obs_shape)
            )
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: a single ``QuantileHead`` for an int action shape, otherwise one ``MultiHead`` over ``action_shape``.
        if isinstance(action_shape, int):
            self.head = QuantileHead(
                head_hidden_size,
                action_shape,
                head_layer_num,
                activation=activation,
                norm_type=norm_type,
                num_quantiles=num_quantiles,
                quantile_embedding_size=quantile_embedding_size,
            )
        else:
            self.head = MultiHead(
                QuantileHead,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                num_quantiles=num_quantiles,
                quantile_embedding_size=quantile_embedding_size,
                activation=activation,
                norm_type=norm_type
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Run the observation through the encoder and then through the quantile head.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): Whatever the head returns for the encoded observation.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Q-value output per action (key provided by the head).
            - q (:obj:`torch.Tensor`): Per-quantile Q values (key provided by the head).
            - quantiles (:obj:`torch.Tensor`): Sampled quantiles (key provided by the head).
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``action_shape``.
            - q (:obj:`torch.FloatTensor`): :math:`(num\_quantiles, B, M)` according to the example below.
            - quantiles (:obj:`torch.Tensor`): :math:`(P, 1)`, where P is ``quantile_embedding_size`` according \
                to the example below; NOTE(review): shapes are determined by ``QuantileHead``, confirm there.
        Examples:
            >>> model = IQN(64, 64)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default num_quantiles: int = 32
            >>> assert outputs['q'].shape == torch.Size([32, 4, 64])
            >>> # default quantile_embedding_size: int = 128
            >>> assert outputs['quantiles'].shape == torch.Size([128, 1])
        """
        return self.head(self.encoder(x))
@MODEL_REGISTRY.register('rainbowdqn')
class RainbowDQN(nn.Module):
    """
    Overview:
        RainbowDQN network (C51 + Dueling + Noisy Block), built from an observation encoder
        followed by ``RainbowHead``.

    .. note::
        RainbowDQN contains dueling architecture by default
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            v_min: Optional[float] = -10,
            v_max: Optional[float] = 10,
            n_atom: Optional[int] = 51,
    ) -> None:
        """
        Overview:
            Build the encoder and the Rainbow head of the model from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape. If it is not an ``int`` after \
                ``squeeze``, the head is wrapped in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Hidden sizes handed to the encoder.
            - head_hidden_size (:obj:`Optional[int]`): Hidden size given to the head. ``None`` means \
                ``encoder_hidden_size_list[-1]``.
            - head_layer_num (:obj:`int`): Layer number given to the head.
            - activation (:obj:`Optional[nn.Module]`): Activation forwarded unchanged to encoder and head.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded unchanged to encoder and head.
            - v_min (:obj:`Optional[float]`): Forwarded to ``RainbowHead``; presumably the lower bound of the \
                value support, confirm against the head implementation.
            - v_max (:obj:`Optional[float]`): Forwarded to ``RainbowHead``; presumably the upper bound of the \
                value support, confirm against the head implementation.
            - n_atom (:obj:`Optional[int]`): Number of atoms in the prediction distribution.
        Raises:
            - RuntimeError: If the squeezed ``obs_shape`` is neither an ``int`` nor a sequence of length 1 or 3.
        """
        super().__init__()
        # ``squeeze`` is applied for compatibility between spellings such as 1, (1, ) and [4, 32, 32].
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        head_hidden_size = encoder_hidden_size_list[-1] if head_hidden_size is None else head_hidden_size

        # Encoder: fully-connected for an int / length-1 shape, convolutional for a length-3 shape.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own RainbowDQN".
                format(obs_shape)
            )
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: a single ``RainbowHead`` for an int action shape, otherwise one ``MultiHead`` over ``action_shape``.
        if isinstance(action_shape, int):
            self.head = RainbowHead(
                head_hidden_size,
                action_shape,
                head_layer_num,
                activation=activation,
                norm_type=norm_type,
                n_atom=n_atom,
                v_min=v_min,
                v_max=v_max,
            )
        else:
            self.head = MultiHead(
                RainbowHead,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type,
                n_atom=n_atom,
                v_min=v_min,
                v_max=v_max,
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Run the observation through the encoder and then through the Rainbow head.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): Whatever the head returns for the encoded observation.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Q-value output per action (key provided by the head).
            - distribution (:obj:`torch.Tensor`): Per-action distribution over atoms (key provided by the head).
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``action_shape``.
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, M, P)`, where P is ``n_atom``.
        Examples:
            >>> model = RainbowDQN(64, 64)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default n_atom: int = 51
            >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])
        """
        return self.head(self.encoder(x))
def parallel_wrapper(forward_fn: Callable) -> Callable:
    """
    Overview:
        Wrap ``forward_fn`` so that it can be fed data with leading ``(T, B)`` dimensions: the two leading \
        dimensions are merged into one before the call (each timestep is treated as a separate batch entry) \
        and split back into ``(T, B)`` on every tensor of the result.
    Arguments:
        - forward_fn (:obj:`Callable`): Normal ``nn.Module`` 's forward function.
    Returns:
        - wrapper (:obj:`Callable`): Wrapped function.
    """

    def wrapper(x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        seq_len, batch_size = x.shape[:2]

        def unflatten(out):
            # Lists and dicts are walked recursively; anything else is treated as a tensor whose
            # leading dimension is split back into (T, B).
            if isinstance(out, list):
                return [unflatten(item) for item in out]
            if isinstance(out, dict):
                return {key: unflatten(value) for key, value in out.items()}
            return out.reshape(seq_len, batch_size, *out.shape[1:])

        merged = x.reshape(seq_len * batch_size, *x.shape[2:])
        return unflatten(forward_fn(merged))

    return wrapper


@MODEL_REGISTRY.register('drqn')
class DRQN(nn.Module):
    """
    Overview:
        DQN + RNN = DRQN: an observation encoder, an LSTM and a discrete-action head chained together.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            lstm_type: Optional[str] = 'normal',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Build the encoder, the LSTM and the head of the DRQN model from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape. If it is not an ``int`` after \
                ``squeeze``, the head is wrapped in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Hidden sizes handed to the encoder.
            - dueling (:obj:`bool`): ``True`` selects ``DuelingHead``, ``False`` selects ``DiscreteHead``.
            - head_hidden_size (:obj:`Optional[int]`): Hidden size given to the head and used as both the input \
                and hidden size of the LSTM. ``None`` means ``encoder_hidden_size_list[-1]``.
            - head_layer_num (:obj:`int`): Layer number given to the head.
            - lstm_type (:obj:`Optional[str]`): Version of lstm cell passed to ``get_lstm``, now support \
                ``['normal', 'pytorch']``.
            - activation (:obj:`Optional[nn.Module]`): Activation forwarded unchanged to encoder and head.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded unchanged to encoder and head.
        Raises:
            - RuntimeError: If the squeezed ``obs_shape`` is neither an ``int`` nor a sequence of length 1 or 3.
        """
        super().__init__()
        # ``squeeze`` is applied for compatibility between spellings such as 1, (1, ) and [4, 32, 32].
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        head_hidden_size = encoder_hidden_size_list[-1] if head_hidden_size is None else head_hidden_size

        # Encoder: fully-connected for an int / length-1 shape, convolutional for a length-3 shape.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape)
            )
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Recurrent core sitting between encoder and head.
        self.rnn = get_lstm(lstm_type, input_size=head_hidden_size, hidden_size=head_hidden_size)

        # Head: a single head for an int action shape, otherwise one ``MultiHead`` over ``action_shape``.
        head_cls = DuelingHead if dueling else DiscreteHead
        if isinstance(action_shape, int):
            self.head = head_cls(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            self.head = MultiHead(
                head_cls,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type
            )

    def forward(self, inputs: Dict, inference: bool = False) -> Dict:
        """
        Overview:
            Run encoder, LSTM and head on the given observation, carrying the recurrent state along.
        Arguments:
            - inputs (:obj:`Dict`): Input data, see ``ArgumentsKeys``.
            - inference (:obj:`bool`): If ``True``, ``obs`` is processed as a single step: a time dimension of \
                size one is added for the LSTM and removed again before the head. If ``False``, ``obs`` must be \
                3- or 5-dimensional with the timestep as leading dimension, and the LSTM is unrolled step by step.
        ArgumentsKeys:
            - obs (:obj:`torch.Tensor`): Observation.
            - prev_state (:obj:`list`): Recurrent state handed to the LSTM for the first step.
        Returns:
            - outputs (:obj:`Dict`): The head output with the key ``next_state`` added.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Q-value output per action (key provided by the head).
            - next_state (:obj:`list`): Recurrent state returned by the LSTM after the last processed step.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N)` when ``inference`` is ``True``; otherwise the leading \
                dimensions are :math:`(T, B)`.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)` when ``inference`` is ``True``, \
                :math:`(T, B, M)` otherwise, where M is ``action_shape``.
        Examples:
            >>> # Init input's Keys:
            >>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)]  # B=4
            >>> obs = torch.randn(4, 64)
            >>> model = DRQN(64, 64)  # arguments: 'obs_shape' and 'action_shape'
            >>> outputs = model({'obs': obs, 'prev_state': prev_state}, inference=True)
            >>> # Check outputs's Keys
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == (4, 64)
            >>> assert len(outputs['next_state']) == 4
            >>> assert all([len(t) == 2 for t in outputs['next_state']])
            >>> assert all([t[0].shape == (1, 1, 64) for t in outputs['next_state']])
        """
        obs, state = inputs['obs'], inputs['prev_state']

        if inference:
            # Single step: add a time dimension of size one for the LSTM, drop it before the head.
            embedding = self.encoder(obs).unsqueeze(0)
            embedding, next_state = self.rnn(embedding, state)
            outputs = self.head(embedding.squeeze(0))
            outputs['next_state'] = next_state
            return outputs

        assert len(obs.shape) in [3, 5], obs.shape
        embedding = parallel_wrapper(self.encoder)(obs)
        # Unroll the LSTM one timestep at a time, threading the state through.
        per_step = []
        for t in range(embedding.shape[0]):
            step_out, state = self.rnn(embedding[t:t + 1], state)
            per_step.append(step_out)
        outputs = parallel_wrapper(self.head)(torch.cat(per_step, 0))
        outputs['next_state'] = state
        return outputs


class GeneralQNetwork(nn.Module):
    """Placeholder ``nn.Module`` subclass; it defines no attributes or methods of its own."""