torch.squeeze
torch.squeeze(input, dim=None, out=None)
- 将输入张量形状中的
1
去除并返回。 - 当给定
dim
时,那么挤压操作只在给定维度上。- 这个意思是只在给定维度上挤压张量中的1,如果给定维度上的数字不是1那也不会进行挤压。
import torch
# 随机生成一个多维整数张量
array = torch.rand(2, 1, 2, 1, 3) * 10
array = torch.ceil(array)
# 打印张量和性状
print(array.shape)
print(array)
# squeeze方法之后打印其形状
array = torch.squeeze(array)
print(array.shape)
print(array)
打印结果我放记事本里排一下版:
-
原来是清晰的2×1×2×1×3
-
经过
torch.squeeze
之后,将其为1
的维度都进行压缩。 -
现在给定压缩的维度:
import torch array = torch.rand(2, 1, 2, 1, 3) print(array.shape) a1 = torch.squeeze(array,2) print(a1.shape) a2 = torch.squeeze(array,0) print(a2.shape) a3 = torch.squeeze(array,3) print(a3.shape)
我之前理解错了,我以为给定维度是多少就压缩多少维度
比如(2, 1, 2, 1, 3)执行torch.squeeze(array,2)之后就会变成(1, 1, 3),这是错的嗷。
实际上是在你给定的维度上,如果是1那就压缩,不是1就不进行压缩.
torch.unsqueeze
torch.unsqueeze(input, dim, out=None)
返回一个新的张量,对输入的制定位置插入维1
。
注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
如果dim
为正就是按照索引序号,如果为负就是从后倒着数。
官方文档写的:如果dim为负就是dim+input.dim()+1
按照索引来是第0个,...第三个。input.dim()+1之后就是第一个,...第四个。加上dim为复数,所以就是从后倒着数,负几就是倒数第几位。
import torch
array = torch.rand(2, 1, 2, 3)
print(array.shape)
a1 = torch.unsqueeze(array,2)
print(a1.shape)
a2 = torch.unsqueeze(array,0)
print(a2.shape)
a3 = torch.unsqueeze(array,4)
print(a3.shape)
a4 = torch.unsqueeze(array,-1)
print(a4.shape)
a5 = torch.unsqueeze(array,-3)
print(a5.shape)
总结
二者可以方便的用于维度修改,给张量的计算提供便利