Attention Is All You Need 从零实现:Transformer 论文详解与代码

2026-06-03阅读 0热度 0
其他

2017年,Google那篇《Attention Is All You Need》横空出世,直接把Transformer推到了前台。这篇文章做了一个在当时相当激进的决定——彻底抛弃RNN和LSTM的递归结构,全靠一种叫Attention的机制来捕捉语义关系。结果我们都看到了:GPT、BERT、LLaMA……几乎整个大模型时代都是从这里开始的。

这篇文章会带着你,用一份干净的、大约350行的纯PyTorch代码,把Transformer的每一个组件拆开来看。不是泛泛地讲概念,是切切实实从零搭到完整模型,代码可以直接跑,非常适合拿来理解原理、调试或者自己二次修改。

1. 组件一览

在深入代码之前,我们先快速扫一眼整个架构的核心模块,知道每个部分干的活和大概的计算量,心里有个底。

组件功能复杂度
Scaled Dot-Product AttentionQ/K/V 相似度计算与聚合O(n·dₖ)
Multi-Head Attention多个表示空间并行注意力O(n·d_model)
Position-wise FFN每个位置非线性变换O(d_model·d_ff)
Positional Encoding引入位置信息O(max_len·d_model)
Layer Norm维度规范化O(d_model)
Encoder Layer自注意力 + FFNO(n²·d_model)
Decoder Layer带掩码 + 交叉注意力O(n²·d_model)

2. Scaled Dot-Product Attention

公式与直观

Scaled Dot-Product Attention是Transformer里最核心的计算单元。它的本质说白了就一句话:用“查询”(Query)和“键”(Key)的相似度,来决定怎么去加权聚合“值”(Value)。

Attention(Q,K,V)=softmax(QKdk)V

为什么要除以根号dk?这个问题其实很重要。当dk比较大的时候,点积的结果会随着维度增加而变大,这会把softmax推到非常极端的区域,梯度基本就消失了。缩放一下,让方差保持稳定,训练才能稳得住。

代码解析

class ScaledDotProductAttention(nn.Module):
    """Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V"""
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

这里有几个关键点值得留意:

  • masked_fill会把padding位置的值变成负无穷,softmax之后权重就变成0了,不会影响输出结果。
  • 计算完attention权重之后才做dropout,这是一个很重要的正则化手段。
  • 函数里同时返回了attn_weights,主要就是为了可视化调试用的。

这个模块的计算瓶颈在QK转置的矩阵乘法上,复杂度是O(n²·dk),这里的n就是序列长度。这也是整个Transformer最吃性能的地方。

3. Multi-Head Attention

从单头到多头

单头注意力有一个天然的局限:它只能在一种表示空间里看问题。Multi-Head Attention的做法是把Q、K、V分别投影到h个不同的表示空间里,各自独立算一遍注意力,最后把结果拼起来,再投影回到原来的维度。

MultiHead(Q,K,V)=Concat(head1,...,headh)WO

代码实现里有一个很聪明的设计决策:先一次性投影,再拆成多个头。数学上定义h个独立的线性层是等价的,但代码里只需要4个线性层而不是3h个,效率高得多。

代码解析

class MultiHeadAttention(nn.Module):
    """MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O
    其中 head_i = Attention(Q @ W_Q_i, K @ W_K_i, V @ W_V_i)"""
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model 必须能被 n_heads 整除"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.attention = ScaledDotProductAttention(dropout)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # 1) 线性投影 → (batch, seq_len, d_model)
        Q = self.W_Q(Q)
        K = self.W_K(K)
        V = self.W_V(V)
        # 2) 拆成多头 → (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # 3) Scaled Dot-Product Attention
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        # 4) 拼接多头 → (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # 5) 最终线性投影
        output = self.W_O(attn_output)
        return output

