torch.gather(input, dim, index, out=None) → Tensor
我们直接用例子来说明好了
import torch
a = torch.Tensor([[1,2],[3,4]])
torch.gather(a,
0,
index=torch.LongTensor([[0,0],[1,0]]))
'''
tensor([[1., 2.],
[3., 2.]])
'''
这个怎么看呢
out[0][0] | a[index[0][0]] [0]] | a[0][0]=1 |
out[1][0] | a[index[1][0]] [0]] | a[1][0] =3 |
out[0][1] | a[index[0][1]] [1]] | a[0][1]=2 |
out[1][1] | a[index[1][1]] [1]] | a[0][1]=2 |