PyTorch implementation and usage of linear self-attention

捌柒陆壹 2022-07-27


import torch
import torch.nn as nn
import torch.nn.functional as F

# Dropout probability applied to the input vectors before scoring
my_dropout_p = 0.0


# For summarizing a set of vectors into a single vector
class LinearSelfAttn(nn.Module):
    """Self attention over a sequence:

    * alpha_i = softmax(W x_i) over the positions i of X.
    """
    def __init__(self, input_size):
        super(LinearSelfAttn, self).__init__()
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x, x_mask):
        """
        x      = [batch, len, hdim]
        x_mask = [batch, len]  (True/1 at padding positions)
        """
        x = F.dropout(x, p=my_dropout_p, training=self.training)

        # Score every position with a single linear unit:
        # [batch * len, hdim] -> [batch * len, 1] -> [batch, len]
        x_flat = x.contiguous().view(-1, x.size(-1))
        scores = self.linear(x_flat).view(x.size(0), x.size(1))
        # Exclude padding positions from the softmax
        scores = scores.masked_fill(x_mask, -float('inf'))
        alpha = F.softmax(scores, dim=1)
        return alpha  # [batch, len]


# bmm: batch matrix multiplication
# unsqueeze: add singleton dimension
# squeeze: remove singleton dimension
def weighted_avg(x, weights):
    """
    x       = [batch, len, d]
    weights = [batch, len]
    """
    # [batch, 1, len] x [batch, len, d] -> [batch, 1, d] -> [batch, d]
    return weights.unsqueeze(1).bmm(x).squeeze(1)
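
The bmm call is just a batched dot product between the attention weights and the hidden states. If the unsqueeze/bmm/squeeze chain is hard to follow, the broadcasting version below (my own sketch, not part of the original snippet) computes the same result:

def weighted_avg_explicit(x, weights):
    """Same output as weighted_avg, written with broadcasting instead of bmm."""
    # [batch, len, 1] * [batch, len, d] -> sum over len -> [batch, d]
    return (weights.unsqueeze(-1) * x).sum(dim=1)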

Usage:

# [batch,sentence_len,hidden_dim] -> [batch,sentence_len]
sentence_weights = linear_self_attn(sentence_hiddens, sentence_mask)

# [batch,hidden_dim]
sentence_avg_hidden = weighted_avg(sentence_hiddens, sentence_weights)
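
For reference, a minimal end-to-end sketch. The tensor sizes, the random inputs, and the way the padding mask is built from sequence lengths are illustrative assumptions, not part of the original post:

batch, sentence_len, hidden_dim = 2, 5, 8

# Fake encoder outputs and per-sentence lengths
sentence_hiddens = torch.randn(batch, sentence_len, hidden_dim)
lengths = torch.tensor([5, 3])

# Mask is True at padding positions (position index >= sentence length)
sentence_mask = torch.arange(sentence_len).unsqueeze(0) >= lengths.unsqueeze(1)

linear_self_attn = LinearSelfAttn(input_size=hidden_dim)

# [batch, sentence_len] attention weights; padded positions get weight 0
sentence_weights = linear_self_attn(sentence_hiddens, sentence_mask)

# [batch, hidden_dim] attention-pooled sentence vector
sentence_avg_hidden = weighted_avg(sentence_hiddens, sentence_weights)

print(sentence_weights.shape, sentence_avg_hidden.shape)  # torch.Size([2, 5]) torch.Size([2, 8])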

