'''torchvision'''
import torchvision
from torchvision.transforms import ToTensor
# ToTensor()的作用:
# 1.将输入转为tensor
# 2.图片格式转换为(通道,高,宽) #其他软件中常见的图片格式:(行,列,通道),一般是3通道,如(512,512,3)。但是在pytorch中经常为(通道,高,宽)
# 3.将像素值转换到(0,1)范围
'''dataset.MNIST内置数据--Dataset类型'''
# 通过MNIST下载内置数据,并创建为Dataset类型
train_ds = torchvision.datasets.MNIST('data', # 下载后保存到文件夹data
train=True, # 下载训练数据True,下载测试数据False,下载全部数据 不填。
transform=ToTensor(), # 转换数据格式
download=True) # 是否下载True,运行的时候就不用再次下载了。也可自行下载
test_ds = torchvision.datasets.MNIST('data', train=False, transform=ToTensor(), download=True) # 下载测试数据False
'''DataLoader'''
from torch.utils.data import DataLoader # 可对Dataset封装
# DataLoader的作用:
# 1.乱序 shuffle。如果不进行乱序,前面一直都是学习“猫”,后面一直学习“狗”,会导致模型训练的不准确,需要更多次的训练。
# 2.自动对数据分批次batch_size。如果不分批次,突然进入一个和模型不匹配的单个输入数据时,会引起模型的巨大震荡,模型受到单个数据的影响大。
# 3.num_workes 加速数据读取
# 4.设置批次处理函数 collate_fn
# 创建DataLoader,DataLoader本质上是可迭代对象
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True) # 把每批分成64个图片batch_size=64,顺序打乱shuffle=True
test_dl = DataLoader(test_ds, batch_size=64, shuffle=True) # 每个批次包含两部分数据:图片、标签
print('dataset有多少条完整数据:', len(train_ds))
print('dataloader有多少条完整数据:', len(train_dl.dataset))
print('dataloader有多少个批次(6000/64=937.5):', len(train_dl))
'''读取DataLoader数据'''
# 取出DataLoader中一个批次(batch)的图片:从可迭代对象中拿数据
for i, (imgs, labels) in enumerate(train_dl): # 逐个取出数据: imgs是一个批次中的所有图片;labels是这个批次图片的标签
break # imgs, labels = next(iter(train_dl)) 可以用这句代替这两行循环
print(imgs.shape) # 一个批次图片的格式img.shape为[图片数,通道数,行,列], [64, 1, 28, 28]
print(imgs.view(-1, 28*28).shape)
print(imgs.view(64, -1).shape)
# 绘图和输出imgs、labels的前10个数据
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(10, 1))
for i, img in enumerate(imgs[0:10]): # 取前十张绘图
img_np = img.numpy() # 转换为numpy数组
img_np = np.squeeze(img_np)
plt.subplot(1, 10, i+1)
plt.imshow(img_np)
plt.axis('off') # 不显示坐标轴
plt.show()
print(labels[0:10])
输出
C:\Users\mayuhuaw\software\Anaconda3_2020_11\anaconda\envs\pytorch\python.exe C:/Users/mayuhuaw/Desktop/深度学习pytorch/Pytorch入门/6-多层感知器(多分类)/1-torchvision、DataLoader的使用.py
dataset有多少条完整数据: 60000
dataloader有多少条完整数据: 60000
dataloader有多少个批次(6000/64=937.5): 938
torch.Size([64, 1, 28, 28])
torch.Size([64, 784])
torch.Size([64, 784])
tensor([6, 6, 3, 2, 9, 9, 7, 8, 1, 6])
进程已结束,退出代码为 0