import tensorflow as tf
# Demonstrate tf.one_hot: encode the indices 0..batch_size-1 as one-hot rows.
batch_size = 4
# Row i is the one-hot encoding of index i, so the result is a
# batch_size x batch_size identity matrix (tf.one_hot defaults to float32).
a = tf.one_hot(tf.range(batch_size), batch_size)
if tf.executing_eagerly():
    # TensorFlow 2.x runs eagerly by default: the tensor already holds its
    # value, and tf.Session no longer exists in the top-level namespace.
    print(a.numpy())
else:
    # Graph mode (TF 1.x, or TF 2.x with eager disabled): evaluate the tensor
    # in a session. The `with` block closes the session and releases its
    # resources; the original never called sess.close().
    # NOTE(review): tf.compat.v1 requires TF >= 1.13 — confirm minimum version.
    with tf.compat.v1.Session() as sess:
        print(sess.run(a))
# Printed output:
# [[1. 0. 0. 0.]
#  [0. 1. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]
# 微信扫一扫 — web-page residue ("scan with WeChat" share widget), not part of the example.
import tensorflow as tf
# Demonstrate tf.one_hot: encode the indices 0..batch_size-1 as one-hot rows.
batch_size = 4
# Row i is the one-hot encoding of index i, so the result is a
# batch_size x batch_size identity matrix (tf.one_hot defaults to float32).
a = tf.one_hot(tf.range(batch_size), batch_size)
if tf.executing_eagerly():
    # TensorFlow 2.x runs eagerly by default: the tensor already holds its
    # value, and tf.Session no longer exists in the top-level namespace.
    print(a.numpy())
else:
    # Graph mode (TF 1.x, or TF 2.x with eager disabled): evaluate the tensor
    # in a session. The `with` block closes the session and releases its
    # resources; the original never called sess.close().
    # NOTE(review): tf.compat.v1 requires TF >= 1.13 — confirm minimum version.
    with tf.compat.v1.Session() as sess:
        print(sess.run(a))
# Printed output:
# [[1. 0. 0. 0.]
#  [0. 1. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]
# 相关推荐 — web-page residue ("related recommendations" header), not part of the example.