
04 A PyTorch Implementation of Positional Encoding in the Transformer

一脸伟人痣 · 2022-12-13



Why does the Transformer need positional encoding at all? Consider the two sentences 我爱你 ("I love you") and 你爱我 ("you love me"): they consist of exactly the same tokens, and only the word order distinguishes their meanings. Self-attention by itself is order-agnostic, so without positional information the model would treat these two sentences the same. The sinusoidal positional encoding fixes this by adding a position-dependent signal to each token's embedding.
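To make the order problem concrete, here is a minimal sketch (not from the original post; the tiny vocabulary and token ids are made up): with an order-insensitive sum of plain word embeddings, 我爱你 and 你爱我 produce exactly the same representation.

import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(10, 8)        # hypothetical tiny vocabulary, 8-dim embeddings
wo, ai, ni = 1, 2, 3             # hypothetical ids for 我 / 爱 / 你

s1 = torch.tensor([wo, ai, ni])  # 我爱你 "I love you"
s2 = torch.tensor([ni, ai, wo])  # 你爱我 "You love me"

# Without position information, an order-insensitive sum cannot tell them apart
print(torch.allclose(emb(s1).sum(dim=0), emb(s2).sum(dim=0)))  # True

The implementation below injects position information so that this ambiguity disappears.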

import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):

    def __init__(self, dim, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()

        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))

        # Build the positional-encoding table pe according to:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_len, dim)  # max_len is the longest sequence the decoder can generate, say 10
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1): the pos index of each row
        # 10000^(-2i/d_model), computed in log space as exp(2i * -ln(10000)/d_model)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))

        pe[:, 0::2] = torch.sin(position.float() * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position.float() * div_term)  # odd dimensions
        pe = pe.unsqueeze(1)  # (max_len, 1, dim): broadcasts over the batch axis
        self.register_buffer('pe', pe)  # saved with the model but not trained
        self.drop_out = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        # Scale the embeddings by sqrt(d_model), as in the Transformer paper
        emb = emb * math.sqrt(self.dim)

        if step is None:
            # Full-sequence mode: emb is (seq_len, batch, dim); add the
            # encodings of the first seq_len positions
            emb = emb + self.pe[:emb.size(0)]
        else:
            # Incremental decoding: add only the encoding of position `step`
            emb = emb + self.pe[step]
        emb = self.drop_out(emb)
        return emb
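A quick usage sketch, continuing from the definitions above and assuming the sequence-first layout (seq_len, batch, dim) that the indexing in forward implies; the sizes here are made up:

pos_enc = PositionalEncoding(dim=512, dropout=0.1)
x = torch.zeros(10, 2, 512)      # (seq_len=10, batch=2, d_model=512)
out = pos_enc(x)                 # same shape, with the position encodings added
print(out.shape)                 # torch.Size([10, 2, 512])

# During step-by-step decoding, pass `step` to add only that position's encoding
y = torch.zeros(1, 2, 512)
out_step = pos_enc(y, step=4)

# Spot-check against the closed-form formula: for pos=3, i=0 the table
# entry should be sin(3 / 10000^0) = sin(3)
assert torch.isclose(pos_enc.pe[3, 0, 0], torch.tensor(math.sin(3.0)))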


