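# Inside the function passed to tf.map_fn, tensors can only be manipulated
# with TensorFlow's own ops (函数里也只能用 TF 自己的算子对 tensor 操作).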
import tensorflow as tf
# Dimensions of the demo input: batch_size rows, hidden_size values per row.
batch_size, hidden_size = 4, 2
# Random [batch_size, hidden_size] input drawn with tf.random_normal's
# default parameters.
input_tensor = tf.random_normal(shape=[batch_size, hidden_size])
def true_function():
    """Return the value produced when the tf.cond predicate is true.

    Returns:
        The Python float ``1.0``; tf.cond converts it into a tensor.
    """
    return 1.0
def false_function():
    """Return the value produced when the tf.cond predicate is false.

    Returns:
        The Python float ``0.0``; tf.cond converts it into a tensor.
    """
    return 0.0
def map_function(each_value, threshold=0.5):
    """Binarise one element: 1.0 if it is greater than ``threshold``, else 0.0.

    Intended to be passed to ``tf.map_fn``, which calls it with a single
    element, so ``threshold`` keeps its default there.

    Args:
        each_value: One element of the mapped tensor. It is a symbolic
            tensor, so the comparison must go through ``tf.cond`` rather
            than a Python ``if``.
        threshold: Strict lower bound for mapping an element to 1.0.

    Returns:
        The tensor produced by ``tf.cond`` from ``true_function`` /
        ``false_function``.
    """
    # Demo output: shows that each_value is a (symbolic) tensor. In graph mode
    # this runs when the function is traced, not once per element at run time.
    print(each_value)
    return tf.cond(each_value > threshold, true_function, false_function)
# Flatten to 1-D so tf.map_fn visits every scalar element individually.
input_tensor = tf.reshape(input_tensor, [-1])
# NOTE: map_fn + cond processes the elements one at a time; the same result
# is available as a single vectorised op, tf.cast(input_tensor > 0.5,
# tf.float32). map_fn is used here to demonstrate a per-element Python function.
input_tensor = tf.map_fn(map_function, input_tensor)
# Restore the original [batch_size, hidden_size] layout.
input_tensor = tf.reshape(input_tensor, [batch_size, hidden_size])
# Use the session as a context manager so it is closed after the fetch
# instead of leaking its resources.
with tf.Session() as sess:
    print(sess.run(input_tensor))