文中内容仅限技术学习与代码实践参考,市场存在不确定性,技术分析需谨慎验证,不构成任何投资建议。
Darts 是一个 Python 库,用于对时间序列进行用户友好型预测和异常检测。它包含多种模型,从 ARIMA 等经典模型到深度神经网络。所有预测模型都能以类似 scikit-learn 的方式使用 fit()
和 predict()
函数。该库还可以轻松地对模型进行回溯测试,将多个模型的预测结果结合起来,并将外部数据考虑在内。Darts 支持单变量和多变量时间序列和模型。基于 ML 的模型可以在包含多个时间序列的潜在大型数据集上进行训练,其中一些模型还为概率预测提供了丰富的支持。
Darts 中的超参数优化
Hyperparameter Optimization in Darts
Darts 在超参数优化方面并无特殊之处。 需要重点关注的主要是:对于基于深度学习的 TorchForecastingModel,可以利用 PyTorch Lightning 的 callback 实现早停(early stopping)和实验剪枝(pruning)。下文分别以 Optuna 和 Ray Tune 为例,展示如何在 Darts 中进行超参数优化。
使用 Optuna 进行超参数优化
Optuna 是与 Darts 配合使用的优秀超参数优化工具。下面给出一个最小示例,利用 PyTorch Lightning 的 callback 实现实验剪枝。示例中,我们在单条时间序列上训练 TCNModel
,并通过最小化验证集上的预测误差来优化(可能过拟合的)超参数。
更完整的示例可参考此 notebook。
注意(2023-19-02):Optuna 的 PyTorchLightningPruningCallback
在 pytorch-lightning>=1.8 时会报错。在问题解决前,可参考此处的 workaround。
import numpy as np
import optuna
import torch
from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning.callbacks import Callback, EarlyStopping
from sklearn.preprocessing import MaxAbsScaler
from darts.dataprocessing.transformers import Scaler
from darts.datasets import AirPassengersDataset
from darts.metrics import smape
from darts.models import TCNModel
from darts.utils.likelihood_models.torch import GaussianLikelihood
# 加载数据
series = AirPassengersDataset().load().astype(np.float32)
# 划分训练 / 验证(注意:实际应用中还需额外划分测试集)
VAL_LEN = 36
train, val = series[:-VAL_LEN], series[-VAL_LEN:]
# 缩放
scaler = Scaler(MaxAbsScaler())
train = scaler.fit_transform(train)
val = scaler.transform(val)
# workaround 来源:https://github.com/Lightning-AI/pytorch-lightning/issues/17485
# 避免同时导入 lightning 和 pytorch_lightning
class PatchedPruningCallback(optuna.integration.PyTorchLightningPruningCallback, Callback):
pass
# 定义目标函数
def objective(trial):
# 选择输入与输出 chunk 长度
in_len = trial.suggest_int("in_len", 12, 36)
out_len = trial.suggest_int("out_len", 1, in_len - 1)
# 其余超参数
kernel_size = trial.suggest_int("kernel_size", 2, 5)
num_filters = trial.suggest_int("num_filters", 1, 5)
weight_norm = trial.suggest_categorical("weight_norm", [False, True])
dilation_base = trial.suggest_int("dilation_base", 2, 4)
dropout = trial.suggest_float("dropout", 0.0, 0.4)
lr = trial.suggest_float("lr", 5e-5, 1e-3, log=True)
include_year = trial.suggest_categorical("year", [False, True])
# 训练过程中,通过验证损失进行剪枝与早停
pruner = PatchedPruningCallback(trial, monitor="val_loss")
early_stopper = EarlyStopping("val_loss", min_delta=0.001, patience=3, verbose=True)
callbacks = [pruner, early_stopper]
# 检测是否可用 GPU
if torch.cuda.is_available():
num_workers = 4
else:
num_workers = 0
pl_trainer_kwargs = {
"accelerator": "auto",
"callbacks": callbacks,
}
# 可选:将(缩放后的)年份作为过去协变量
if include_year:
encoders = {"datetime_attribute": {"past": ["year"]},
"transformer": Scaler()}
else:
encoders = None
# 可复现性
torch.manual_seed(42)
# 构建 TCN 模型
model = TCNModel(
input_chunk_length=in_len,
output_chunk_length=out_len,
batch_size=32,
n_epochs=100,
nr_epochs_val_period=1,
kernel_size=kernel_size,
num_filters=num_filters,
weight_norm=weight_norm,
dilation_base=dilation_base,
dropout=dropout,
optimizer_kwargs={"lr": lr},
add_encoders=encoders,
likelihood=GaussianLikelihood(),
pl_trainer_kwargs=pl_trainer_kwargs,
model_name="tcn_model",
force_reset=True,
save_checkpoints=True,
)
# 验证时,可额外包含 input_chunk_length 长度的历史数据
model_val_set = scaler.transform(series[-(VAL_LEN + in_len):])
# 训练模型
model.fit(
series=train,
val_series=model_val_set,
)
# 加载训练过程中的最佳模型
model = TCNModel.load_from_checkpoint("tcn_model")
# 在验证集上评估模型,使用 sMAPE
preds = model.predict(series=train, n=VAL_LEN)
smapes = smape(val, preds, n_jobs=-1, verbose=True)
smape_val = np.mean(smapes)
return smape_val if smape_val != np.nan else float("inf")
# 为方便查看,打印试验信息
def print_callback(study, trial):
print(f"Current value: {trial.value}, Current params: {trial.params}")
print(f"Best value: {study.best_value}, Best params: {study.best_trial.params}")
# 通过最小化验证集 sMAPE 优化超参数
if __name__ == "__main__":
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=100, callbacks=[print_callback])
使用 Ray Tune 进行超参数优化
Ray Tune 是另一种支持自动剪枝的超参数优化方案。
以下示例展示了如何将 Ray Tune 与 NBEATSModel
结合,并使用 Asynchronous Hyperband 调度器。示例在 ray==2.32.0
下测试通过。
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from ray import tune
from ray.train import RunConfig
from ray.tune import CLIReporter
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray.tune.schedulers import ASHAScheduler
from ray.tune.tuner import Tuner
from torchmetrics import (
MeanAbsoluteError,
MeanAbsolutePercentageError,
MetricCollection,
)
from darts.dataprocessing.transformers import Scaler
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel
def train_model(model_args, callbacks, train, val):
torch_metrics = MetricCollection(
[MeanAbsolutePercentageError(), MeanAbsoluteError()]
)
# 使用 Ray Tune 提供的 model_args 创建模型
model = NBEATSModel(
input_chunk_length=24,
output_chunk_length=12,
n_epochs=100,
torch_metrics=torch_metrics,
pl_trainer_kwargs={"callbacks": callbacks, "enable_progress_bar": False},
**model_args,
)
model.fit(
series=train,
val_series=val,
)
# 读取数据
series = AirPassengersDataset().load().astype(np.float32)
# 创建训练与验证集
train, val = series.split_after(pd.Timestamp(year=1957, month=12, day=1))
# 标准化时间序列(注意:避免在验证集上拟合 transformer)
transformer = Scaler()
transformer.fit(train)
train = transformer.transform(train)
val = transformer.transform(val)
# 早停 callback
my_stopper = EarlyStopping(
monitor="val_MeanAbsolutePercentageError",
patience=5,
min_delta=0.05,
mode="min",
)
# 设置 ray tune callback
class TuneReportCallback(TuneReportCheckpointCallback, pl.Callback):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
tune_callback = TuneReportCallback(
{
"loss": "val_loss",
"MAPE": "val_MeanAbsolutePercentageError",
},
on="validation_end",
)
# 定义 Ray Tune 将调优的 trainable 函数
train_fn_with_parameters = tune.with_parameters(
train_model,
callbacks=[tune_callback, my_stopper],
train=train,
val=val,
)
# 设置每个试验使用的资源(无 GPU 时禁用)
resources_per_trial = {"cpu": 8, "gpu": 1}
# 定义超参数空间
config = {
"batch_size": tune.choice([16, 32, 64, 128]),
"num_blocks": tune.choice([1, 2, 3, 4, 5]),
"num_stacks": tune.choice([32, 64, 128]),
"dropout": tune.uniform(0, 0.2),
}
# 尝试的组合数量
num_samples = 10
# 配置 ASHA 调度器
scheduler = ASHAScheduler(max_t=1000, grace_period=3, reduction_factor=2)
# 配置 CLI 报告器以显示进度
reporter = CLIReporter(
parameter_columns=list(config.keys()),
metric_columns=["loss", "MAPE", "training_iteration"],
)
# 创建 Tuner 对象并运行超参数搜索
tuner = Tuner(
trainable=tune.with_resources(
train_fn_with_parameters, resources=resources_per_trial
),
param_space=config,
tune_config=tune.TuneConfig(
metric="MAPE", mode="min", num_samples=num_samples, scheduler=scheduler
),
run_config=RunConfig(name="tune_darts", progress_reporter=reporter),
)
results = tuner.fit()
# 打印找到的最佳超参数
print("Best hyperparameters found were: ", results.get_best_result().config)
使用 gridsearch()
进行超参数优化
Darts 中的每个预测模型均提供 gridsearch()
方法,用于最基本的超参数搜索。此方法仅适用于极简单场景,超参数数量极少,且只能处理单一时间序列。
常见问题
Frequently Asked Questions
-
Darts 只是对其他库的封装吗?
不是。在合适的情况下,我们会复用已有的实现(例如来自 statsforecasts),但我们也经常自行实现(例如神经网络)。此外,Darts 中的模型通常比原有版本具备更多功能。例如,与原始版本不同,我们的 N-BEATS 实现支持多元时间序列、过去协变量以及概率预测。
-
Darts 看起来很棒,我可以贡献吗?
当然可以!我们始终欢迎社区贡献。若您参与贡献,您将被列入“荣誉墙”(即变更日志)!贡献不限于代码,也可以是文档等。此外,我们也乐于在 GitHub 上以 issue 的形式接收建议。贡献者最佳起点是贡献指南。
-
我想贡献一个新模型到 Darts,可以吗?
一般而言可以,我们欢迎新的参考实现。不过,我们会进行松散筛选,仅保留经典模型,或已在论文或其他形式证据中被证明在某方面达到 SOTA 的模型。
-
如何让 Darts 在 Google Colab 上运行?
Colab 可能对最新版 pyyaml 存在兼容问题。在安装 Darts 前先安装 pyyaml 5.4.1 可解决:
!pip install pyyaml==5.4.1
-
我在预测中得到了 NaN,该怎么办?
通常意味着以下两种情况之一:
- 训练用的
TimeSeries
(目标或协变量)中存在 NaN。这是最常见的情况,大多数模型会因此始终预测 NaN。请注意,将pd.DataFrame
转换为TimeSeries
时,若存在缺失日期且freq
参数设置错误,也可能引入 NaN。 - 训练过程出现数值发散。若使用神经网络,请确保数据已正确缩放;若问题依旧,可尝试降低学习率。
- 训练用的
-
我的预测模型效果不佳,能帮忙吗?
获得良好预测不仅仅是调用
fit()
/predict()
那么简单,还需进行数据科学工作以理解哪些方法适用。我们无法给出通用答案;若您面临重要预测问题或需将预测工业化,Unit8 提供技术咨询,欢迎联系我们。
风险提示与免责声明
本文内容基于公开信息研究整理,不构成任何形式的投资建议。历史表现不应作为未来收益保证,市场存在不可预见的波动风险。投资者需结合自身财务状况及风险承受能力独立决策,并自行承担交易结果。作者及发布方不对任何依据本文操作导致的损失承担法律责任。市场有风险,投资须谨慎。