0
点赞
收藏
分享

微信扫一扫

关于pytorch语义分割二分类问题的两种做法

讨论了二分类语义分割使用单通道和多通道输出,对应的训练和预测方法。


形式1:输出为单通道

分析

即网络的输出 ​​output​​​ 为 [batch_size, 1, height, width] 形状。其中 ​​batch_szie​​​ 为批量大小,​​1​​​ 表示输出一个通道,​​height​​​ 和 ​​width​​ 与输入图像的高和宽保持一致。

在训练时,输出通道数是 1,网络得到的 ​​output​​​ 包含的数值是任意的数。给定的 ​​target​​​ ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 ​​output​​​ 不断逼近这个标签,首先会让 ​​output​​ 经过一个sigmoid 函数,使其数值归一化到[0, 1],得到 ​​output1​​​ ,然后让这个 ​​output1​​​ 与 ​​target​​​ 进行交叉熵计算,得到损失值,反向传播更新网络权重。最终,网络经过学习,会使得 ​​output1​​​ 逼近​​target​​。

训练结束后,网络已经具备让输出的 ​​output​​​ 经过转换从而逼近 ​​target​​​ 的能力。首先将输出的 ​​output​​ 通过sigmoid 函数,然后取一个阈值(一般设置为0.5),大于阈值则取1反之则取0,从而得到预测图 ​​predict​​。后续则是一些评估相关的计算。

代码实现

在这个过程中,训练的损失函数为二进制交叉熵损失函数,然后根据输出是否用到了sigmoid有两种可选的pytorch实现方式:

output = net(input)  # net的最后一层没有使用sigmoid
loss_func1 = torch.nn.BCEWithLogitsLoss()
loss = loss_func1(output, target)

当网络最后一层没有使用sigmoid时,需要使用 ​​torch.nn.BCEWithLogitsLoss()​​ ,顾名思义,在这个函数中,拿到output首先会做一个sigmoid操作,再进行二进制交叉熵计算。上面的操作等价于

output = net(input)  # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
loss_func1 = torch.nn.BCEWithLoss()
loss = loss_func1(output, target)

当然,你也可以在网络最后一层加上sigmoid操作。从而省去第二行的代码(在预测时也可以省去)。

在预测试时,可用下面的代码实现预测图的生成

output = net(input)  # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output))
...

即大于0.5的记为1,小于0.5记为0。

形式2:输出为多通道

分析

即网络的输出 ​​output​​​ 为 [batch_size, num_class, height, width] 形状。其中 ​​batch_szie​​​ 为批量大小,​​num_class​​​ 表示输出的通道数与分类数量一致,​​height​​​ 和 ​​width​​ 与输入图像的高和宽保持一致。

在训练时,输出通道数是 ​​num_class​​​(这里取2),网络得到的 ​​output​​​ 包含的数值是任意的数。给定的 ​​target​​​ ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 ​​output​​​ 不断逼近这个标签,首先会让 ​​output​​ 经过一个 softmax 函数,使其数值归一化到[0, 1],得到 ​​output1​​​ ,在各通道中,这个数值加起来会等于1。对于​​target​​​ 他是一个单通道图,首先使用​​onehot​​​编码,转换成 ​​num_class​​​个通道的图像,每个通道中的取值是根据单通道中的取值计算出来的,例如单通道中的第一个像素取值为1(0<= 1 <=num_class-1,这里num_class=2),那么​​onehot​​​编码后,在第一个像素的位置上,两个通道的取值分别为0,1。也就是说像素的取值决定了对应序号的通道取1,其他的通道取0,这个非常关键。上面的操作执行完后得到​​target1​​​,让这个 ​​output1​​​ 与 ​​target1​​​ 进行交叉熵计算,得到损失值,反向传播更新网路权重。最终,网络经过学习,会使得 ​​output1​​​ 逼近​​target1​​(在各通道层面上)。

训练结束后,网络已经具备让输出的 ​​output​​​ 经过转换从而逼近 ​​target​​​ 的能力。计算 ​​output​​​ 中各通道每一个像素位置上,取值最大的那个对应的通道序号,从而得到预测图 ​​predict​​。后续则是一些评估相关的计算。

代码实现

在这个过程中,则可以使用交叉熵损失函数:

output = net(input)  # net的最后一层没有使用sigmoid
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(output, target)

根据前面的分析,我们知道,正常的​​output​​​是 [batch_size, num_class, height, width]形状的,而​​target​​是[batch_size, height, width]形状的,需要按照上面的分析进行转换才可以计算交叉熵,而在pytorch中,我们不需要进一步做这个处理,直接使用就可以了。

在预测试时,使用下面的代码实现预测图的生成

output = net(input)  # net的最后一层没有使用sigmoid
predict = output.argmax(dim=1)
...

即得到输出后,在通道方向上找出最大值所在的索引号。

小结

总的来说,我觉得第二种方式更值得推广,一方面不用考虑阈值的选取问题;另一方面,该方法同样适用于多类别的语义分割任务,通用性更强。

参考资料

[2]​​https://cuijiahua.com/blog/2020/03/dl-16.html​​

未经作者授权,禁止转载 THE END


举报

相关推荐

0 条评论