前言:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
数据预处理部分:
data_transforms = {
'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
transforms.CenterCrop(224),#从中心开始裁剪
transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
]),
'valid': transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
图像数据集加载部分:
Dataset类
PyTorch读取图片,主要是通过Dataset类,所以先简单了解一下Dataset类。Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类。
源码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
这里重点看 getitem函数,getitem接收一个index,然后返回一个batch大小的图片数据和标签,其中这个index是一个列表,这个列表是由dataloader里的sampler采样器生成的。感兴趣的可以详细了解这里的数据集的加载Dataset和DataLoader原理。
如bitch_size的值是16,其在pycharm中的表示形式为:
一、自定义Dataset加载
要让PyTorch能读取自己的数据集,只需要两步:
- 制作图片数据的索引
- 构建Dataset子类
然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
整个读取自己数据的基本流程就是:
- 制作存储了图片的路径和标签信息的txt
- 将这些信息转化为list,该list每一个元素对应一个样本
- 通过getitem函数,读取数据和标签,并返回数据和标签。
首先制作图片数据的索引
就是读取图片路径,标签,保存到txt文件中。
1)一堆相同类别的图片已经在一个文件夹下了,可以用下面这种方法产生一个txt文件。
参考:如何用python生成带图片名称和标签的.txt文件(代码)
2)标签和图片标号都在csv文件里,可以用以下方法。
pytorch 自定义数据集载入(标签在csv文件里)
然后构建Dataset子类
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r') #读取 制作好的txt文件的 图片路径和标签到imgs里
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index] #self.imgs是一个list,self.imgs的一个元素是一个str,包含图片路径,图片标签,这些信息是在init函数中从txt文件中读取的
# fn是一个图片路径
img = Image.open(fn).convert('RGB') #利用Image.open对图片进行读取,img类型为 Image ,mode=‘RGB’
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
- 注意到Dataset类里的初始化中还会初始化transform,transform是一个Compose类型,里边有一个list,list中就会定义了各种对图像进行处理的操作,可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作。
- 在这里我们要知道,一张图片读取进来之后,会经过数据处理(数据增强),最终变成模型的输入数据。这里就有一点需要注意,PyTorch的数据增强是将原始图片进行了处理,并不会生成新的一份图片,而是“覆盖”原图,当采用randomcrop之类的随机操作时,每个epoch输入进来的图片几乎不会是一模一样的,这达到了样本多样性的功能。
最后DataLoader加载即可
- 当自定义Dataset构建好,剩下的操作就交给DataLoader了。在DataLoader中,会触发Mydataset中的getiterm函数读取一个batch大小的图片的数据和标签,并返回,(清晰的底层逻辑见该博客)作为模型真正的输入。
- 最后像下面这样,处理好了前面说的两步之后,得到data,交给DataLoader就很简单了。
train_data = MyDataset(txt='../gender/train1.txt',type = "train", transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_data,
batch_size=batch_size,
sampler=train_sampler)
二、用torchvision里的ImageFolder图像分类数据集的加载
仍然先制作数据源
举个例子,做的另一个项目,花的类别分类。他的数据集如下图。即同一种花都在一个文件夹中,文件夹的名称即为标签类别。
然后利用torchvision里的ImageFolder类
如下面的代码,ImageFolder已经写好datasets了。就像手写数字的datasets一样该datasets里面init,getitem,len魔法函数已实现了,只要保存数据集的格式符合要求,就可以直接使用。
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
上面的代码是花分类的项目的train和valid两个数据集
关于datasets官方给的答案是:
All datasets are subclasses of torch.utils.data.Dataset i.e, they have getitem and len methods implemented(都已实现了getitem和len,不需要像第一种自定义方法自己写dataset类了). Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples in parallel using torch.multiprocessing workers.
这里All datasets还有很多,比如ImageNet等,具体可以去pytorch官网查看。
最后再dataloader加载
使用如下。
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
三、torch自带图像分类数据集的处理和加载
像手写数字之类等,都可以在官网查看具体还有哪些数据集,都自带相应的dataset。
import torch
from torchvision import datasets, transforms
import helper
import matplotlib.pyplot as plt
import numpy
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))],)
# Download and load the training data
trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)