Transformer

https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/

Attention Is All You Need (Vaswani et al., 2017)

\begin{aligned}
\text{attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) &= \text{softmax}\Big(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\Big)\mathbf{V} \\
\text{MultiHeadAttn}(\mathbf{X}_q, \mathbf{X}_k, \mathbf{X}_v) &= [\text{head}_1; \dots; \text{head}_h] \mathbf{W}^o \\
\text{where head}_i &= \text{Attention}(\mathbf{X}_q\mathbf{W}^q_i, \mathbf{X}_k\mathbf{W}^k_i, \mathbf{X}_v\mathbf{W}^v_i)
\end{aligned}
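
A minimal sketch of these two formulas in PyTorch (the function names and the per-head weight lists are illustrative, not from the paper; it assumes h heads, each with its own projection matrices):

import torch
import torch.nn.functional as F

def attn(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k**0.5     # (..., T_q, T_k)
    return F.softmax(scores, dim=-1) @ V            # (..., T_q, d_v)

def multi_head_attn(X_q, X_k, X_v, W_q, W_k, W_v, W_o):
    # W_q, W_k, W_v: lists of h per-head projection matrices; W_o: (h * d_v, d_model)
    heads = [attn(X_q @ W_q[i], X_k @ W_k[i], X_v @ W_v[i]) for i in range(len(W_q))]
    return torch.cat(heads, dim=-1) @ W_o           # [head_1; ...; head_h] W^o

nn.MultiheadAttention below fuses the per-head projections into single weight matrices, but computes the same thing.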

Encoder Layer

# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Callable

class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Callable[[Tensor], Tensor] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,  # PyTorch's default is False; see the batch_first discussion below
    ) -> None:
        super(TransformerEncoderLayer, self).__init__()

        # (1) Multi-head Attention
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )

        # (2) Feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = activation

    def forward(
        self,
        src: Tensor,
        src_mask: Tensor | None = None,
        src_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """
        `N`: batch size
        `S`: sequence length
        `E`: embedding feature dim
        src: `(N, S, E)` if batch_first=True
        """
        x = src
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(
        self, x: Tensor, attn_mask: Tensor | None, key_padding_mask: Tensor | None
    ) -> Tensor:
        # Input: (N, S, E)
        x = self.self_attn(x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )[0]  # -> (Tensor(N, S, E), Tensor(N, S, S))
        # attn's output[1] is the self-attention weight matrix (averaged over heads)
        # inside nn.MultiheadAttention, per head (head_dim = d_model // nhead):
        #   q_scaled = q / math.sqrt(head_dim)
        #   attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        return self.dropout1(x)  # (N, S, 512)

    # feed-forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        # Input: (N, S, 512)
        x = self.dropout(self.activation(self.linear1(x)))  # (N, S, 2048)
        return self.dropout2(self.linear2(x))  # (N, S, 512)
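
A quick shape check for the layer above (a sketch; it assumes the class and imports shown):

import torch

layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)

src = torch.randn(2, 7, 512)                        # (N=2, S=7, E=512)
# boolean padding mask: True marks key positions to ignore
src_key_padding_mask = torch.zeros(2, 7, dtype=torch.bool)
src_key_padding_mask[1, 5:] = True                  # last two tokens of example 1 are padding

out = layer(src, src_key_padding_mask=src_key_padding_mask)
print(out.shape)                                    # torch.Size([2, 7, 512])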

Decoder Layer

# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Callable

class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        activation: Callable[[Tensor], Tensor] = F.relu,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,  # PyTorch's default is False
    ) -> None:
        super(TransformerDecoderLayer, self).__init__()

        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )

        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor | None = None,
        memory_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """
        `N`: batch size
        `S`: memory sequence length (output from the last layer of the encoder)
        `T`: target sequence length
        tgt: `(N, T, 512)` if batch_first=True
        memory: `(N, S, 512)` if batch_first=True
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
        x = self.norm2(
            x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
        )
        x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(
        self, x: Tensor, attn_mask: Tensor | None, key_padding_mask: Tensor | None
    ) -> Tensor:
        x = self.self_attn(x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)  # (N, T, 512)

    # Different from Encoder
    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,  # from the encoder
        attn_mask: Tensor | None,
        key_padding_mask: Tensor | None,
    ) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]  # Q from the decoder, K/V from memory: (N, T, 512) @ (N, S, 512)^T -> (N, T, S), then @ (N, S, 512) -> (N, T, 512)
        # (the internal scaling is 1/sqrt(d_k) per head, with d_k = 512 / 8 = 64)
        return self.dropout2(x)  # (N, T, 512)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.dropout(self.activation(self.linear1(x)))  # (N, T, 2048)
        return self.dropout3(self.linear2(x))  # (N, T, 512)
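
And the corresponding check for the decoder layer, with a causal tgt_mask built via torch.triu (equivalent to nn.Transformer.generate_square_subsequent_mask) so that position t can only attend to positions ≤ t (a sketch):

import torch

layer = TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)

memory = torch.randn(2, 7, 512)                     # encoder output, (N=2, S=7, E=512)
tgt = torch.randn(2, 5, 512)                        # (N=2, T=5, E=512)

# (T, T) additive float mask: -inf above the diagonal blocks attention to future positions
tgt_mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)

out = layer(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)                                    # torch.Size([2, 5, 512])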

From ChatGPT

  • Why is batch_first False by default?

    With the default batch_first=False, the sequence (time) dimension comes first in the input tensor, i.e. inputs are shaped (S, N, E) rather than (N, S, E).

    The reason for this choice is largely historical: it follows the convention used around the original Transformer paper by Vaswani et al. (2017), which has since become a standard in NLP. Sticking to the same convention makes it easier to compare and replicate results from the existing literature and to reuse existing code and resources.

  • But batch_first=True is more memory-efficient!

    Still, the choice of batch_first can affect the memory usage and performance of a model: the batch dimension is typically the largest, and keeping each example's tokens contiguous by putting the batch dimension first is more memory-friendly.

    If the sequence dimension comes first instead, the time steps of a single example are spread across memory rather than stored in one contiguous block, which can lead to more frequent cache misses and slower memory access.

  • But nn.MultiheadAttention still computes internally in (sequence length, batch_size, features) form; with batch_first=True the inputs are transposed on the way in and the output is transposed back, as the sketch below shows.
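
This is easy to verify: with identical weights, a batch_first=True module applied to (N, S, E) input and a batch_first=False module applied to the transposed (S, N, E) input agree up to a transpose (a sketch using nn.MultiheadAttention directly):

import torch
import torch.nn as nn

mha_bf = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
mha_sf = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=False)
mha_sf.load_state_dict(mha_bf.state_dict())         # copy the same weights

x_bf = torch.randn(2, 7, 512)                       # (N, S, E)
x_sf = x_bf.transpose(0, 1)                         # (S, N, E)

out_bf, _ = mha_bf(x_bf, x_bf, x_bf)
out_sf, _ = mha_sf(x_sf, x_sf, x_sf)
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6))  # True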