0
点赞
收藏
分享

微信扫一扫

语义分割单通道和多通道输出交叉熵损失函数的计算问题

语义分割单通道和多通道输出交叉熵损失函数的计算问题_二分类本文验证了语义分割任务下,单通道输出和多通道输出时,使用交叉熵计算损失值的细节问题。对比验证了使用简单的函数和自带损失函数的结果,通过验证,进一步加强了对交叉熵的理解。


摘要

本文验证了语义分割任务下,单通道输出和多通道输出时,使用交叉熵计算损失值的细节问题。对比验证了使用简单的函数和自带损失函数的结果,通过验证,进一步加强了对交叉熵的理解。

交叉熵损失函数

交叉熵损失函数的原理和推导过程,可以参考​​这篇博文​​,交叉熵的计算公式如下:

\[CE(p,q) = -p*log(q) \]

其中 \(q\) 为预测的概率,\(q∈[0,1]\), \(p\) 为标签,\(p∈\{0,1\}\)。

而交叉熵损失函数则是利用上式计算每一个分类的交叉熵之和。对于概率,所有分类的概率 \(q\) 之和满足相加等于1,而对于标签,则需要进行one-hot编码,使得有且只有一个分类的 \(p\) 为1,其余的分类为0。

单通道输出时的交叉熵损失计算

语义分割单通道和多通道输出交叉熵损失函数的计算问题_损失函数_02

首先,假设我们研究的是一个二分类语义分割问题。

网络的输入是一个 2×2 的图像,设置 batch_size 为 2,网络输出单通道特征图。网络的标签也是一个 2 ×2 的二进制掩模图(即只有0和1的单通道图像)。

我们在 ​​pytorch​​ 中将其定义:

import torch

# 假设输出一个 [batch_size=2, channel=1, height=2, width=2] 格式的张量 x1
x1 = torch.tensor(
[[[[ 0.43, -0.25],
[-0.32, 0.69]]],

[[[-0.29, 0.37],
[0.54, -0.72]]]])

# 假设标签图像为与 x1 同型的张量 y1
y1 = torch.tensor(
[[[[0., 0.],
[0., 1.]]],

[[[0., 0.],
[1., 1.]]]])

在进行交叉熵前,首先需要做一个 ​​sigmoid​​ 操作,将数值压缩到0到1之间:

# 根据二进制交叉熵的计算过程
# 首先进行sigmoid计算,然后与标签图像进行二进制交叉熵计算,最后取平均值,即为损失值

# 1. sigmoid
s1 = torch.sigmoid(x1)
s1

'''
out:
tensor([[[[0.6059, 0.4378],
[0.4207, 0.6660]]],


[[[0.4280, 0.5915],
[0.6318, 0.3274]]]]
'''

然后进行交叉熵计算,由于计算的是每个像素的损失值,所以要取个平均值:

# 2.交叉熵计算
loss_cal = -1*(y1*torch.log(s1)+(1-y1)*torch.log(1-s1)) # 此处相当于一个one-hot编码
loss_cal_mean = torch.mean()
loss_cal_mean

'''
out:
tensor(0.6861)
'''

为了验证结果,我们使用 ​​pytorch​​ 自带的二进制交叉熵损失函数计算:

# 使用torch自带的二进制交叉熵计算
loss_bce = torch.nn.BCELoss()(s1,y1)
loss_bce

'''
out:
tensor(0.6861)
'''

当计算损失值前没有进行 ​​sigmoid​​​ 操作时,​​pytorch​​ 还提供了包含这个操作的二进制交叉熵损失函数:

# 使用带sigmoid的二进制交叉熵计算
loss_bce2 = torch.nn.BCEWithLogitsLoss()(x1,y1)
loss_bce2

'''
out:
tensor(0.6861)
'''

可以看到,我们使用了三种方式,计算了交叉熵损失,结果一致。

多通道输出时的交叉熵损失计算

语义分割单通道和多通道输出交叉熵损失函数的计算问题_多通道_03

首先,假设我们研究的是一个二分类语义分割问题。

网络的输入是一个 2×2 的图像,设置 ​​batch_size​​ 为 2,网络输出多(二)通道特征图。网络的标签也是一个 2 ×2 的二进制掩模图(即只有0和1的单通道图像)。

我们在 ​​pytorch​​ 中将其定义:

