0
点赞
收藏
分享

微信扫一扫

RNN使用示例

飞进科技 2022-02-25 阅读 91

RNN使用示例

import torch 
import torch.nn as nn
rnn = nn.RNN(5,6,1)
input = torch.randn(1,3,5)
h0 = torch.randn(1,3,6)
print(input)
print(h0)
output,hn = rnn(input,h0)
print(output)
print(hn)


结果

tensor([[[ 1.0300,  0.4837,  1.2938,  0.6715,  1.5796],
         [-1.0450,  0.6950, -1.6691, -0.3455,  0.7442],
         [-0.6285, -0.5371,  0.8405,  0.5266, -0.5502]]])
tensor([[[-0.7893,  1.0546, -0.9698, -0.2873,  0.3692,  1.4670],
         [-1.6947, -1.3487, -0.7390,  1.0247,  1.6895, -1.7836],
         [-1.2155, -1.1685, -0.2168, -0.1386,  0.9072,  2.5924]]])
tensor([[[-0.2580, -0.2601,  0.9075,  0.6198,  0.3910,  0.8222],
         [ 0.5959,  0.7814,  0.4612, -0.3636,  0.6169,  0.1007],
         [-0.1534,  0.3039,  0.1671,  0.7664, -0.2574, -0.6732]]],
       grad_fn=<StackBackward0>)
tensor([[[-0.2580, -0.2601,  0.9075,  0.6198,  0.3910,  0.8222],
         [ 0.5959,  0.7814,  0.4612, -0.3636,  0.6169,  0.1007],
         [-0.1534,  0.3039,  0.1671,  0.7664, -0.2574, -0.6732]]],
       grad_fn=<StackBackward0>)
举报

相关推荐

0 条评论