0
点赞
收藏
分享

微信扫一扫

tensorflow 1.15 KL loss 代码

Yaphets_巍 2022-07-27 阅读 53


from tensorflow.python.keras.utils import losses_utils

kl = tf.keras.losses.KLDivergence(
reduction = losses_utils.ReductionV2.NONE,
name = 'kullback_leibler_divergence')

kl_loss = tf.reduce_mean(kl(logit1, logit2))


举报

相关推荐

0 条评论