0
点赞
收藏
分享

微信扫一扫

pytorch的tf.slice

向上的萝卜白菜 2022-07-27 阅读 66


import torch
A_idx = torch.LongTensor([0, 2]) # the index vector
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B.index_select(1, A_idx)
# 1 3
# 4 6


举报

相关推荐

0 条评论