PyTorch: positional embedding implementation

jjt二向箔 2022-07-27
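The snippet below is the sinusoidal positional embedding as implemented in Transformer-XL. For an embedding dimension demb, it precomputes the frequencies inv_freq[i] = 1 / 10000^(2i/demb) for i = 0 .. demb/2 - 1; a position p is then encoded as the concatenation of sin(p * inv_freq) and cos(p * inv_freq). Note that, unlike the interleaved layout in "Attention Is All You Need", the sin half and the cos half are concatenated, and the output has shape [seq_len, 1, demb] (or [seq_len, batch_size, demb] when a batch size is given).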


import torch
import torch.nn as nn


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        # Frequencies 1 / 10000^(2i/demb) for i = 0 .. demb/2 - 1, stored as a
        # buffer: it moves with the module across devices but is not trained.
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, batch_size=None):
        # pos_seq is a 1-D float tensor of positions,
        # e.g. torch.arange(seq_len - 1, -1, -1.0)
        # Outer product: [seq_len] x [demb/2] -> [seq_len, demb/2]
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        # Concatenate the sin and cos halves -> [seq_len, demb]
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if batch_size is not None:
            # Broadcast across the batch: [seq_len, batch_size, demb]
            return pos_emb[:, None, :].expand(-1, batch_size, -1)
        else:
            # [seq_len, 1, demb]
            return pos_emb[:, None, :]
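
A minimal usage sketch follows; the embedding dimension, sequence length, and batch size are illustrative values chosen here, not from the original post:

# Illustrative values (assumptions, not from the original post)
demb, seq_len, batch_size = 512, 10, 4

pos_emb_layer = PositionalEmbedding(demb)
# Transformer-XL feeds positions in reverse order: seq_len-1, ..., 0
pos_seq = torch.arange(seq_len - 1, -1, -1.0)

emb = pos_emb_layer(pos_seq, batch_size=batch_size)
print(emb.shape)  # torch.Size([10, 4, 512])

In recent PyTorch releases, torch.ger is a deprecated alias of torch.outer; replacing the call with torch.outer(pos_seq, self.inv_freq) produces the same result.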

