https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html
from typing import Callable

import torch
import torch.nn.functional as F
from torch import Tensor, nn
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Callable[[Tensor], Tensor] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,  # PyTorch's default is False; see the batch_first note below
    ) -> None:
        super().__init__()
        # (1) Multi-head attention
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # (2) Feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation

    def forward(
        self,
        src: Tensor,
        src_mask: Tensor | None = None,
        src_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """
        `N`: batch size
        `S`: sequence length
        `E`: embedding feature dim (d_model)
        src: `(N, S, E)` if batch_first=True
        """
        x = src
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(
        self, x: Tensor, attn_mask: Tensor | None, key_padding_mask: Tensor | None
    ) -> Tensor:
        # Input: (N, S, E)
        x = self.self_attn(x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )[0]  # -> (Tensor(N, S, E), Tensor(N, S, S))
        # attn's output[1] is the (head-averaged) self-attention matrix, computed per head as
        #   q_scaled = q / math.sqrt(head_dim)        # head_dim = E / nhead
        #   attn_output_weights = softmax(q_scaled @ k.transpose(-2, -1))
        return self.dropout1(x)  # (N, S, 512)

    # feed-forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        # Input: (N, S, 512)
        x = self.dropout(self.activation(self.linear1(x)))  # (N, S, 2048)
        return self.dropout2(self.linear2(x))  # (N, S, 512)
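A quick smoke test of the encoder layer above; this is not part of the PyTorch source, and the batch size and sequence length are arbitrary.

enc_layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
src = torch.randn(2, 10, 512)  # (N=2, S=10, E=512)
# Boolean key-padding mask: True marks padded positions to be ignored
src_key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
out = enc_layer(src, src_key_padding_mask=src_key_padding_mask)
print(out.shape)  # torch.Size([2, 10, 512])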
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        activation: Callable[[Tensor], Tensor] = F.relu,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,  # PyTorch's default is False
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor | None = None,
        memory_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """
        `N`: batch size
        `S`: memory sequence length (output from the last layer of the encoder)
        `T`: target sequence length
        tgt: `(N, T, 512)` if batch_first=True
        memory: `(N, S, 512)` if batch_first=True
        """
        # Post-LN layout; see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = tgt
        x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
        x = self.norm2(
            x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
        )
        x = self.norm3(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(
        self, x: Tensor, attn_mask: Tensor | None, key_padding_mask: Tensor | None
    ) -> Tensor:
        x = self.self_attn(x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)  # (N, T, 512)

    # cross-attention block (this part is not present in the encoder)
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,  # from the encoder
        attn_mask: Tensor | None,
        key_padding_mask: Tensor | None,
    ) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]  # softmax(Q K^T / sqrt(head_dim)) V: (N, T, S) @ (N, S, 512) -> (N, T, 512)
        return self.dropout2(x)  # (N, T, 512)

    # feed-forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.dropout(self.activation(self.linear1(x)))  # (N, T, 2048)
        return self.dropout3(self.linear2(x))  # (N, T, 512)
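The same kind of smoke test for the decoder layer (again not part of the PyTorch source; sizes are arbitrary), using a boolean causal mask so that target position t can only attend to positions <= t.

dec_layer = TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
memory = torch.randn(2, 10, 512)  # (N=2, S=10, E=512), e.g. the encoder output
tgt = torch.randn(2, 7, 512)      # (N=2, T=7, E=512)
# True entries are *not* allowed to attend -> strict upper triangle masks the future
tgt_mask = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)
out = dec_layer(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 7, 512])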
From ChatGPT

Why is `batch_first` False by default?

The `batch_first` argument defaults to `False`, which means that the sequence (time) dimension comes first in the input tensor. The reason for this choice is historical: it reflects the convention used in the original Transformer paper by Vaswani et al. (2017), which has since become a standard in NLP. By keeping the same convention, researchers and practitioners can easily compare and replicate results from the existing literature, as well as reuse existing code and resources.
But `batch_first=True` is more memory-efficient!

However, the choice of `batch_first` can affect a model's memory usage and performance: the batch dimension is typically the largest, and it is more memory-efficient to put it first. If the sequence dimension comes first instead, the time steps of a single example are interleaved with the rest of the batch rather than stored as one contiguous block, which can lead to more frequent cache misses and slower memory access.
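A small sketch of that layout argument (the sizes here are arbitrary): compare how far apart consecutive time steps of a single example land in memory under the two conventions.

x_bf = torch.randn(32, 100, 512)  # (N, S, E), contiguous batch-first
x_sf = torch.randn(100, 32, 512)  # (S, N, E), contiguous sequence-first
# Element distance between time step s and s+1 of the *same* example:
print(x_bf.stride(1))  # 512    -> each example occupies one contiguous block
print(x_sf.stride(0))  # 16384  -> an example's time steps are interleaved with the whole batch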
But `nn.MultiheadAttention` computes in `(sequence_length, batch_size, features)` form anyway. Even with `batch_first=True`, the module simply transposes its inputs to that layout, runs the attention computation, and transposes the output back.
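A sketch of this last point: the two conventions carry the same information, and with identical weights an `nn.MultiheadAttention` built with `batch_first=True` matches the sequence-first module applied to the transposed input.

mha_sf = nn.MultiheadAttention(512, 8, batch_first=False)
mha_bf = nn.MultiheadAttention(512, 8, batch_first=True)
mha_bf.load_state_dict(mha_sf.state_dict())  # same weights in both modules

x_bf = torch.randn(2, 10, 512)        # (N, S, E)
x_sf = x_bf.transpose(0, 1)           # (S, N, E)
out_sf, _ = mha_sf(x_sf, x_sf, x_sf)  # (S, N, E)
out_bf, _ = mha_bf(x_bf, x_bf, x_bf)  # (N, S, E)
print(torch.allclose(out_sf.transpose(0, 1), out_bf, atol=1e-6))  # True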