0
点赞
收藏
分享

微信扫一扫

tf.dynamic_stitch 和  tf.dynamic_partition

菜头粿子园 2023-01-13 阅读 123


import  tensorflow as tf
x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])

#判断x里面的元素是否是1
condition_mask=tf.not_equal(x,tf.constant(-1.))

#[ True, False, True, True, False, True]

#将张量拆成两个,按照condition_mask的对应位置
partitioned_data = tf.dynamic_partition(
x, tf.cast(condition_mask, tf.int32) , 2)


#partitioned_data[0]=[-1., -1.]
#partitioned_data[1]=[2.1, 7.2, 6.3, 9.4]


partitioned_data[1] = partitioned_data[1] + 1.0



#这行代码是提取索引位置
condition_indices = tf.dynamic_partition(
tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)


x = tf.dynamic_stitch(condition_indices, partitioned_data)
# Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
# unchanged.


举报

相关推荐

0 条评论