0
点赞
收藏
分享

微信扫一扫

PoolNet

code:​​https://github.com/backseason/PoolNet​​

paper: ​​https://arxiv.org/abs/1904.09569​​


文章目录

  • ​​摘要​​
  • ​​网络实现​​
  • ​​Global Guidance Module (GGM)​​
  • ​​Feature Aggregation Module (FAM)​​
  • ​​整体结构​​
  • ​​结果​​
  • ​​Pytorch PoolNet​​

摘要

最先进的突出对象检测结果之一是PoolNet模型,它在不同的骨干模型(如VGG或ResNet)下表现良好。该模型的关键操作是池化操作,池化操作对深层特征和浅层特征都有较好的表示。在5个基准数据集的显著性目标检测问题上,该算法的结果名列前茅。它的网络中有两个主要模块。全局指导模块(Global Guidance Module,GGM)和特征聚合模块(Feature Aggregation Module,FAM)。两者都有助于对显著区域有更好的特征表示。

网络实现

PoolNet模型在其网络中有两个主要模块。一种是针对深层特征提取更好的表示,另一种是负责浅层特征,其中有许多关于显著区域的细节。

Global Guidance Module (GGM)

对于缺乏精密的高层语义信息特征图谱的自上而下的途径,他们引入一个全局指导模块包含一个改良版的金字塔池模块(PPM)和一系列全局指导流动(ggf)显式地在每个级别特征图的位置突出对象的注意。

Feature Aggregation Module (FAM)

利用我们的GGM可以将全局指导信息交付到不同金字塔级别的特征图上。然而,一个值得探讨的新问题是如何使GGM的粗层次特征图与金字塔的不同尺度的特征图无缝融合。
PoolNet_边缘检测

整体结构

如果你看一下这个网络的可视化,你会发现,GGM位于编码器和解码器之间,它与浅层的解码器部分有几个连接。
PoolNet_池化_02
它还进行了与边缘检测的联合训练,即每个块的前3张图像。这种改进边缘,使它们精致的方式。

结果

PoolNet_池化_03
PoolNet_2d_04

Pytorch PoolNet

正如我们已经说过的,该网络有几个变种,包括ResNet的主干、VGG和边缘检测。

class ConvertLayer(nn.Module):
def __init__(self, list_k):
super(ConvertLayer, self).__init__()
up = []
for i in range(len(list_k[0])):
up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True)))
self.convert0 = nn.ModuleList(up)

def forward(self, list_x):
resl = []
for i in range(len(list_x)):
resl.append(self.convert0[i](list_x[i]))
return resl

class DeepPoolLayer(nn.Module):
def __init__(self, k, k_out, need_x2, need_fuse):
super(DeepPoolLayer, self).__init__()
self.pools_sizes = [2,4,8]
self.need_x2 = need_x2
self.need_fuse = need_fuse
pools, convs = [],[]
for i in self.pools_sizes:
pools.append(nn.AvgPool2d(kernel_size=i, stride=i))
convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False))
self.pools = nn.ModuleList(pools)
self.convs = nn.ModuleList(convs)
self.relu = nn.ReLU()
self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)
if self.need_fuse:
self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False)

def forward(self, x, x2=None, x3=None):
x_size = x.size()
resl = x
for i in range(len(self.pools_sizes)):
y = self.convs[i](self.pools[i](x))
resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
resl = self.relu(resl)
if self.need_x2:
resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True)
resl = self.conv_sum(resl)
if self.need_fuse:
resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3))
return resl

class PoolNet(nn.Module):
def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, score_layers):
super(PoolNet, self).__init__()
self.base_model_cfg = base_model_cfg
self.base = base
self.deep_pool = nn.ModuleList(deep_pool_layers)
self.score = score_layers
if self.base_model_cfg == 'resnet':
self.convert = convert_layers

def forward(self, x):
x_size = x.size()
conv2merge, infos = self.base(x)
if self.base_model_cfg == 'resnet':
conv2merge = self.convert(conv2merge)
conv2merge = conv2merge[::-1]

edge_merge = []
merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0])
for k in range(1, len(conv2merge)-1):
merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k])

merge = self.deep_pool[-1](merge)
merge = self.score(merge, x_size)
return merge


举报
0 条评论