Source code for ding.torch_utils.network.transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional, Tuple
from .nn_module import fc_block, build_normalization
class Attention(nn.Module):
    r"""
    Overview:
        Multi-head self-attention over a set of entries. Every entry is projected to a query, a key and a
        value per head; each entry attends to all entries, and the per-head results are concatenated and
        projected to ``output_dim``.
    Interfaces:
        split, forward
    """

    def __init__(self, input_dim: int, head_dim: int, output_dim: int, head_num: int, dropout: nn.Module) -> None:
        r"""
        Overview:
            Init attention
        Arguments:
            - input_dim (:obj:`int`): dimension of input
            - head_dim (:obj:`int`): dimension of each head
            - output_dim (:obj:`int`): dimension of output
            - head_num (:obj:`int`): head num for multihead attention
            - dropout (:obj:`nn.Module`): dropout layer applied to the attention weights
        """
        super(Attention, self).__init__()
        self.head_num = head_num
        self.head_dim = head_dim
        self.dropout = dropout
        # A single projection produces query, key and value for all heads at once (hence the factor 3).
        self.attention_pre = fc_block(input_dim, head_dim * head_num * 3)  # query, key, value
        self.project = fc_block(head_dim * head_num, output_dim)

    def split(self, x: torch.Tensor, T: bool = False) -> torch.Tensor:
        r"""
        Overview:
            Reshape a packed query/key/value tensor so that the heads get their own dimension
        Arguments:
            - x (:obj:`torch.Tensor`): query or key or value, viewed as (B, N, head_num * head_dim)
            - T (:obj:`bool`): whether to transpose the last two dimensions of the output
        Returns:
            - x (:obj:`torch.Tensor`): a single tensor of shape (B, head_num, N, head_dim), or \
                (B, head_num, head_dim, N) when ``T`` is True
        """
        B, N = x.shape[:2]
        x = x.view(B, N, self.head_num, self.head_dim)
        x = x.permute(0, 2, 1, 3).contiguous()  # B, head_num, N, head_dim
        if T:
            x = x.permute(0, 1, 3, 2).contiguous()  # B, head_num, head_dim, N
        return x

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        Overview:
            Compute attention
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor with 3 dimensions, (B, N, input_dim)
            - mask (:obj:`Optional[torch.Tensor]`): bool tensor broadcastable to the score shape \
                (B, head_num, N, N). Positions where the mask is False get their score replaced by \
                -1e9 before the softmax, so they receive (numerically) zero attention weight
        Returns:
            - attention (:obj:`torch.Tensor`): attention tensor, (B, N, output_dim)
        """
        assert (len(x.shape) == 3)
        B, N = x.shape[:2]
        x = self.attention_pre(x)
        query, key, value = torch.chunk(x, 3, dim=2)
        # The key is transposed so that query @ key yields pairwise scores.
        query, key, value = self.split(query), self.split(key, T=True), self.split(value)

        score = torch.matmul(query, key)  # B, head_num, N, N
        score = score / math.sqrt(self.head_dim)
        if mask is not None:
            # ``masked_fill`` is out-of-place: its result must be assigned, otherwise the mask
            # is silently ignored.
            score = score.masked_fill(~mask, value=-1e9)

        score = F.softmax(score, dim=-1)
        score = self.dropout(score)
        attention = torch.matmul(score, value)  # B, head_num, N, head_dim

        attention = attention.permute(0, 2, 1, 3).contiguous()  # B, N, head_num, head_dim
        attention = self.project(attention.view(B, N, -1))  # B, N, output_dim
        return attention
class TransformerLayer(nn.Module):
    r"""
    Overview:
        One transformer block: self-attention followed by a feedforward MLP, each wrapped in a
        residual connection and followed by a normalization layer built with ``build_normalization('LN')``.
    Interfaces:
        forward
    """

    def __init__(
            self, input_dim: int, head_dim: int, hidden_dim: int, output_dim: int, head_num: int, mlp_num: int,
            dropout: nn.Module, activation: nn.Module
    ) -> None:
        r"""
        Overview:
            Init transformer layer
        Arguments:
            - input_dim (:obj:`int`): dimension of input
            - head_dim (:obj:`int`): dimension of each head
            - hidden_dim (:obj:`int`): dimension of hidden layer in mlp
            - output_dim (:obj:`int`): dimension of output
            - head_num (:obj:`int`): number of heads for multihead attention
            - mlp_num (:obj:`int`): number of mlp layers
            - dropout (:obj:`nn.Module`): dropout layer, shared by the attention, the mlp and the residuals
            - activation (:obj:`nn.Module`): activation function passed to every mlp ``fc_block``
        """
        super().__init__()
        self.attention = Attention(input_dim, head_dim, output_dim, head_num, dropout)
        self.layernorm1 = build_normalization('LN')(output_dim)
        self.dropout = dropout

        # Feature sizes along the mlp: output_dim -> hidden_dim (mlp_num - 1 times) -> output_dim.
        widths = [output_dim] + [hidden_dim] * (mlp_num - 1) + [output_dim]
        mlp_blocks = []
        for idx in range(mlp_num):
            if idx > 0:
                # Dropout sits between consecutive fc blocks.
                mlp_blocks.append(self.dropout)
            mlp_blocks.append(fc_block(widths[idx], widths[idx + 1], activation=activation))
        # NOTE(review): the mlp itself ends with dropout and ``forward`` applies dropout to the
        # mlp output once more; kept as in the original implementation.
        mlp_blocks.append(self.dropout)
        self.mlp = nn.Sequential(*mlp_blocks)

        self.layernorm2 = build_normalization('LN')(output_dim)

    def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Transformer layer forward
        Arguments:
            - inputs (:obj:`Tuple[torch.Tensor, torch.Tensor]`): ``(x, mask)``; the mask is handed to the \
                attention module unchanged
        Returns:
            - output (:obj:`Tuple[torch.Tensor, torch.Tensor]`): the transformed ``x`` and the same mask, \
                so that layers can be chained inside ``nn.Sequential``
        """
        hidden, mask = inputs
        # Attention sub-block: residual add, then normalization.
        attended = self.attention(hidden, mask)
        hidden = self.layernorm1(hidden + self.dropout(attended))
        # Feedforward sub-block: residual add, then normalization.
        fed_forward = self.mlp(hidden)
        hidden = self.layernorm2(hidden + self.dropout(fed_forward))
        return (hidden, mask)
