# Binarize input_tensor element-wise: 1 where the value is > 0.0001, else 0.
# NaN compares false, so it maps to 0. The result keeps the input's dtype.
#
# One vectorized comparison + cast replaces the earlier sequence:
#   flatten to 1-D -> tf.map_fn(tf.cond(...)) per element -> reshape back.
# That sequence ran one small op per element. It also mixed dtypes:
# tf.cond returned float32 constants, while map_fn's output signature
# defaults to the element dtype, which could mismatch for non-float32 input.
# An elementwise op needs no flattening, so the 1-D round trip is dropped.
# Assumes a floating-point input_tensor (comparing with 0.0001) — TODO confirm.
input_tensor = tf.cast(input_tensor > 0.0001, input_tensor.dtype)

# Reshape to [batch_size, the_origin_size], exactly as the original code did.
# Both names must be defined earlier in the file.
input_tensor = tf.reshape(input_tensor, [batch_size, the_origin_size])