from data_provider.data_factory import data_provider #数据提供模块
from data_provider.m4 import M4Meta #M4数据集的元数据
from exp.exp_basic import Exp_Basic #实验基类
from utils.tools import EarlyStopping, adjust_learning_rate, visual #工具函数
from utils.losses import mape_loss, mase_loss, smape_loss #自定义损失函数
from utils.m4_summary import M4Summary #M4数据集的汇总统计
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
import pandas
# Ignore all warnings (action 'ignore', default category Warning) from the
# moment this module is imported.
# NOTE(review): this mutates the interpreter-wide warning filter list, so it
# also silences warnings from unrelated code in the same process; consider
# scoping it with warnings.catch_warnings() around the code that needs it.
warnings.filterwarnings('ignore')
#短期预测类
class Exp_Short_Term_Forecast(Exp_Basic):
#构造函数
def __init__(self, args):
    """Set up the short-term forecasting experiment.

    Args:
        args: experiment configuration object; passed unchanged to
            ``Exp_Basic.__init__``, which handles the common setup.
    """
    super().__init__(args)
#创建模型
def _build_model(self):
    """Construct the network named by ``self.args.model``.

    When ``self.args.data == 'm4'`` the window sizes are taken from the M4
    metadata for ``self.args.seasonal_patterns`` and written back onto
    ``self.args`` (a side effect visible to the rest of the experiment):
    ``pred_len`` is the M4 horizon, ``seq_len`` is twice ``pred_len``,
    ``label_len`` equals ``pred_len``, and ``frequency_map`` is set from
    ``M4Meta.frequency_map``.

    Returns:
        The model instance after ``.float()``, wrapped in ``nn.DataParallel``
        over ``self.args.device_ids`` when both ``use_multi_gpu`` and
        ``use_gpu`` are truthy.
    """
    cfg = self.args
    if cfg.data == 'm4':
        pattern = cfg.seasonal_patterns
        # Window sizes follow the M4 configuration for this seasonal pattern.
        cfg.pred_len = M4Meta.horizons_map[pattern]
        cfg.seq_len = 2 * cfg.pred_len  # input window is twice the horizon
        cfg.label_len = cfg.pred_len
        cfg.frequency_map = M4Meta.frequency_map[pattern]

    network = self.model_dict[cfg.model].Model(cfg).float()

    # Wrap for data parallelism only when multi-GPU is requested AND GPUs are enabled.
    wants_data_parallel = cfg.use_multi_gpu and cfg.use_gpu
    if wants_data_parallel:
        network = nn.DataParallel(network, device_ids=cfg.device_ids)
    return network
#从data_provider函数获取数据集合和数据加载器
def _get_data(self, flag):
    """Fetch the dataset and its loader for one data split.

    Args:
        flag: split name forwarded to ``data_provider`` (``train`` in this
            class requests 'train' and 'val').

    Returns:
        A ``(data_set, data_loader)`` tuple unpacked from ``data_provider``.
    """
    dataset, loader = data_provider(self.args, flag)
    return dataset, loader
# Select the optimizer: Adam over self.model's parameters, with learning rate self.args.learning_rate
def _select_optimizer(self):
    """Create the optimizer used for training.

    Returns:
        A ``torch.optim.Adam`` instance over ``self.model.parameters()``
        with learning rate ``self.args.learning_rate``.
    """
    trainable = self.model.parameters()
    step_size = self.args.learning_rate
    return optim.Adam(trainable, lr=step_size)
# Select the loss function by name: 'MSE' (mean squared error, the default), 'MAPE', 'MASE' or 'SMAPE'
def _select_criterion(self, loss_name='MSE'):
    """Build the training loss selected by name.

    Args:
        loss_name: exact, case-sensitive name of the loss. Supported values:
            'MSE' (default, ``torch.nn.MSELoss``), 'MAPE' (``mape_loss``),
            'MASE' (``mase_loss``) and 'SMAPE' (``smape_loss``).

    Returns:
        A newly constructed instance of the selected loss.

    Raises:
        ValueError: if ``loss_name`` is not one of the supported names.
    """
    # Map each supported name to its loss class; only the requested one is
    # instantiated.
    factories = {
        'MSE': nn.MSELoss,
        'MAPE': mape_loss,
        'MASE': mase_loss,
        'SMAPE': smape_loss,
    }
    try:
        factory = factories[loss_name]
    except KeyError:
        # Fail at selection time instead of returning None, which would only
        # surface later as an opaque error when the caller uses the criterion.
        raise ValueError(
            f"Unsupported loss_name {loss_name!r}; "
            f"expected one of {sorted(factories)}"
        ) from None
    return factory()
#训练方法
def train(self, setting):
#获取数据
train_data, train_loader = self._get_data(flag='train')
vali_data, vali_loader = self._get_data(flag='val')
#创建模型存储文件
path = os.path.join(self.args.checkpoints, setting)
if not os.path.exists(path):
os.makedirs(path)
#获取时间戳
time_now = time