LSTM / GRU: getting all the states instead of only the last state

孟祥忠诗歌 · 2022-07-27


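```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.variable_scope)


def custom_net(cell, inputs, init_state, timesteps, time_major=False, scope='custom_net_0'):
    # convert to time-major format: [batch, time, dim] -> [time, batch, dim]
    if not time_major:
        inputs_tm = tf.transpose(inputs, [1, 0, 2], name="input_time_major")
    else:
        inputs_tm = inputs
    # collection of states and outputs
    states, outputs = [init_state], []

    with tf.variable_scope(scope):
        for i in range(timesteps):
            # share the cell's weights across all time steps
            if i > 0:
                tf.get_variable_scope().reuse_variables()
            output, state = cell(inputs_tm[i], states[-1])
            outputs.append(output)
            states.append(state)

    # outputs: one entry per time step, stacked along axis 0 (time-major);
    # states: the state after every step (the initial state is dropped)
    return tf.stack(outputs), tf.stack(states[1:])
```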

Code taken from https://github.com/ai-guild/r-net/blob/master/lib/recurrence.py
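The TensorFlow code above needs a 1.x graph-mode environment. As a framework-free illustration of the same loop, here is a minimal NumPy sketch. The names `lstm_cell` and `unroll` are hypothetical, and the cell is a hand-written textbook LSTM whose gate order (i, f, g, o) and weight layout are assumptions, not TensorFlow's internal layout. It shows why the manual loop matters for LSTM: the per-step output is only h, while the state is the pair (c, h), so the cell states c are lost unless every state is kept. In TensorFlow's GRUCell the per-step output is the new hidden state itself, so for GRU the stacked outputs already contain every state.

```python
import numpy as np


def _sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def lstm_cell(params, x, state):
    """One step of a hypothetical minimal LSTM cell (not TF's implementation).

    params: {"W": [input_dim + units, 4 * units], "b": [4 * units]}
    x: [batch, input_dim]; state: (c, h), each [batch, units].
    Returns (output, new_state) like a TF RNN cell: output is h, state is (c, h).
    """
    c, h = state
    z = np.concatenate([x, h], axis=1) @ params["W"] + params["b"]
    i, f, g, o = np.split(z, 4, axis=1)  # assumed gate order
    new_c = _sigmoid(f) * c + _sigmoid(i) * np.tanh(g)
    new_h = _sigmoid(o) * np.tanh(new_c)
    return new_h, (new_c, new_h)


def unroll(cell, params, inputs, init_state, time_major=False):
    """Run `cell` step by step and keep every state, mirroring custom_net."""
    # Convert [batch, time, features] to [time, batch, features].
    inputs_tm = inputs if time_major else np.transpose(inputs, (1, 0, 2))
    states, outputs = [init_state], []
    for x_t in inputs_tm:
        output, state = cell(params, x_t, states[-1])
        outputs.append(output)
        states.append(state)
    # Drop the initial state; stack c and h separately -> [time, batch, units].
    all_c = np.stack([c for c, _ in states[1:]])
    all_h = np.stack([h for _, h in states[1:]])
    return np.stack(outputs), (all_c, all_h)


rng = np.random.default_rng(0)
batch, time, input_dim, units = 2, 5, 3, 4
params = {"W": rng.normal(size=(input_dim + units, 4 * units)) * 0.1,
          "b": np.zeros(4 * units)}
inputs = rng.normal(size=(batch, time, input_dim))
zeros = np.zeros((batch, units))
outputs, (all_c, all_h) = unroll(lstm_cell, params, inputs, (zeros, zeros))
print(outputs.shape, all_c.shape, all_h.shape)  # (5, 2, 4) (5, 2, 4) (5, 2, 4)
```

Returning c and h as two separate `[time, batch, units]` arrays avoids stacking tuple states directly. `all_h` matches `outputs` step by step, while `all_c` is the extra information that the per-step outputs alone do not provide.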
