- Pytorch默认的交叉熵函数使用
loss=(pred=浮点数, target=整数)
的形式
# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()
- 但是当
target
需要为浮点数的时候,没法使用loss = nn.CrossEntropyLoss()
直接计算, 此处修改损失函数 - 也可参考标签平滑损失 [3]
def cross_entropy(pred, soft_targets):
logsoftmax = nn.LogSoftmax()
return torch.mean(torch.sum(- soft_targets * logsoftmax(pred), 1))