0
点赞
收藏
分享

微信扫一扫

【PyTorch】两种不同分类层的设计方法


问题

涉及到图像分类的网络的最后一层分类层,有两种实现方法,如下所示,你更偏向于哪种方法呢?

方法

方法1

import torch
from torch import nn


'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
def __init__(self) -> None:
super().__init__()

self.conv = nn.Conv2d(3, 32, 3, padding=1)

self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(32, 2)


def forward(self, x):
x = self.conv(x)
x = self.avg_pool(x)
x = x.view(x.size(0), -1) # 展开所有元素
out = self.classifier(x)

return out

if __name__ == '__main__':

from torchsummary import summary

device = 'cuda' if torch.cuda.is_available() else 'cpu'

x = torch.rand(size=(1, 3, 7, 7)).to(device)
net = MyNet().to(device)

summary(net, (3, 7, 7))

方法2

import torch
from torch import nn


'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
def __init__(self) -> None:
super().__init__()

self.conv = nn.Conv2d(3, 32, 3, padding=1)

self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(32, 2)
)


def forward(self, x):
x = self.conv(x)
out = self.classifier(x)

return out

if __name__ == '__main__':

from torchsummary import summary

device = 'cuda' if torch.cuda.is_available() else 'cpu'

x = torch.rand(size=(1, 3, 7, 7)).to(device)
net = MyNet().to(device)
out = net(x)

summary(net, (3, 7, 7))

结语

从扩展性、可读性的角度来说,更偏向于方法2的设计。


举报

相关推荐

0 条评论