0
点赞
收藏
分享

微信扫一扫

torch.masked_select实例

狐沐说 2022-05-06 阅读 107

给定两个句子,每个句子由3个单词组成,每个单词的embedding尺寸为10,则可得到一个尺寸为[2,3,10]的张量
现在要获取第0个句子的第0个单词和第一个句子的第2个单词的词嵌入组成的张量
可通过如下代码实现

input = torch.randint(0, 10, (2, 3, 10))
index = torch.tensor([0, 2])

def select(src, index):
    x, y, z = src.size()
    index = index.view(x, 1)
    mask = torch.full((x, y), False).scatter(1, index, True).unsqueeze(-1).repeat(1, 1, z)
    res = torch.masked_select(src, mask).view(x, z)
    return res
res = select(input, index)
print(input)
print(index)
print(res)
举报

相关推荐

0 条评论