0
点赞
收藏
分享

微信扫一扫

Deep Coral loss

鱼满舱 2022-04-17 阅读 89
迁移学习
import torch


def CORAL(source, target):
    d = source.data.shape[1] #coral公式中的分母部分
    ns, nt = source.data.shape[0], target.data.shape[0]
    # source covariance
    xm = torch.mean(source, 0, keepdim=True) - source #对应着Cs的分子部分
    xc = xm.t() @ xm/(ns-1)  #对应着Cs的分子部分

    # target covariance
    xmt = torch.mean(target, 0, keepdim=True) - target#对应着Ct的分子部分
    xct = xmt.t() @ xmt/(nt-1)#对应着Ct的分子部分

    # frobenius norm between source and target
    loss = torch.mean(torch.mul((xc - xct), (xc - xct))) #Cs-Ct的点乘
    loss = loss/(4*d*d)

    return loss

Coral公式:

只做学习使用,作者也是看了别人进行了学习总结,希望能对你有所帮助。

举报

相关推荐

0 条评论