torch.max(input, dim)使用说明
pred = torch.max(input, dim)
返回每行(dim=1)或每列(dim=0)的最大值。
_, pred = torch.max(input, dim)
仅返回每行(dim=1)或每列(dim=0)中最大值所在位置。
示例:
import torch
# 构造一个5x3随机初始化的矩阵
x = torch.rand(5, 3)
print('input: ', x)
print('-'*10)
y1 = torch.max(x, 1)
print('max by row: ', y1)
print('-'*10)
y2 = torch.max(x, 0)
print('max by col: ', y2)
print('-'*10)
_, y3 = torch.max(x, 1)
print('max index by row: ', y3)
print('-'*10)
_, y4 = torch.max(x, 0)
print('max index by col: ', y4)
输出结果:
input: tensor([[0.5504, 0.3160, 0.2448],
[0.8694, 0.3295, 0.2085],
[0.5530, 0.9984, 0.3531],
[0.2874, 0.1025, 0.9419],
[0.0867, 0.4234, 0.8334]])
----------
max by row: torch.return_types.max(
values=tensor([0.5504, 0.8694, 0.9984, 0.9419, 0.8334]),
indices=tensor([0, 0, 1, 2, 2]))
----------
max by col: torch.return_types.max(
values=tensor([0.8694, 0.9984, 0.9419]),
indices=tensor([1, 2, 3]))
----------
max index by row: tensor([0, 0, 1, 2, 2])
----------
max index by col: tensor([1, 2, 3])