tensorflow set contain


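If the length of the tensor is known, this is fairly simple. Slide a window the length of the pattern over the tensor and compare each window element-wise with the pattern: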

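import tensorflow as tf  # TF 1.x graph-mode API (tf.Session)

one_vector = tf.constant([0, 111, 222, 333, 0])
pattern = tf.constant([111, 222, 333])

# Slide a window of len(pattern) = 3 over one_vector (length 5):
# 5 - 3 + 1 = 3 windows, starting at indices 0, 1, 2.
tmp_list = []
for tmp_index in range(0, 3):
    # 1 where the window element equals the pattern element, else 0
    tmp_list.append(tf.cast(tf.math.equal(one_vector[tmp_index:tmp_index + 3],
                                          pattern), tf.int32))

# Total number of element-wise matches over all windows.
# Note: 3 can also come from partial matches spread over different windows.
total = tf.reduce_sum(tmp_list)

init = tf.global_variables_initializer()  # there are no variables, so this is a no-op
sess = tf.Session()
sess.run(init)

print(sess.run(tmp_list))
print(sess.run(total))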

Printed output:
[array([0, 0, 0], dtype=int32), array([1, 1, 1], dtype=int32), array([0, 0, 0], dtype=int32)]
3
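The summed count does not by itself prove containment. For [111, 222, 0, 0, 333], the three windows give [1, 1, 0], [0, 0, 0] and [0, 0, 1], which also sum to 3, even though 111, 222, 333 never appear consecutively. Below is a minimal sketch of a stricter check, assuming TensorFlow 2.x eager execution. A window counts only if all of its elements match (tf.reduce_all), and the pattern is contained if any window matches (tf.reduce_any). The windows are built with tf.signal.frame, so the length of the vector does not have to be hard-coded. The function name contains_subsequence is a hypothetical helper, not part of the TensorFlow API.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def contains_subsequence(vector, pattern):
    """Hypothetical helper: True if `pattern` occurs contiguously in `vector`.

    Both arguments are 1-D; the length of `vector` need not be known in advance.
    """
    vector = tf.convert_to_tensor(vector)
    pattern = tf.convert_to_tensor(pattern, dtype=vector.dtype)
    # All windows of len(pattern) consecutive elements with stride 1,
    # shape [num_windows, len(pattern)].
    windows = tf.signal.frame(vector, frame_length=tf.size(pattern), frame_step=1)
    # A window matches only if every element equals the corresponding pattern element.
    window_matches = tf.reduce_all(tf.equal(windows, pattern), axis=1)
    # The pattern is contained if at least one window matches.
    return tf.reduce_any(window_matches)


print(contains_subsequence([0, 111, 222, 333, 0], [111, 222, 333]).numpy())  # True
print(contains_subsequence([111, 222, 0, 0, 333], [111, 222, 333]).numpy())  # False
```

The same ops should also work in graph mode, evaluated with sess.run(...) instead of .numpy(), provided the installed TensorFlow version ships tf.signal.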

