GAT多头注意力讲解
时间: 2025-05-27 15:02:23 浏览: 25
### GAT 中多头注意力机制的工作原理及其实现
#### 工作原理
在图注意力网络 (GAT, Graph Attention Network) 中,多头注意力机制通过多个独立的注意力计算通道来增强模型的表现力和稳定性。具体来说,每个注意力头会单独学习节点之间的不同关系模式,并最终将这些结果组合起来以获得更加丰富的特征表示[^3]。
为了实现这一点,GAT 首先定义了一个共享权重矩阵 \( W \),用于线性变换输入节点特征向量。接着,在每一个注意力头上分别执行如下操作:
1. **计算注意力系数**
对于每一对相邻节点 \( i \) 和 \( j \),基于它们经过线性变换后的特征向量 \( h_iW \) 和 \( h_jW \),使用一个可训练参数向量 \( a \in R^{F'} \) 来衡量两者的重要性得分:
\[
e_{ij} = LeakyReLU(a^T [Wh_i || Wh_j])
\]
这里 \(||\) 表示拼接操作,\(LeakyReLU\) 是激活函数[^1]。
2. **标准化注意力系数**
使用 Softmax 函数对邻域内的所有节点进行归一化处理,得到最终的注意力权重:
\[
\alpha_{ij} = softmax(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in N(i)} \exp(e_{ik})}
\]
3. **加权求和**
将邻居节点的特征按照其对应的注意力权重聚合到中心节点上:
\[
h'_i = \sigma(\sum_{j \in N(i)} \alpha_{ij}(Wh_j))
\]
其中 \( N(i) \) 表示节点 \( i \) 的直接邻居集合,而 \( \sigma \) 则是一个非线性的激活函数。
#### 多头注意力的具体实现方式
由于单个注意力头可能无法捕捉所有的潜在关联信息,因此引入了多头注意力的概念。即并行运行若干组上述过程(通常称为 “head”),并将各 head 输出的结果串联或平均作为最后的输出。假设总共有 \( K \) 个 heads,则第 k 个 head 的输出可以写成:
\[ z_k = AGGREGATE(ATTENTION(k)(h_1,...,h_N)) \]
其中 ATTENTION(k) 表达的是第 k 种形式下的注意力建模方法;AGGREGATE 负责汇总来自各个 attention head 的贡献[^4]。
当采用 concat 方法时,
\[ h' = ||_{k=1}^K z_k \]
如果选择 average 方式则有,
\[ h' = mean(z_1 ... z_K). \]
以下是 Python 实现的一个简单例子:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttentionLayer(nn.Module):
def __init__(self, input_dim, output_dim, num_heads, leaky_relu_negative_slope=0.2):
super(MultiHeadAttentionLayer, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.num_heads = num_heads
# 定义线性变换层
self.W = nn.Parameter(torch.Tensor(input_dim, output_dim * num_heads))
self.a = nn.ParameterList([nn.Parameter(torch.Tensor(output_dim*2, 1)) for _ in range(num_heads)])
# 初始化参数
nn.init.xavier_uniform_(self.W.data)
for param in self.a:
nn.init.xavier_uniform_(param.data)
def forward(self, h, adj_matrix):
batch_size, node_num, feature_dim = h.size()
# 应用线性变换
Wh = torch.matmul(h, self.W).view(batch_size, node_num, self.num_heads, -1)
results = []
for idx_head in range(self.num_heads):
# 提取当前头部的数据
Wh_single_head = Wh[:, :, idx_head, :]
# 构建边缘特征
edge_features = self._build_edge_features_for_each_node_pair(Wh_single_head)
# 计算注意力分数
attention_scores = F.leaky_relu(
torch.matmul(edge_features, self.a[idx_head]), negative_slope=leaky_relu_negative_slope)
mask = ~adj_matrix.bool().unsqueeze(-1)
masked_attention_scores = attention_scores.masked_fill(mask, float('-inf'))
attentions = F.softmax(masked_attention_scores.squeeze(), dim=-1)
# 更新节点特征
new_h_per_head = torch.bmm(attentions.unsqueeze(dim=1), Wh_single_head).squeeze()
results.append(new_h_per_head)
final_result = torch.cat(results, dim=-1) if len(results)>1 else results[0]
return final_result
@staticmethod
def _build_edge_features_for_each_node_pair(node_embeddings):
src_nodes_emb = node_embeddings.unsqueeze(2) # shape: B x N x 1 x D'
dst_nodes_emb = node_embeddings.unsqueeze(1) # shape: B x 1 x N x D'
combined_embs = torch.cat((src_nodes_emb.expand(-1,-1,node_embeddings.shape[-2],-1),
dst_nodes_emb.expand(-1,node_embeddings.shape[-2],-1,-1)),dim=-1)
return combined_embs.reshape(combined_embs.shape[:3]+ (-1,))
```
阅读全文
相关推荐




