几个值得注意的点:

  • d_model % n_heads == 0这个断言保证了每个head能分到整数维度。
  • view + transpose这个操作就像"重新排片"一样,先把维度拆开再转置,让head维度排到前面。
  • .contiguous()这一步不能漏,transpose只是改了视图的内存布局,不调contiguous的话,后续的view会报错。
  • 最后的W_O是拼接后的最终投影,它负责把多头信息融合回到d_model维度。

mask处理这里有个小技巧:mask用(batch, 1, 1, seq_len)的格式,可以直接和scores的(batch, n_heads, seq_len, seq_len)做广播,省掉了额外的unsqueeze操作。

4. Position-wise Feed-Forward Network

非线性变换与容量

每个位置的表示在经过注意力层之后,还要过一个两层的全连接网络。这个FFN是position-wise的,意思是它对序列里的每个位置独立应用同样的参数,效果相当于卷积核大小为1的卷积。

FFN(x)=ReLU(xW1+b1)W2+b2

代码解析

class PositionWiseFeedForward(nn.Module):
    """FFN(x) = ReLU(x @ W_1 + b_1) @ W_2 + b_2
    内部维度从 d_model → d_ff → d_model"""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

这里的设计思路也很清晰:

  • 内部的d_ff通常比d_model大得多,论文里用的是512到2048的扩展,这一下子就给了非线性变换充足的容量。
  • dropout放在ReLU之后、第二次线性投影之前,这是目前比较主流的做法。
  • 原始论文用的是ReLU,后来像GPT这些工作更多用GELU,这是一个值得留意的演进细节。

为什么偏偏是两层?论文里的实验表明,一层表达能力确实不够,但三层以上的收益又微乎其微。两层就是性能和资源之间最优的那个平衡点。

5. Positional Encoding

为序列引入位置信息

Self-Attention有一个特性一定会让你意外——它是对位置完全不敏感的。不管你把序列里的元素怎么打乱,输出结果都一样。为了引入位置信息,原始论文用的是正余弦编码。

PE(pos,2i)=sin(pos100002i/dmodel)

PE(pos,2i+1)=cos(pos100002i/dmodel)

为什么用正余弦而不用可学习的位置嵌入?这里有三个很实在的理由:

  1. 它可以处理比训练时更长的序列,也就是有外推能力。
  2. 不需要额外参数,省内存。
  3. 相对位置信息通过线性变换就能表达出来——因为sin(α+Δ) = sinα·cosΔ + cosα·sinΔ,这里天然存在线性关系。

代码解析

class PositionalEncoding(nn.Module):
    """PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))"""
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

代码里有几个细节值得拿出来说说:

  • div_term用了指数形式而不是直接算10000的2i/d_model次方,这是为了数值稳定性。
  • register_buffer这一步让pe可以跟着模型一起移动到CPU或GPU上,但不会作为参数被优化器更新。
  • forward里直接把pe和输入相加,通过broadcast机制对齐维度,这是最经典的做法了。

6. Layer Normalization

维度规范化

Layer Normalization做的事情是:对每个样本的所有维度做一次变换——减去均值,除以标准差,然后再做一个可学习的线性变换。跟Batch Norm不同,LN不依赖batch的大小,处理变长序列的时候要稳定得多。

LayerNorm(x)=γxμσ2+ϵ+β

代码解析

class LayerNorm(nn.Module):
    """LayerNorm(x) = gamma * (x - mean) / sqrt(var + eps) + beta
    手写版,方便理解;实际可直接用 nn.LayerNorm"""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

这段代码虽然手写了一遍,但实际用的时候直接用nn.LayerNorm就行。不过这个手写版能让你看得更清楚:unbiased=False用的是样本标准差而不是无偏估计,这和原始论文的做法一致。eps取1e-6是为了防止除零。

7. Encoder Layer

网络中的网络单元

Encoder层是Transformer的基本构建块。每层包含两个子层:多头自注意力和FFN,每个子层后面紧跟着一个残差连接和层规范化。

x → MultiHead Self-Attention → Add & Norm → FFN → Add & Norm

代码解析

