Longformer Model: A Walkthrough of the Tricky Parts
While Longformer extends the maximum sequence length, its implementation also contains several structurally tricky pieces. This post works through them step by step.
The shape transformation inside _sliding_chunks_query_key_matmul
The hardest part to understand is the following few lines:
query_size = list(query.size())
query_size[1] = query_size[1] * 2 - 1
query_stride = list(query.stride())
query_stride[1] = query_stride[1] // 2
query = query.as_strided(size=query_size, stride=query_stride)
First, the sequence lengths behind query_size and key_size are always integer multiples of 512: Longformer's maximum length is 4096, and any sequence whose length is not a multiple of 512 is automatically padded up to one.
The incoming query_size can be (24,1,512,64), (24,2,512,64), …, (24,n,512,64), and so on. Here batch_size = 2, so the leading 24 = batch_size * num_heads; the second dimension (1, 2, …, n) counts how many 512-row chunks the sequence contains; the last two dimensions are fixed at 512 and 64, where 512 is Longformer's fixed chunk (window) length and 64 is size_per_head, the size of a single attention head.
So the snippet above essentially doubles query_size[1] and subtracts 1, halves query_stride[1], and then reads out a new tensor with as_strided using that size and stride.
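Before moving to a toy tensor, it can help to print the size and stride for a realistically-shaped input. The shape below (batch_size * num_heads = 24, two 512-row chunks, head size 64) is just an assumed example for this post, not taken from the library:
import torch

# assumed example shape: (batch_size * num_heads, n_chunks, 512, 64)
query = torch.zeros(24, 2, 512, 64)
print(query.size())    # torch.Size([24, 2, 512, 64])
print(query.stride())  # (65536, 32768, 64, 1)

query_size = list(query.size())
query_size[1] = query_size[1] * 2 - 1    # 2 -> 3 chunks
query_stride = list(query.stride())
query_stride[1] = query_stride[1] // 2   # 32768 -> 16384 elements = 256 rows
query = query.as_strided(size=query_size, stride=query_stride)
print(query.size())    # torch.Size([24, 3, 512, 64])
print(query.stride())  # (65536, 16384, 64, 1)
Halving stride[1] means chunk i starts 256 rows (16384 elements) after chunk i-1, so consecutive 512-row chunks overlap by exactly half a window.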
Because the real query and key are too large to inspect directly, a good way to see what the transform does is to simplify: look at the transformation of a small tensor first and spot the pattern.
import torch
import numpy as np
import random

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

# fix the random seed
setup_seed(20)

# small toy tensor; the printouts below appear to have been produced with
# small shapes such as (1, 2, 4, 3), (1, 2, 5, 3) and (1, 3, 4, 3)
data = torch.rand(5, 2, 4, 3)
print('data = ')
print(data)

# the same size/stride trick as in _sliding_chunks_query_key_matmul
data_size = list(data.size())
data_size[1] = data_size[1] * 2 - 1
data_stride = list(data.stride())
data_stride[1] = data_stride[1] // 2
data = data.as_strided(size=data_size, stride=data_stride)
print('data = ')
print(data)
1. When size[1] = 1, the data stays unchanged (1 * 2 - 1 = 1, so the view is identical to the input).
2. When size[1] = 2, the chunk dimension becomes 2 * 2 - 1 = 3, i.e. an extra chunk appears in the middle. The data changes as follows:
The original data:
data =
tensor([[[[0.5615, 0.1774, 0.8147],
[0.3295, 0.2319, 0.7832],
[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156]],
[[0.9423, 0.2536, 0.7388],
[0.5404, 0.4356, 0.4430],
[0.6257, 0.0379, 0.7130],
[0.3229, 0.9631, 0.2284]]]])
The data after as_strided:
data =
tensor([[[[0.5615, 0.1774, 0.8147],
[0.3295, 0.2319, 0.7832],
[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156]],
[[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156],
[0.9423, 0.2536, 0.7388],
[0.5404, 0.4356, 0.4430]],
[[0.9423, 0.2536, 0.7388],
[0.5404, 0.4356, 0.4430],
[0.6257, 0.0379, 0.7130],
[0.3229, 0.9631, 0.2284]]]])
As you can see, because each chunk here has an even number of rows (4), the new middle chunk is stitched together from the bottom half of the first chunk and the top half of the second chunk (this is verified in the sketch after the data):
[[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156],
[0.9423, 0.2536, 0.7388],
[0.5404, 0.4356, 0.4430]]
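The stitching can be confirmed directly by comparing the middle chunk with an explicit torch.cat of the two halves. The snippet below is a small sanity check written for this post, assuming a (1, 2, 4, 3) toy tensor:
import torch

x = torch.rand(1, 2, 4, 3)                        # two chunks of 4 rows each
size = list(x.size()); size[1] = size[1] * 2 - 1
stride = list(x.stride()); stride[1] = stride[1] // 2
y = x.as_strided(size=size, stride=stride)

# middle chunk = bottom half of chunk 0 + top half of chunk 1
expected = torch.cat([x[:, 0, 2:], x[:, 1, :2]], dim=1)
print(torch.equal(y[:, 1], expected))             # True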
What if each chunk has an odd number of rows? Trying the same experiment, the original data is:
data =
tensor([[[[0.5615, 0.1774, 0.8147],
[0.3295, 0.2319, 0.7832],
[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156],
[0.9423, 0.2536, 0.7388]],
[[0.5404, 0.4356, 0.4430],
[0.6257, 0.0379, 0.7130],
[0.3229, 0.9631, 0.2284],
[0.4489, 0.2113, 0.6839],
[0.7478, 0.4627, 0.7742]]]])
After as_strided, the new data is:
data =
tensor([[[[0.5615, 0.1774, 0.8147],
[0.3295, 0.2319, 0.7832],
[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156],
[0.9423, 0.2536, 0.7388]],
[[0.1012, 0.1877, 0.9310],
[0.0899, 0.3156, 0.9423],
[0.2536, 0.7388, 0.5404],
[0.4356, 0.4430, 0.6257],
[0.0379, 0.7130, 0.3229]],
[[0.7388, 0.5404, 0.4356],
[0.4430, 0.6257, 0.0379],
[0.7130, 0.3229, 0.9631],
[0.2284, 0.4489, 0.2113],
[0.6839, 0.7478, 0.4627]]]])
As you can see, the middle chunk is read straight out of the underlying flat storage. The halved stride is 15 // 2 = 7, so the middle chunk is simply the 15 elements starting at offset 7:
0.1012, 0.1877, 0.9310, 0.0899, 0.3156, 0.9423, 0.2536, 0.7388, 0.5404, 0.4356, 0.4430, 0.6257, 0.0379, 0.7130, 0.3229
while the last chunk is the 15 elements starting at offset 14:
0.7388, 0.5404, 0.4356, 0.4430, 0.6257, 0.0379, 0.7130, 0.3229, 0.9631, 0.2284, 0.4489, 0.2113, 0.6839, 0.7478, 0.4627
Each run of 15 values is reshaped into a 5x3 chunk, and together they form the new tensor. Note that because the half-stride 7 is not a multiple of the row length 3, the chunks no longer start on row boundaries, so the values are shifted within rows, as seen above.
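The same pattern can be checked programmatically: each overlapped chunk is just a window of 15 consecutive elements in the flattened storage, starting at i * (stride[1] // 2). Below is a small sanity check on an assumed (1, 2, 5, 3) toy tensor, not code from the Longformer source:
import torch

x = torch.rand(1, 2, 5, 3)
flat = x.flatten()

size = list(x.size()); size[1] = size[1] * 2 - 1
stride = list(x.stride()); stride[1] = stride[1] // 2   # 15 // 2 = 7
y = x.as_strided(size=size, stride=stride)

for i in range(y.size(1)):
    # chunk i is the 15 flat elements starting at offset i * 7, reshaped to 5x3
    window = flat[i * stride[1] : i * stride[1] + 15].view(5, 3)
    print(i, torch.equal(y[0, i], window))              # True for every chunk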
3. When size[1] = 3, the data changes as follows:
data =
tensor([[[[0.5615, 0.1774, 0.8147],
[0.3295, 0.2319, 0.7832],
[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156]],
[[0.9423, 0.2536, 0.7388],
[0.5404, 0.4356, 0.4430],
[0.6257, 0.0379, 0.7130],
[0.3229, 0.9631, 0.2284]],
[[0.4489, 0.2113, 0.6839],
[0.7478, 0.4627, 0.7742],
[0.3861, 0.0727, 0.8736],
[0.3510, 0.3279, 0.3254]]]])
The content after the transform is as follows (analogous to the size[1] = 2 case: 3 chunks become 3 * 2 - 1 = 5, and each new chunk shares half its rows with each neighbour; the overlap can be checked with the snippet after the printout):
data =
tensor([[[[0.5615, 0.1774, 0.8147],
[0.3295, 0.2319, 0.7832],
[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156]],
[[0.8544, 0.1012, 0.1877],
[0.9310, 0.0899, 0.3156],
[0.9423, 0.2536, 0.7388],
[0.5404, 0.4356, 0.4430]],
[[0.9423, 0.2536, 0.7388],
[0.5404, 0.4356, 0.4430],
[0.6257, 0.0379, 0.7130],
[0.3229, 0.9631, 0.2284]],
[[0.6257, 0.0379, 0.7130],
[0.3229, 0.9631, 0.2284],
[0.4489, 0.2113, 0.6839],
[0.7478, 0.4627, 0.7742]],
[[0.4489, 0.2113, 0.6839],
[0.7478, 0.4627, 0.7742],
[0.3861, 0.0727, 0.8736],
[0.3510, 0.3279, 0.3254]]]])
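More generally, when each chunk has an even number of rows, the second half of chunk i always equals the first half of chunk i+1, which is exactly the 50% window overlap that sliding-chunks attention relies on. A short check of this property on an assumed (1, 3, 4, 3) toy tensor:
import torch

x = torch.rand(1, 3, 4, 3)                        # 3 chunks of 4 rows
size = list(x.size()); size[1] = size[1] * 2 - 1  # 3 -> 5 chunks
stride = list(x.stride()); stride[1] = stride[1] // 2
y = x.as_strided(size=size, stride=stride)

half = x.size(2) // 2                             # 2 rows
for i in range(y.size(1) - 1):
    print(i, torch.equal(y[:, i, half:], y[:, i + 1, :half]))  # True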
The shape changes of attention_scores inside the attention computation
Here,
attention_scores = torch.matmul(query, key.transpose(-1, -2))
gives
attention_scores = (24, 5, 512, 512)
(the chunked query, key and value passed in each have shape (24, 5, 512, 64)).
Next, one row of zeros is padded onto the bottom of each 512x512 score matrix (the pad tuple (0, 0, 0, 1) leaves the last dimension alone and pads the second-to-last dimension by one):
attention_scores = nn.functional.pad(
    attention_scores, (0, 0, 0, 1)
)
which gives attention_scores = (24, 5, 513, 512).
Then view is called:
attention_scores = attention_scores.view(*attention_scores.size()[:-2], attention_scores.size(-1), attention_scores.size(-2))
Here each score matrix goes from (24, 5, 513, 512) to (24, 5, 512, 513). The padding step appended one row of 512 zeros to each of the 24*5 matrices (the zeros have nothing to do with the 513; the 513 is just that extra row, and later the extra column). Since view does not move any data, the same 512 zeros still sit at the very end of each matrix after the reshape: the last row of each (512, 513) matrix now starts with a single non-zero value followed by 512 zeros, and every successive row begins one element further along the original row-major order than the row above it. This per-row shift is the skew that lines the chunked attention scores up along diagonals.
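To see this shift concretely, here is the same pad-and-view trick applied to a tiny 4x4 matrix instead of 512x512 (a toy illustration only, not the library's own code):
import torch
import torch.nn as nn

scores = torch.arange(16.).view(1, 1, 4, 4)       # stand-in for (24, 5, 512, 512)
scores = nn.functional.pad(scores, (0, 0, 0, 1))  # (1, 1, 5, 4): one extra row of zeros
scores = scores.view(*scores.size()[:-2], scores.size(-1), scores.size(-2))
print(scores)
# tensor([[[[ 0.,  1.,  2.,  3.,  4.],
#           [ 5.,  6.,  7.,  8.,  9.],
#           [10., 11., 12., 13., 14.],
#           [15.,  0.,  0.,  0.,  0.]]]])
Each row now begins one element further along than the row above did in the original layout, and the last row keeps a single leftover value followed by the padded zeros, mirroring the (24, 5, 512, 513) case described above.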