0
点赞
收藏
分享

微信扫一扫

6-1-torchvision、DataLoader的使用

'''torchvision'''
import torchvision
from torchvision.transforms import ToTensor
# ToTensor()的作用:
# 1.将输入转为tensor
# 2.图片格式转换为(通道,高,宽) #其他软件中常见的图片格式:(行,列,通道),一般是3通道,如(512,512,3)。但是在pytorch中经常为(通道,高,宽)
# 3.将像素值转换到(0,1)范围

'''dataset.MNIST内置数据--Dataset类型'''
# 通过MNIST下载内置数据,并创建为Dataset类型
train_ds = torchvision.datasets.MNIST('data',      # 下载后保存到文件夹data
                           train=True,             # 下载训练数据True,下载测试数据False,下载全部数据 不填。
                           transform=ToTensor(),   # 转换数据格式
                           download=True)          # 是否下载True,运行的时候就不用再次下载了。也可自行下载

test_ds = torchvision.datasets.MNIST('data', train=False, transform=ToTensor(), download=True)   # 下载测试数据False

'''DataLoader'''
from torch.utils.data import DataLoader   # 可对Dataset封装
# DataLoader的作用:
# 1.乱序 shuffle。如果不进行乱序,前面一直都是学习“猫”,后面一直学习“狗”,会导致模型训练的不准确,需要更多次的训练。
# 2.自动对数据分批次batch_size。如果不分批次,突然进入一个和模型不匹配的单个输入数据时,会引起模型的巨大震荡,模型受到单个数据的影响大。
# 3.num_workes 加速数据读取
# 4.设置批次处理函数 collate_fn

# 创建DataLoader,DataLoader本质上是可迭代对象
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)    # 把每批分成64个图片batch_size=64,顺序打乱shuffle=True
test_dl = DataLoader(test_ds, batch_size=64, shuffle=True)      # 每个批次包含两部分数据:图片、标签
print('dataset有多少条完整数据:', len(train_ds))
print('dataloader有多少条完整数据:', len(train_dl.dataset))
print('dataloader有多少个批次(6000/64=937.5):', len(train_dl))

'''读取DataLoader数据'''
# 取出DataLoader中一个批次(batch)的图片:从可迭代对象中拿数据
for i, (imgs, labels) in enumerate(train_dl):      # 逐个取出数据: imgs是一个批次中的所有图片;labels是这个批次图片的标签
    break                                          # imgs, labels = next(iter(train_dl))  可以用这句代替这两行循环

print(imgs.shape)      # 一个批次图片的格式img.shape为[图片数,通道数,行,列], [64, 1, 28, 28]
print(imgs.view(-1, 28*28).shape)
print(imgs.view(64, -1).shape)


# 绘图和输出imgs、labels的前10个数据
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(10, 1))
for i, img in enumerate(imgs[0:10]):   # 取前十张绘图
    img_np = img.numpy()   # 转换为numpy数组
    img_np = np.squeeze(img_np)
    plt.subplot(1, 10, i+1)
    plt.imshow(img_np)
    plt.axis('off')   # 不显示坐标轴
plt.show()

print(labels[0:10])

输出
在这里插入图片描述

C:\Users\mayuhuaw\software\Anaconda3_2020_11\anaconda\envs\pytorch\python.exe C:/Users/mayuhuaw/Desktop/深度学习pytorch/Pytorch入门/6-多层感知器(多分类)/1-torchvision、DataLoader的使用.py
dataset有多少条完整数据: 60000
dataloader有多少条完整数据: 60000
dataloader有多少个批次(6000/64=937.5)938
torch.Size([64, 1, 28, 28])
torch.Size([64, 784])
torch.Size([64, 784])
tensor([6, 6, 3, 2, 9, 9, 7, 8, 1, 6])

进程已结束,退出代码为 0

举报

相关推荐

0 条评论