Transformer 多头注意力机制详解(示例详细推导)
通过一个简单的例子来详细说明 Transformer 的多头注意力机制是如何工作的。假设我们有一个非常短的句子,只包含3个单词,我们用这个例子来逐步计算。
1. 词嵌入(Embedding)
首先,我们需要将每个单词转换为向量。假设我们使用一个简单的嵌入矩阵,将每个单词映射到一个二维向量:
单词 | 嵌入向量 |
---|---|
Hello | [1, 2] |
world | [3, 4] |
! | [5, 6] |
所以,输入的嵌入矩阵XXX 是:
X=[123456] X = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} X=135246
2. 多头注意力(Multi-Head Attention)
假设我们只用一个头来简化计算。在实际的 Transformer 中,通常会有多个头(比如8个或16个),但原理是一样的。
步骤1:线性变换
在多头注意力中,输入XXX 会被分别投影到三个不同的矩阵:查询(Query)、键(Key)和值(Value)。假设我们随机初始化以下线性变换矩阵:
WQ=[0.10.20.30.4],WK=[0.50.60.70.8],WV=[0.91.01.11.2] W^Q = \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \end{bmatrix}, \quad W^K = \begin{bmatrix} 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix}, \quad W^V = \begin{bmatrix} 0.9 & 1.0 \\ 1.1 & 1.2 \end{bmatrix} WQ=[0.10.30.20.4],WK=[0.50.70.60.8],WV=[0.91.11.01.2]
通过矩阵乘法,我们得到QQQ、KKK 和VVV:
Q=XWQ=[123456][0.10.20.30.4]=[0.71.01.52.22.33.4] Q = XW^Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \end{bmatrix} = \begin{bmatrix} 0.7 & 1.0 \\ 1.5 & 2.2 \\ 2.3 & 3.4 \end{bmatrix} Q=XWQ=135246[0.10.30.20.4]=0.71.52.31.02.23.4
K=XWK=[123456][0.50.60.70.8]=[1.92.24.35.06.77.8] K = XW^K = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \begin{bmatrix} 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix} = \begin{bmatrix} 1.9 & 2.2 \\ 4.3 & 5.0 \\ 6.7 & 7.8 \end{bmatrix} K=XWK=135246[0.50.70.60.8]=1.94.36.72.25.07.8
V=XWV=[123456][0.91.01.11.2]=[3.13.46.97.610.711.8] V = XW^V = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \begin{bmatrix} 0.9 & 1.0 \\ 1.1 & 1.2 \end{bmatrix} = \begin{bmatrix} 3.1 & 3.4 \\ 6.9 & 7.6 \\ 10.7 & 11.8 \end{bmatrix} V=XWV=135246[0.91.11.01.2]=3.16.910.73.47.611.8
步骤2:计算注意力分数
接下来,我们计算QQQ 和KKK 的点积,并除以dk\sqrt{d_k}dk(这里 dk=2d_k = 2dk=2):
QKT=[0.71.01.52.22.33.4][1.94.36.72.25.07.8]=[3.177.0310.895.7713.1320.498.3719.2330.09] QK^T = \begin{bmatrix} 0.7 & 1.0 \\ 1.5 & 2.2 \\ 2.3 & 3.4 \end{bmatrix} \begin{bmatrix} 1.9 & 4.3 & 6.7 \\ 2.2 & 5.0 & 7.8 \end{bmatrix} = \begin{bmatrix} 3.17 & 7.03 & 10.89 \\ 5.77 & 13.13 & 20.49 \\ 8.37 & 19.23 & 30.09 \end{bmatrix} QKT=0.71.52.31.02.23.4[1.92.24.35.06.77.8]=3.175.778.377.0313.1319.2310.8920.4930.09
除以2≈1.414\sqrt{2} \approx 1.4142≈1.414:
QKT2=[2.244.977.704.089.2914.495.9213.6021.28] \frac{QK^T}{\sqrt{2}} = \begin{bmatrix} 2.24 & 4.97 & 7.70 \\ 4.08 & 9.29 & 14.49 \\ 5.92 & 13.60 & 21.28 \end{bmatrix} 2QKT=2.244.085.924.979.2913.607.7014.4921.28
步骤3:应用 softmax
对每一行应用 softmax 函数,将分数转换为概率(权重):
softmax(QKT2)=[0.0020.0120.9860.0000.0010.9990.0000.0001.000] \text{softmax}\left(\frac{QK^T}{\sqrt{2}}\right) = \begin{bmatrix} 0.002 & 0.012 & 0.986 \\ 0.000 & 0.001 & 0.999 \\ 0.000 & 0.000 & 1.000 \end{bmatrix} softmax(2QKT)=0.0020.0000.0000.0120.0010.0000.9860.9991.000
步骤4:加权求和
最后,将这些权重与VVV 相乘,得到每个单词的加权表示:
Attention(Q,K,V)=[0.0020.0120.9860.0000.0010.9990.0000.0001.000][3.13.46.97.610.711.8]=[10.6511.7510.6911.7910.7011.80]
\text{Attention}(Q, K, V) = \begin{bmatrix}
0.002 & 0.012 & 0.986 \\
0.000 & 0.001 & 0.999 \\
0.000 & 0.000 & 1.000
\end{bmatrix}
\begin{bmatrix}
3.1 & 3.4 \\
6.9 & 7.6 \\
10.7 & 11.8
\end{bmatrix}
= \begin{bmatrix}
10.65 & 11.75 \\
10.69 & 11.79 \\
10.70 & 11.80
\end{bmatrix}
Attention(Q,K,V)=0.0020.0000.0000.0120.0010.0000.9860.9991.0003.16.910.73.47.611.8=10.6510.6910.7011.7511.7911.80
3. 多头注意力(如果有多个头)
如果有多个头,比如2个头,我们会重复上述过程,使用不同的WQ,WK,WVW^Q, W^K, W^VWQ,WK,WV。然后将所有头的输出拼接起来,并通过一个线性变换WOW^OWO:
MultiHead(Q,K,V)=Concat(head1,head2)WO \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2)W^O MultiHead(Q,K,V)=Concat(head1,head2)WO
4. 前馈神经网络(FFN)
假设我们对每个单词的向量通过一个简单的前馈网络:
FFN(x)=max(0,xW1+b1)W2+b2 \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2
假设:
W1=[0.10.20.30.4],b1=[0.10.1],W2=[0.50.60.70.8],b2=[0.10.1] W_1 = \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \end{bmatrix}, \quad b_1 = \begin{bmatrix} 0.1 \\ 0.1 \end{bmatrix}, \quad W_2 = \begin{bmatrix} 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix}, \quad b_2 = \begin{bmatrix} 0.1 \\ 0.1 \end{bmatrix} W1=[0.10.30.20.4],b1=[0.10.1],W2=[0.50.70.60.8],b2=[0.10.1]
以第一个单词为例:
x=[10.65,11.75] x = [10.65, 11.75] x=[10.65,11.75]
xW1+b1=[10.65,11.75][0.10.20.30.4]+[0.10.1]=[5.657.85] xW_1 + b_1 = [10.65, 11.75] \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \end{bmatrix} +\begin{bmatrix} 0.1 \\ 0.1 \end{bmatrix} = \begin{bmatrix} 5.65 \\ 7.85 \end{bmatrix} xW1+b1=[10.65,11.75][0.10.30.20.4]+[0.10.1]=[5.657.85]
ReLU(xW1+b1)=[5.657.85] \text{ReLU}(xW_1 + b_1) = \begin{bmatrix} 5.65 \\ 7.85 \end{bmatrix} ReLU(xW1+b1)=[5.657.85]
FFN(x)=[5.657.85][0.50.60.70.8]+[0.10.1]=[8.4510.65] \text{FFN}(x) = \begin{bmatrix} 5.65 \\ 7.85 \end{bmatrix} \begin{bmatrix} 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix} +\begin{bmatrix} 0.1 \\ 0.1 \end{bmatrix} = \begin{bmatrix} 8.45 \\ 10.65 \end{bmatrix} FFN(x)=[5.657.85][0.50.70.60.8]+[0.10.1]=[8.4510.65]
通过这个例子,我们可以看到 Transformer 的核心操作:
- 词嵌入:将单词转换为向量。
- 多头注意力:通过线性变换、点积、softmax 和加权求和,捕捉单词之间的关系。
- 前馈神经网络:进一步处理每个单词的向量。
这个过程会重复多次(多层 Transformer),最终输出用于下游任务(如翻译、分类等)。