0
点赞
收藏
分享

微信扫一扫

pytorch使用布尔索引获取指定维度元素


对于一些任务,我们想从tensor中提取符合指定要求的数值,那么一般我们有两种方法,第一种是采用​​布尔索引​​​,第二种是使用​​masked_select()​​方法来实现。

其实还有一种方法​​torch.where()​​​,但是这个与上述两个方法不同,上述两个方法会把我们需要的数值挑出来形成一个​​一维张量​​​,对于where我们会得到与​​原来形状一样​​的tensor,所以本文只介绍上面两种方法。

方法一:采用布尔索引

该方法我们采用布尔索引进行提取,首先获取一个​​布尔矩阵mask​​,标记每个位置是否符合我们要求,符合则为True,不符则为False,然后我们会把True的位置提取出来。

a = torch.randn(3, 4)
print(a)

mask = a > 0
print(mask)

print(a[mask])

tensor([[ 0.5748,  1.4601,  1.8610, -0.8904],
[-1.5891, -1.2431, 0.1356, -0.6111],
[-0.5736, -0.7268, -0.2200, 0.4816]])
tensor([[ True, True, True, False],
[False, False, True, False],
[False, False, False, True]])
tensor([0.5748, 1.4601, 1.8610, 0.1356, 0.4816])

方法二:masked_select()

使用masked_select()方法同样可以实现,只需要将条件传入。

print(a.masked_select(a > 0))

tensor([0.5748, 1.4601, 1.8610, 0.1356, 0.4816])


举报

相关推荐

0 条评论