0
点赞
收藏
分享

微信扫一扫

Pytorch学习-调整torchvision.models中模型输出类别数

高子歌 2024-05-15 阅读 6

假设你的类别只有10个,而torchvision.models中Vgg16的输出类别为1000,这时应该如何调整呢?

方法一,直接修改模型中类别的输出。

from torch.nn import Linear
import torchvision
import torch

Vgg16=torchvision.models.vgg16(pretrained=True)
Vgg16.classifier[6]=Linear(in_features=4096,out_features=10)
if torch.cuda.is_available():
    T=Vgg16.cuda()

方法二,再模型的最后增加全连接层,改变输出类别。

from torch.nn import Linear
import torchvision
import torch

res=torchvision.models.resnet101(pretrained=True,progress=True)
res.fc.add_module('linelayer',Linear(in_features=1000,out_features=10))
if torch.cuda.is_available():
    T=res.cuda()

 

举报

相关推荐

0 条评论