RNN使用示例
import torch
import torch.nn as nn
# Single-layer RNN (tanh nonlinearity by default):
# 5 input features per time step, hidden state of size 6, 1 stacked layer.
# With the default batch_first=False, tensors are laid out as
# (seq_len, batch, features).
rnn = nn.RNN(input_size=5, hidden_size=6, num_layers=1)

# Input sequence of shape (seq_len=1, batch=3, input_size=5).
# Named `x` rather than `input` so the builtin input() is not shadowed.
x = torch.randn(1, 3, 5)
# Initial hidden state of shape (num_layers * num_directions, batch, hidden_size)
# = (1, 3, 6).
h0 = torch.randn(1, 3, 6)
print(x)
print(h0)

# output: last layer's hidden state at every time step, (seq_len, batch, hidden_size).
# hn: final hidden state of every layer, (num_layers, batch, hidden_size).
# Because seq_len == 1 and there is only one layer, both hold the same values.
output, hn = rnn(x, h0)
print(output)
print(hn)
结果（依次打印输入序列、初始隐状态 h0、output、hn。输入与网络权重均为随机生成且未设置随机种子，因此每次运行的数值都会不同。由于序列长度为 1 且只有一层，output 与 hn 的数值相同）
tensor([[[ 1.0300, 0.4837, 1.2938, 0.6715, 1.5796],
[-1.0450, 0.6950, -1.6691, -0.3455, 0.7442],
[-0.6285, -0.5371, 0.8405, 0.5266, -0.5502]]])
tensor([[[-0.7893, 1.0546, -0.9698, -0.2873, 0.3692, 1.4670],
[-1.6947, -1.3487, -0.7390, 1.0247, 1.6895, -1.7836],
[-1.2155, -1.1685, -0.2168, -0.1386, 0.9072, 2.5924]]])
tensor([[[-0.2580, -0.2601, 0.9075, 0.6198, 0.3910, 0.8222],
[ 0.5959, 0.7814, 0.4612, -0.3636, 0.6169, 0.1007],
[-0.1534, 0.3039, 0.1671, 0.7664, -0.2574, -0.6732]]],
grad_fn=<StackBackward0>)
tensor([[[-0.2580, -0.2601, 0.9075, 0.6198, 0.3910, 0.8222],
[ 0.5959, 0.7814, 0.4612, -0.3636, 0.6169, 0.1007],
[-0.1534, 0.3039, 0.1671, 0.7664, -0.2574, -0.6732]]],
grad_fn=<StackBackward0>)