0
点赞
收藏
分享

微信扫一扫

LSTM内部结构-我彻底明白了

悄然丝语 2022-01-14 阅读 51

原创:李孟启

1、背景

LSTM(Long Short-term Memory,LSTM)长短期记忆[1],是一种用于处理序列数据的神经网络。相比一般的神经网络来说,他能够处理序列变化的数据。比如某个单词的意思会因为上文提到的内容不同而有不同的含义,LSTM就能够很好地解决这类问题。

本教程适合入门深度学习的小白,大神请绕行,你要是即将参加面试,那么这篇文章也会对你产生帮助。

2、LSTM结构

如图1所示,这是一个时刻的LSTM的内部结构图,多个时刻就是下图在横向上拼接组成。

图1  LSTM单元结构图

(1)从图中我们可以看到有三个门,即输⼊门(input gate)、遗忘门(forget gate)和输出门(output gate),以及记忆细胞(某些⽂献把记忆细胞当成⼀种特殊的隐藏状态),从⽽记录额外的信息。这里你可能不懂隐藏状态,但是没关系,你只需要理解下面的计算过程即可。一提到LSTM,我们就不得不面对下面的公式,这些公式也是面试中常被问及的地方,也有一定的重复性,突出两个字“好记“。

I t = σ ( X t W x i + H t − 1 W h i + b i ) \boldsymbol{I}_{t}=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x i}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h i}+\boldsymbol{b}_{i}\right) It=σ(XtWxi+Ht1Whi+bi) 式1

F t = σ ( X t W x f + H t − 1 W h f + b f ) \boldsymbol{F}_{t}=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x f}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h f}+\boldsymbol{b}_{f}\right) Ft=σ(XtWxf+Ht1Whf+bf) 式2

O t = σ ( X t W x o + H t − 1 W h o + b o ) \boldsymbol{O}_{t}=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x o}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h o}+\boldsymbol{b}_{o}\right) Ot=σ(XtWxo+Ht1Who+bo) 式3

这里从数学运算的角度上讲解,所以不会涉及太多深度学习的名词,我们先来明确公式中每个参数的含义,这些参数是以矩阵的形式存在的, X t X_{t} Xt t t t 时刻的输入,尺寸为 [ n , d ] [n,d] [n,d](可以理解为一个样本数量为n每个样本特征数为d的一个矩阵), H t − 1 H_{t-1} Ht1是上一个时刻传过来的状态信息,它的尺寸是 [ n , h ] [n,h] [n,h] W x i W_{xi} Wxi W x f W_{xf} Wxf W x o W_{xo} Wxo ∈ \in [ d , h ] [d,h] [d,h] W h i W_{hi} Whi W h f W_{hf} Whf W h o W_{ho} Who ∈ \in [ h , h ] [h,h] [h,h]是权重参数, b i , b f , b o \boldsymbol{b}_{i}, \boldsymbol{b}_{f}, \boldsymbol{b}_{o} bi,bf,bo的尺寸是 [ 1 , h ] [1,h] [1,h](其实就是偏差参数),公示中“+”是元素的对位相加,相乘是矩阵乘法。 σ \sigma σ s i g m o i d sigmoid sigmoid 激活函数如公式4所示,就是把括号中运算得到的矩阵每一个元素都带入 s i g m o i d sigmoid sigmoid 函数求值,运算前后矩阵的形状不变。

S ( x ) = 1 1 + e − x S(x)=\frac{1}{1+e^{-x}} S(x)=1+ex1 式4

最终我们可以得到 I t I_{t} It F t F_{t} Ft O t O_{t} Ot ∈ \in [ n , h ] [n,h] [n,h]

(2)下面是候选记忆细胞的计算公式:

C ~ t = tanh ⁡ ( X t W x c + H t − 1 W h c + b c ) \tilde{\boldsymbol{C}}_{t}=\tanh \left(\boldsymbol{X}_{t} \boldsymbol{W}_{x c}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h c}+\boldsymbol{b}_{c}\right) C~t=tanh(XtWxc+Ht1Whc+bc) 式5

其中 W x c ∈ [ d , h ] W_{xc}\in [d,h] Wxc[d,h] W h c ∈ [ h , h ] W_{hc}\in[h,h] Whc[h,h] 是权重参数, b c ∈ [ 1 , h ] b_{c}\in[1,h] bc[1,h] 偏差参数 ,这里括号里面的计算方式同上三个门的计算过程一致,只不过这里使用了 t a n h tanh tanh激活函数, t a n h tanh tanh如式6所示。最终 C ~ t ∈ [ n , h ] \tilde{\boldsymbol{C}}_{t}\in[n,h] C~t[n,h]

tanh ⁡ x = sinh ⁡ x cosh ⁡ x = e x − e − x e x + e − x \tanh x=\frac{\sinh x}{\cosh x}=\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}} tanhx=coshxsinhx=ex+exexex 式6

