Longformer code structure walkthrough

SPEIKE 2022-02-12

Working through the tricky parts of the Longformer model


While extending the maximum sequence length, Longformer's implementation also contains several structural difficulties. This post analyzes them step by step.

The shape transform inside `_sliding_chunks_query_key_matmul`

The hardest lines to understand are these:

query_size = list(query.size())
query_size[1] = query_size[1] * 2 - 1           # n chunks -> 2n - 1 chunks
query_stride = list(query.stride())
query_stride[1] = query_stride[1] // 2          # step forward by half a chunk
query = query.as_strided(size=query_size, stride=query_stride)

First, the sequence lengths behind query and key here are always integer multiples of 512: Longformer's maximum length is 4096, and any input whose length is not a multiple is automatically padded up to one.
The incoming query size is (24, 1, 512, 64), (24, 2, 512, 64), ..., (24, n, 512, 64), depending on the input. Here batch_size = 2, and the leading 24 = batch_size * num_heads; the second dimension (1, 2, ..., n) counts how many 512-long chunks the sequence contains; the last two dimensions are generally fixed at 512 (the length of one Longformer window period) and 64 (size_per_head, the width of one attention head).
So, in essence, the code doubles query_size[1] and subtracts 1, halves query_stride[1], and reads the same memory through the new size/stride pair to obtain a new tensor.
Because the real query and key are too large to inspect the transformed tensor directly, a good approach is to simplify: run the transform on a small tensor first and look for the pattern.

import torch
import numpy as np
import random

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

# set the random seed for reproducibility
setup_seed(20)
data = torch.rand(1, 2, 4, 3)
print('data = ')
print(data)

data_size = list(data.size())
data_size[1] = data_size[1] * 2 - 1
data_stride = list(data.stride())
data_stride[1] = data_stride[1] // 2
data = data.as_strided(size=data_size, stride=data_stride)

print('data = ')
print(data)

1. When size[1] = 1, the data stays unchanged.

2. When size[1] = 2, the second dimension becomes 2 * 2 - 1 = 3, with a new chunk appearing in the middle. The data change is as follows:

Original data:

data = 
tensor([[[[0.5615, 0.1774, 0.8147],
          [0.3295, 0.2319, 0.7832],
          [0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156]],

         [[0.9423, 0.2536, 0.7388],
          [0.5404, 0.4356, 0.4430],
          [0.6257, 0.0379, 0.7130],
          [0.3229, 0.9631, 0.2284]]]])

Data after the transform:

data = 
tensor([[[[0.5615, 0.1774, 0.8147],
          [0.3295, 0.2319, 0.7832],
          [0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156]],
          
         [[0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156],
          [0.9423, 0.2536, 0.7388],
          [0.5404, 0.4356, 0.4430]],

         [[0.9423, 0.2536, 0.7388],
          [0.5404, 0.4356, 0.4430],
          [0.6257, 0.0379, 0.7130],
          [0.3229, 0.9631, 0.2284]]]])

As you can see, because the chunk has an even number of rows (the 4 in the third dimension), half a chunk is a whole number of rows, so the middle chunk is stitched from the bottom half of the first chunk and the top half of the second:

[[0.8544, 0.1012, 0.1877],
 [0.9310, 0.0899, 0.3156],
 [0.9423, 0.2536, 0.7388],
 [0.5404, 0.4356, 0.4430]]
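The stitching can be verified mechanically on a tiny integer tensor (a sketch of my own, not code from the Longformer repository), where the half-chunk boundaries are easy to read off:

```python
import torch

# Two 4x3 chunks of consecutive integers; chunk stride = 12 (even).
x = torch.arange(24).reshape(1, 2, 4, 3)

size = list(x.size())
size[1] = size[1] * 2 - 1          # 2 chunks -> 3 chunks
stride = list(x.stride())
stride[1] = stride[1] // 2         # step forward by half a chunk
y = x.as_strided(size=size, stride=stride)

# The new middle chunk is the bottom half of chunk 0
# followed by the top half of chunk 1.
expected_mid = torch.cat([x[0, 0, 2:], x[0, 1, :2]], dim=0)
assert torch.equal(y[0, 1], expected_mid)
print(y.shape)  # torch.Size([1, 3, 4, 3])
```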

If the row count is odd instead, let's try it. With 5 rows per chunk, the chunk stride is 5 * 3 = 15, and 15 // 2 = 7 is neither a row boundary nor exactly half a chunk. The original data:

data = 
tensor([[[[0.5615, 0.1774, 0.8147],
          [0.3295, 0.2319, 0.7832],
          [0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156],
          [0.9423, 0.2536, 0.7388]],

         [[0.5404, 0.4356, 0.4430],
          [0.6257, 0.0379, 0.7130],
          [0.3229, 0.9631, 0.2284],
          [0.4489, 0.2113, 0.6839],
          [0.7478, 0.4627, 0.7742]]]])

After as_strided, the new data is:

data = 
tensor([[[[0.5615, 0.1774, 0.8147],
          [0.3295, 0.2319, 0.7832],
          [0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156],
          [0.9423, 0.2536, 0.7388]],

         [[0.1012, 0.1877, 0.9310],
          [0.0899, 0.3156, 0.9423],
          [0.2536, 0.7388, 0.5404],
          [0.4356, 0.4430, 0.6257],
          [0.0379, 0.7130, 0.3229]],

         [[0.7388, 0.5404, 0.4356],
          [0.4430, 0.6257, 0.0379],
          [0.7130, 0.3229, 0.9631],
          [0.2284, 0.4489, 0.2113],
          [0.6839, 0.7478, 0.4627]]]])

As you can see, the middle chunk was extracted starting mid-row, from this span of the flattened data:

          0.1012, 0.1877],
 [0.9310, 0.0899, 0.3156],
 [0.9423, 0.2536, 0.7388]],

[[0.5404, 0.4356, 0.4430],
 [0.6257, 0.0379, 0.7130],
 [0.3229,

while the last chunk was extracted from this span:

                  0.7388]],
[[0.5404, 0.4356, 0.4430],
 [0.6257, 0.0379, 0.7130],
 [0.3229, 0.9631, 0.2284],
 [0.4489, 0.2113, 0.6839],
 [0.7478, 0.4627,

Together these form the new tensor; the windows no longer align with rows because 15 // 2 = 7 truncates the remainder.
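The mid-row start can be reproduced deliberately (again a toy of my own, not from the source): with 5 rows per chunk the chunk stride is 15, and 15 // 2 = 7 lands one element past a row boundary:

```python
import torch

# Two 5x3 chunks; chunk stride = 15, which does not halve evenly.
x = torch.arange(30).reshape(1, 2, 5, 3)

size = [1, 3, 5, 3]
stride = list(x.stride())
stride[1] = stride[1] // 2         # 15 // 2 = 7: off by one element
y = x.as_strided(size=size, stride=stride)

print(y[0, 1, 0])  # tensor([7, 8, 9]) -- starts mid-row, one element late
```

In the real model this misalignment never occurs, because the chunk stride there is window_length * head_dim (e.g. 512 * 64), which is always even.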

3. When size[1] = 3, the data change is as follows:

data = 
tensor([[[[0.5615, 0.1774, 0.8147],
          [0.3295, 0.2319, 0.7832],
          [0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156]],

         [[0.9423, 0.2536, 0.7388],
          [0.5404, 0.4356, 0.4430],
          [0.6257, 0.0379, 0.7130],
          [0.3229, 0.9631, 0.2284]],

         [[0.4489, 0.2113, 0.6839],
          [0.7478, 0.4627, 0.7742],
          [0.3861, 0.0727, 0.8736],
          [0.3510, 0.3279, 0.3254]]]])

The content after the transform is as follows (analogous to the size[1] = 2 case):

data = 
tensor([[[[0.5615, 0.1774, 0.8147],
          [0.3295, 0.2319, 0.7832],
          [0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156]],

         [[0.8544, 0.1012, 0.1877],
          [0.9310, 0.0899, 0.3156],
          [0.9423, 0.2536, 0.7388],
          [0.5404, 0.4356, 0.4430]],

         [[0.9423, 0.2536, 0.7388],
          [0.5404, 0.4356, 0.4430],
          [0.6257, 0.0379, 0.7130],
          [0.3229, 0.9631, 0.2284]],

         [[0.6257, 0.0379, 0.7130],
          [0.3229, 0.9631, 0.2284],
          [0.4489, 0.2113, 0.6839],
          [0.7478, 0.4627, 0.7742]],

         [[0.4489, 0.2113, 0.6839],
          [0.7478, 0.4627, 0.7742],
          [0.3861, 0.0727, 0.8736],
          [0.3510, 0.3279, 0.3254]]]])
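Generalizing the three cases (my own check, under the assumption of an even chunk stride): n non-overlapping chunks always become 2n - 1 half-overlapping ones, which is the same result `torch.Tensor.unfold` produces over the flattened sequence:

```python
import torch

for n in (1, 2, 3, 4):
    x = torch.arange(n * 12).reshape(1, n, 4, 3)
    size = [1, 2 * n - 1, 4, 3]
    stride = list(x.stride())
    stride[1] = stride[1] // 2     # half-chunk step
    y = x.as_strided(size=size, stride=stride)

    # Same view via unfold: windows of 4 rows, step 2 rows.
    z = x.reshape(1, n * 4, 3).unfold(1, 4, 2).transpose(-1, -2)
    assert torch.equal(y, z)
    print(n, '->', y.size(1))      # 1 -> 1, 2 -> 3, 3 -> 5, 4 -> 7
```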

Shape changes of attention_scores inside attention

Here, the attention scores are computed as

attentions = torch.matmul(query,key.transpose(-1,-2))

which yields

attentions = (24, 5, 512, 512)
(the query, key, and value passed in each have shape (24, 5, 512, 64))
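A quick sanity check on those shapes, with random tensors standing in for the real activations:

```python
import torch

q = torch.randn(24, 5, 512, 64)
k = torch.randn(24, 5, 512, 64)
scores = torch.matmul(q, k.transpose(-1, -2))
print(scores.shape)  # torch.Size([24, 5, 512, 512])
```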

Next, one full row of zeros is padded onto the bottom of each score matrix (the pad spec (0, 0, 0, 1) leaves the last dimension alone and pads one row after the second-to-last):

attention_scores = nn.functional.pad(
	attention_scores,(0,0,0,1)
)

giving attention_scores = (24, 5, 513, 512).
Then the view function is called:

attention_scores = attention_scores.view(*attention_scores.size()[:-2], attention_scores.size(-1), attention_scores.size(-2))

Here each (513, 512) matrix becomes (512, 513). The pad laid down 512 zeros per score matrix (24 * 5 such rows in total; the 513 is just the extra row and has nothing to do with the zero count). After the view, the same flat memory is read in rows of 513 instead of 512, so each new row starts one element later relative to the old row boundaries: every row is skewed one position against the row above, and the trailing zeros are absorbed into the final rows, with the very last row holding one leftover score followed by 512 zeros.
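The skew is easiest to see on a small matrix (a toy of my own, not the Longformer code): pad one row of zeros below a 4x4 block, then re-view the last two dims swapped:

```python
import torch
import torch.nn.functional as F

s = torch.arange(1, 17).reshape(1, 1, 4, 4)  # stand-in for a (512, 512) score block
p = F.pad(s, (0, 0, 0, 1))                   # (1, 1, 5, 4): one zero row below
v = p.view(1, 1, 4, 5)                       # same memory, read in rows of 5

print(v[0, 0])
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10],
#         [11, 12, 13, 14, 15],
#         [16,  0,  0,  0,  0]])
```

Reading the same flat memory in rows of 5 instead of 4 shifts each row one slot relative to the row above it.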
