0
点赞
收藏
分享

微信扫一扫

ShuffleNet V1代码和总结

流计算Alink 2022-04-27 阅读 60
pytorch

针对1×1卷积数量增多后计算量较大的问题,ShuffleNet V1采用1×1分组卷积来降低计算量。然而,1×1卷积本身只负责融合通道间的信息,若直接对其分组,每一组的输出只能看到本组的输入通道,不同组之间的通道信息无法交流。因此,需要在1×1分组卷积之后加入通道混洗(channel shuffle),再送入3×3深度卷积中提取相应的特征信息。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码:

import torch
import torch.nn as nn

class Channel_Shuffle(nn.Module):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis is split into ``groups`` consecutive chunks; the output
    takes the first channel of every chunk, then the second channel of every
    chunk, and so on. The tensor shape is unchanged.
    """

    def __init__(self,groups):
        super().__init__()
        self.groups = groups

    def forward(self,x):
        """Return ``x`` (N, C, H, W) with its channels shuffled across groups."""
        n, c, h, w = x.shape
        per_group = c // self.groups
        # Expose the group structure: (N, groups, per_group, H, W).
        grouped = x.view(n, self.groups, per_group, h, w)
        # Swap the group and within-group axes, then materialise the new
        # memory layout so the result can be viewed as 4-D again.
        interleaved = grouped.permute(0, 2, 1, 3, 4).contiguous()
        return interleaved.view(n, -1, h, w)


class BLOCK(nn.Module):
    """ShuffleNet V1 unit.

    Residual branch: grouped 1x1 conv -> BN -> ReLU -> channel shuffle ->
    3x3 depthwise conv -> BN -> grouped 1x1 conv -> BN.

    * ``stride == 1``: the branch output is added to the identity shortcut,
      so the branch produces ``outchannels`` channels.
    * ``stride == 2``: the shortcut is a 3x3 average pool with stride 2 and
      the branch output is concatenated with it along the channel axis, so
      the branch only produces ``outchannels - inchannels`` channels.

    Args:
        inchannels: number of channels of the input tensor.
        outchannels: number of channels of the unit's output.
        stride: stride of the depthwise convolution; must be 1 or 2.
        group: number of groups of the two 1x1 convolutions (also used by
            the channel shuffle).

    Raises:
        ValueError: if ``stride`` is neither 1 nor 2.
    """

    def __init__(self,inchannels,outchannels, stride,group):
        super(BLOCK, self).__init__()
        # Fail at construction time: previously an unsupported stride left
        # the module without ``self.conv`` and only blew up in ``forward``.
        if stride not in (1, 2):
            raise ValueError(
                "BLOCK only supports stride 1 or 2, got {!r}".format(stride)
            )

        # NOTE(review): the bottleneck is half of the output width here; the
        # ShuffleNet paper uses a quarter -- confirm this is intentional.
        hidden_channels = outchannels//2

        # Down-sampling units concatenate with the pooled input; the others
        # add the unchanged input.
        self.cat = stride == 2
        if self.cat:
            self.shortcut = nn.Sequential(
                nn.AvgPool2d(kernel_size=3,stride=2,padding = 1)
            )
            branch_channels = outchannels-inchannels
        else:
            self.shortcut = nn.Sequential()
            branch_channels = outchannels

        self.conv = nn.Sequential(
            nn.Conv2d(inchannels,hidden_channels,1,1,groups = group),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
            Channel_Shuffle(group),
            # groups == channels makes this a depthwise convolution.
            nn.Conv2d(hidden_channels,hidden_channels,3,stride,1,groups=hidden_channels),
            nn.BatchNorm2d(hidden_channels),
            nn.Conv2d(hidden_channels,branch_channels,1,1,groups=group),
            nn.BatchNorm2d(branch_channels)
        )
        self.relu = nn.ReLU(inplace=True)


    def forward(self,x):
        """Apply the unit to ``x`` and return the ReLU-activated result."""
        out = self.conv(x)
        x = self.shortcut(x)
        if self.cat:
            # Stride-2 unit: stack branch and pooled shortcut channel-wise.
            x = torch.cat([out,x],1)
        else:
            # Stride-1 unit: plain residual addition.
            x = out+x
        return self.relu(x)


class Shuffle_v1(nn.Module):
    """ShuffleNet V1 classifier.

    Layout: 3x3 stride-2 stem conv -> 3x3 stride-2 max pool -> three stages
    of ``BLOCK`` units -> global average pool -> dropout -> linear layer.
    Each stage opens with one stride-2 unit followed by 3, 7 and 3 stride-1
    units respectively.

    Args:
        classes: number of output logits.
        group: group count of the 1x1 convolutions; selects the channel
            widths and must be one of 1, 2, 3, 4 or 8 (any other value
            raises ``KeyError``).
    """

    def __init__(self, classes,group = 1):
        super().__init__()
        # Channel plan per group count:
        # [image channels, stem, stage 2, stage 3, stage 4].
        widths_by_group = {
            1: [3, 24, 144, 288, 576],
            2: [3, 24, 200, 400, 800],
            3: [3, 24, 240, 480, 960],
            4: [3, 24, 272, 544, 1088],
            8: [3, 24, 384, 768, 1536],
        }
        # Number of stride-1 units that follow the stride-2 unit of a stage.
        units_per_stage = [3, 7, 3]
        widths = widths_by_group[group]
        stem_in, stem_out = widths[0], widths[1]

        self.conv1 = nn.Sequential(
            nn.Conv2d(stem_in, stem_out, 3, 2, 1),
            nn.BatchNorm2d(stem_out),
            nn.ReLU(inplace=True),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.block = BLOCK
        units = []
        # Pair each stage's input width with its output width.
        for n_units, c_in, c_out in zip(units_per_stage, widths[1:], widths[2:]):
            units.append(self.block(c_in, c_out, stride=2, group=group))
            units.extend(
                [self.block(c_out, c_out, stride=1, group=group) for _ in range(n_units)]
            )
        self.stages = nn.ModuleList(units)

        self.pool2 = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(widths[-1], classes),
        )

        # Re-initialise every conv / batch-norm / linear layer of the model.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.zeros_(layer.bias)

    def forward(self,x):
        """Return the ``(N, classes)`` logits for an image batch ``x``."""
        feats = self.pool1(self.conv1(x))
        for unit in self.stages:
            feats = unit(feats)
        pooled = self.pool2(feats)
        # Drop the 1x1 spatial dimensions left by the global pool.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

if __name__ == '__main__':
    # Smoke test: push one ImageNet-sized image through the group=8 model.
    # torch.randn is used instead of torch.empty because the latter returns
    # uninitialized memory, which can contain NaN/Inf and makes the printed
    # result meaningless. The tensor is not called `input` so the builtin of
    # that name is not shadowed.
    dummy = torch.randn((1,3,224,224))
    model = Shuffle_v1(10,8)
    out = model(dummy)
    print(out)


举报

相关推荐

0 条评论