0
点赞
收藏
分享

微信扫一扫

tensorflow 自定义函数

秀儿2020 2022-07-27 阅读 86


函数里也只能用 TF自己的算子 对tensor操作

import tensorflow as tf


batch_size = 4
hidden_size = 2
input_tensor = tf.random_normal([batch_size, hidden_size])

def true_function():
return 1.0

def false_function():
return 0.0

def map_function(each_value):
print(each_value) # a tensor
return tf.cond(each_value > 0.5, true_function, false_function)

input_tensor = tf.reshape(input_tensor, [-1]) # 转成1d的

input_tensor = tf.map_fn(map_function,
input_tensor)

input_tensor = tf.reshape(input_tensor, [batch_size, hidden_size])

sess = tf.Session()

print(sess.run(input_tensor))


举报

相关推荐

0 条评论