0
点赞
收藏
分享

微信扫一扫

SE模块及其代码

凉夜lrs 2022-04-29 阅读 54
SE模块首先使用全局平均池化层将全局空间信息压缩到通道域以实现空间信息的聚合。

z(c)就是通道c的全局平均池化的结果,uc(i,j)是通道c特征图在空间(i,j)处的值,H和W分别是特征图在行列方向的数量,然后使用ReLU和Sigmoid函数来获得通道之间的依赖关系,r是压缩比。


整体结构图如下:



nn.AdaptiveAvgPool2d():二维自适应平均池化

Pytorch 里 nn.AdaptiveAvgPool2d(output_size) 原理是什么? - 知乎

对于任何输入大小,输出大小为H x W。输出特征的数量等于输入平面的数量。

 nn.Linear()

 nn,ReLU()

 nn.sigmoid()

view()函数:可以理解为reshape功能,重构张量的维度。

代码实现:

import torch.nn as nn


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


if __name__ == '__main__':
    import torch

    x = torch.randn(2, 64, 7, 7)
    se = SELayer(64)
    y = se(x)
    print(y.shape)

 

举报

相关推荐

0 条评论