0
点赞
收藏
分享

微信扫一扫

【代码记录】pytorch矩阵取数据--避免for循环

路西法阁下 2022-02-23 阅读 18
import torch
input = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]
torch.gather(torch.tensor(input),1,torch.tensor([[3],[2],[4],[0]]))

在这里插入图片描述

注意点 input 和index 需要转成tensor
1. input 和index 需要转成tensor

torch.gather 官网链接

举报

相关推荐

0 条评论