Replacing additive attention with scaled dot-product attention
Scaled dot-product attention is the attention mechanism used throughout Transformer models. Its computational complexity is essentially the same as plain dot-product attention, but the dot products are divided by sqrt(d_k), where d_k is the dimension of the key vectors. This scaling keeps the scores from growing too large in magnitude, which would otherwise push the softmax into regions with very small gradients, and so it helps keep training stable.
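For reference, the computation is Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V. The following is a minimal sketch of that step in PyTorch; the function name scaled_dot_product_attention and the optional mask argument are illustrative and are not part of the example further below.

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, n_head, seq_len, d_k)
    d_k = q.size(-1)
    # Dot-product similarity, scaled by sqrt(d_k) to keep the scores in a moderate range
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are excluded from the softmax
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    # Weighted sum of the value vectors
    return torch.matmul(weights, v)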
Below is a simple example that uses scaled dot-product attention via PyTorch's nn.MultiheadAttention:
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, d_hidden, dropout):
        super(TransformerBlock, self).__init__()
        # nn.MultiheadAttention computes scaled dot-product attention internally
        self.multihead_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(d_hidden, d_model),
            nn.Dropout(dropout)
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Multi-head self-attention: query, key and value are all the input x
        # (expected shape: (seq_len, batch, d_model), the nn.MultiheadAttention default)
        attn_output, _ = self.multihead_attn(x, x, x)
        x = x + self.dropout(attn_output)
        x = self.layernorm1(x)
        # Position-wise feed-forward network
        ff_output = self.feedforward(x)
        x = x + self.dropout(ff_output)
        x = self.layernorm2(x)
        return x
In the code above, the multihead_attn module computes scaled dot-product attention internally, so the block can be applied directly as an encoder layer for sequence data. By stacking several TransformerBlock modules, you can build a complete Transformer encoder.
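As a rough sketch of that stacking (the hyperparameter values below are placeholders chosen only for illustration; nn.Sequential works here because each block maps one tensor to another of the same shape):

import torch
import torch.nn as nn

# Placeholder hyperparameters for illustration
d_model, n_head, d_hidden, n_layers, dropout = 512, 8, 2048, 6, 0.1

encoder = nn.Sequential(
    *[TransformerBlock(d_model, n_head, d_hidden, dropout) for _ in range(n_layers)]
)

# nn.MultiheadAttention expects (seq_len, batch, d_model) by default
x = torch.randn(10, 32, d_model)   # sequence of length 10, batch of 32
out = encoder(x)                   # same shape as the input: (10, 32, d_model)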