0
点赞
收藏
分享

微信扫一扫

pytorch数据加载器类

pytorch数据加载器类

import math

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

data_path=r'smsspamcollection\SMSSpamCollection'

#完成数据集类的定义 --继承自DataSet
class MyDataSet(Dataset):
    def __init__(self):
        self.lines=open(data_path,"rb").readlines()

    def __getitem__(self, index):  #必须重写Dataset的 getitem 和len方法
        #返回对应索引的一个值   --重写 []
        curr_line=self.lines[index].decode().strip()
        label=curr_line[:4].strip()
        context=curr_line[4:].strip()
        return label,context

    def __len__(self):
        return len(self.lines)


mydataset=MyDataSet()

"""
1.批处理数据
2.打乱数据
3.多线程 并行加载数据
"""
#实例化数据加载器类:
data_loader=DataLoader(mydataset,batch_size=10,shuffle=True,num_workers=2) #num_workers 多线程数

if __name__ == '__main__':
    # for i in data_loader:
    #     print(i)

    #使用enumerate 遍历数据加载器
    for index,(label,context) in enumerate(data_loader):
        print(index,label,context)
        print("*"*100)

    print("注意dataset与dataloader的长度")
    print("DataSet:",len(mydataset))
    print("DataLoader:",len(data_loader))
    print("上取整 【 DataSet//bacth_size 】",math.ceil(len(mydataset)/10))


在这里插入图片描述

数据加载器类:

from torch.utils.data import DataLoader
在这里插入图片描述
在这里插入图片描述
使用的数据集:垃圾短信数据集SMSSpamCollection

举报

相关推荐

0 条评论