Data Loading
1. Why use a data loader in a model
2. The Dataset class
2.1 Introduction to the Dataset base class
```python
class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
```
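Here is a minimal sketch of the contract above. A subclass must implement `__getitem__`. It should also implement `__len__` if the DataLoader's default sampler is going to be used. The `+` operator, through `Dataset.__add__`, joins two datasets into a `ConcatDataset`. The class `SquaresDataset` is a hypothetical example.

```python
from torch.utils.data import ConcatDataset, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: index i maps to the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Reject indices outside [0, n)
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        return self.n


a = SquaresDataset(3)
b = SquaresDataset(2)
combined = a + b  # Dataset.__add__ returns ConcatDataset([a, b])
print(type(combined).__name__, len(combined))  # ConcatDataset 5
print([combined[i] for i in range(len(combined))])
# [(0, 0), (1, 1), (2, 4), (0, 0), (1, 1)]
```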
2.2 Example: a dataset built from a text file
```python
# -*- coding: utf-8 -*-
"""
@Time    : 2021/5/17 9:28
@Author  : XXX
@File    : DemoGetData.py
@Software: PyCharm
"""
from torch.utils.data import Dataset


class MyDLoadData(Dataset):
    def __init__(self):
        # Read every line of the file into memory; `with` closes the file afterwards
        with open('../databases/test.txt', encoding='utf-8') as f:
            self.file = f.readlines()

    def __getitem__(self, item):
        # Return the line at index `item` (indices start at 0)
        return self.file[item]

    def __len__(self):
        # Number of lines in the file
        return len(self.file)


if __name__ == '__main__':
    data = MyDLoadData()
    print("First line:", data[0])
    print("Dataset length:", len(data))
```
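Running the script prints the selected line of test.txt, including its trailing newline, followed by the number of lines in the file.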
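The script above only works if `../databases/test.txt` exists. The variant below is a self-contained sketch that writes its own temporary file. It also strips the newline that `readlines()` leaves at the end of each line. The class name `LineDataset` and the sample file contents are hypothetical.

```python
import os
import tempfile

from torch.utils.data import Dataset


class LineDataset(Dataset):
    """Hypothetical variant of MyDLoadData that takes the file path as an argument."""

    def __init__(self, path):
        with open(path, encoding="utf-8") as f:
            # rstrip("\n") removes the trailing newline from each line
            self.lines = [line.rstrip("\n") for line in f]

    def __getitem__(self, item):
        return self.lines[item]

    def __len__(self):
        return len(self.lines)


# Write a small sample file so the example runs anywhere
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "test.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("first line\nsecond line\nthird line\n")
    data = LineDataset(path)  # the lines are read before the directory is deleted

print("First line:", data[0])        # First line: first line
print("Dataset length:", len(data))  # Dataset length: 3
```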
2.3 Iterating over the dataset
```python
# -*- coding: utf-8 -*-
from torch.utils.data import DataLoader, Dataset


class MyDLoadData(Dataset):
    def __init__(self):
        with open('../databases/test.txt', encoding='utf-8') as f:
            self.file = f.readlines()

    # Process a single sample and return it
    def __getitem__(self, item):
        return self.file[item]

    # Return the number of samples
    def __len__(self):
        return len(self.file)


if __name__ == '__main__':
    my_data = MyDLoadData()
    # Parameters:
    # · dataset: an instance of the Dataset defined above;
    # · batch_size: number of samples per batch, commonly 128/256;
    # · shuffle: [bool] whether to shuffle the data (reshuffled at every epoch);
    # · num_workers: number of worker subprocesses that load data (0 = load in the main process).
    # Note: len(data_load) is the number of batches, ceil(len(my_data) / batch_size) with the
    # default drop_last=False, so batch_size * len(data_load) can exceed the dataset length.
    data_load = DataLoader(dataset=my_data, batch_size=2, shuffle=True, num_workers=2)
    for i in data_load:
        print(i)  # one batch: a list of up to batch_size lines
        break
```
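The batch-count note is easier to see with a concrete case. This sketch replaces the text file with a hypothetical in-memory list of 5 strings. With `batch_size=2`, the loader yields ceil(5 / 2) = 3 batches, and the last batch holds a single item. Setting `drop_last=True` discards that incomplete batch. `num_workers` is left at its default of 0, so loading runs in the main process.

```python
import math

from torch.utils.data import DataLoader

# A plain list can serve as a map-style dataset: it has __getitem__ and __len__
lines = ["a", "b", "c", "d", "e"]

loader = DataLoader(lines, batch_size=2, shuffle=False)
batches = list(loader)
print(len(loader), batches)  # 3 [['a', 'b'], ['c', 'd'], ['e']]
print(len(loader) == math.ceil(len(lines) / 2))  # True

# drop_last=True discards the final, incomplete batch
loader_drop = DataLoader(lines, batch_size=2, drop_last=True)
print(len(loader_drop), list(loader_drop))  # 2 [['a', 'b'], ['c', 'd']]
```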
3. Built-in datasets
3.1 torchvision.datasets
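```python
torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
```
Parameters:
· root: directory where the data is stored (or will be downloaded to);
· train: [bool] True for the training set, False for the test set;
· download: [bool] whether to download the data into root (nothing is downloaded if it is already there);
· transform: function applied to each image, e.g. converting it to a tensor.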
Download:
```python
import torchvision

# Download the MNIST training set (train=True by default) into ../databases
dataset = torchvision.datasets.MNIST(root='../databases', download=True)
```
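4. Implementing handwritten digit recognition
4.1 Approach and workflow
4.2 Preparing the training and test sets
Downloading the data:
Data processing: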
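```python
from torchvision import transforms

# dataset[0] is an (image, label) pair; ToTensor converts the 28x28 PIL image
# into a float tensor of shape [1, 28, 28] with values in [0, 1]
result = transforms.ToTensor()(dataset[0][0])
```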
4.3 Building the model
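Here is a sketch of a small fully connected network for MNIST. The class name `MnistModel` and the hidden size of 28 are illustrative choices. The model flattens each 1×28×28 image into a 784-element vector and returns 10 raw class scores (logits). The loss functions in 4.4 expect logits like these.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MnistModel(nn.Module):
    """Hypothetical MLP: 784 inputs -> 28 hidden units -> 10 class scores."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 28)
        self.fc2 = nn.Linear(28, 10)

    def forward(self, x):
        # x: [batch_size, 1, 28, 28] -> [batch_size, 784]
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        # Return raw scores (logits); softmax is left to the loss function
        return self.fc2(x)


model = MnistModel()
out = model(torch.randn(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])
```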
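4.4 Loss function
```python
import torch.nn as nn
import torch.nn.functional as F

# x: model output (logits) of shape [batch_size, 10]; target: class indices of shape [batch_size]

# Option 1: CrossEntropyLoss takes the logits and applies log_softmax internally
criterion = nn.CrossEntropyLoss()
loss = criterion(x, target)

# Option 2: apply softmax and take the log explicitly...
output = F.log_softmax(x, dim=-1)
# ...then compute the negative log-likelihood loss
loss = F.nll_loss(output, target)
```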
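The two options give the same value, because `nn.CrossEntropyLoss` is `log_softmax` followed by `nll_loss`. The quick check below uses random logits. The batch of 4 samples and the target labels are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 10)               # logits for 4 samples, 10 classes
target = torch.tensor([3, 0, 9, 1])  # true class index of each sample

loss1 = nn.CrossEntropyLoss()(x, target)              # option 1
loss2 = F.nll_loss(F.log_softmax(x, dim=-1), target)  # option 2
print(loss1.item(), loss2.item())
print(torch.allclose(loss1, loss2))  # True
```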
4.5 Training the model
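Here is a minimal training-loop sketch. It assumes a model that returns logits, `CrossEntropyLoss`, and the Adam optimizer with a learning rate of 1e-3; the optimizer and learning rate are assumptions. To stay self-contained, it trains on a small random `TensorDataset` instead of MNIST. The model is a stand-in MLP, and the helper `train_one_epoch` is hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Random stand-in data: 64 "images" of shape 1x28x28 with labels 0-9
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 28), nn.ReLU(), nn.Linear(28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_one_epoch():
    """Hypothetical helper: one pass over the loader, returns the mean batch loss."""
    model.train()
    total = 0.0
    for x, y in loader:
        optimizer.zero_grad()          # clear gradients from the previous step
        loss = criterion(model(x), y)  # forward pass and loss
        loss.backward()                # backpropagate
        optimizer.step()               # update the parameters
        total += loss.item()
    return total / len(loader)


losses = [train_one_epoch() for _ in range(20)]
print(f"first epoch loss {losses[0]:.3f}, last epoch loss {losses[-1]:.3f}")
```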
4.6 Saving and loading the model
Saving the model:
Loading the model:
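This sketch shows the common approach. Each `state_dict` (the parameters only) is saved with `torch.save`. Loading means building a model with the same architecture and restoring the weights with `load_state_dict`. Saving the optimizer state as well lets training resume where it stopped. The file names and the temporary directory are illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Same stand-in architecture for saving and loading
    return nn.Sequential(nn.Flatten(), nn.Linear(784, 28), nn.ReLU(), nn.Linear(28, 10))


model = build_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.pt")
    optim_path = os.path.join(tmp, "optimizer.pt")

    # Saving: store the parameter dictionaries rather than the whole objects
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optim_path)

    # Loading: build the same architecture, then copy the saved parameters in
    new_model = build_model()
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
    new_model.load_state_dict(torch.load(model_path))
    new_optimizer.load_state_dict(torch.load(optim_path))

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                              new_model.state_dict().values()))
print(same)  # True
```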