文章目录
1. torch.take
torch.take(input, index) → Tensor
返回一个新的张量,其输入元素为给定指标。输入张量被看成是一维张量。结果与指标的形状相同。
分两步:
- 将输入input展开成一个一维张量
- 根据index序号进行索引input里面的值
import torch
input = torch.tensor([[4, 3, 5], [6, 7, 8]])
index = torch.tensor([0, 2, 5])
output = torch.take(input,index)
print(f"input={input}")
# input=tensor([[4, 3, 5],
# [6, 7, 8]])
print(f"index={index}")
# index=tensor([0, 2, 5])
print(f"output={output}")
# output=tensor([4, 5, 8])