维度变换 (Dimension transforms)
Operation
View/reshape
>>> a=torch.rand(4,1,28,28)
>>> a.shape
torch.Size([4, 1, 28, 28])
>>> a.view(4,28*28)
tensor([[0.1881, 0.6594, 0.4549, ..., 0.5385, 0.8488, 0.4619],
[0.7678, 0.0459, 0.3048, ..., 0.0318, 0.7407, 0.7854],
[0.7252, 0.7116, 0.4051, ..., 0.0523, 0.1067, 0.4984],
[0.6943, 0.5723, 0.4737, ..., 0.5214, 0.5718, 0.7721]])
>>> a.view(4,28*28).shape
torch.Size([4, 784])
>>> a.view(4*28,28).shape
torch.Size([112, 28])
>>> a.view(4*1,28,28).shape
torch.Size([4, 28, 28])
Squeeze / unsqueeze
unsqueeze
>>> a.shape
torch.Size([4, 1, 28, 28])
>>> a.unsqueeze(0).shape
torch.Size([1, 4, 1, 28, 28])
>>> a.unsqueeze(-1).shape
torch.Size([4, 1, 28, 28, 1])
>>> a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])
>>> a.unsqueeze(-4).shape
torch.Size([4, 1, 1, 28, 28])
>>> a.unsqueeze(-5).shape
torch.Size([1, 4, 1, 28, 28])
>>> a.unsqueeze(5).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
>>> a=torch.tensor([1.2,2.3])
>>> a.unsqueeze(-1)
tensor([[1.2000],
[2.3000]])
>>> a.unsqueeze(0)
tensor([[1.2000, 2.3000]])
Index positions for a 4-D tensor of shape [4, 1, 28, 28] (the `a` above):
Pos.Idx | 0 | 1 | 2 | 3 |
---|---|---|---|---|
4 | 1 | 28 | 28 | |
Neg.Idx | -4 | -3 | -2 | -1 |
For example
>>> b=torch.rand(32)
>>> f=torch.rand(4,32,14,14)
>>> b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
>>> b.shape
torch.Size([1, 32, 1, 1])
>>> (b+f).shape
torch.Size([4, 32, 14, 14])
squeeze
>>> b.shape
torch.Size([1, 32, 1, 1])
>>> b.squeeze().shape
torch.Size([32])
>>> b.squeeze(0).shape
torch.Size([32, 1, 1])
>>> b.squeeze(-1).shape
torch.Size([1, 32, 1])
>>> b.squeeze(1).shape
torch.Size([1, 32, 1, 1])
>>> b.squeeze(-4).shape
torch.Size([32, 1, 1])
Expand / repeat
Expand: broadcasting, no data is copied — recommended (推荐使用)
Repeat: the data is actually copied into new memory
Expand/expand_as
>>> b.shape
torch.Size([1, 32, 1, 1])
>>> b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])
>>> b.expand(-1,32,-1,-1).shape
torch.Size([1, 32, 1, 1])
>>> b.expand(-1,32,-1,-4).shape
torch.Size([1, 32, 1, -4])
Note: only -1 ("keep this dimension") is a valid negative size for expand; the -4 here silently produced a nonsensical shape in the PyTorch version used for this transcript — newer versions raise an error instead.
Repeat
>>> b.shape
torch.Size([1, 32, 1, 1])
>>> b.repeat(4,32,1,1).shape
torch.Size([4, 1024, 1, 1])
Note: repeat's arguments are repeat COUNTS, not target sizes — dim 1 becomes 32 * 32 = 1024.
>>> b.repeat(4,1,1,1).shape
torch.Size([4, 32, 1, 1])
>>> b.repeat(4,1,32,32).shape
torch.Size([4, 32, 32, 32])
Transpose/t/permute
.t
>>> b.t()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
>>> a=torch.randn(3,4)
>>> a.t()
tensor([[-2.0308, 0.8653, -1.1071],
[ 1.3861, 0.6934, -0.7272],
[ 0.2027, -0.8399, -0.2866],
[ 0.9593, -1.1604, 0.0354]])
Transpose
>>> a.shape
torch.Size([4, 3, 32, 32])
>>> a1=a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous sub
spaces). Use .reshape(...) instead.
>>> a1=a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
>>> a2=a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
>>> a1.shape,a2.shape
(torch.Size([4, 3, 32, 32]), torch.Size([4, 3, 32, 32]))
>>> torch.all(torch.eq(a,a1))
tensor(False)
>>> torch.all(torch.eq(a,a2))
tensor(True)
Note: a1 has the right shape but wrong contents — view() after transpose reinterprets the permuted memory in a different element order. a2 transposes back after the view, restoring the original layout, so it equals a.
permute
>>> a=torch.rand(4,3,28,28)
>>> a.transpose(1,3).shape
torch.Size([4, 28, 28, 3])
>>> b=torch.rand(4,3,28,32)
>>> b.transpose(1,3).shape
torch.Size([4, 32, 28, 3])
>>> b.transpose(1,3).transpose(1,2).shape
torch.Size([4, 28, 32, 3])
>>> b.permute(0,2,3,1).shape
torch.Size([4, 28, 32, 3])