from typing import Optional
import torch
import torch.nn as nn
from ding.torch_utils import fc_block, ResFCBlock, ResBlock
from ding.utils import SequenceType
class ConvEncoder(nn.Module):
    r"""
    Overview:
        The ``Convolution Encoder`` used in models. Used to encode raw 2-dim observation.
        The network is three ``nn.Conv2d`` layers (kernel sizes 8/4/3, strides 4/2/1), each followed by
        the activation, then zero or more ``ResBlock`` layers, a ``nn.Flatten`` and a final ``nn.Linear``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            hidden_size_list: SequenceType = [32, 64, 64, 128],
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        r"""
        Overview:
            Init the Convolution Encoder according to arguments.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, some ``output size``. \
                ``obs_shape[0]`` is used as the input channel number of the first conv layer.
            - hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size``. Entries 0-2 are the \
                output channels of the three conv layers, entries from index 3 up to (excluding) the last one \
                each create one ``ResBlock`` and must all be equal, the last entry is the output embedding size.
            - activation (:obj:`nn.Module`):
                The type of activation to use in the conv ``layers`` and ``ResBlock``,
                if ``None`` then default set to ``nn.ReLU()``
            - norm_type (:obj:`str`):
                The type of normalization to use, see ``ding.torch_utils.ResBlock`` for more details
        """
        super(ConvEncoder, self).__init__()
        # Honour the documented contract: ``None`` means "use ReLU". Without this, ``None`` would be
        # appended to the layer list and rejected by ``nn.Sequential``.
        if activation is None:
            activation = nn.ReLU()
        self.obs_shape = obs_shape
        self.act = activation
        # NOTE: the default list is shared between calls; it is only read here, never mutated.
        self.hidden_size_list = hidden_size_list

        layers = []
        kernel_size = [8, 4, 3]
        stride = [4, 2, 1]
        input_size = obs_shape[0]  # in_channel
        for i, (kernel, step) in enumerate(zip(kernel_size, stride)):
            layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel, step))
            layers.append(self.act)
            input_size = hidden_size_list[i]

        res_sizes = hidden_size_list[3:-1]
        assert len(set(res_sizes)) <= 1, "Please indicate the same hidden size for res block parts"
        # NOTE(review): each ResBlock is built with its own entry as channel number while its input has
        # ``hidden_size_list[2]`` channels -- presumably these must match; this is not validated here.
        for res_size in res_sizes:
            layers.append(ResBlock(res_size, activation=self.act, norm_type=norm_type))
        layers.append(nn.Flatten())
        self.main = nn.Sequential(*layers)

        # The in-feature number of the final linear layer depends on the observation's spatial size,
        # so it is measured with a dummy forward pass.
        flatten_size = self._get_flatten_size()
        self.mid = nn.Linear(flatten_size, hidden_size_list[-1])

    def _get_flatten_size(self) -> int:
        r"""
        Overview:
            Get the encoding size after ``self.main`` to get the number of ``in-features`` to feed to ``nn.Linear``.
            A random tensor of shape ``(1, *self.obs_shape)`` is passed through ``self.main`` without gradient.
        Returns:
            - outputs (:obj:`int`): Size of dim 1 of the flattened output, also number of in-feature
        """
        # ``torch.randn`` is kept (rather than e.g. zeros) so the amount of global RNG consumed before
        # ``self.mid`` is initialised stays the same.
        test_data = torch.randn(1, *self.obs_shape)
        # Probe in eval mode so that the batch-of-one random input cannot touch (or break) layers that
        # behave differently while training -- presumably batch-statistics norms inside ``ResBlock``
        # when ``norm_type`` selects one. The previous mode is restored afterwards.
        was_training = self.main.training
        self.main.eval()
        try:
            with torch.no_grad():
                output = self.main(test_data)
        finally:
            self.main.train(was_training)
        return output.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Return embedding tensor of the env observation
        Arguments:
            - x (:obj:`torch.Tensor`): Env raw observation
        Returns:
            - outputs (:obj:`torch.Tensor`): Embedding tensor
        """
        x = self.main(x)
        x = self.mid(x)
        return x
class FCEncoder(nn.Module):
    r"""
    Overview:
        The ``FCEncoder`` used in models. Used to encode raw 1-dim observation.
        An initial ``nn.Linear`` plus activation is followed either by a stack of ``ResFCBlock``
        or by a stack of ``nn.Linear`` + activation pairs.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: int,
            hidden_size_list: SequenceType,
            res_block: bool = False,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        r"""
        Overview:
            Init the FC Encoder according to arguments.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape
            - hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size``. The first entry is \
                the output size of the initial linear layer. With ``res_block`` all entries must be equal and \
                one ``ResFCBlock`` is created per entry; otherwise each consecutive pair of entries \
                defines one additional linear layer.
            - res_block (:obj:`bool`): Whether use ``res_block``.
            - activation (:obj:`nn.Module`):
                The type of activation to use in the ``ResFCBlock``,
                if ``None`` then default set to ``nn.ReLU()``
            - norm_type (:obj:`str`):
                The type of normalization to use, see ``ding.torch_utils.ResFCBlock`` for more details. \
                Only used when ``res_block`` is True.
        """
        super(FCEncoder, self).__init__()
        # Honour the documented contract: ``None`` means "use ReLU". Without this, ``forward`` would
        # try to call ``None``.
        if activation is None:
            activation = nn.ReLU()
        self.obs_shape = obs_shape
        self.act = activation
        self.init = nn.Linear(obs_shape, hidden_size_list[0])

        if res_block:
            assert len(set(hidden_size_list)) == 1, "Please indicate the same hidden size for res block parts"
            hidden_size = hidden_size_list[0]
            if len(hidden_size_list) == 1:
                # A single block is stored directly (not wrapped in ``nn.Sequential``); kept as is so
                # the type of ``self.main`` in this case does not change.
                self.main = ResFCBlock(hidden_size, activation=self.act, norm_type=norm_type)
            else:
                blocks = [
                    ResFCBlock(hidden_size, activation=self.act, norm_type=norm_type) for _ in hidden_size_list
                ]
                self.main = nn.Sequential(*blocks)
        else:
            layers = []
            # One Linear + activation per consecutive (in, out) pair; a single-entry list yields an
            # empty ``nn.Sequential``.
            for in_size, out_size in zip(hidden_size_list[:-1], hidden_size_list[1:]):
                layers.append(nn.Linear(in_size, out_size))
                layers.append(self.act)
            self.main = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Return embedding tensor of the env observation
        Arguments:
            - x (:obj:`torch.Tensor`): Env raw observation
        Returns:
            - outputs (:obj:`torch.Tensor`): Embedding tensor
        """
        x = self.act(self.init(x))
        x = self.main(x)
        return x
class StructEncoder(nn.Module):
    r"""
    Overview:
        Placeholder for a structured-observation encoder. Not implemented yet: the class adds nothing
        on top of ``nn.Module``.
    """
    # TODO(nyz)
    pass