0
点赞
收藏
分享

微信扫一扫

ResNet残差网络Pytorch实现——对花的种类进行批量数据预测


ResNet残差网络Pytorch实现——对花的种类进行批量数据预测

上一篇:​​【对花的种类进行单数据预测】​​ ✌✌✌✌ ​​【目录】​​ ✌✌✌✌ 下一篇:​​【cifar10数据集训练】​​

大学生一枚,最近在学习神经网络,写这篇文章只是记录自己的学习历程,本文参考了​​Github上fengdu78老师的文章​​进行学习

✌ 使用ResNet进行对花的种类进行批量数据预测

import os
import json

import torch
from torchvision import transforms

from PIL import Image

# 加载运算设备
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 数据处理
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# 因为是同时预测多个图片,所以要将路径添加到列表中
img_path_list=['13290033_ebd7c7abba_n.jpg','302782756_d35cb3e468.jpg','40411019_526f3fc8d9_m.jpg']
img_list=[]

# 将转化后的数据添加到list中
for img_path in img_path_list:
img=Image.open(img_path)
img=data_transform(img)
img_list.append(img)

# 将列表转化成tensor,因为列表中已经存在3个样本,所以直接就升维了,不用unsqueeze
batch_img=torch.stack(img_list,dim=0)

# 加载预测结果与真实分类的映射
json_path='./class_indices.json'
json_file=open(json_path,'r')
class_indict=json.load(json_file)

# 构建网络
model=resnet34(num_classes=5).to(device)

# 加载训练模型参数
weights_path='./resNet34.pth'
model.load_state_dict(torch.load(weights_path,map_location=device))

# 开启验证模式
model.eval()
# 不需要进行求导
with torch.no_grad():
output=model(batch_img.to(device)).cpu()
# 这里的dim维度就是1,因为此时预测的输出结果为二维,dim=1代表,每一行所有列进行计算
# 输出维度(3,5)
predict=torch.softmax(output,dim=1)
probs,classes=torch.max(predict,dim=1)

for idx, (pro, cla) in enumerate(zip(probs, classes)):
print("image: {} class: {} prob: {:.3}".format(img_path_list[idx],
class_indict[str(cla.item())],
pro.item()))


举报

相关推荐

Pytorch 残差网络 ResNet

ResNet残差网络

0 条评论