1. ask bing(Temporal Convolutional Network)
问:“我对CNN、RNN、TCN等神经网络没有任何基础,你能直观地给我讲一下TCN的结构、输入输出和原理吗?”
bing对TCN的解释如下:
TCN是一种用于处理序列数据的神经网络,它可以看作是一种特殊的CNN。CNN是一种用于处理图像数据的神经网络,它可以看作是一种特殊的卷积运算。卷积运算就是将一个小的滑动窗口(称为卷积核或过滤器)在一个大的数据(称为输入或特征图)上滑动,并对每个窗口内的数据进行加权求和,得到一个新的数据(称为输出或激活图)。这样就可以提取出输入中的局部特征,并减少参数数量。
TCN与CNN不同之处在于,它只在序列数据的时间维度上进行卷积运算,而不是在空间维度上。这样就可以保持序列数据的时序信息,并且使得输出只依赖于当前及之前的输入,而不会受到未来输入的影响。这种卷积运算称为因果卷积。
TCN还使用了空洞卷积,这是一种在卷积核中插入空白位置(称为膨胀因子)的技术,使得卷积核可以覆盖更长范围的输入,而不增加参数数量。这样就可以增加感受野,即输出能够感知到输入中多远距离的信息。
TCN还使用了残差连接,这是一种将每层输入直接与输出相加(称为跳跃连接)的技术,使得网络可以更容易地学习恒等映射,并防止梯度消失或爆炸。这样就可以构建更深层次的网络,并提高性能。
TCN通常由多个重复模块组成,每个模块包含两个因果空洞卷积层和一个残差连接。每个模块都有自己独立的膨胀因子和过滤器数量。最后一个模块后面还有一个全连接层或者softmax层来输出预测结果。
TCN接收一个多维数组作为输入,其中第一维表示序列长度(即时间步数),第二维表示特征数量(即信号通道数)。例如,在生物信号处理中,如果我们有1000个时间步长和8个信号通道,则我们可以将其表示为一个1000×8 的数组作为TCN 的输入。
TCN输出一个多维数组作为预测结果,其中第一维表示序列长度(即时间步数),第二维表示类别数量(即预测目标数)。例如,在生物信号分类中,如果我们有1000个时间步长和4个类别,则我们可以将其表示为一个1000×4 的数组作为TCN 的输出。
总结一下:
- TCN是在CNN基础上衍生的一种神经网络;CNN主要用于处理数字图像,本质是一种特殊的卷积(滤波)运算,且是在空间维度进行的;而FCN使用的是因果卷积,即不是在空间维度进行卷积运算,而是对序列(sequence)数据在时间维度进行卷积运算。
- TCN在算法优化上采用了:膨胀因果卷积、空洞卷积(通过在卷积核中插入空白位置/膨胀因子,使卷积核覆盖更长范围的输入而不增加参数数量,实现增加感受野)、残差连接(可以构建更深层次的网络、提高性能)。
- TCN的结构上由多个重复模块组成,每个模块包含两个因果空洞卷积层+一个残差连接,每个模块的膨胀因子、过滤器数量都不一样;最后一个模块后再加一个全连接层/softmax层来输出预测结果。
- TCN输入输出数据:输入数据是一个多维数组,比如处理生物信号的话,第一维度表示信号序列长度/时间步数(取决于采样频率),第二维度表示特征数量/信号通道数(取决于采集了几个通道和特征处理的方法);输出数据也是多维数组,第一维度还是信号序列长度/时间步数,第二维度表示类别数量/预测目标数(比如根据多个通道的生物信号预测关节角度和时间的关系)。
2. 边跑边学
2.1 先跑起来
The Adding Problem with various T (we evaluated on T=200, 400, 600)
Copying Memory Task with various T (we evaluated on T=500, 1000, 2000)
Sequential MNIST digit classification
Permuted Sequential MNIST (based on Seq. MNIST, but more challenging)
JSB Chorales polyphonic music
Nottingham polyphonic music
PennTreebank [SMALL] word-level language modeling (LM)
Wikitext-103 [LARGE] word-level LM
LAMBADA [LARGE] word-level LM and textual understanding
PennTreebank [MEDIUM] char-level LM
text8 [LARGE] char-level LM
我这里选择了:JSB Chorales polyphonic music和Nottingham polyphonic music,对应poly_music文件夹,因为处理预测声波数据看起来和我要应用的处理生物信号数据比较相近。当然我们可以根据自己的需求选择其他的案例跑。
README中强调了,对应每个案例跑模型时只需要运行[TASK_NAME]_test.py
,比如打开music_test.py
,先让他运行着。
2.2 学习原理
2.2.1 TCN网络结构直观了解(参考:机器学习进阶之 时域/时间卷积网络 TCN 概念+由来+原理+代码实现)
这部分内容对应tcn.py
中的class Chomp1d
、class TemporalBlock
。
TCN网络结构左边一大串主要包括四个部分:
膨胀因果卷积(Dilated Causal Conv
):膨胀就是说卷积时的输入存在间隔采样;因果指每层某时刻的数据只依赖于之前层当前时刻及之前时刻的数据,与未来时刻的数据无关;卷积就是CNN的卷积(卷积核在数据上进行的一种滑动运算的操作)。
权重归一化(WeightNorm
):通过重写深度网络的权重来进行加速。代码tcn.py
中从torch.nn.utils
中调用weight_norm
使用。
激活函数(ReLU
):挺有名的,从torch.nn
中调用。
Dropout
:指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。意义是防止过拟合,提高模型的运算速度。
TCN网络结构右边是残差连接:
残差连接:
1*1的卷积块儿,作者说:不仅可以使网络拥有跨层传递信息的功能,而且可以保证输入输出的一致性。
2.2.2 TCN结构关系详解(参考:时间卷积网络(TCN):结构+pytorch代码,有对代码非常详细的标注)
TCN与LSTM的区别:LSTM是通过引入卷积操作使其能够处理图像信息,卷积只对一个时刻的输入图像进行操作;而TCN是利用卷积进行跨时间步提取特征。
TCN的实现——1-D FCN结构;
TCN的实现——因果卷积、膨胀因果卷积(对比膨胀非因果卷积)、残差块结构(参考ResNet,使TCN结构更具有泛化能力)
2.3测试结果
Namespace(clip=0.2, cuda=True, data='Nott', dropout=0.25, epochs=100, ksize=5, levels=4, log_interval=100, lr=0.001, nhid=150, optim='Adam', seed=1111)
loading Nott data...
Epoch 1 | lr 0.00100 | loss 24.20483
Epoch 1 | lr 0.00100 | loss 12.54757
Epoch 1 | lr 0.00100 | loss 10.34167
Epoch 1 | lr 0.00100 | loss 7.46519
Epoch 1 | lr 0.00100 | loss 6.22717
Epoch 1 | lr 0.00100 | loss 5.83443
Validation loss: 5.29395
Test loss: 5.40475
Saved model!
Epoch 2 | lr 0.00100 | loss 5.49445
Epoch 2 | lr 0.00100 | loss 5.08582
Epoch 2 | lr 0.00100 | loss 5.28078
Epoch 2 | lr 0.00100 | loss 5.21557
Epoch 2 | lr 0.00100 | loss 5.09972
Epoch 2 | lr 0.00100 | loss 4.87487
Validation loss: 4.88671
Test loss: 4.91229
Saved model!
Epoch 3 | lr 0.00100 | loss 4.55071
Epoch 3 | lr 0.00100 | loss 4.64663
Epoch 3 | lr 0.00100 | loss 4.62720
Epoch 3 | lr 0.00100 | loss 4.40581
Epoch 3 | lr 0.00100 | loss 4.54712
Epoch 3 | lr 0.00100 | loss 4.48989
Validation loss: 4.23855
Test loss: 4.27290
Saved model!
Epoch 4 | lr 0.00100 | loss 4.15868
Epoch 4 | lr 0.00100 | loss 4.19424
Epoch 4 | lr 0.00100 | loss 3.93361
Epoch 4 | lr 0.00100 | loss 3.87698
Epoch 4 | lr 0.00100 | loss 4.26776
Epoch 4 | lr 0.00100 | loss 4.32880
Validation loss: 4.00616
Test loss: 4.04496
Saved model!
Epoch 5 | lr 0.00100 | loss 4.02574
Epoch 5 | lr 0.00100 | loss 4.59837
Epoch 5 | lr 0.00100 | loss 4.17430
Epoch 5 | lr 0.00100 | loss 3.96050
Epoch 5 | lr 0.00100 | loss 4.04181
Epoch 5 | lr 0.00100 | loss 3.97466
Validation loss: 3.78924
Test loss: 3.84233
Saved model!
Epoch 6 | lr 0.00100 | loss 3.77337
Epoch 6 | lr 0.00100 | loss 3.78759
Epoch 6 | lr 0.00100 | loss 4.05782
Epoch 6 | lr 0.00100 | loss 3.61807
Epoch 6 | lr 0.00100 | loss 3.66880
Epoch 6 | lr 0.00100 | loss 3.68237
Validation loss: 3.69415
Test loss: 3.71941
Saved model!
Epoch 7 | lr 0.00100 | loss 3.71184
Epoch 7 | lr 0.00100 | loss 3.65575
Epoch 7 | lr 0.00100 | loss 3.50422
Epoch 7 | lr 0.00100 | loss 3.69709
Epoch 7 | lr 0.00100 | loss 3.39189
Epoch 7 | lr 0.00100 | loss 3.60912
Validation loss: 3.50421
Test loss: 3.52113
Saved model!
Epoch 8 | lr 0.00100 | loss 3.39342
Epoch 8 | lr 0.00100 | loss 3.45223
Epoch 8 | lr 0.00100 | loss 3.47272
Epoch 8 | lr 0.00100 | loss 3.47585
Epoch 8 | lr 0.00100 | loss 3.88333
Epoch 8 | lr 0.00100 | loss 3.51368
Validation loss: 3.45557
Test loss: 3.46655
Saved model!
Epoch 9 | lr 0.00100 | loss 3.38787
Epoch 9 | lr 0.00100 | loss 3.49427
Epoch 9 | lr 0.00100 | loss 3.42003
Epoch 9 | lr 0.00100 | loss 3.45465
Epoch 9 | lr 0.00100 | loss 3.44894
Epoch 9 | lr 0.00100 | loss 3.35138
Validation loss: 3.39177
Test loss: 3.39885
Saved model!
Epoch 10 | lr 0.00100 | loss 3.33982
Epoch 10 | lr 0.00100 | loss 3.33333
Epoch 10 | lr 0.00100 | loss 3.27813
Epoch 10 | lr 0.00100 | loss 3.39872
Epoch 10 | lr 0.00100 | loss 3.31045
Epoch 10 | lr 0.00100 | loss 3.47179
Validation loss: 3.35350
Test loss: 3.35821
Saved model!
Epoch 11 | lr 0.00100 | loss 3.24939
Epoch 11 | lr 0.00100 | loss 3.28225
Epoch 11 | lr 0.00100 | loss 3.31755
Epoch 11 | lr 0.00100 | loss 3.31538
Epoch 11 | lr 0.00100 | loss 3.34717
Epoch 11 | lr 0.00100 | loss 3.47794
Validation loss: 3.27830
Test loss: 3.28621
Saved model!
Epoch 12 | lr 0.00100 | loss 3.24459
Epoch 12 | lr 0.00100 | loss 3.26871
Epoch 12 | lr 0.00100 | loss 2.83995
Epoch 12 | lr 0.00100 | loss 3.24781
Epoch 12 | lr 0.00100 | loss 3.25777
Epoch 12 | lr 0.00100 | loss 3.09675
Validation loss: 3.25199
Test loss: 3.25987
Saved model!
Epoch 13 | lr 0.00100 | loss 3.18712
Epoch 13 | lr 0.00100 | loss 3.15744
Epoch 13 | lr 0.00100 | loss 3.08412
Epoch 13 | lr 0.00100 | loss 2.98677
Epoch 13 | lr 0.00100 | loss 3.23000
Epoch 13 | lr 0.00100 | loss 3.12484
Validation loss: 3.22609
Test loss: 3.22669
Saved model!
Epoch 14 | lr 0.00100 | loss 2.86843
Epoch 14 | lr 0.00100 | loss 3.05798
Epoch 14 | lr 0.00100 | loss 3.11845
Epoch 14 | lr 0.00100 | loss 3.14372
Epoch 14 | lr 0.00100 | loss 3.19728
Epoch 14 | lr 0.00100 | loss 3.12642
Validation loss: 3.20776
Test loss: 3.20785
Saved model!
Epoch 15 | lr 0.00100 | loss 3.09529
Epoch 15 | lr 0.00100 | loss 3.05085
Epoch 15 | lr 0.00100 | loss 3.12605
Epoch 15 | lr 0.00100 | loss 3.14538
Epoch 15 | lr 0.00100 | loss 3.09047
Epoch 15 | lr 0.00100 | loss 3.14403
Validation loss: 3.19157
Test loss: 3.19761
Saved model!
Epoch 16 | lr 0.00100 | loss 3.08435
Epoch 16 | lr 0.00100 | loss 3.06446
Epoch 16 | lr 0.00100 | loss 3.07964
Epoch 16 | lr 0.00100 | loss 2.92217
Epoch