0
点赞
收藏
分享

微信扫一扫

pytorch,nonzero 实例 使用

心存浪漫 2022-07-27 阅读 72


import torch
input_tensor = torch.tensor([1,2,3,4,5])
mask = input_tensor>3
print(mask)
indexes = mask.nonzero().squeeze()
print(indexes)

tensor([0, 0, 0, 1, 1], dtype=torch.uint8)
tensor([3, 4])


举报

相关推荐

0 条评论