reshape的使用
1、reshape(-1)
print("===================test reshape(-1)==============================")
test_arr = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
print(test_arr)
print(test_arr.reshape(-1))
test_arr = torch.Tensor([[1, 4, 7, 10], [2, 5, 8, 11], [3, 6, 9, 12]])
print(test_arr)
print(test_arr.reshape(-1))
结果:按照行的顺序将数据拉长
===================test reshape(-1)==============================
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
tensor([[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.],
[10., 11., 12.]])
[[ 1 4 7 10]
[ 2 5 8 11]
[ 3 6 9 12]]
[ 1 4 7 10 2 5 8 11 3 6 9 12]
2、reshape维度不变
print("===================test reshape==============================")
test_arr = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
print(test_arr)
print(test_arr.reshape(-1, 3))
test_arr = np.array([[1, 4, 7, 10], [2, 5, 8, 11], [3, 6, 9, 12]])
print(test_arr)
print(test_arr.reshape(-1, 3))
结果:按照行的顺序将数据重新排列
===================test reshape==============================
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
tensor([[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.],
[10., 11., 12.]])
[[ 1 4 7 10]
[ 2 5 8 11]
[ 3 6 9 12]]
[[ 1 4 7]
[10 2 5]
[ 8 11 3]
[ 6 9 12]]
3、reshape增加一个维度
print("===================test reshape==============================")
test_arr = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
print(test_arr)
print(test_arr.reshape(-1, 2, 3, 2))
test_arr = np.array([[1, 4, 7, 10], [2, 5, 8, 11], [3, 6, 9, 12]])
print(test_arr)
print(test_arr.reshape((-1, 2, 3, 2)))
结果:按照行的顺序将数据重新排列
===================test reshape==============================
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
tensor([[[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]]]])
[[ 1 4 7 10]
[ 2 5 8 11]
[ 3 6 9 12]]
[[[[ 1 4]
[ 7 10]
[ 2 5]]
[[ 8 11]
[ 3 6]
[ 9 12]]]]