0
点赞
收藏
分享

微信扫一扫

220315-PyTorch中target为浮点数float时的交叉熵loss计算


  • Pytorch默认的交叉熵函数使用​​loss=(pred=浮点数, target=整数)​​的形式

# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()

  • 但是当​​target​​需要为浮点数的时候,没法使用​​loss = nn.CrossEntropyLoss()​​直接计算, 此处修改损失函数
  • 也可参考标签平滑损失 [3]

def cross_entropy(pred, soft_targets):
logsoftmax = nn.LogSoftmax()
return torch.mean(torch.sum(- soft_targets * logsoftmax(pred), 1))

220315-PyTorch中target为浮点数float时的交叉熵loss计算_html


举报

相关推荐

0 条评论