
Deep Learning Paper: CenterMask: Real-Time Anchor-Free Instance Segmentation and Its PyTorch Implementation

耳一文 2022-08-08


CenterMask : Real-Time Anchor-Free Instance Segmentation
PDF: https://arxiv.org/pdf/1911.06667.pdf
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

Related articles:
VoVNet: Deep Learning Paper: An Energy and GPU-Computation Efficient Backbone Network for Object Detection and Its PyTorch Implementation

VoVNetV2: Deep Learning Paper: CenterMask: Real-Time Anchor-Free Instance Segmentation and Its PyTorch Implementation

1 Overview

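CenterMask builds on the anchor-free one-stage detector FCOS and adds a SAG-Mask (spatial attention-guided mask) branch on top of it, producing instance masks in the same way as Mask R-CNN.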

2 CenterMask

[Figure: CenterMask architecture]

2-1 Adaptive RoI Assignment Function

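Mask R-CNN assigns each detected RoI to an FPN level according to its size, and then uses RoI Align to extract spatially aligned features from that level. The assignment is the following mapping: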

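$$k = \left\lfloor k_0 + \log_2\left(\sqrt{wh}/224\right) \right\rfloor$$

where $w$ and $h$ are the width and height of the RoI, 224 is the canonical ImageNet pretraining size, and $k_0 = 4$ is the level to which a 224×224 RoI is mapped.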


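However, this formula does not adapt well to input images of different sizes, so CenterMask changes the RoI-to-level mapping to the following form: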

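$$k = \left\lceil k_{max} - \log_2\left(A_{input}/A_{RoI}\right) \right\rceil$$

where $k_{max}$ is the last level of the feature pyramid (e.g., 5), and $A_{input}$ and $A_{RoI}$ are the areas of the input image and of the RoI. Instead of comparing against the fixed constant 224, the level is chosen from the ratio of the input-image area to the RoI area; if $k$ falls below the minimum level (e.g., P3), it is clamped to that level.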
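The two assignment rules are easy to compare numerically. Below is a minimal plain-Python sketch; the function names `mask_rcnn_level` and `centermask_level` are hypothetical, and the clamping ranges (P2 to P5 for Mask R-CNN, P3 to P5 for CenterMask) are assumptions chosen for illustration.

```python
import math


def mask_rcnn_level(w, h, k0=4, canonical=224, k_min=2, k_max=5):
    """Mask R-CNN / FPN rule: k = floor(k0 + log2(sqrt(w*h) / 224)).

    Clamping to [k_min, k_max] is an assumption of this sketch.
    """
    k = math.floor(k0 + math.log2(math.sqrt(w * h) / canonical))
    return max(k_min, min(k_max, k))


def centermask_level(w, h, img_w, img_h, k_min=3, k_max=5):
    """Adaptive rule: k = ceil(k_max - log2(A_input / A_RoI)).

    Levels below k_min are clamped up; the upper clamp is a safety assumption.
    """
    area_ratio = (img_w * img_h) / (w * h)  # A_input / A_RoI
    k = math.ceil(k_max - math.log2(area_ratio))
    return max(k_min, min(k_max, k))


if __name__ == "__main__":
    # A square RoI covering 1/4 of a square image, at two input resolutions.
    for img, roi in [(512, 256), (1024, 512)]:
        print(img, roi, mask_rcnn_level(roi, roi), centermask_level(roi, roi, img, img))
    # 512 256 4 3
    # 1024 512 5 3
```

With the fixed constant 224, Mask R-CNN's rule sends the same relative RoI to a different level once the whole image is scaled up, whereas the adaptive rule depends only on the area ratio and keeps the level unchanged.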

2-2 Spatial Attention-Guided Mask (SAG-Mask)

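The spatial attention map helps the mask head focus on the object and suppress noise from irrelevant pixels. It is computed from the RoI features by average pooling and max pooling along the channel axis, concatenating the two resulting maps, and applying a 3×3 convolution followed by a sigmoid; the map then re-weights every spatial position of the features before the mask is predicted.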

[Figure: SAG-Mask branch structure]


PyTorch code:

```python
import torch
import torch.nn as nn


def Conv3x3BNReLU(in_channels, out_channels, stride=1):
    # 3x3 conv + BN + ReLU; padding=1 keeps the spatial size when stride=1
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


class SAG_Mask(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SAG_Mask, self).__init__()
        mid_channels = in_channels

        # four 3x3 conv layers on the RoI-aligned features
        self.first_convs = nn.Sequential(
            Conv3x3BNReLU(in_channels=in_channels, out_channels=mid_channels, stride=1),
            Conv3x3BNReLU(in_channels=mid_channels, out_channels=mid_channels, stride=1),
            Conv3x3BNReLU(in_channels=mid_channels, out_channels=mid_channels, stride=1),
            Conv3x3BNReLU(in_channels=mid_channels, out_channels=mid_channels, stride=1),
        )

        # spatial attention: [channel-avg, channel-max] (2 channels) -> 3x3 conv -> sigmoid
        # (no ReLU before the sigmoid, otherwise the attention could never drop below 0.5)
        self.sam_conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

        # 2x upsampling, then a plain 1x1 conv that outputs one mask logit map per class
        self.deconv = nn.ConvTranspose2d(mid_channels, mid_channels, kernel_size=2, stride=2)
        self.conv1x1 = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.first_convs(x)
        avg_map = torch.mean(x, dim=1, keepdim=True)    # (N, 1, H, W)
        max_map, _ = torch.max(x, dim=1, keepdim=True)  # (N, 1, H, W)
        sag = self.sigmoid(self.sam_conv(torch.cat([avg_map, max_map], dim=1)))
        sag_x = sag * x  # the 1-channel map broadcasts over all channels
        return self.conv1x1(self.deconv(sag_x))


if __name__ == '__main__':
    sag_mask = SAG_Mask(16, 80)
    print(sag_mask)
    x = torch.randn(1, 16, 14, 14)
    out = sag_mask(x)
    print(out.shape)  # torch.Size([1, 80, 28, 28])
```

2-3 VoVNetV2

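VoVNetV2 makes two effective improvements to VoVNet: 1) residual connections are introduced to alleviate the performance saturation of larger VoVNets; 2) effective Squeeze-Excitation (eSE) is used to address the information loss of the original SE module. With the same ResNet-101-FPN backbone, CenterMask reaches 38.3% mask AP, surpassing all previous state-of-the-art methods while also running faster.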

[Figure: OSA module of VoVNetV2]


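1) A residual connection from the input of each OSA module to its output is added, alleviating the performance saturation and gradient problems that come with stacking the network deeper;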

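2) A channel attention module, eSE, is added inside the OSA module, applied to its aggregated output. eSE replaces the two FC layers of the original SE module with a single FC layer that keeps the full channel dimension, so no information is lost to SE's channel reduction.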
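To make the difference concrete, here is a minimal PyTorch sketch of a standard SE block next to the single-FC eSE block. The class names `SEModule` and `ESEModule` and the helper `count_params` are hypothetical, and the SE reduction ratio r = 16 is an assumed value. Both blocks use 1×1 convolutions on the globally pooled vector, which act as fully connected layers.

```python
import torch
import torch.nn as nn


class SEModule(nn.Module):
    """Standard SE: GAP -> FC(C -> C/r) -> ReLU -> FC(C/r -> C) -> sigmoid."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1),  # channel reduction
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.fc(self.pool(x))


class ESEModule(nn.Module):
    """eSE: GAP -> single FC(C -> C) -> sigmoid; no channel reduction."""

    def __init__(self, channels):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return x * self.sigmoid(self.fc(self.pool(x)))


def count_params(module):
    return sum(p.numel() for p in module.parameters())


if __name__ == "__main__":
    x = torch.randn(2, 256, 14, 14)
    se, ese = SEModule(256), ESEModule(256)
    print(se(x).shape, ese(x).shape)       # both keep the input shape
    print(count_params(se), count_params(ese))  # 8464 65792
```

Both blocks return a tensor with the input's shape. For 256 channels the SE attention branch has 8,464 parameters and the eSE branch 65,792: eSE spends more parameters so that the excitation step sees every channel, which is how it avoids the information loss of SE's bottleneck.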

PyTorch code:

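```python
import torch
import torch.nn as nn


def Conv3x3BNReLU(in_channels, out_channels, stride=1):
    # 3x3 conv + BN + ReLU; padding=1 keeps the spatial size when stride=1
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


def Conv1x1BNReLU(in_channels, out_channels):
    # 1x1 conv + BN + ReLU
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


class eSE_Module(nn.Module):
    def __init__(self, channel):
        super(eSE_Module, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # a single FC (1x1 conv) that keeps all C channels, followed directly by a sigmoid
        self.excitation = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        z = self.excitation(self.squeeze(x))  # (N, C, 1, 1) channel weights
        return x * z  # broadcasts over H and W


class OSAv2_module(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, block_nums=5):
        super(OSAv2_module, self).__init__()
        self._layers = nn.ModuleList()
        self._layers.append(Conv3x3BNReLU(in_channels=in_channels, out_channels=mid_channels, stride=1))
        for _ in range(block_nums - 1):
            self._layers.append(Conv3x3BNReLU(in_channels=mid_channels, out_channels=mid_channels, stride=1))

        # one-shot aggregation: concat the input and every layer output, fuse with a 1x1 conv
        self.conv1x1 = Conv1x1BNReLU(in_channels + mid_channels * block_nums, out_channels)
        self.ese = eSE_Module(out_channels)
        # 1x1 projection so the residual branch matches out_channels
        self.pass_conv1x1 = Conv1x1BNReLU(in_channels, out_channels)

    def forward(self, x):
        residual = x
        outputs = [x]
        for _layer in self._layers:
            x = _layer(x)
            outputs.append(x)
        out = self.ese(self.conv1x1(torch.cat(outputs, dim=1)))
        return out + self.pass_conv1x1(residual)


if __name__ == '__main__':
    osa = OSAv2_module(64, 32, 128)
    x = torch.randn(1, 64, 56, 56)
    print(osa(x).shape)  # torch.Size([1, 128, 56, 56])
```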

3 Experimental Results

[Figure: experimental results]

