0
点赞
收藏
分享

微信扫一扫

机器学习笔记 - 基于Torch Hub的DCGAN图像生成 + 调用自定义网络模型

一、Torch Hub概述

        Pytorch Hub 是一个预训练模型存储库,旨在促进研究的可重复性。Torch Hub 允许您发布预先训练的模型,以帮助促进研究共享和可重复性。

        Torch Hub 在其官方展示中总共纳入了48个研究模型(目前为止)。音频模型8个、生成式模型2个、自然语言处理 (NLP)3个、可编写脚本20个、视觉模型等等。这些模型还在基准数据集(例如Kinetics 400和COCO 2017)上进行了训练。

        在您的项目中使用这些模型很容易,使用torch.hub.load功能就可以调用。让我们看一个例子来说明它是如何工作的。

二、示例:使用DCGAN生成图像

        首先进行模型的加载,如果下载失败,可以看失败的log手动下载,手动下载之后。修改source='local',默认是'github'。

# USAGE
# python inference.py
# import the necessary packages
import matplotlib.pyplot as plt
import torchvision
import argparse
import torch
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-n", "--num-images", type=int, default=64, help="# of images you want the DCGAN to generate")
args = vars(ap.parse_args())
# check if gpu is available for use
useGpu = True if torch.cuda.is_available() else False
# load the DCGAN model
model = torch.hub.load("facebookresearch/pytorch_GAN_zoo-hub", "DCGAN", pretrained=True, useGPU=useGpu, source='local')

        生成图片

# 生成随机噪声输入到生成器
(noise, _) = model.buildNoiseData(args["num_images"])
# 关闭 autograd 并将输入噪声提供给模型
with torch.no_grad():
	generatedImages = model.test(noise)
# 重新配置图像通道顺序并显示输出
output = torchvision.utils.make_grid(generatedImages).permute(1, 2, 0).cpu().numpy()
plt.imshow(output)
plt.show()

        显示结果如下

三、训练自定义模型后使用Torch Hub加载

1、创建并训练神经网络

        这里使用之前的简单神经网络进行测试,就不重复描述了,见下面链接。

机器学习笔记 - win10安装Pytorch-GPU版本并训练第一个神经网络_bashendixie5的博客-CSDN博客一、环境准备和Pytorch安装1、基础环境window10系统全新的conda的python3.7.11隔离环境 cuda11.2cudnn8.12、安装Pytorch GPU版本根据官方网站查看GPU版本安装命令,需要提前安装好CUDA和CUDNNStart Locally | PyTorch https://pytorch.org/get-started/locally/ ...https://skydance.blog.csdn.net/article/details/122350126

2、使用Torch Hub加载

        模型训练完成后,下一步是修改配置文件中的 repo 以使我们的模型可以通过 Torch Hub 访问。

(1)配置hubconf.py脚本

import torch
# 我们的模型
import mlp

# 定义入口点/可调用函数来初始化和返回模型
def custom_model():

    # 初始化模型
	model = mlp.get_training_model()
    # 加载权重
	model.load_state_dict(torch.load("model_wt.pth"))
    # 返回模型
	return model

(2)调用模型

        创建一个python文件。

# import the necessary packages
from train import next_batch
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_blobs
import torch.nn as nn
import argparse
import torch


ap = argparse.ArgumentParser()
ap.add_argument("-b", "--batch-size", type=int, default=64, help="input batch size")
args = vars(ap.parse_args())

# 使用torch hub加载模型
print("[INFO] loading the model using torch hub...")
model = torch.hub.load("torch_hub/test", "custom_model")
# 生成具有 1000 个数据点的 3 类分类问题,其中每个数据点是一个 4D 特征向量
print("[INFO] preparing data...")
(X, Y) = make_blobs(n_samples=1000, n_features=4, centers=3, cluster_std=2.5, random_state=95)
# 创建训练和测试拆分,并将它们转换为 PyTorch 张量
(trainX, testX, trainY, testY) = train_test_split(X, Y, test_size=0.15, random_state=95)
trainX = torch.from_numpy(trainX).float()
testX = torch.from_numpy(testX).float()
trainY = torch.from_numpy(trainY).float()
testY = torch.from_numpy(testY).float()

(3)模型设置为评估模式

        初始化交叉熵损失函数、初始化评估指标。模型设置为评估模式,并抓取一批数据以供模型评估。

# 初始化损失函数
lossFunc = nn.CrossEntropyLoss()
# 如果可用,将设备设置为 cuda 并初始化测试损失和准确性
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
testLoss = 0
testAcc = 0
samples = 0

# 将模型设置为 eval 并获取一批数据
print("[INFO] setting the model in evaluation mode...")
model.eval()
(batchX, batchY) = next(next_batch(testX, testY, args["batch_size"]))

(4)进行评估

        关闭自动渐变,我们将这批数据加载到设备并将其提供给模型。使用损失函数计算损失。

# initialize a no-gradient context
with torch.no_grad():
	# load the data into device
	(batchX, batchY) = (batchX.to(DEVICE), batchY.to(DEVICE))
	# pass the data through the model to get the output and calculate
	# loss
	predictions = model(batchX)
	loss = lossFunc(predictions, batchY.long())
	# update test loss, accuracy, and the number of
	# samples visited
	testLoss += loss.item() * batchY.size(0)
	testAcc += (predictions.max(1)[1] == batchY).sum().item()
	samples += batchY.size(0)
	print("[INFO] test loss: {:.3f}".format(testLoss / samples))
	print("[INFO] test accuracy: {:.3f}".format(testAcc / samples))

举报

相关推荐

0 条评论