0
点赞
收藏
分享

微信扫一扫

pytorch 笔记:gather 函数

飞鸟不急 2022-01-06 阅读 64
torch.gather(input, dim, index, out=None) → Tensor

我们直接用例子来说明好了

import torch
a = torch.Tensor([[1,2],[3,4]])
torch.gather(a,
            0,
            index=torch.LongTensor([[0,0],[1,0]]))
'''
tensor([[1., 2.],
        [3., 2.]])
'''

这个怎么看呢

out[0][0]a[index[0][0]]  [0]]a[0][0]=1
out[1][0]a[index[1][0]]  [0]]a[1][0] =3
out[0][1]a[index[0][1]]  [1]]   a[0][1]=2
out[1][1]a[index[1][1]]  [1]]a[0][1]=2
举报

相关推荐

0 条评论