0
点赞
收藏
分享

微信扫一扫

深度学习 | pytorch数据集加载

文章目录

数据加载

1、为何在模型中使用数据加载器

2、数据集类

2.1 Dataset基类介绍
class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
2.2 案例
# -*- coding: utf-8 -*-
""" @Time : 2021/5/17 9:28
 @Author : XXX
 @Site : 
 @File : DemoGetData.py
 @Software: PyCharm 
"""
from torch.utils.data import Dataset

class MyDLoadData(Dataset):

    def __init__(self):
        self.file = open('../databases/test.txt').readlines()

    def __getitem__(self, item):

        return self.file[item]

    def __len__(self):

        return len(self.file)


if __name__ == '__main__':
    data =MyDLoadData()
    print("第一行的数据为:",data[1])
    print("该数据集的长度为:",len(data))

运行结果
在这里插入图片描述

2.3 迭代数据集
# -*- coding: utf-8 -*-
""" @Time : 2021/5/17 9:28
 @Author : XXX
 @Site : 
 @File : DemoGetData.py
 @Software: PyCharm 
"""
from torch.utils.data import DataLoader, Dataset


class MyDLoadData(Dataset):

    def __init__(self):
        self.file = open('../databases/test.txt').readlines()

    # 对数据进行处理返回
    def __getitem__(self, item):

        return self.file[item]

    # 返回数据的长度
    def __len__(self):

        return len(self.file)



if __name__ == '__main__':

    my_data = MyDLoadData()
    data_load = DataLoader(dataset=my_data, batch_size=2, shuffle=True, num_workers=2)
    """
    参数:
        · dataset:提前定义的dataset的实例;
        · batch_size:传入数据的batch的大小,常用128/256;
        · shuffle:【bool】获取数据的时候是否打乱;
        · num_workers:加载数据的线程数。 
    *** 注意:数据总长度的:batch_size*len(data_load)
    """
    for i in data_load:
        print(i)
        break

在这里插入图片描述

3、自带数据集

3.1 torchversion.datasets
torchvision.datasets.MNIST(root='files', train, download, transform)
"""
参数:
	· root:参数表示数据存放的位置;
	· train:【bool】使用训练集的数据还是测试集的数据;
	· download:【bool】是否需要下载数据到root目录;
	· transform:实现的对图片的处理函数。
"""

下载

import torchvision

dataset = torchvision.datasets.MNIST(root='../databases', download=True)

在这里插入图片描述

4、实现手写数字识别

4.1 思路和流程分析
4.2 准备训练集和测试集

下载数据:

数据处理:

result = transform.ToTensor()(dataset[0][0])
4.3 构建模型
4.4 损失函数

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

# 第一种
criterion = nn.CrossEntropyLoss()
loss = criterion(input, targtet)

# 第二种
# 对输出值计算softmax和取对数
output = F.log_softmax(x, dim=-1)
# 使用带权损失
loss = F.nll_loss(output, target)

4.5 训练模型
4.6 模型保存和加载

模型保存

模型加载

4.7 模型评估
举报

相关推荐

0 条评论