0
点赞
收藏
分享

微信扫一扫

torch.topk和tensor.topk

沪钢木子 2022-05-04 阅读 136
pytorch

https://pytorch.org/docs/stable/generated/torch.topk.html#torch.topk

1. 定义

按照指定的维度进行数值大小的排序,返回top-k个数值。

2. 二者的区别

torch.topk()tensor.topk()本质上是一个函数,只不过调用方式不同。

3. 具体用法(以tensor.topk为例)

tensor.topk(k, dim=None, largest=True, sorted=True, *, out=None)

参数:

  1. k -> top-k
  2. dim:按照哪个维度进行排序
  3. largest=True: 是否按照从大到小的顺序排序
  4. sorted:控制是否按排序顺序返回元素(如果为False,则只返回k个值,但这k个值并不会排序)

返回值:

  1. 返回top-k排序后的数值
  2. 返回top-k排序后的索引

3.1 例子1(正常使用):

>>> x = torch.arange(1, 10)
>>> x
tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> a, b = x.topk(k=5, dim=0)
>>> a
tensor([9, 8, 7, 6, 5])
>>> b
tensor([8, 7, 6, 5, 4])
>>> x.topk(k=5, dim=0)
torch.return_types.topk(
values=tensor([9, 8, 7, 6, 5]),
indices=tensor([8, 7, 6, 5, 4]))

3.2 例子2(从小到大排序)

>>> x = torch.arange(1, 11)
>>> x
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
>>> x.topk(k=5, dim=0, largest=False, sorted=True)
torch.return_types.topk(
values=tensor([1, 2, 3, 4, 5]),
indices=tensor([0, 1, 2, 3, 4]))

3.3 例子3(只挑出topk个值,但不进行排序)

>>> x = torch.arange(1, 11)
>>> x
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
>>> x.topk(k=5, dim=0, sorted=False)
torch.return_types.topk(
values=tensor([ 9, 10,  8,  7,  6]),
indices=tensor([8, 9, 7, 6, 5]))
举报

相关推荐

0 条评论