class Transformer(nn.Module):
    '''
    Overview:
        Transformer encoder: an embedding block followed by a stack of ``TransformerLayer`` modules.

    .. note::
        For details refer to Attention is all you need: http://arxiv.org/abs/1706.03762
    '''

    def __init__(
        self,
        input_dim: int,
        head_dim: int = 128,
        hidden_dim: int = 1024,
        output_dim: int = 256,
        head_num: int = 2,
        mlp_num: int = 2,
        layer_num: int = 3,
        dropout_ratio: float = 0.,
        activation: nn.Module = nn.ReLU(),
    ) -> None:
        r"""
        Overview:
            Init transformer
        Arguments:
            - input_dim (:obj:`int`): dimension of input
            - head_dim (:obj:`int`): dimension of each head
            - hidden_dim (:obj:`int`): dimension of hidden layer in mlp
            - output_dim (:obj:`int`): dimension of output
            - head_num (:obj:`int`): number of heads for multihead attention
            - mlp_num (:obj:`int`): number of mlp layers
            - layer_num (:obj:`int`): number of transformer layers
            - dropout_ratio (:obj:`float`): dropout ratio
            - activation (:obj:`nn.Module`): activation function
        """
        super().__init__()
        self.embedding = fc_block(input_dim, output_dim, activation=activation)
        self.act = activation
        # One dropout module is shared by the embedding output and every layer.
        self.dropout = nn.Dropout(dropout_ratio)
        # Every layer maps output_dim -> output_dim, so the stack can have any depth.
        blocks = [
            TransformerLayer(
                output_dim, head_dim, hidden_dim, output_dim, head_num, mlp_num, self.dropout, self.act
            ) for _ in range(layer_num)
        ]
        self.main = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        Overview:
            Transformer forward
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor. Shape (B, N, C), B is batch size, \
                N is number of entries, C is feature dimension
            - mask (:obj:`Optional[torch.Tensor]`): bool tensor, can be used to mask out invalid entries in attention. \
                Shape (B, N), B is batch size, N is number of entries
        Returns:
            - x (:obj:`torch.Tensor`): transformer output
        """
        if mask is not None:
            entry_num = mask.shape[1]
            # (B, N) -> (B, N, N): every query row carries the same per-key validity flags.
            pairwise = mask.unsqueeze(dim=1).repeat(1, entry_num, 1)
            # (B, N, N) -> (B, 1, N, N): singleton dimension to broadcast over the heads.
            mask = pairwise.unsqueeze(dim=1)
        embedded = self.dropout(self.embedding(x))
        # The layers exchange (x, mask) tuples; only the features are returned.
        output, _ = self.main((embedded, mask))
        return output
class ScaledDotProductAttention(nn.Module):
    '''
    Overview:
        Implementation of dot product attention with scaling: the queries are divided by
        ``sqrt(d_k)`` before being multiplied with the keys.
    Interfaces:
        forward
    '''

    def __init__(self, d_k: int, dropout: float = 0.0) -> None:
        r"""
        Overview:
            Init scaled dot product attention
        Arguments:
            - d_k (:obj:`int`): value whose square root is used to scale the queries
            - dropout (:obj:`float`): dropout ratio applied to the attention weights
        """
        super().__init__()
        self.d_k = d_k
        self.dropout = nn.Dropout(dropout)

    def forward(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""
        Overview:
            Compute ``softmax(q / sqrt(d_k) @ k^T) @ v`` with optional masking and dropout
        Arguments:
            - q (:obj:`torch.Tensor`): queries
            - k (:obj:`torch.Tensor`): keys; dimensions 2 and 3 are swapped before the product, \
                so at least 4 dimensions are required
            - v (:obj:`torch.Tensor`): values
            - mask (:obj:`Optional[torch.Tensor]`): bool tensor; positions where it is False get \
                their score replaced by -1e9 before the softmax
        Returns:
            - output (:obj:`torch.Tensor`): attention-weighted combination of the values
        """
        scaled_q = q / (self.d_k ** 0.5)
        logits = torch.matmul(scaled_q, k.transpose(2, 3))
        if mask is not None:
            logits = logits.masked_fill(~mask, -1e9)
        weights = F.softmax(logits, dim=-1)
        weights = self.dropout(weights)
        return torch.matmul(weights, v)