前言
Hi,我是GISerLiu🙂, 这篇文章是参加2025年5月datawhale学习赛的打卡文章!💡 本文深入剖析了基于自注意力机制的SAITS模型在时间序列缺失值插补中的应用。通过结合论文解读和PyPOTS实战案例,文章详细介绍了SAITS的双任务学习框架、自注意力机制原理及其实际应用,并提供了完整的代码实现与性能分析,帮助读者掌握这一先进的时序插补技术。
一、时间序列数据与缺失值问题
时间序列数据在现实世界中无处不在,从股票市场的价格波动到医疗监测数据,再到工业传感器的读数,都是时间序列数据的典型例子。然而,由于各种原因(如传感器故障、网络中断、人为错误等),这些数据通常存在缺失值问题,给数据分析和模型训练带来了巨大挑战。
1. 缺失值类型与影响
时间序列数据的缺失值通常可以分为三类:
缺失类型 | 描述 | 示例 |
---|---|---|
MCAR(完全随机缺失) | 数据缺失完全随机,与其他变量无关 | 传感器随机故障 |
MAR(随机缺失) | 数据缺失与其他观测变量相关 | 当温度过高时,湿度传感器失效 |
MNAR(非随机缺失) | 数据缺失与其缺失值本身相关 | 血压计无法记录过高的血压值 |
缺失值会导致:
- 统计分析偏差
- 模型训练效果下降
- 预测精度降低
2. PyPOTS的核心特性
- 统一的API:提供一致的接口,便于模型比较和集成
- 多种模型支持:包含传统统计方法和先进的深度学习方法
- GPU加速:通过PyTorch支持GPU训练,提高计算效率
- 可视化工具:内置TensorBoard支持,便于训练监控和结果分析
二、SAITS模型解析
SAITS (Self-Attention-based Imputation for Time Series) 是一种基于自注意力机制的时间序列缺失值插补模型,由Du, Cote和Liu在2023年提出,发表在《Expert systems with applications》期刊上。
1. SAITS的核心架构
SAITS模型结合了两个关键创新:
- 双任务学习框架:
- 观测重构任务(ORT):学习已知数据点的内在关系
- 掩码插补任务(MIT):直接学习预测缺失值
- 门控融合机制:根据不同任务的表现自适应调整权重
2. 自注意力机制在时序插补中的优势
传统RNN和LSTM模型在处理长序列时存在长期依赖问题,而自注意力机制能够直接建立序列中任意两个时间点之间的关联,更适合捕捉时间序列中的长期依赖关系。
模型类型 | 长距离依赖 | 并行计算 | 参数效率 |
---|---|---|---|
RNN/LSTM | 弱 | 差 | 中 |
CNN | 中 | 好 | 中 |
自注意力 | 强 | 好 | 好 |
三、实战:使用PyPOTS和SAITS进行时序插补
下面我们通过一个完整的示例,展示如何使用PyPOTS库中的SAITS模型进行时间序列插补。
1. 环境配置
PyPOTS支持多种安装方式:
# 从源码安装(最新特性)
pip install https://github.com/WenjieDu/PyPOTS/archive/refs/tags/v0.18.zip
# 从PyPI安装(稳定版)
pip install pypots==0.18
# 从conda-forge安装
conda install conda-forge::pypots=0.18
# 使用Docker
docker run -it --name pypots wenjiedu/pypots
我这里使用方法2:
2. 数据准备
我们使用PhysioNet2012数据集作为示例:
# 导入必要的库
from benchpots.datasets import preprocess_physionet2012
# 加载并预处理physionet2012数据集
physionet2012_dataset = preprocess_physionet2012(
subset="set-a",
pattern="point",
rate=0.1, # 设置缺失率为10%
)
# 构建训练、验证和测试集
train_set = {"X": physionet2012_dataset["train_X"]}
val_set = {
"X": physionet2012_dataset["val_X"],
"X_ori": physionet2012_dataset["val_X_ori"],
}
test_set = {
"X": physionet2012_dataset["test_X"],
"X_ori": physionet2012_dataset["test_X_ori"],
}
初始的时候需要时间去下载!
3. 模型训练与评估
# 导入SAITS模型
from pypots.imputation import SAITS
from pypots.optim import Adam
# 初始化SAITS模型
saits = SAITS(
n_steps=physionet2012_dataset['n_steps'],
n_features=physionet2012_dataset['n_features'],
n_layers=3,
d_model=64,
n_heads=4,
d_k=16,
d_v=16,
d_ffn=128,
dropout=0.1,
ORT_weight=1, # 观测重构任务权重
MIT_weight=1, # 掩码插补任务权重
batch_size=32,
epochs=10,
patience=3,
optimizer=Adam(lr=1e-3),
device='cpu',
saving_path="result_saving/imputation/saits",
model_saving_strategy="best",
)
# 训练模型
saits.fit(train_set, val_set)
# 使用模型进行预测
test_set_imputation_results = saits.predict(test_set)
# 评估模型性能
from pypots.nn.functional import calc_mse
test_MSE = calc_mse(
test_set_imputation_results["imputation"],
physionet2012_dataset["test_X_ori"],
physionet2012_dataset["test_X_indicating_mask"],
)
print(f"SAITS test_MSE: {test_MSE}")
方便起见,这里只进行了10轮循环;
4. 结果保存与后续分析
from pypots.data.saving import pickle_dump
# 对所有数据集进行插补
train_set_imputation = saits.impute(train_set)
val_set_imputation = saits.impute(val_set)
test_set_imputation = test_set_imputation_results["imputation"]
# 保存结果
dict_to_save={
'train_set_imputation': train_set_imputation,
'train_set_labels': physionet2012_dataset['train_y'],
'val_set_imputation': val_set_imputation,
'val_set_labels': physionet2012_dataset['val_y'],
'test_set_imputation': test_set_imputation,
'test_set_labels': physionet2012_dataset['test_y'],
}
pickle_dump(dict_to_save, "result_saving/imputed_physionet2012.pkl")
四、SAITS的工作原理深度解析
1. 双任务学习框架
SAITS的一个关键创新是采用了双任务学习框架,包括两个互补的任务:
- 观测重构任务(ORT):
- 目标:重建已知的观测值
- 作用:帮助模型学习数据的内在结构和时间依赖关系
- 计算方式:仅在有观测值的位置计算重建误差
- 掩码插补任务(MIT):
- 目标:直接预测缺失值
- 作用:让模型专注于缺失值的预测
- 计算方式:在训练过程中随机掩盖一部分已知值,让模型预测这些"伪缺失值"
结合这两个任务,SAITS可以更全面地学习时间序列的特性,提高插补准确性。
2. 自注意力机制原理
自注意力机制是SAITS的核心组件,它有效地捕捉时间序列中的长期依赖关系:
自注意力计算公式:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
在SAITS中,通过多头自注意力机制,模型能够从不同角度学习时间序列中的依赖关系,更有效地捕捉复杂的时间模式。
3. 门控融合机制
SAITS的另一个创新是门控融合机制,用于自适应地结合ORT和MIT两个任务的预测结果:
X ^ = G ⊙ X ^ O R T + ( 1 − G ) ⊙ X ^ M I T \hat{X} = G \odot \hat{X}_{ORT} + (1-G) \odot \hat{X}_{MIT} X^=G⊙X^ORT+(1−G)⊙X^MIT
其中$ G $是一个门控参数,通过网络学习得到,根据不同数据点的特性自动调整两个任务的权重。这使得模型能够灵活地适应不同类型的缺失模式。
五、 实验与性能分析
1. 实验结果
SAITS在多个数据集上优于主流方法(BRITS、GRU-D、MRNN):
数据集 | SAITS | BRITS | GRU-D | MRNN |
---|---|---|---|---|
PhysioNet | 0.326 | 0.375 | 0.422 | 0.401 |
Air Quality | 0.187 | 0.221 | 0.267 | 0.258 |
KDD-CUP | 0.219 | 0.247 | 0.283 | 0.271 |
2. 核心优势
- 长期依赖捕捉:通过自注意力机制建模长序列依赖。
- 缺失模式鲁棒性:适应随机、连续、混合缺失场景。
- 高效计算:支持并行化,提升训练速度。
3. 应用场景
- 医疗:补全生命体征数据
- 金融:修复股价/交易量缺失
- 工业:传感器数据修复
- 环境:气象/空气质量数据补全
4. 调优策略
- 模型架构:调整注意力层数、头数、隐藏层维度。
- 训练参数:优化学习率、批大小、训练轮次。
- 任务权重:
- 规律性数据 → 增大ORT权重
- 高噪声数据 → 增大MIT权重
总结
PyPOTS库和SAITS模型为时间序列缺失值插补提供了一个强大的工具,通过结合自注意力机制和双任务学习框架,实现了高精度的缺失值预测。在实际应用中,应根据具体数据特性选择合适的模型参数和训练策略,以获得最佳效果。
SAITS的成功也启示我们,在时间序列分析中,捕捉长期依赖关系和设计专门针对缺失值的学习任务是提高模型性能的关键。未来的研究方向可能包括:
- 将SAITS与其他类型的神经网络结合
- 探索更复杂的融合机制
- 扩展模型以处理多变量之间的相互影响
文章参考
项目地址
如果觉得我的文章对您有帮助,三连+关注便是对我创作的最大鼓励!或者一个star🌟也可以😂.