背景
进行训练和测试时,有时难以保证输入文本长度的一致性,因此常常需要截断操作(即将超过预设长度的文本截断)和pad操作(即对不足预设长度的文本进行补0)。
在Pytorch中的torch.nn.utils.rnn,提供了pad和pack,pack_padded_sequence和pad_packed_sequence四种方法实现这一操作。
pad和pack
举一个简单的例子:
from torch.nn.utils.rnn import pack_sequence, pad_sequence,pad_packed_sequence, pack_padded_sequence,
text1 = torch.tensor([1,2,3,4]) # 可视为有4个文字的样本
text2 = torch.tensor([5,6,7]) # 可视为有3个文字的样本
text3 = torch.tensor([8,9]) # 可视为有2个文字的样本
sequences = [text1, text2, text3] # 三个文本序列
pack操作将原来的二维数据(batch*sequence)进行了压缩,但其排列是按照列(即sequence的顺序)进行排列,每个时间步一次性输出batch上的所有样本,即:
pack后的返回值包括两数据。一类为data,即压缩后的数据;而另一类batch_sizes表示每个时间步,batch中包含的样本量。
[Input] pack_sequence(sequences)
[Output] PackedSequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 4]), batch_sizes=tensor([3, 3, 2, 1]))
pad操作即是将不同长度的文本序列进行补0。需要注意的是,这个没有dim参数,替代的是batch_first参数,即第一个维度是否是batch,在默认情况下,参数batch_first=False,这是rnn网络的推荐用法。其返回值的第一个维度将变成sequence,而第二个维度才为batch。如果是仅仅使用这种方法对数据进行补齐或截断,可以通过设置batch_first=True,使得返回值的第一个维度为batch,从而保持与输入值的一致性。
[Input] pad_sequence(sequences)
[Output] tensor([[1, 5, 8],
[2, 6, 9],
[3, 7, 0],
[4, 0, 0]])
[Input] pad_sequence(sequences, batch_first=True)
[Output] tensor([[1, 2, 3, 4],
[5, 6, 7, 0],
[8, 9, 0, 0]])
pack_padded_sequence和pad_packed_sequence
因为pytorch中的RNN网络可以接受的是PackedSequence类型数据(通过pack操作实现),而pad操作又可以实现不等长文本的填充对齐,所以自然会想到将两个操作联合起来,这就是pytorch提供的pack_padded_sequence和pad_packed_sequence功能。
pack_padded_sequence就是将经pad后的文本序列在做pack,从而实现对文本缺失位置的填0和维度压缩:
[Input] pack_padded_sequence(pad_sequence(sequences,batch_first=True),lengths=[4,3,3], batch_first=True)
[Output] PackedSequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 0, 4, 0]), batch_sizes=tensor([3, 3, 3, 2]))
pack_padded_sequence函数接收一个padded_sequence数据;根据batch_first参数明确该数据的布局(默认为batch_first=False);根据lengths参数明确batch内各样本的时间步长,选择数据;将上述数据按照时间维度进行压缩,得到目标的PackedSequence类型数据。
pad_packed_sequence函数即为pack_padded_sequence的逆操作,其在参数设定时也通过batch_first控制返回值的维度顺序,同时可通过设置total_lengths来控制pad后的总步长(该值必须不小于输入PackedSequence数据的步长数):
[Input] pad_packed_sequence(pack_sequence(sequences),total_length=5,batch_first=True)
[Output] (tensor([[1, 2, 3, 4, 0],
[5, 6, 7, 0, 0],
[8, 9, 0, 0, 0]]), tensor([4, 3, 2]))