import torch
import torch.nn as nn


class PositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding, as used in Transformer-XL."""

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        # Inverse frequencies: 1 / 10000^(2i / d) for i = 0 .. d/2 - 1.
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        # Registered as a buffer so it moves with the module (.to / .cuda) but is not a trainable parameter.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, batch_size=None):
        # pos_seq is a 1-D tensor of positions, e.g. torch.arange(seq_len - 1, -1, -1.0).
        # Outer product: [seq_len] x [demb/2] -> [seq_len, demb/2].
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        # Concatenate sine and cosine halves -> [seq_len, demb].
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        if batch_size is not None:
            # Add a batch dimension and expand -> [seq_len, batch_size, demb].
            return pos_emb[:, None, :].expand(-1, batch_size, -1)
        else:
            return pos_emb[:, None, :]
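

# A minimal usage sketch (not part of the original module): the values of
# `demb`, `seq_len`, and `bsz` below are illustrative assumptions.
if __name__ == '__main__':
    demb, seq_len, bsz = 512, 10, 4
    pos_emb_layer = PositionalEmbedding(demb)
    # Transformer-XL passes positions in descending order, i.e. relative distances.
    pos_seq = torch.arange(seq_len - 1, -1, -1.0)
    pos_emb = pos_emb_layer(pos_seq, batch_size=bsz)
    print(pos_emb.shape)  # torch.Size([10, 4, 512])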