0
点赞
收藏
分享

微信扫一扫

tf.reverse_sequence

颜路在路上 2023-01-16 阅读 94


tf.reverse_sequence 只翻转前n个数据,seq_lengths = [7, 2, 3, 5]就是inptuts 第一行的前7个,第二行的前2个…

import tensorflow as tf
seq_lengths = [7, 2, 3, 5]
inputs = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
output = tf.reverse_sequence(inputs, seq_lengths, seq_axis=1, batch_axis=0)
output

<tf.Tensor: shape=(4, 8), dtype=int32, numpy=
array([[0, 0, 5, 4, 3, 2, 1, 0],
[2, 1, 0, 0, 0, 0, 0, 0],
[3, 2, 1, 4, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>


举报

相关推荐

0 条评论