二、现有网络模型(torchvision.models.vgg16)的修改使用、保存加载
1.torchvision.models.vgg16
官方文档 : https://pytorch.org/vision/stable/models.html#id2
ImageNet数据集太大不好下载
2.pretrained设置不同时网络模型的差别
3.如何修改现有网络结构
修改vgg16_true网路结构,添加linear层
import torchvision
from torch.nn import Linear
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true)
# vgg16_true.add_module("add_linear",Linear(1000,10))
vgg16_true.classifier.add_module("add_linear",Linear(1000,10))
print(vgg16_true)
修改vgg16_false网路结构,更改分类器第6层为指定linear层
print(vgg16_false)
vgg16_false.classifier[6]=Linear(4096,10)
print(vgg16_false)
4.模型的保存、加载
vgg16_method1 结构+参数
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# vgg16_method1 结构+参数
torch.save(vgg16, "vgg16_method1.pth")
# 模型加载(在另一个文件加载)
model = torch.load("vgg16_method1.pth")
print(model)
import torch
import torchvision
from torch import nn
class Qu(nn.Module):
def __init__(self):
super(Qu, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return
qu = Qu()
torch.save(qu, "qu_method1.pth")
正确的调用格式需要复制原模型的类定义
class Qu(nn.Module):
def __init__(self):
super(Qu, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return
model = torch.load("qu_method1.pth")
print(model)
或者用import
from model_save import *
model = torch.load("qu_method1.pth")
print(model)
vgg16_method2 参数(官方推荐)
import torch
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)
# vgg16_method2 参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 模型加载(在另一个文件加载)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)