一 背景
既然有了RNN,为何又需要LSTM呢?
循环神经网络RNN的网络结构使得它可以使用历史信息来帮助当前的决策。例如使用之前出现的单词来加强对当前文字的理解。可以解决传统神经网络模型不能充分利用上下文的信息增益的问题,但是同时,这也带来了更大的技术挑战 – 长期依赖问题(long-term dependencies)。
-
在某些场景中,模型需要短期内的信息来进行增强,比如模型尝试去预测短语“天空非常蓝”,并不需要记忆这个短语之前更长的上下文,因为相关的信息与待预测的词的空间非常接近。
-
在某些场景中,模型需要长期的信息来进行增强,比如当模型尝试去预测“xxx市建设了大量的化工厂,水污染非常严重,已经严重的影响了当地的生态环境,同时空气污染十分严重 … 天空都是灰色的”的最后一个单词灰色的时候,仅仅依赖短期的信息无法进行准确的预测,模型需要记忆更长的信息上下文来进行判断。
在这种长期上下文依赖的情况下,RNN的网路结构会丧失学习到如此长距离信息的能力。
那么如何解决呢?
解决方法我们可以联想我们的记忆能力,其实我们的记忆是有选择的,有些重要的事情印象特别深刻,有些非重要的事情就已经模糊,这样才可以在有限的脑力的情况下记忆尽量重要的事情,总结下来就是记忆+忘记,只有忘记才能记忆。
将这个思想运用到模型里面,就引出了本章要讲述的内容LSTM(long short term memory,LSTM)。
二 LSTM
长短时记忆网络(Long Short Term Memory Network)LSTM,是一种改进之后的循环神经网络,通过门控机制有选择的记忆重要的内容,可以解决RNN无法处理长距离的依赖的问题,目前比较流行。LSTM结构(图右)和普通RNN的主要输入输出区别如下所示。
相比RNN只有一个传输状态 h t h_t ht,LSTM有两个传输状态:
- 一个状态是 c t c_t ct(cell state)
- 一个状态是 h t h_t ht(hidden state)。
其中对于传输状态的 c t c_t ct改变很慢,而 h t h_t ht则在不同时刻下往往会有很大的区别。
LSTM是有三个gate,当外界某个neural的output想要被写到memory cell里面的时候,必须通过一个input Gate,那这个input Gate要被打开的时候,你才能把值写到memory cell里面去,如果把这个关起来的话,就没有办法把值写进去。至于input Gate是打开还是关起来,这个是neural network自己学的(它可以自己学说,它什么时候要把input Gate打开,什么时候要把input Gate关起来)。那么输出的地方也有一个output Gate,这个output Gate会决定说,外界其他的neural可不可以从这个memory里面把值读出来(把output Gate关闭的时候是没有办法把值读出来,output Gate打开的时候,才可以把值读出来)。那跟input Gate一样,output Gate什么时候打开什么时候关闭,network是自己学到的。那第三个gate叫做forget Gate,forget Gate决定说:什么时候memory cell要把过去记得的东西忘掉。这个forget Gate什么时候会把存在memory的值忘掉,什么时候会把存在memory里面的值继续保留下来),这也是network自己学到的。
同时相对于RNN的neuron,LSTM的neuron要做的事情其实就是将原来简单的neuron换成LSTM。 i n p u t [ h t , x t ] input[h_t, x_t ] input[ht,xt]会乘以不同的weight当做LSTM不同的输入(假设我们这个hidden layer只有两个neuron,但实际上是有很多的neuron)。
- i n p u t [ h t , x t ] input[h_t, x_t ] input[ht,xt]乘以不同的weight当做forget gate。
- i n p u t [ h t , x t ] input[h_t, x_t ] input[ht,xt]乘以不同的weight操控input gate。
- i n p u t [ h t , x t ] input[h_t, x_t ] input[ht,xt]乘以不同的weight会去操控output gate。
- i n p u t [ h t , x t ] input[h_t, x_t ] input[ht,xt]乘以不同的weight当做底下的input。
第二个LSTM也是一样的。所以LSTM是有四个input(一个是想要被存在memory cell的值(但它不一定存的进去)还有操控input Gate的讯号,操控output Gate的讯号,操控forget Gate的讯号,有着四个input)跟一个output,所以LSTM需要的参数量(假设你现在用的neural的数目跟LSTM是一样的)是一般neural network的四倍。这个跟Recurrent Neural Network 的关系是什么,这个看起来好像不一样,所以我们要画另外一张图来表示。
2.1 Sigmoid函数
Sigmoid函数是深度学习神经网络常用的激活函数,他的特点值在0和1之间。用于LSTM的门槛,那么离0越近表示忘记,离1越近表示保留。
2.2 Forget Gate
遗忘门主要是对上一个时刻传进来的输入进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”,是针对长期依赖的不重要的信息的遗忘。
算法流程:
- 时刻t的输入 X t X_t Xt与上一个时刻 t − 1 t-1 t−1的隐藏层状态 h t − 1 h_{t-1} ht−1进行向量拼接。
- 拼接后的向量经过 S i g m o d Sigmod Sigmod函数(门控),进行忘记门控的计算。
- 该门控后续会和上个时刻的状态 C t − 1 C_{t-1} Ct−1进行权重控制(矩阵星乘),决定哪些信息信息需要忘记,哪些信息需要保留。
2.3 Input Gate
输入门主要是将这个时刻的输入有选择性地进行记忆。通过输入门控,对于当前时刻的输入信息进行控制,更多的记录下重要的信息,对于不重要的信息则少记一些。
算法流程:
- 时刻t的输入 x t x_t xt与上一个时刻 t − 1 t-1 t−1的隐藏层状态 h t − 1 h_{t-1} ht−1进行向量拼接。
- 拼接后的向量分别进入两个全连接网络
- 门控网络:单层全连接,通过激活函数 S i g m o i d Sigmoid Sigmoid计算门控值得到 i t i_t it。
- 输入网络:单层全连接,通过激活函数 t a n h tanh tanh计算得到值candidate。
- 经过 S i g m o d Sigmod Sigmod函数(门控),进行忘记门控的计算。
- 该门控会和当前时刻的状态 C o n d i d a t e Condidate Condidate进行权重控制(矩阵星乘),首先,我们将之前的隐藏状态 h t − 1 h_{t-1} ht−1和输入 x t x_t xt,经过全连接网络使用激活函数sigmoid。它通过将值转换为0到1之间来决定哪些值将被更新。0表示不重要,1表示重要。同时还将之前的隐藏状态 h t − 1 h_{t-1} ht−1和输入 x t x_t xt,经过全连接网络传使用激活函数tanh,以压缩-1和1之间的值,以帮助调节网络。然后将tanh输出和sigmoid输出相乘。sigmoid输出将决定哪些信息是重要的,而不是tanh输出。生成向量 i t ∗ c i i_t * c_i it∗ci,如下图
2.4 Cell State
下面我们看下当前时刻的 c t c_t ct是如何计算的,状态 c t c_t ct作为LSTM循环体中传输的状态之一,它的状态的改变对于整个LSTM的中长期记忆与遗忘起到关键的作用。
算法流程:
- 首先,时刻t的遗忘门值 f t f_t ft与上一个时刻 t − 1 t-1 t−1的状态 c t − 1 c_{t-1} ct−1进行向量星乘,生成 f t ∗ c t − 1 f_t * c_{t-1} ft∗ct−1。
- 然后,将上面输入门计算的结果 i t ∗ c t i_t * c_t it∗ct与上一步生成的结果 f t ∗ c t − 1 f_t * c_{t-1} ft∗ct−1进行矩阵加法。
- 最后,更新当前时刻t的状态 c t = i t ∗ c t + f t ∗ c t − 1 c_t = i_t * c_t + f_t * c_{t-1} ct=it∗ct+ft∗ct−1,同时作为下一个时刻的状态输入。
2.5 Output Gate
最后介绍输出门,输出门有两个作用,一个是LSTM循环体的输出(也可以理解为RNN循环体的y值),另外一个就是下一个时刻的输入(同状态 c t c_t ct一样, h t h_t ht, c t 和 h t c_t和h_t ct和ht作为LSTM循环体的双输入)。下面我们看下当前时刻的 o t o_t ot是如何计算的,输出门 o t o_t ot作为LSTM循环体中传输的状态之一,它的状态的改变对于整个LSTM的中短期记忆与遗忘起到关键的作用。
算法流程:
-
时刻t的输入 x t x_t xt与上一个时刻 t − 1 t-1 t−1的隐藏层状态 h t − 1 h_{t-1} ht−1进行向量拼接。
-
拼接后的向量分别进入单层全连接,通过激活函数 S i g m o i d Sigmoid Sigmoid计算门控值得到 o t o_t ot。
-
当前时刻的状态 c t c_t ct通过tanh函数后与当前的门控 o t o_t ot进行矩阵乘法,生成最后的输出值 h t h_t ht
2.6 多层LSTM
在实践中经常会使用多层LSTM,所以这里也简单介绍下,主要看下面的图就好。
三 伪码
通过上面的讲解与分析,相信大家对LSTM应该已经有了全面的理解,下面附上代码,大家可以通过代码再进行下加深理解。
四 参考链接
- https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation-44e9eb85bf21
- http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
- http://colah.github.io/posts/2015-08-Understanding-LSTMs/
- https://www.youtube.com/watch?v=WCUNPb-5EYI
五 番外篇
个人介绍:杜宝坤,隐私计算行业从业者,从0到1带领团队构建了京东的联邦学习解决方案9N-FL,同时主导了联邦学习框架与联邦开门红业务。
框架层面:实现了电商营销领域支持超大规模的工业化联邦学习解决方案,支持超大规模样本PSI隐私对齐、安全的树模型与神经网络模型等众多模型支持。
业务层面:实现了业务侧的开门红业务落地,开创了新的业务增长点,产生了显著的业务经济效益。
个人比较喜欢学习新东西,乐于钻研技术。基于从全链路思考与决策技术规划的考量,研究的领域比较多,从工程架构、大数据到机器学习算法与算法框架均有涉及。欢迎喜欢技术的同学和我交流,邮箱:baokun06@163.com
六 公众号导读
自己撰写博客已经很长一段时间了,由于个人涉猎的技术领域比较多,所以对高并发与高性能、分布式、传统机器学习算法与框架、深度学习算法与框架、密码安全、隐私计算、联邦学习、大数据等都有涉及。主导过多个大项目包括零售的联邦学习,社区做过多次分享,另外自己坚持写原创博客,多篇文章有过万的阅读。公众号大家可以按照话题进行连续阅读,里面的章节我都做过按照学习路线的排序,话题就是公众号里面下面的标红的这个,大家点击去就可以看本话题下的多篇文章了,比如下图(话题分为:一、隐私计算 二、联邦学习 三、机器学习框架 四、机器学习算法 五、高性能计算 六、广告算法 七、程序人生),知乎号同理关注专利即可。
一切有为法,如梦幻泡影,如露亦如电,应作如是观。