# seq: [batch,sequence_len,hidden_dim]
# cond: [batch,hidden_dim]
# lens: [batch]
# context: [batch,hidden_dim]
# scores: [batch,sequence_len]
def attend(seq, cond, lens):
    """Attend over the sequences `seq` using the condition vector `cond`.

    Scores are the dot product of `cond` with every timestep of `seq`.
    Positions at or beyond each example's length are masked out. The scores
    are then softmax-normalised and used to take a weighted sum of `seq`.

    Args:
        seq: [batch, sequence_len, hidden_dim] tensor of timestep vectors.
        cond: [batch, hidden_dim] tensor used as the attention query.
        lens: [batch] valid lengths (list or tensor). Positions `>= lens[i]`
            in row `i` get zero attention weight.

    Returns:
        Tuple ``(context, scores)``:
            context: [batch, hidden_dim] attention-weighted sum of `seq`.
            scores:  [batch, sequence_len] attention weights (rows sum to 1).
        ``result[0]`` is still the context, as with the previous 1-tuple return.

    Note:
        A row with length 0 is fully masked, so softmax yields NaN for it.
    """
    # Local import so the block does not rely on a module-level `torch` name.
    import torch

    # Dot product of cond with each timestep -> [batch, sequence_len].
    # Broadcasting replaces the explicit expand_as().
    scores = (seq * cond.unsqueeze(1)).sum(2)

    # Build the padding mask over the FULL padded width (seq.size(1)).
    # Masking only up to max(lens) left the tail unmasked whenever seq was
    # padded beyond the longest example.
    lens = torch.as_tensor(lens, device=scores.device)
    positions = torch.arange(scores.size(1), device=scores.device)
    pad_mask = positions.unsqueeze(0) >= lens.unsqueeze(1)

    # Out-of-place masked_fill keeps autograd bookkeeping intact,
    # unlike writing through `.data`.
    scores = scores.masked_fill(pad_mask, float("-inf"))
    scores = F.softmax(scores, dim=1)

    # Weighted sum over the sequence axis -> [batch, hidden_dim].
    context = (scores.unsqueeze(2) * seq).sum(1)
    return context, scores
这个 attend 的输入是一个序列 seq 和一个 item 的向量 cond（作为注意力的条件/查询）。
输出的 context 形状为 [batch_size, hidden_dim]；
输出的 score 形状为 [batch_size, sequence_len]。
（原笔记此处写有"不同于"，后文缺失，待补全。）