CBAM: Convolutional Block Attention Module

在觉 · 2022-01-25

GitHub - Jongchan/attention-module: Official PyTorch code for "BAM: Bottleneck Attention Module (BMVC2018)" and "CBAM: Convolutional Block Attention Module (ECCV2018)"

Channel attention module

Given an input feature map $F$, average-pooling and max-pooling are applied separately to obtain $F_{avg}^{c}$ and $F_{max}^{c}$. Both descriptors are then fed into the same network, an MLP with one hidden layer (i.e., two fully connected layers); to limit the extra parameter overhead, the hidden layer size is set to $\mathbb{R}^{C/r\times 1\times 1}$, where $r$ is the reduction ratio, and the second FC layer maps back to $C$ channels. The two branch outputs are merged by element-wise summation, and a sigmoid activation finally yields the channel attention map $M_{c}(F)\in \mathbb{R}^{C\times 1\times 1}$. Concretely:

$$M_{c}(F)=\sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F))+\mathrm{MLP}(\mathrm{MaxPool}(F))\big)=\sigma\big(W_{1}(W_{0}(F_{avg}^{c}))+W_{1}(W_{0}(F_{max}^{c}))\big)$$

where $\sigma$ denotes the sigmoid function. Note that the MLP weights $W_{0}$ and $W_{1}$ are shared between the two branches, and $W_{0}$ is followed by a ReLU activation.
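
To make the shapes concrete, here is a minimal sketch of the channel branch described above (not the official implementation, which appears under "Official code" below; the class name and the 1×1-convolution form of the shared MLP are illustrative assumptions):

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    # Sketch of M_c: shared two-layer MLP (W0 -> ReLU -> W1) over pooled descriptors.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),  # W0: C -> C/r
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),  # W1: C/r -> C
        )

    def forward(self, x):                                      # x: (N, C, H, W)
        avg = self.mlp(x.mean(dim=(2, 3), keepdim=True))       # MLP(F_avg^c): (N, C, 1, 1)
        mx = self.mlp(x.amax(dim=(2, 3), keepdim=True))        # MLP(F_max^c): (N, C, 1, 1)
        return torch.sigmoid(avg + mx)                         # M_c(F): (N, C, 1, 1)

Applying it amounts to x * ChannelAttention(C)(x), with the attention map broadcast over the spatial dimensions.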

Spatial attention module

Given an input feature map $F$, average-pooling and max-pooling are applied along the channel dimension to obtain $F_{avg}^{s}\in \mathbb{R}^{1\times H\times W}$ and $F_{max}^{s}\in \mathbb{R}^{1\times H\times W}$. The two maps are concatenated along the channel dimension and passed through a 7×7 convolution, and a sigmoid finally yields the spatial attention map $M_{s}(F)\in \mathbb{R}^{H\times W}$. Concretely:

$$M_{s}(F)=\sigma\big(f^{7\times 7}([\mathrm{AvgPool}(F);\mathrm{MaxPool}(F)])\big)=\sigma\big(f^{7\times 7}([F_{avg}^{s};F_{max}^{s}])\big)$$

where $f^{7\times 7}$ denotes a convolution with a 7×7 kernel.
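
Likewise, a minimal sketch of the spatial branch (again separate from the official code below; the names are illustrative):

import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    # Sketch of M_s: channel-wise avg/max pooling, concat, 7x7 conv, sigmoid.
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):                      # x: (N, C, H, W)
        avg = x.mean(dim=1, keepdim=True)      # F_avg^s: (N, 1, H, W)
        mx = x.amax(dim=1, keepdim=True)       # F_max^s: (N, 1, H, W)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))  # M_s(F): (N, 1, H, W)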

Arrangement of attention modules

Through experiments the authors found that arranging the two attention modules sequentially works better than in parallel, and that placing the channel attention module before the spatial attention module gives the best results. The final structure is therefore:

$$F' = M_{c}(F)\otimes F,\qquad F'' = M_{s}(F')\otimes F'$$

where $\otimes$ denotes element-wise multiplication, with the attention maps broadcast along the missing dimensions.

Ablation studies

Channel attention

The experiments compare using AvgPool only, MaxPool only, and AvgPool & MaxPool jointly in the channel attention; the results show that combining the two works best. "We argue that max-pooled features which encode the degree of the most salient part can compensate the average-pooled features which encode global statistics softly."
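
The pool_types argument of the ChannelGate class in the official code below makes it straightforward to reproduce the three variants of this ablation; the 64-channel input here is just an arbitrary example:

import torch

# Assumes ChannelGate from the official code below is in scope.
x = torch.randn(2, 64, 32, 32)

avg_only = ChannelGate(64, pool_types=['avg'])         # AvgPool only (SE-style)
max_only = ChannelGate(64, pool_types=['max'])         # MaxPool only
avg_max  = ChannelGate(64, pool_types=['avg', 'max'])  # AvgPool & MaxPool (best in the paper)

for gate in (avg_only, max_only, avg_max):
    assert gate(x).shape == x.shape  # attention only rescales, so the shape is unchanged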

Spatial attention

Arrangement of the channel and spatial attention

In this part the authors compare three ways of combining the channel and spatial branches: sequential channel-spatial, sequential spatial-channel, and parallel. The results show that sequential channel-spatial works best.
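
As a rough illustration using the sketch modules above (not the paper's exact setup; in particular, how the parallel variant fuses the two maps is an assumption here, with both maps simply applied to the same input):

import torch

# Reuses the ChannelAttention / SpatialAttention sketches defined earlier.
x = torch.randn(2, 64, 32, 32)
ca, sa = ChannelAttention(64), SpatialAttention()

f1 = x * ca(x)
f2 = f1 * sa(f1)        # sequential channel-spatial (the arrangement adopted by CBAM)

g1 = x * sa(x)
g2 = g1 * ca(g1)        # sequential spatial-channel

p = x * ca(x) * sa(x)   # parallel (fusion detail is an assumption)

assert f2.shape == g2.shape == p.shape == x.shape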

Official code

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        if self.relu:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
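    """Channel attention: a shared MLP applied to avg- and max-pooled descriptors, summed and sigmoid-gated."""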
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
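    """Spatial attention: channel-wise max/mean pooling, 7x7 conv + BN, then sigmoid gating."""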
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
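    """CBAM block: channel attention followed (optionally) by spatial attention."""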
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
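
A quick smoke test of the module above (the input size is an arbitrary example); since CBAM preserves the input shape, it can be inserted after any convolutional block:

import torch

# Assumes the CBAM class defined above is in scope.
x = torch.randn(4, 256, 14, 14)                     # e.g. a ResNet stage output
cbam = CBAM(gate_channels=256, reduction_ratio=16)
y = cbam(x)
print(y.shape)                                      # torch.Size([4, 256, 14, 14])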

BAM的区别

  • In BAM, the channel and spatial branches are arranged in parallel, while CBAM uses a sequential arrangement.
  • In the channel attention, BAM only uses avg pooling, while CBAM uses both avg pooling and max pooling.