Multi-Head Attention for Time Series in Python
When working with time-series data, the multi-head attention mechanism (Multi-Head Attention) can effectively capture both long-range dependencies and local patterns in the sequence. It computes several attention heads in parallel so that the model can attend to different feature subspaces, then concatenates the head outputs and applies a linear projection to obtain a richer representation.
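Formally, each head applies scaled dot-product attention in its own projected subspace, and the head outputs are concatenated and projected back (the standard formulation from the original Transformer paper):

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O,\qquad
\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)
$$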
Below is a PyTorch implementation of multi-head attention for time-series data:
### Multi-Head Attention Implementation
```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads, dropout=0.1):
        """
        Multi-head attention.
        :param embed_size: input embedding dimension
        :param num_heads: number of attention heads
        :param dropout: dropout probability applied to the attention weights
        """
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by num_heads"

        # Per-head linear projections (applied after splitting into heads)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, queries, mask=None):
        """
        Forward pass.
        :param values: value tensor (batch_size, seq_len, embed_size)
        :param keys: key tensor (batch_size, seq_len, embed_size)
        :param queries: query tensor (batch_size, seq_len, embed_size)
        :param mask: optional mask, broadcastable to (batch_size, num_heads, query_len, key_len)
        :return: output tensor and attention weights
        """
        batch_size = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding dimension into num_heads subspaces
        values = values.reshape(batch_size, value_len, self.num_heads, self.head_dim)
        keys = keys.reshape(batch_size, key_len, self.num_heads, self.head_dim)
        queries = queries.reshape(batch_size, query_len, self.num_heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention scores
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # (batch_size, num_heads, query_len, key_len)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), the per-head dimension
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)  # (batch_size, num_heads, query_len, key_len)
        attention = self.dropout(attention)

        # Weighted sum over values, then merge the heads back together
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            batch_size, query_len, self.embed_size
        )  # (batch_size, query_len, embed_size)
        out = self.fc_out(out)
        return out, attention
```
### Usage Example
Suppose we have a time-series dataset in which each sample contains `seq_len` time steps, each with `input_dim` features. Once the features are projected to `embed_size` dimensions (or when `input_dim` already equals `embed_size`), the data can be fed directly into the module.
```python
# Example hyperparameters
batch_size = 32
seq_len = 50
embed_size = 64
num_heads = 8

# Randomly generated time-series input
x = torch.randn(batch_size, seq_len, embed_size)

# Instantiate the multi-head attention module
mha = MultiHeadAttention(embed_size=embed_size, num_heads=num_heads)

# Forward pass (self-attention: values, keys and queries are the same tensor)
output, attn_weights = mha(x, x, x)
print(f"Output shape: {output.shape}")                    # (batch_size, seq_len, embed_size)
print(f"Attention weights shape: {attn_weights.shape}")   # (batch_size, num_heads, seq_len, seq_len)
```
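For autoregressive use (e.g. one-step-ahead forecasting), a causal mask keeps each position from attending to future time steps. A minimal sketch reusing `mha`, `x`, and `seq_len` from above; the lower-triangular mask broadcasts against the (batch_size, num_heads, query_len, key_len) attention scores:

```python
# Causal (look-ahead) mask: position t may only attend to positions <= t
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

output, attn_weights = mha(x, x, x, mask=causal_mask)
print(attn_weights[0, 0, 0])  # the first query position attends only to itself
```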
### Adapting to Time-Series Data
For time-series tasks, attention is usually combined with positional encoding (Positional Encoding) to preserve temporal-order information, since the attention operation itself is permutation-invariant. Either fixed sinusoidal encodings or learned positional embeddings can be used.
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Precompute sinusoidal encodings for up to max_len positions
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), batch-first to match the attention module
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x
```
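A short sketch of wiring the two modules together, reusing `x`, `embed_size`, and `mha` from the usage example (the encoding above assumes batch-first input of shape (batch_size, seq_len, embed_size)):

```python
pos_enc = PositionalEncoding(d_model=embed_size)

x_pe = pos_enc(x)                            # inject temporal-order information
output, attn_weights = mha(x_pe, x_pe, x_pe)
print(output.shape)                          # torch.Size([32, 50, 64])
```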
### Summary
The implementation above shows how to build a multi-head attention mechanism for time-series data in PyTorch. It can capture complex dependencies within a sequence and improve performance on tasks such as forecasting and classification. To further strengthen time-series modeling, it can be combined with the other components of the Transformer, such as the feed-forward network (Feed-Forward Network) and layer normalization (Layer Normalization)[^1].
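As a rough sketch of how these components are typically assembled (the class name `TransformerEncoderBlock` and the `ff_hidden` size are illustrative choices, not part of the code above), a post-norm encoder block wraps the attention and feed-forward sublayers in residual connections and layer normalization:

```python
class TransformerEncoderBlock(nn.Module):
    """Illustrative encoder block: multi-head attention + feed-forward,
    each wrapped in a residual connection followed by LayerNorm."""
    def __init__(self, embed_size, num_heads, ff_hidden=256, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.ff = nn.Sequential(
            nn.Linear(embed_size, ff_hidden),
            nn.ReLU(),
            nn.Linear(ff_hidden, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))   # residual + LayerNorm around attention
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))     # residual + LayerNorm around feed-forward
        return x
```

Stacking several such blocks, preceded by an input projection and the positional encoding above, yields a Transformer-style encoder suitable for time-series forecasting or classification.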