import torch
import numpy as np
# Demo of torch.nn.utils.rnn.pack_padded_sequence with batch_first=True.
#
# Example 1: a (3, 4) integer batch in which every sequence has length 4.
# With batch_first=True the packed `data` is laid out time step by time step
# (column by column of the padded batch): 1, 5, 9, 2, 6, 10, ...
# Named `padded` rather than `input` so the builtin input() is not shadowed.
padded = torch.from_numpy(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]))
# With the default enforce_sorted=True (PyTorch >= 1.1), lengths must be sorted
# in decreasing order; pass enforce_sorted=False to let PyTorch handle unsorted
# batches. Older releases always required sorted lengths.
length = [4, 4, 4]
result = torch.nn.utils.rnn.pack_padded_sequence(padded, lengths=length, batch_first=True)
print(result)

# Example 2: a random (8, 10, 300) float batch, i.e. 8 sequences of 10 steps.
padded = torch.randn(8, 10, 300)
length = [10] * 8
# Identity permutation. In real code this would be the index order that sorts
# the batch by decreasing length before packing.
perm = torch.arange(8)
result = torch.nn.utils.rnn.pack_padded_sequence(padded[perm], lengths=length, batch_first=True)
print(result)
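Printed output (recorded with an older PyTorch release; newer versions format tensors differently and report batch_sizes as a tensor):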
PackedSequence(data=
1
5
9
2
6
10
3
7
11
4
8
12
[torch.LongTensor of size 12]
, batch_sizes=[3, 3, 3, 3])
PackedSequence(data=
-1.0129e+00 -1.5844e+00 -4.0759e-02 ... -9.6837e-01 5.7004e-01 -1.6919e-01
5.4662e-01 8.6405e-01 7.8474e-01 ... 5.2483e-01 1.9581e-02 7.2974e-01
7.7569e-02 -7.1858e-03 -2.9401e-01 ... -2.5550e-01 6.6782e-01 5.6192e-01
... ⋱ ...
-6.7423e-01 -1.6357e+00 1.4011e+00 ... 6.4557e-02 9.4204e-01 6.1430e-01
-1.0300e+00 4.6429e-01 1.4219e+00 ... 2.9208e+00 1.5081e+00 7.6805e-02
-7.2723e-01 6.2770e-01 -6.2025e-01 ... -3.5286e-01 1.0199e+00 8.8412e-01
[torch.FloatTensor of size 80x300]
, batch_sizes=[8, 8, 8, 8, 8, 8, 8, 8, 8, 8])