### Implementation Code for Attention Mechanisms
Below are implementation examples of several common attention mechanisms, covering Bahdanau attention, Luong attention, self-attention (as used in Transformers), and axial attention.
#### 1. Bahdanau Attention Implementation
Bahdanau attention is an additive attention mechanism: it scores the decoder hidden state (the query) against each encoder output (the keys) and turns the scores into attention weights, which are then used to form a context vector. Below is a simple PyTorch implementation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        # query: (batch_size, hidden_size)
        # keys:  (batch_size, seq_len, hidden_size)
        query = query.unsqueeze(1)  # (batch_size, 1, hidden_size)
        # Additive scoring: v_a^T tanh(W_a q + U_a k)
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))  # (batch_size, seq_len, 1)
        weights = F.softmax(scores, dim=1)                            # (batch_size, seq_len, 1)
        context = torch.sum(weights * keys, dim=1)                    # (batch_size, hidden_size)
        return context, weights.squeeze(2)
```
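As a quick sanity check of the interface above, here is a minimal usage sketch; the shapes (batch size 4, sequence length 10, hidden size 256) are made up and the tensors are random placeholders rather than real encoder/decoder states:
```python
# Hypothetical shapes: one decoder state per batch element attending
# over a sequence of encoder outputs.
batch_size, seq_len, hidden_size = 4, 10, 256
attention = BahdanauAttention(hidden_size)
query = torch.randn(batch_size, hidden_size)          # decoder hidden state
keys = torch.randn(batch_size, seq_len, hidden_size)  # encoder outputs
context, weights = attention(query, keys)
print(context.shape)  # torch.Size([4, 256])
print(weights.shape)  # torch.Size([4, 10])
```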
---
#### 2. Luong Attention Implementation
Luong attention is a multiplicative attention mechanism that is simpler and easier to implement than the additive variant. The PyTorch implementation below supports the three classic scoring functions (dot, general, and concat):
```python
class LuongAttention(nn.Module):
    def __init__(self, method='dot', hidden_size=256):
        super(LuongAttention, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError('Unsupported attention type')
        if self.method == 'general':
            self.attn = nn.Linear(hidden_size, hidden_size)
        elif self.method == 'concat':
            # The concat score needs its own projection and a learned vector v.
            self.attn = nn.Linear(hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.rand(hidden_size))

    def dot_score(self, hidden, encoder_outputs):
        return torch.sum(hidden * encoder_outputs, dim=2)

    def general_score(self, hidden, encoder_outputs):
        energy = self.attn(encoder_outputs)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_outputs):
        expand_hidden = hidden.expand(encoder_outputs.size(0), -1, -1)
        energy = torch.tanh(self.attn(torch.cat((expand_hidden, encoder_outputs), 2)))
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # hidden:          (1, batch_size, hidden_size) -- current decoder state
        # encoder_outputs: (seq_len, batch_size, hidden_size) -- sequence-first layout
        if self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        elif self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        attn_energies = attn_energies.t()                            # (batch_size, seq_len)
        attn_weights = F.softmax(attn_energies, dim=1).unsqueeze(1)  # (batch_size, 1, seq_len)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # (batch_size, 1, hidden_size)
        return context, attn_weights
```
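A minimal usage sketch, assuming the sequence-first layout noted in `forward` and random placeholder tensors:
```python
# Hypothetical shapes: decoder state (1, batch, hidden) against
# sequence-first encoder outputs (seq_len, batch, hidden).
seq_len, batch_size, hidden_size = 10, 4, 256
attention = LuongAttention(method='general', hidden_size=hidden_size)
hidden = torch.randn(1, batch_size, hidden_size)
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)
context, attn_weights = attention(hidden, encoder_outputs)
print(context.shape)       # torch.Size([4, 1, 256])
print(attn_weights.shape)  # torch.Size([4, 1, 10])
```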
---
#### 3. Self-Attention Implementation
Self-attention is widely used in the Transformer architecture to capture relationships between positions within a sequence. Below is a simple multi-head implementation based on TensorFlow/Keras:
```python
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense
class MultiHeadSelfAttention(Layer):
    def __init__(self, embed_dim, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.projection_dim = embed_dim // num_heads
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        self.combine_heads = Dense(embed_dim)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)  # (batch_size, seq_len, embed_dim)
        key = self.key_dense(inputs)      # (batch_size, seq_len, embed_dim)
        value = self.value_dense(inputs)  # (batch_size, seq_len, embed_dim)
        query = self.split_heads(query, batch_size)  # (batch_size, num_heads, seq_len_q, projection_dim)
        key = self.split_heads(key, batch_size)      # (batch_size, num_heads, seq_len_k, projection_dim)
        value = self.split_heads(value, batch_size)  # (batch_size, num_heads, seq_len_v, projection_dim)
        scaled_attention, _ = self.scaled_dot_product_attention(query, key, value)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])                # (batch_size, seq_len_q, num_heads, projection_dim)
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.embed_dim))   # (batch_size, seq_len_q, embed_dim)
        outputs = self.combine_heads(concat_attention)                                      # (batch_size, seq_len_q, embed_dim)
        return outputs

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, projection_dim)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    @staticmethod
    def scaled_dot_product_attention(q, k, v):
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)                             # (..., seq_len_q, depth_v)
        return output, attention_weights
```
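A minimal usage sketch with made-up dimensions (batch of 2, sequence length 16, embedding size 64 split across 8 heads):
```python
# Random placeholder inputs; output shape matches the input shape.
embed_dim, num_heads = 64, 8
layer = MultiHeadSelfAttention(embed_dim, num_heads)
inputs = tf.random.normal((2, 16, embed_dim))  # (batch, seq_len, embed_dim)
outputs = layer(inputs)
print(outputs.shape)  # (2, 16, 64)
```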
---
#### 4. Axial Attention
Axial attention is designed to handle high-dimensional data such as images or videos efficiently by processing one axis (rows or columns) at a time. Below is an implementation snippet simplified from the `axial-attention` library[^2]; it keeps only the per-axis (height/width) convolutional processing and the residual structure:
```python
class AxialAttentionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=8):
        super(AxialAttentionBlock, self).__init__()
        # One grouped convolution along the width axis and one along the height axis.
        self.conv_h = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3), padding=(0, 1), groups=groups)
        self.conv_w = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 1), padding=(1, 0), groups=groups)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: (batch, in_channels, height, width)
        identity = x
        x_h = self.conv_h(x)  # processes each row (along the width axis)
        x_w = self.conv_w(x)  # processes each column (along the height axis)
        out = self.bn(x_h + x_w)
        # The residual connection assumes in_channels == out_channels.
        out += identity
        out = self.relu(out)
        return out
```
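A minimal usage sketch; because of the residual connection the block assumes `in_channels == out_channels`, and the feature-map size below is made up:
```python
# Random placeholder feature map; output shape matches the input shape.
block = AxialAttentionBlock(in_channels=64, out_channels=64, groups=8)
x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
out = block(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```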