0
点赞
收藏
分享

微信扫一扫

TensorFlow 的KL divergence的实现

扬帆远航_df7c 2022-07-27 阅读 46


def kl_for_log_probs(log_p, log_q):
p = tf.exp(log_p)
neg_ent = tf.reduce_sum(p * log_p, axis=-1)
neg_cross_ent = tf.reduce_sum(p * log_q, axis=-1)
kl = neg_ent - neg_cross_ent
return

传​​log_probs = tf.nn.log_softmax(logits, axis=-1)​


举报

相关推荐

0 条评论