class EncoderLayer(nn.Module):
    """一个 Encoder 层:x → MultiHead Self-Attention → Add & Norm → FFN → Add & Norm"""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-Attention + Add & Norm
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # FFN + Add & Norm
        ffn_output = self.ffn(x)
        x = x + self.dropout2(ffn_output)
        x = self.norm2(x)
        return x

几个要点:

  • 残差连接(x + sublayer(x))是解决深层网络梯度消失的关键设计——梯度可以直接通过这一条shortcut回传。
  • 这里用的是Post-LN模式,先做残差再做规范化,和原始论文的做法保持一致。
  • self_attn的三个输入参数全都是x,这正好对应了"自注意力"的概念——Q、K、V都来自同一个序列。

8. Decoder Layer

带掩码的自注意力与交叉注意力

Decoder层比Encoder多了一个子层——Cross-Attention。在这个子层里,Encoder的输出作为K和V,Decoder的输入作为Q。同时,自注意力层需要用下三角mask来遮住后面的位置,防止信息泄露。

x → Masked Self-Attention → Add & Norm → Cross-Attention → Add & Norm → FFN → Add & Norm

代码解析

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Self-Attention(带 look-ahead mask)
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # Cross-Attention: Q 来自 Decoder, K/V 来自 Encoder
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = x + self.dropout2(attn_output)
        x = self.norm2(x)
        # FFN
        ffn_output = self.ffn(x)
        x = x + self.dropout3(ffn_output)
        x = self.norm3(x)
        return x

这里有几个明显的区别:

  • Self-Attention用了tgt_mask,也就是下三角mask来遮住未来的位置;Cross-Attention则用src_mask来过滤掉Encoder那边padding的位置。
  • Cross-Attention的K和V来自Encoder,Q来自Decoder——这是一种"引导"机制,Decoder每走一步都能看到输入序列的全部信息。
  • Decoder比Encoder多了整整一套残差连接和LayerNorm,总共3个。

9. 完整 Transformer

拼装成网络

最后一步就是把N层Encoder和N层Decoder叠起来,再加上嵌入层、位置编码和最终的分类头,一个完整的Transformer就拼装完成了。

src → Embedding → Positional Encoding → N × EncoderLayer
                                ↓
tgt → Embedding → Positional Encoding → N × DecoderLayer → Linear → output

代码解析

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, n_heads=8, d_ff=2048, n_layers=6, dropout=0.1, max_len=5000):
        super().__init__()
        self.encoder_embed = nn.Embedding(src_vocab, d_model)
        self.decoder_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder
        src_emb = self.pos_encoding(self.encoder_embed(src))
        for layer in self.encoder_layers:
            src_emb = layer(src_emb, src_mask)
        # Decoder
        tgt_emb = self.pos_encoding(self.decoder_embed(tgt))
        for layer in self.decoder_layers:
            tgt_emb = layer(tgt_emb, src_emb, src_mask, tgt_mask)
        return self.fc_out(tgt_emb)

最后这几个设计点值得记住:

  • nn.ModuleList保证了每一层的参数都能被正确注册。
  • Encoder和Decoder各自有独立的嵌入层和位置编码,互不干扰。
  • fc_out把d_model投影到词表的大小,负责输出下一个token的概率分布。

10. 总结

这份从零开始的实现,覆盖了Transformer从Scaled Dot-Product Attention到完整Encoder-Decoder架构的所有核心组件。每一行代码背后都有明确的动机和设计思考。

把这些基础组件吃透之后,再去看GPT系列那种只用Decoder的做法,或者BERT系列只用Encoder的方案,以及LLaMA这些现代变体,就能很快抓住它们各自的设计决策——知道它们在哪一步做了什么取舍,为什么那么做。

免责声明

本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。

相关阅读

更多
欢迎回来 登录或注册后,可保存提示词和历史记录
登录后可同步收藏、历史记录和常用模板
注册即表示同意服务条款与隐私政策