原创:李孟启
1、背景
LSTM(Long Short-term Memory,LSTM)长短期记忆[1],是一种用于处理序列数据的神经网络。相比一般的神经网络来说,他能够处理序列变化的数据。比如某个单词的意思会因为上文提到的内容不同而有不同的含义,LSTM就能够很好地解决这类问题。
本教程适合入门深度学习的小白,大神请绕行,你要是即将参加面试,那么这篇文章也会对你产生帮助。
2、LSTM结构
如图1所示,这是一个时刻的LSTM的内部结构图,多个时刻就是下图在横向上拼接组成。
(1)从图中我们可以看到有三个门,即输⼊门(input gate)、遗忘门(forget gate)和输出门(output gate),以及记忆细胞(某些⽂献把记忆细胞当成⼀种特殊的隐藏状态),从⽽记录额外的信息。这里你可能不懂隐藏状态,但是没关系,你只需要理解下面的计算过程即可。一提到LSTM,我们就不得不面对下面的公式,这些公式也是面试中常被问及的地方,也有一定的重复性,突出两个字“好记“。
I t = σ ( X t W x i + H t − 1 W h i + b i ) \boldsymbol{I}_{t}=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x i}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h i}+\boldsymbol{b}_{i}\right) It=σ(XtWxi+Ht−1Whi+bi) 式1
F t = σ ( X t W x f + H t − 1 W h f + b f ) \boldsymbol{F}_{t}=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x f}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h f}+\boldsymbol{b}_{f}\right) Ft=σ(XtWxf+Ht−1Whf+bf) 式2
O t = σ ( X t W x o + H t − 1 W h o + b o ) \boldsymbol{O}_{t}=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x o}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h o}+\boldsymbol{b}_{o}\right) Ot=σ(XtWxo+Ht−1Who+bo) 式3
这里从数学运算的角度上讲解,所以不会涉及太多深度学习的名词,我们先来明确公式中每个参数的含义,这些参数是以矩阵的形式存在的, X t X_{t} Xt是 t t t 时刻的输入,尺寸为 [ n , d ] [n,d] [n,d](可以理解为一个样本数量为n每个样本特征数为d的一个矩阵), H t − 1 H_{t-1} Ht−1是上一个时刻传过来的状态信息,它的尺寸是 [ n , h ] [n,h] [n,h], W x i W_{xi} Wxi、 W x f W_{xf} Wxf、 W x o W_{xo} Wxo ∈ \in ∈ [ d , h ] [d,h] [d,h]和 W h i W_{hi} Whi、 W h f W_{hf} Whf、 W h o W_{ho} Who ∈ \in ∈ [ h , h ] [h,h] [h,h]是权重参数, b i , b f , b o \boldsymbol{b}_{i}, \boldsymbol{b}_{f}, \boldsymbol{b}_{o} bi,bf,bo的尺寸是 [ 1 , h ] [1,h] [1,h](其实就是偏差参数),公示中“+”是元素的对位相加,相乘是矩阵乘法。 σ \sigma σ 是 s i g m o i d sigmoid sigmoid 激活函数如公式4所示,就是把括号中运算得到的矩阵每一个元素都带入 s i g m o i d sigmoid sigmoid 函数求值,运算前后矩阵的形状不变。
S ( x ) = 1 1 + e − x S(x)=\frac{1}{1+e^{-x}} S(x)=1+e−x1 式4
最终我们可以得到 I t I_{t} It、 F t F_{t} Ft、 O t O_{t} Ot ∈ \in ∈ [ n , h ] [n,h] [n,h] 。
(2)下面是候选记忆细胞的计算公式:
C ~ t = tanh ( X t W x c + H t − 1 W h c + b c ) \tilde{\boldsymbol{C}}_{t}=\tanh \left(\boldsymbol{X}_{t} \boldsymbol{W}_{x c}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h c}+\boldsymbol{b}_{c}\right) C~t=tanh(XtWxc+Ht−1Whc+bc) 式5
其中 W x c ∈ [ d , h ] W_{xc}\in [d,h] Wxc∈[d,h]和 W h c ∈ [ h , h ] W_{hc}\in[h,h] Whc∈[h,h] 是权重参数, b c ∈ [ 1 , h ] b_{c}\in[1,h] bc∈[1,h] 偏差参数 ,这里括号里面的计算方式同上三个门的计算过程一致,只不过这里使用了 t a n h tanh tanh激活函数, t a n h tanh tanh如式6所示。最终 C ~ t ∈ [ n , h ] \tilde{\boldsymbol{C}}_{t}\in[n,h] C~t∈[n,h] 。
tanh x = sinh x cosh x = e x − e − x e x + e − x \tanh x=\frac{\sinh x}{\cosh x}=\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}} tanhx=coshxsinhx=ex+e−xex−e−x 式6
那么,剩余的计算就简单多了,现在我门已知 I t 、 F t 、 O t ∈ [ n , h ] I_{t}、F_{t}、O_{t}\in[n,h] It、Ft、Ot∈[n,h]、 C ~ t ∈ [ n , h ] \tilde{\boldsymbol{C}}_{t}\in[n,h] C~t∈[n,h] ,那么 C t C_{t} Ct的计算如式7所示。
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t \boldsymbol{C}_{t}=\boldsymbol{F}_{t} \odot \boldsymbol{C}_{t-1}+\boldsymbol{I}_{t} \odot \tilde{\boldsymbol{C}}_{t} Ct=Ft⊙Ct−1+It⊙C~t 式7
其中, C t − 1 C_{t-1} Ct−1与 C t C_{t} Ct的形状一致同为 [ n , h ] [n,h] [n,h],只不过是上一个时刻的输出, ⊙ \odot ⊙是两个矩阵对位元素的相乘,“ + + +”就是两个矩阵对位元素相加。
(3)有了记忆细胞以后,接下来我们还可以通过输出⻔来控制从记忆细胞到隐藏状态 H t H_{t} Ht的信息的
流动:
H t = O t ⊙ tanh ( C t ) \boldsymbol{H}_{t}=\boldsymbol{O}_{t} \odot \tanh \left(\boldsymbol{C}_{t}\right) Ht=Ot⊙tanh(Ct) 式8
这里的计算过程类比式7即可,就不再赘述。
上述过程就是LSTM一个单元细胞的数据流动。
3、pytorch实现LSTM
下面我们再从代码的角度理解LSTM内部结构。
(1)导入必要的包
import torch
import numpy as np
from torch import nn
(2)这里生成LSTM内部的权重参数
def get_params():
def _one(shape):
ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
return torch.nn.Parameter(ts, requires_grad=True)
def _three():
return (_one((num_inputs, num_hiddens)),
_one((num_hiddens, num_hiddens)),
torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))
W_xi, W_hi, b_i = _three() # 输⼊⻔参数
W_xf, W_hf, b_f = _three() # 遗忘⻔参数
W_xo, W_ho, b_o = _three() # 输出⻔参数
W_xc, W_hc, b_c = _three() # 候选记忆细胞参数
# 输出层参数 (下面的两行代码新建的权重不属于LSTM,属于输出层;意思是LSTM每个时刻的输出又乘了一个矩阵)
W_hq = _one((num_hiddens, num_outputs))
b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
return nn.ParameterList([W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q])
(3)这里用代码实现LSTM内部的计算公式
def lstm(inputs, state, params):
# 这里获取上面代码初始化的参数
[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q] = params
(H, C) = state # 接受初始化的H和C,这里未给出相关的代码,大家知道即可
outputs = [] # 用来存储每个时间步的输出信息
for X in inputs: # 在这里inputs的形状是[t, n, d],通过for循环,可以逐一的拿出每个时刻的输入信息X(X的shape是[n,d])
# 输入门的计算,公式1
I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)
# 遗忘门计算,公式2
F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)
# 输出门计算,公式3
O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)
# 候选记忆细胞计算,公式5
C_tilda = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
# C的计算,公式7
C = F * C + I * C_tilda
# 输出隐藏状计算,公式8
H = O * C.tanh()
# 输出层的计算,该过程不包括在LSTM中
Y = torch.matmul(H, W_hq) + b_q
# 记录每个时刻的输出信息
outputs.append(Y) # 把每个时刻的输出Y追加到列表中
return outputs, (H, C) # 返回所有时刻的输出outputs,最终时刻的隐藏状态H、记忆细胞C
(4)生成初始时刻的需要输入的记忆细胞和隐藏状态
def init_lstm_state(batch_size, num_hiddens, device):
return (torch.zeros((batch_size, num_hiddens), device=device), torch.zeros((batch_size, num_hiddens), device=device))
(5)开始运行程序
if __name__ == '__main__':
device = torch.device('gpu' if torch.cuda.is_available() else 'cpu') # 指定运算使用的设备(gpu or cpu)
batch_size, num_inputs, num_hiddens, num_outputs = 32, 512, 256, 512 # 定义公式中参数的维度
print('将会使用:', device)
C, H = init_lstm_state(batch_size, num_hiddens, device) # 这里生成LSTM初始时刻需要的C和H
print('开始时刻记忆细胞的输入设置为全0', C)
print('开始时刻隐藏状态的输入设置为全0', H)
# 在这里随机生成每个时刻的输入inputs(inputs的shape是[t, n, d])
shape = (56, 32, 512) # t=56;n=32;d=512
inputs = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
# 通过事先定义好的函数get_params获取LSTM中使用的参数
params = get_params()
state = (H, C)
outputs, (H, C) = lstm(inputs, state, params)
print(outputs[0].shape) # 取出第一个时刻的输出,它的shape是[32, 512]
4、总结
从代码中可以看到LSTM每个时刻的单元是共享同一组参数的。
这里列举了LSTM网络中涉及的公式和代码的详细实现,但是在实际应用中,我们只需要做好调包侠即可。
参考文献
[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation,9(8), 1735-1780.