import tensorflow as tf
# Demo: use tf.gather_nd to pick the element at one fixed sequence position
# from every item in a batch, by pairing each batch index with that position.
batch_size = 3
seq_len = 5
# Sequence position to extract from every batch element (the original demo
# hard-coded this as 1 via tf.ones). Must lie in [0, seq_len).
position = 1
if not 0 <= position < seq_len:
    raise ValueError(
        "position must be in [0, {}), got {}".format(seq_len, position))

# tf.Session only exists at the top level in TF 1.x; TF 2.x keeps it under
# tf.compat.v1. The `or` short-circuits, so old TF never touches tf.compat.
_Session = getattr(tf, "Session", None) or tf.compat.v1.Session

# Build the ops in an explicit graph: under TF 2.x eager execution is the
# default, and Session.run cannot fetch eager tensors.
graph = tf.Graph()
with graph.as_default():
    # Literal of shape (batch_size, seq_len, 2) = (3, 5, 2).
    sequence_tensor = tf.constant([
        [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
        [[6, 6], [7, 7], [8, 8], [9, 9], [10, 10]],
        [[11, 11], [12, 12], [13, 13], [14, 14], [15, 15]],
    ])
    # tf.range over Python ints already yields int32, so no cast is needed.
    batch_index = tf.range(batch_size)
    # [position] * batch_size as int32, built directly instead of
    # creating a float tensor with tf.ones and casting it.
    sequence_index = tf.fill([batch_size], position)
    # Rows (b, position), shape (batch_size, 2): row b addresses
    # sequence_tensor[b, position].
    index = tf.stack([batch_index, sequence_index], -1)
    # Each index row consumes the first two dims, so result has shape
    # (batch_size, 2) and equals sequence_tensor[:, position, :]
    # (for position == 1: [[2, 2], [7, 7], [12, 12]]).
    result = tf.gather_nd(params=sequence_tensor, indices=index)

# The context manager closes the session when done; the original never did.
with _Session(graph=graph) as sess:
    print(sess.run(sequence_tensor))
    print(sess.run(index))
    print(sess.run(result))
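# Note (so as not to mislead): if you are looking for standard code for the reading-comprehension task, use the official BERT SQuAD script instead: https://github.com/google-research/bert/blob/master/run_squad.py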