# 假设输出一个[batch_size=2, channel=2, height=2, width=2]格式的张量 x1
x1 = torch.tensor([[[[ 0.3164, -0.1922],
[ 0.4326, -1.2193]],

[[ 0.6873, 0.6838],
[ 0.2244, 0.5615]]],


[[[-0.2516, -0.8875],
[-0.6289, -0.1796]],

[[ 0.0411, -1.7851],
[-0.3069, -1.0379]]]])

# 假设标签图像为与x1同型,然后去掉channel的张量 y1 (注意两点,channel没了,格式为LongTensor)
y1 = torch.LongTensor([[[0., 1.],
[1., 0.]],

[[1., 1.],
[0., 1.]]])

在进行交叉熵前,首先需要做一个 ​​softmax​​ 操作,将数值压缩到0到1之间,且使得各通道之间的数值之和为1:

# 1.softmax
s1 = torch.softmax(x1,dim=1)
s1

'''
out:
tensor([[[[0.4083, 0.2940],
[0.5519, 0.1442]],

[[0.5917, 0.7060],
[0.4481, 0.8558]]],


[[[0.4273, 0.7105],
[0.4202, 0.7023]],

[[0.5727, 0.2895],
[0.5798, 0.2977]]]])
'''

对于标签图,由于其张量的形状与网络输出张量不一样,因此需要做一个one-hot转换,什么是one-hot?请看​​这篇博文​​:

# 2.one-hot
y1_one_hot = torch.zeros_like(x1).scatter_(dim=1,index=y1.unsqueeze(dim=1),src=torch.ones_like(x1))
y1_one_hot

'''
out:
tensor([[[[1., 0.],
[0., 1.]],

[[0., 1.],
[1., 0.]]],


[[[0., 0.],
[1., 0.]],

[[1., 1.],
[0., 1.]]]])
'''

这里需要重点理解这个​​scatter_​​​函数,他起到的作用十分关键,​​one-hot​​​ 转换时,其实可以理解为将一个同型的全1矩阵中的元素,有选择性的复制到全0矩阵中的过程,这里的选择依据就是我们的标签图,它决定了哪个位置和通道上的元素取值为 1 。在​​scatter_​​ 函数中,dim 决定了用于确定我们在哪个维度上开始定位要建立联系的元素,index是我们选择的依据。

按照交叉熵定义,继续计算:

# 交叉熵计算
loss_cal = -1 *(y1_one_hot * torch.log(s1))
loss_cal_mean = loss_cal.sum(dim=1).mean() # 在batch维度下计算每个样本的交叉熵
loss_cal_mean

'''
out:
tensor(0.9823)
'''

我们也可以使用 ​​pytorch​​ 自带的交叉熵损失函数计算:

loss_ce = torch.nn.CrossEntropyLoss()(x1,y1)
loss_ce

'''
tensor(0.9823)
'''

可以看到,两种方式结果一样。

结论

  • 交叉熵本质上将一群对象择其一进行研究,自然就变成一个二进制问题,即是这个对象或不是这个对象,然后将标签与概率融进公式中,计算损失值。对于每一个对象都可以计算一个损失值,求个平均值就是最后这个群体的损失值了。
  • 不论是sigmoid或者softmax,我们都是在有目的将数据规整到0到1之间,从而形成一个概率值,sigmoid针对的是二分类问题,因此算出一个概率,另一个用一减去就到了。多分类问题,由于最后会输出对应数量的值,softmax 能够将这些值转换到0到1,并满足加起来等于1,这样的话,当我们只研究其中一个类的概率时,其他类的概率自然就是用1减去这个类的概率了,又回到了二分类问题。
  • 对于二分类语义分割问题,其实采用上述两种方式都是可以的。

参考资料

[1] ​​pytorch中的 scatter_()函数使用和详解​​

[2] ​​pytorch交叉熵使用方法​​

[3] ​​pytorch损失函数之nn.BCELoss()(为什么用交叉熵作为损失函数)​​

[4] ​​pytorch损失函数之nn.CrossEntropyLoss()、nn.NLLLoss()​​

[5] ​​PyTorch中名不符实的损失函数​​

[6] ​​Pytorch中Softmax、Log_Softmax、NLLLoss以及CrossEntropyLoss的关系与区别详解​​

[7] ​​二分类问题,应该选择sigmoid还是softmax?​​

未经作者授权,禁止转载 THE END

举报

相关推荐

0 条评论