0
点赞
收藏
分享

微信扫一扫

Java中的垃圾回收机制

晚熟的猫 2024-05-27 阅读 10

目录

词嵌入层

RNN网络层

RNN 网络原理


自然语言处理(Nature language Processing, NLP)研究的主要是通过计算机算法来理解自然语言。对于自然语言来说,处理的数据主要就是人类的语言,我们在进行文本数据处理时,需要将文本进行数据值化,然后进行后续的训练工作。

词嵌入层

词嵌入层首先会根据输入的词的数量构建一个词向量矩阵,例如: 我们有 100 个词,每个词希望转换成 128 维度的向量,那么构建的矩阵形状即为: 100*128,输入的每个词都对应了一个该矩阵中的一个向量。在 PyTorch 中,我们可以使用 nn.Embedding 词嵌入层来实现输入词的向量化。

  1. 定义词汇表大小和词嵌入维度 
  2. 初始化词嵌入层

  3. 将词索引转换为词向量

nn.Embedding 对象构建时,最主要有两个参数:

  1. num_embeddings 表示词的数量
  2. embedding_dim 表示用多少维的向量来表示每个词
import torch
import torch.nn as nn
import jieba


if __name__ == '__main__':

    text = '北京冬奥的进度条已经过半,不少外国运动员在完成自己的比赛后踏上归途。'

    
    words = jieba.lcut(text)

    # 构建词表
    index_to_word = {}
    word_to_index = {}

    # 分词去重
    unique_words = list(set(words))
    for idx, word in enumerate(unique_words):
        index_to_word[idx] = word
        word_to_index[word] = idx

    # 词嵌入层
    embed = nn.Embedding(num_embeddings=len(index_to_word), embedding_dim=4)

    # 文本转换为词向量表示
    for word in words:
        # 获得词对应的索引
        idx = word_to_index[word]
        # 获得词嵌入向量
        word_vec = embed(torch.tensor(idx))
        print('%3s\t' % word, word_vec)

我们的词嵌入层是随机初始化的,现在还不能表示出文本的真正含义,当一个词输入进来之后,会使用随机产生的向量来表示该词,我们将该词向量参与到下游任务的计算,经过下游任务计算之后,会和目标结果进行对比产生损失,通过反向传播更新所有的网络参数,这里的参数就包括了 nn.Embedding 中的词向量表示。

RNN网络层

我们的文本数据是具有序列特性的,上面的词嵌入层将文本数据映射为数值向量,进而能够送入到网络进行计算,比如‘我爱你’是有顺序的,如果颠倒了顺序,那么可能就会表达不同的意思。

为了能够表示出数据的序列关系我们需要使用循环神经网络对数据进行建模,RNN 是一个具有记忆功能的网络,它作用于处理带有序列特点的样本数据。

RNN 网络原理

RNN的计算过程:

h 表示隐藏状态, 每一次的输入都会有包含两个值: 上一个时间步的隐藏状态、当前状态的输入值,最后输出当前时间步的隐藏状态。 

上面的神经元实际上只有一个神经元,只是时间不同的三个状态,"我爱你" 三个字是重复输入到同一个神经元中。

循环网络网络可以有多个神经元,多少个神经元就相当于是输出多少维度的词。

每个神经元内部还是使用像以往神经网络的激活函数和w、b权重。

RNN 层输入的数据为三个维度: (seq_len, batch_size,input_size)

seq_len是一次处理的句子中有几个词,batch_size是几个句子;

import torch
import torch.nn as nn

def func():

    # 输入数据维度 128, 输出维度 256
    rnn = nn.RNN(input_size=128, hidden_size=256)

    inputs = torch.randn(1, 32, 128)
    hn = torch.zeros(1, 32, 256)  # 初始隐藏层为0

    output, hn = rnn(inputs, hn)
    print(output.shape)
    print(hn.shape)

举报

相关推荐

0 条评论