今天写代码遇到了torch.repeat_interleave
,去查了一下。在此记录。
torch.repeat_interleave
官方文档里边提示了这么一句话:This is different from torch.Tensor.repeat()
but similar to numpy.repeat
.
就是说它的功能和 torch.Tensor.repeat()
不太一样,更类似于numpy.repeat
,我也不怎么用numpy,所以这里就不解释写numpy的了。
torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor1
-
参数列表如下:
input
,就是你要执行repeat操作的张量。-
repeats
,你要重复的次数,可以是数字,也可以是张量。如果是张量的时候必须保证它能够进行广播(broadcast)。
-
可选参数
dim
就是要重复的维度了。如果不写会默认展开成一个向量,如果写了维度会在保持原来的维度基础上进行重复。
- 最后返回的是重复之后的张量。
代码示例:
- 先来看一下一个普通的向量的重复:
a = torch.tensor(1) res = a.repeat_interleave(5) print(res)
输出结果为:
tensor([1, 1, 1, 1, 1])
上边代码的写法等价于这个:
a = torch.tensor(1) res = torch.repeat_interleave(a,5) print(res)
- 看完向量看一下高维张量的重复:
a = torch.arange(6).reshape(2,1,3) res = torch.repeat_interleave(a,3,dim = 1) print(res) print(a.shape) print(res.shape)
输出结果为:
tensor([[[0, 1, 2],<br>
[0, 1, 2],<br>
[0, 1, 2]],<br>
[[3, 4, 5],<br>
[3, 4, 5],<br>
[3, 4, 5]]])<br>
torch.Size([2, 1, 3])<br>
torch.Size([2, 3, 3])我们可以看出在指定维度的时候会在相应的维度上展开。
a = torch.arange(6).reshape(2,1,3) res = torch.repeat_interleave(a,3) print(res) print(a.shape) print(res.shape)
输出结果:
tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])<br>
torch.Size([2, 1, 3])<br>
torch.Size([18])反之,不指定结果的话就会将其复制之后展开成一个一维的向量。
- 最后看一下正常的二维向量的重复:
a = torch.arange(6).reshape(2,3) res = torch.repeat_interleave(a,3,dim=0) print(res) print(a.shape) print(res.shape)
输出结果为:
tensor([[0, 1, 2],<br>
[0, 1, 2],<br>
[0, 1, 2],<br>
[3, 4, 5],<br>
[3, 4, 5],<br>
[3, 4, 5]])<br>
torch.Size([2, 3])<br>
torch.Size([6, 3])这个结果应该没什么疑问,一个2*3的矩阵,对其第0维进行重复操作,重复三次,最后可以看到输出结果是对第0维度重复了三次。然后用下边这个代码看一下广播机制的重复操作:
a = torch.arange(6).reshape(2,3) res = torch.repeat_interleave(a,torch.tensor([1,3]),dim=0) res2 = torch.torch.repeat_interleave(a,torch.tensor([1,3,2]),dim=1) print(res) print(res2) print(a.shape) print(res.shape) print(res2.shape)
输出结果:
tensor([[0, 1, 2],<br>
[3, 4, 5],<br>
[3, 4, 5],<br>
[3, 4, 5]])<br>
tensor([[0, 1, 1, 1, 2, 2],<br>
[3, 4, 4, 4, 5, 5]])<br>
torch.Size([2, 3])<br>
torch.Size([4, 3])<br>
torch.Size([2, 6])我们使用了两次重复,先看一下a的形状是一个2*3的矩阵,也就是说我在某一维度上进行重复操作的时候,为了让它运行广播机制,那我张量传入的向量长度必须和矩阵某一维度上的长度一致。
比如
res
中我们是从第0维进行重复操作,a
的第0维是2,所以我们后边传入的张量长度是2。他们分别对应着:对第0维上第一个元素重复2次,对第0维上第二个元素重复三次。同理,
res
中我们是对第1维进行重复操作,a
的第0维是3,所以我们后边传入的张量长度是3。不重复解释了。
总结!
对于传入的参数可以是:
- Tensor,int
- Tensor,Tensor
- Tensor,int,dim*
- Tensor,Tensor,dim*
Tensor.repeat
上一个我们说是“This is different from torch.Tensor.repeat()
but similar to numpy.repeat
”。既然上一个更类似于numpy的repeat,那这个肯定不能类似于repeat了。这个和numpy.tile
更为相似。
这个比较简单,大致说一下子就OK。
torch.Tensor.repeat(*sizes) → Tensor
- 参数列表:只有一个参数,传入int或者需要重复操作的torch.Size。注意,参数必须大于等于要复制的张量。
- 结果返回一格张量
代码示例:
参数时候我们说“参数必须大于等于要复制的张量”,意思是说你二维的矩阵,传入的参数比如是两个及两个以上的torch.Size,不能传入一个int,只有一维的可以传入一个int。
-
二维矩阵:
a = torch.arange(6).reshape(2,3) res = a.repeat([2,3]) print(res)
输出结果:
tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2],<br>
[3, 4, 5, 3, 4, 5, 3, 4, 5],<br>
[0, 1, 2, 0, 1, 2, 0, 1, 2],<br>
[3, 4, 5, 3, 4, 5, 3, 4, 5]])我们传入的是一个torch.Size,因此会在对应的维度上进行复制,不过和上一个函数的区别在于上一个函数是按元素复制,这个是按维度复制。(000111222和012012012的区别。)
-
一维向量:
a = torch.arange(6) res = a.repeat(2) print(res)
输出结果:
tensor([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5])
在向里中传入int是可以使用的。你也可以试一下在二维矩阵复制时候传入int,会报错。
-
一维向量:
a = torch.arange(6) res = a.repeat([2,2]) print(res)
输出结果:
tensor([[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],<br>
[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]])只能传入等于或者大于原张量维度的torch.Size。