def custom_net(cell, inputs, init_state, timesteps, time_major=False, scope='custom_net_0'):
    """Statically unroll ``cell`` for ``timesteps`` steps and stack the results.

    Args:
        cell: Callable invoked as ``cell(step_input, prev_state)`` that
            returns an ``(output, state)`` pair.
        inputs: Input tensor. When ``time_major`` is False its first two
            axes are swapped so the loop can index it by step
            (presumably [batch, time, depth] -- TODO confirm with callers).
        init_state: State handed to ``cell`` on the first step.
        timesteps: Number of steps to unroll; must be a Python int because
            it drives a Python ``range`` loop.
        time_major: If True, ``inputs`` is used as-is and indexed by step
            along its leading axis.
        scope: Name of the variable scope the cell is called under.

    Returns:
        A tuple ``(outputs, states)`` where ``outputs`` is the per-step
        outputs stacked along a new leading axis and ``states`` is the
        per-step states stacked the same way (``init_state`` excluded).
    """
    if time_major:
        # Already step-major: use directly. Previously this path left
        # ``inputs_tm`` undefined and raised NameError inside the loop.
        inputs_tm = inputs
    else:
        # Convert to time-major format.
        # NOTE(review): the permutation uses -1 for the last axis, which
        # presumably assumes a rank-3 input -- confirm.
        inputs_tm = tf.transpose(inputs, [1, 0, -1], name="input_time_major")

    # Collection of states and outputs; states starts with the initial state
    # so that states[-1] is always the previous step's state.
    states, outputs = [init_state], []
    with tf.variable_scope(scope):
        for i in range(timesteps):
            if i > 0:
                # Share the cell's variables across all steps after the first.
                tf.get_variable_scope().reuse_variables()
            output, state = cell(inputs_tm[i], states[-1])
            outputs.append(output)
            states.append(state)

    # Drop init_state so both stacks have one entry per step.
    return tf.stack(outputs), tf.stack(states[1:])
# Code excerpted from https://github.com/ai-guild/r-net/blob/master/lib/recurrence.py