[PyTorch Learning] -- Using torchvision.datasets & DataLoader

Tutorial video: https://www.bilibili.com/video/BV1hE411t7RN?p=1 (also covers environment setup)

Using torchvision.datasets

PyTorch ships with many built-in datasets that are convenient for learning, so here we pick one of them to work with; see the official documentation for the full list.

CIFAR10

This article uses the CIFAR10 dataset.

import torchvision

train_set = torchvision.datasets.CIFAR10(root = "./CIFAR10_Dataset",train = True,download = True)
test_set = torchvision.datasets.CIFAR10(root = "./CIFAR10_Dataset",train = False,download = True)
# root: root directory the dataset is downloaded into
# train: True for the training set, False for the test set
# download: whether to download from the internet; if the download is slow, you can fetch the archive manually with a download manager:
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
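
As a quick sanity check (using the two dataset objects created above), you can print their lengths; CIFAR10 contains 50,000 training images and 10,000 test images:

print(len(train_set))  # 50000
print(len(test_set))   # 10000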

Let's check what type of data the dataset returns:

print(test_set[0])

From the output we can see that each sample is a PIL Image, and the "3" means this sample's label is class 3. CIFAR10 has 10 classes; see the official documentation for the class list.

(<PIL.Image.Image image mode=RGB size=32x32 at 0x281591063D0>, 3)
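
If you want a human-readable class name rather than the bare index, the CIFAR10 dataset object exposes a classes attribute. A small sketch, assuming test_set was created as above:

img, target = test_set[0]
print(test_set.classes)          # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print(test_set.classes[target])  # 'cat' -- the class with index 3
# img.show()                     # optionally open the 32x32 PIL image in an image viewer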

Converting to Tensor

Simply pass the built-in transform argument when creating the dataset:

test_data = torchvision.datasets.CIFAR10(root = "./CIFAR10_Dataset",
                                         train = False,
                                         transform = torchvision.transforms.ToTensor())
print(test_data[0])

Now the sample comes back as a tensor:

(tensor([[[0.6196, 0.6235, 0.6471,  ..., 0.5373, 0.4941, 0.4549],
         [0.5961, 0.5922, 0.6235,  ..., 0.5333, 0.4902, 0.4667],
         [0.5922, 0.5922, 0.6196,  ..., 0.5451, 0.5098, 0.4706],
         ...,
         [0.2667, 0.1647, 0.1216,  ..., 0.1490, 0.0510, 0.1569],
         [0.2392, 0.1922, 0.1373,  ..., 0.1020, 0.1137, 0.0784],
         [0.2118, 0.2196, 0.1765,  ..., 0.0941, 0.1333, 0.0824]],

        [[0.4392, 0.4353, 0.4549,  ..., 0.3725, 0.3569, 0.3333],
         [0.4392, 0.4314, 0.4471,  ..., 0.3725, 0.3569, 0.3451],
         [0.4314, 0.4275, 0.4353,  ..., 0.3843, 0.3725, 0.3490],
         ...,
         [0.4863, 0.3922, 0.3451,  ..., 0.3804, 0.2510, 0.3333],
         [0.4549, 0.4000, 0.3333,  ..., 0.3216, 0.3216, 0.2510],
         [0.4196, 0.4118, 0.3490,  ..., 0.3020, 0.3294, 0.2627]],

        [[0.1922, 0.1843, 0.2000,  ..., 0.1412, 0.1412, 0.1294],
         [0.2000, 0.1569, 0.1765,  ..., 0.1216, 0.1255, 0.1333],
         [0.1843, 0.1294, 0.1412,  ..., 0.1333, 0.1333, 0.1294],
         ...,
         [0.6941, 0.5804, 0.5373,  ..., 0.5725, 0.4235, 0.4980],
         [0.6588, 0.5804, 0.5176,  ..., 0.5098, 0.4941, 0.4196],
         [0.6275, 0.5843, 0.5176,  ..., 0.4863, 0.5059, 0.4314]]]), 3)
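
ToTensor converts the 32x32 PIL image (H x W x C, uint8 values in 0-255) into a float tensor of shape C x H x W with values scaled to [0, 1]. This can be verified directly with the test_data object defined above:

img, target = test_data[0]
print(img.shape)             # torch.Size([3, 32, 32])
print(img.dtype)             # torch.float32
print(img.min(), img.max())  # values lie in [0, 1]
print(target)                # 3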

Using DataLoader

DataLoader is the tool that packs samples into batches in a form suitable for feeding into a neural network.

from torch.utils.data import DataLoader

test_loader = DataLoader(dataset = test_data,
                        batch_size = 4,     # how many samples per batch to load
                        shuffle = True,     # set to True to have the data reshuffled at every epoch
                        num_workers = 0,    # how many subprocesses to use for data loading
                        drop_last = False)  # set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size

for data in test_loader:
    imgs,targets = data
    print(imgs.shape) # batch_size = 4: four images, each with 3 channels and 32*32 pixels
    print(targets)    # the class labels of the four images

One batch of output:

torch.Size([4, 3, 32, 32])
tensor([0, 8, 5, 1])
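
Since the test set has 10,000 images and batch_size is 4, the loader yields 2,500 batches per epoch; calling len() on the DataLoader gives this batch count directly (a quick check with the loader above):

print(len(test_data))    # 10000 samples
print(len(test_loader))  # 2500 batches (10000 / 4)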

Use TensorBoard to inspect the batches more intuitively:

from torch.utils.tensorboard import SummaryWriter

test_loader = DataLoader(dataset = test_data,
                        batch_size = 64,   
                        shuffle = True,  
                        num_workers = 0, 
                        drop_last = False)

writer = SummaryWriter("dataloader")   # logs are written to the ./dataloader directory
step = 0
for data in test_loader:
    imgs,targets = data
    writer.add_images("test_data",imgs,step)   # add_images (plural) accepts a whole batch of images
    step = step + 1

writer.close()

Each step contains 64 images.
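
To actually view the grids, launch TensorBoard with tensorboard --logdir=dataloader (assuming TensorBoard is installed) and open the URL it prints in a browser.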
