0
点赞
收藏
分享

微信扫一扫

torch.unsqueeze和torch.squeeze维度修改 #yyds干货盘点#

以前干嘛去了 2022-03-17 阅读 61

torch.squeeze

torch.squeeze(input, dim=None, out=None)

  • 将输入张量形状中的1 去除并返回。
  • 当给定dim时,那么挤压操作只在给定维度上。
    • 这个意思是只在给定维度上挤压张量中的1,如果给定维度上的数字不是1那也不会进行挤压。
import torch

# 随机生成一个多维整数张量
array = torch.rand(2, 1, 2, 1, 3) * 10
array = torch.ceil(array)

# 打印张量和性状
print(array.shape)
print(array)

# squeeze方法之后打印其形状
array = torch.squeeze(array)
print(array.shape)
print(array)

image.png

打印结果我放记事本里排一下版:

  • 原来是清晰的2×1×2×1×3
    image.png

  • 经过torch.squeeze之后,将其为1的维度都进行压缩。

    image.png

  • 现在给定压缩的维度:

    import torch
    
    array = torch.rand(2, 1, 2, 1, 3)
    
    print(array.shape)
    
    a1 = torch.squeeze(array,2)
    print(a1.shape)
    
    a2 = torch.squeeze(array,0)
    print(a2.shape)
    
    a3 = torch.squeeze(array,3)
    print(a3.shape)

    image.png

    我之前理解错了,我以为给定维度是多少就压缩多少维度

    • 比如(2, 1, 2, 1, 3)执行torch.squeeze(array,2)之后就会变成(1, 1, 3),这是错的嗷。
      实际上是在你给定的维度上,如果是1那就压缩,不是1就不进行压缩.

torch.unsqueeze

torch.unsqueeze(input, dim, out=None)

返回一个新的张量,对输入的制定位置插入维1

注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果dim为正就是按照索引序号,如果为负就是从后倒着数。

官方文档写的:如果dim为负就是dim+input.dim()+1

image.png

按照索引来是第0个,...第三个。input.dim()+1之后就是第一个,...第四个。加上dim为复数,所以就是从后倒着数,负几就是倒数第几位。

import torch

array = torch.rand(2, 1, 2, 3)

print(array.shape)

a1 = torch.unsqueeze(array,2)
print(a1.shape)

a2 = torch.unsqueeze(array,0)
print(a2.shape)

a3 = torch.unsqueeze(array,4)
print(a3.shape)

a4 = torch.unsqueeze(array,-1)
print(a4.shape)

a5 = torch.unsqueeze(array,-3)
print(a5.shape)

image.png

总结

二者可以方便的用于维度修改,给张量的计算提供便利

举报

相关推荐

0 条评论