那么,剩余的计算就简单多了,现在我门已知 I t 、 F t 、 O t ∈ [ n , h ] I_{t}、F_{t}、O_{t}\in[n,h] ItFtOt[n,h] C ~ t ∈ [ n , h ] \tilde{\boldsymbol{C}}_{t}\in[n,h] C~t[n,h] ,那么 C t C_{t} Ct的计算如式7所示。

C t = F t ⊙ C t − 1 + I t ⊙ C ~ t \boldsymbol{C}_{t}=\boldsymbol{F}_{t} \odot \boldsymbol{C}_{t-1}+\boldsymbol{I}_{t} \odot \tilde{\boldsymbol{C}}_{t} Ct=FtCt1+ItC~t 式7

其中, C t − 1 C_{t-1} Ct1 C t C_{t} Ct的形状一致同为 [ n , h ] [n,h] [n,h],只不过是上一个时刻的输出, ⊙ \odot 是两个矩阵对位元素的相乘,“ + + +”就是两个矩阵对位元素相加。

(3)有了记忆细胞以后,接下来我们还可以通过输出⻔来控制从记忆细胞到隐藏状态 H t H_{t} Ht的信息的

流动:

H t = O t ⊙ tanh ⁡ ( C t ) \boldsymbol{H}_{t}=\boldsymbol{O}_{t} \odot \tanh \left(\boldsymbol{C}_{t}\right) Ht=Ottanh(Ct) 式8

这里的计算过程类比式7即可,就不再赘述。

上述过程就是LSTM一个单元细胞的数据流动。

3、pytorch实现LSTM

下面我们再从代码的角度理解LSTM内部结构。

(1)导入必要的包

import torch
import numpy as np
from torch import nn

(2)这里生成LSTM内部的权重参数

def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)
    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))
    W_xi, W_hi, b_i = _three()  # 输⼊⻔参数
    W_xf, W_hf, b_f = _three()  # 遗忘⻔参数
    W_xo, W_ho, b_o = _three()  # 输出⻔参数
    W_xc, W_hc, b_c = _three()  # 候选记忆细胞参数
    # 输出层参数 (下面的两行代码新建的权重不属于LSTM,属于输出层;意思是LSTM每个时刻的输出又乘了一个矩阵)
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q])

(3)这里用代码实现LSTM内部的计算公式

def lstm(inputs, state, params):
    # 这里获取上面代码初始化的参数
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q] = params
    (H, C) = state  # 接受初始化的H和C,这里未给出相关的代码,大家知道即可
    outputs = []  # 用来存储每个时间步的输出信息
    for X in inputs:  # 在这里inputs的形状是[t, n, d],通过for循环,可以逐一的拿出每个时刻的输入信息X(X的shape是[n,d])
        # 输入门的计算,公式1
        I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)
        # 遗忘门计算,公式2
        F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)
        # 输出门计算,公式3
        O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)
        # 候选记忆细胞计算,公式5
        C_tilda = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
        # C的计算,公式7
        C = F * C + I * C_tilda
        # 输出隐藏状计算,公式8
        H = O * C.tanh()
        # 输出层的计算,该过程不包括在LSTM中
        Y = torch.matmul(H, W_hq) + b_q
        # 记录每个时刻的输出信息
        outputs.append(Y)  # 把每个时刻的输出Y追加到列表中
    return outputs, (H, C)  # 返回所有时刻的输出outputs,最终时刻的隐藏状态H、记忆细胞C

(4)生成初始时刻的需要输入的记忆细胞和隐藏状态

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), torch.zeros((batch_size, num_hiddens), device=device))

(5)开始运行程序

if __name__ == '__main__':
    device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')  # 指定运算使用的设备(gpu or cpu)
    batch_size, num_inputs, num_hiddens, num_outputs = 32, 512, 256, 512  # 定义公式中参数的维度
    print('将会使用:', device)
    C, H = init_lstm_state(batch_size, num_hiddens, device)  # 这里生成LSTM初始时刻需要的C和H
    print('开始时刻记忆细胞的输入设置为全0', C)
    print('开始时刻隐藏状态的输入设置为全0', H)
    # 在这里随机生成每个时刻的输入inputs(inputs的shape是[t, n, d])
    shape = (56, 32, 512)  # t=56;n=32;d=512
    inputs = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
    # 通过事先定义好的函数get_params获取LSTM中使用的参数
    params = get_params()
    state = (H, C)
    outputs, (H, C) = lstm(inputs, state, params)
    print(outputs[0].shape)  # 取出第一个时刻的输出,它的shape是[32, 512]

4、总结

从代码中可以看到LSTM每个时刻的单元是共享同一组参数的。

这里列举了LSTM网络中涉及的公式和代码的详细实现,但是在实际应用中,我们只需要做好调包侠即可。

参考文献

[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation,9(8), 1735-1780.

举报

相关推荐

0 条评论