Adaptive embedding: principle and PyTorch code

忆北文学摄影爱好员 · 2022-07-27

The idea is simple: high-frequency words get large embeddings (e.g. 1024 or 512 dimensions) while low-frequency words get small ones (e.g. 256 or 64 dimensions); a Linear layer then projects them all back to the same dimension.
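To see the savings, plug in the values from the comments in the code below: a 793470-token vocabulary with cutoffs [60000, 100000, 640000] and div_val=4 gives per-cluster dimensions of 1024, 256, 64 and 16. A quick back-of-the-envelope sketch (the hyperparameters come from the code comments; the comparison with a flat table is illustrative arithmetic):

# Rough parameter count for the adaptive scheme vs. one flat table.
cutoff_ends = [0, 60000, 100000, 640000, 793470]
d_embed = d_proj = 1024
div_val = 4

adaptive = 0
for i in range(len(cutoff_ends) - 1):
    n_i = cutoff_ends[i + 1] - cutoff_ends[i]   # tokens in cluster i
    d_i = d_embed // div_val ** i               # 1024, 256, 64, 16
    adaptive += n_i * d_i + d_proj * d_i        # embedding table + projection

full = 793470 * d_embed                         # one flat 1024-dim table
print(f"{adaptive:,} vs {full:,}")              # ~110M vs ~812M, about 7.4x smaller

The module itself (this is essentially the AdaptiveEmbedding from the Transformer-XL codebase):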

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token  # e.g. 793470
        self.d_embed = d_embed  # e.g. 1024

        self.cutoffs = cutoffs + [n_token]  # e.g. [60000, 100000, 640000, 793470]
        self.div_val = div_val  # e.g. 4
        self.d_proj = d_proj    # e.g. 1024

        self.emb_scale = d_proj ** 0.5  # sqrt(1024) = 32

        self.cutoff_ends = [0] + self.cutoffs  # [0, 60000, 100000, 640000, 793470]

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            # Single full-size table; project only if d_proj differs from d_embed.
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
        else:
            # One table per frequency cluster; the dimension shrinks by a
            # factor of div_val for each rarer cluster.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)  # 1024, 256, 64, 16
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                # Projection back to the shared dimension d_proj.
                # NOTE: torch.Tensor(...) is uninitialized; initialize these
                # (e.g. with nn.init.normal_) before training.
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                # Positions of the tokens that fall into cluster i.
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze(-1)  # squeeze(-1) stays 1-D even for a single match

                if indices_i.numel() == 0:
                    continue

                # Shift ids so they index into cluster i's own table.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])  # project to the shared dim

                # Scatter the projected embeddings back to their positions.
                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        # Scale by sqrt(d_proj), as in the original Transformer.
        embed.mul_(self.emb_scale)

        return embed
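
A minimal usage sketch, assuming the hyperparameters from the comments above. The projection matrices are created with torch.Tensor and are therefore uninitialized, so they must be initialized by hand; std=0.01 below is an assumption, mirroring the proj_init_std default in the Transformer-XL reference code.

import torch
import torch.nn as nn

emb = AdaptiveEmbedding(n_token=793470, d_embed=1024, d_proj=1024,
                        cutoffs=[60000, 100000, 640000], div_val=4)
for p in emb.emb_projs:               # torch.Tensor(...) left these uninitialized
    nn.init.normal_(p, 0.0, 0.01)     # assumed init, per Transformer-XL's proj_init_std

inp = torch.randint(0, 793470, (36, 4))   # (seq_len, batch) of token ids
out = emb(inp)
print(out.shape)                          # torch.Size([36, 4, 1024]): every token is now 1024-dim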

