给定两个句子,每个句子由3个单词组成,每个单词的embedding尺寸为10,则可得到一个尺寸为[2,3,10]的张量
现在要获取第0个句子的第0个单词和第一个句子的第2个单词的词嵌入组成的张量
可通过如下代码实现
input = torch.randint(0, 10, (2, 3, 10))
index = torch.tensor([0, 2])
def select(src, index):
x, y, z = src.size()
index = index.view(x, 1)
mask = torch.full((x, y), False).scatter(1, index, True).unsqueeze(-1).repeat(1, 1, z)
res = torch.masked_select(src, mask).view(x, z)
return res
res = select(input, index)
print(input)
print(index)
print(res)