code:https://github.com/backseason/PoolNet
paper: https://arxiv.org/abs/1904.09569
文章目录
- 摘要
- 网络实现
- Global Guidance Module (GGM)
- Feature Aggregation Module (FAM)
- 整体结构
- 结果
- Pytorch PoolNet
摘要
最先进的突出对象检测结果之一是PoolNet模型,它在不同的骨干模型(如VGG或ResNet)下表现良好。该模型的关键操作是池化操作,池化操作对深层特征和浅层特征都有较好的表示。在5个基准数据集的显著性目标检测问题上,该算法的结果名列前茅。它的网络中有两个主要模块。全局指导模块(Global Guidance Module,GGM)和特征聚合模块(Feature Aggregation Module,FAM)。两者都有助于对显著区域有更好的特征表示。
网络实现
PoolNet模型在其网络中有两个主要模块。一种是针对深层特征提取更好的表示,另一种是负责浅层特征,其中有许多关于显著区域的细节。
Global Guidance Module (GGM)
对于缺乏精密的高层语义信息特征图谱的自上而下的途径,他们引入一个全局指导模块包含一个改良版的金字塔池模块(PPM)和一系列全局指导流动(ggf)显式地在每个级别特征图的位置突出对象的注意。
Feature Aggregation Module (FAM)
利用我们的GGM可以将全局指导信息交付到不同金字塔级别的特征图上。然而,一个值得探讨的新问题是如何使GGM的粗层次特征图与金字塔的不同尺度的特征图无缝融合。
整体结构
如果你看一下这个网络的可视化,你会发现,GGM位于编码器和解码器之间,它与浅层的解码器部分有几个连接。
它还进行了与边缘检测的联合训练,即每个块的前3张图像。这种改进边缘,使它们精致的方式。
结果
Pytorch PoolNet
正如我们已经说过的,该网络有几个变种,包括ResNet的主干、VGG和边缘检测。
class ConvertLayer(nn.Module):
def __init__(self, list_k):
super(ConvertLayer, self).__init__()
up = []
for i in range(len(list_k[0])):
up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True)))
self.convert0 = nn.ModuleList(up)
def forward(self, list_x):
resl = []
for i in range(len(list_x)):
resl.append(self.convert0[i](list_x[i]))
return resl
class DeepPoolLayer(nn.Module):
def __init__(self, k, k_out, need_x2, need_fuse):
super(DeepPoolLayer, self).__init__()
self.pools_sizes = [2,4,8]
self.need_x2 = need_x2
self.need_fuse = need_fuse
pools, convs = [],[]
for i in self.pools_sizes:
pools.append(nn.AvgPool2d(kernel_size=i, stride=i))
convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False))
self.pools = nn.ModuleList(pools)
self.convs = nn.ModuleList(convs)
self.relu = nn.ReLU()
self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)
if self.need_fuse:
self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False)
def forward(self, x, x2=None, x3=None):
x_size = x.size()
resl = x
for i in range(len(self.pools_sizes)):
y = self.convs[i](self.pools[i](x))
resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
resl = self.relu(resl)
if self.need_x2:
resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True)
resl = self.conv_sum(resl)
if self.need_fuse:
resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3))
return resl
class PoolNet(nn.Module):
def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, score_layers):
super(PoolNet, self).__init__()
self.base_model_cfg = base_model_cfg
self.base = base
self.deep_pool = nn.ModuleList(deep_pool_layers)
self.score = score_layers
if self.base_model_cfg == 'resnet':
self.convert = convert_layers
def forward(self, x):
x_size = x.size()
conv2merge, infos = self.base(x)
if self.base_model_cfg == 'resnet':
conv2merge = self.convert(conv2merge)
conv2merge = conv2merge[::-1]
edge_merge = []
merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0])
for k in range(1, len(conv2merge)-1):
merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k])
merge = self.deep_pool[-1](merge)
merge = self.score(merge, x_size)
return merge