pytorch数据加载器类
import math
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
data_path=r'smsspamcollection\SMSSpamCollection'
#完成数据集类的定义 --继承自DataSet
class MyDataSet(Dataset):
def __init__(self):
self.lines=open(data_path,"rb").readlines()
def __getitem__(self, index): #必须重写Dataset的 getitem 和len方法
#返回对应索引的一个值 --重写 []
curr_line=self.lines[index].decode().strip()
label=curr_line[:4].strip()
context=curr_line[4:].strip()
return label,context
def __len__(self):
return len(self.lines)
mydataset=MyDataSet()
"""
1.批处理数据
2.打乱数据
3.多线程 并行加载数据
"""
#实例化数据加载器类:
data_loader=DataLoader(mydataset,batch_size=10,shuffle=True,num_workers=2) #num_workers 多线程数
if __name__ == '__main__':
# for i in data_loader:
# print(i)
#使用enumerate 遍历数据加载器
for index,(label,context) in enumerate(data_loader):
print(index,label,context)
print("*"*100)
print("注意dataset与dataloader的长度")
print("DataSet:",len(mydataset))
print("DataLoader:",len(data_loader))
print("上取整 【 DataSet//bacth_size 】",math.ceil(len(mydataset)/10))
数据加载器类:
from torch.utils.data import DataLoader
使用的数据集:垃圾短信数据集SMSSpamCollection