0
点赞
收藏
分享

微信扫一扫

pytorch attend操作 代码

時小白 2022-07-27 阅读 74


# seq: [batch,sequence_len,hidden_dim]
# cond: [batch,hidden_dim]
# lens: [batch]

# context: [batch,hidden_dim]
# scores: [batch,sequence_len]
def attend(seq, cond, lens):
"""
attend over the sequences `seq` using the condition `cond`.
"""
scores = cond.unsqueeze(1).expand_as(seq).mul(seq).sum(2)
max_len = max(lens)
for i, l in enumerate(lens):
if l < max_len:
scores.data[i, l:] = -np.inf
scores = F.softmax(scores, dim=1)
context = scores.unsqueeze(2).expand_as(seq).mul(seq).sum(1)
return context,

这个attend的输入是一个sequence和一个item的,
输出的​​​context​​​是[batch_size, hidden_dim]
输出的​​​score​​​是[batch_size, sequence_len] 不同于


举报

相关推荐

0 条评论