#列出所有环境
#conda info --env
#conda activate mypytorch
#安装opencv环境
#pip3 install opencv-python
```python
```python
import os.path
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset
import re
from functools import reduce
#正则表达式匹配出最后的数字:12
#print(re.findall("(\d+)","flower")[-1])
#创建自定义DataSet类
class myDataSet(Dataset):
#每个分类的子文件夹独立成一个标签数据集,标签例如flower0
def __init__(self,rootdir,labeldir):
self.rootdir=rootdir
self.labeldir=labeldir
self.imagePaths=os.path.join(rootdir,labeldir)
'''
#item作为编号:opencv版本
def __getitem__(self, item):
imagePath=os.listdir(self.imagePaths)[item]
imagePath=os.path.join(self.imagePaths,imagePath)
img=cv2.imdecode(np.fromfile(imagePath,np.uint8),-1)
#bgr转rgb
img = img[:, :, ::-1]
labelComopent =re.findall("(\d+)",self.labeldir)
#如果在标签中取不出对应tag
if len(labelComopent)==0:
raise ValueError
label=int(labelComopent[-1])
return img,label
'''
#item作为编号:opencv版本
def __getitem__(self, item):
imagePath=os.listdir(self.imagePaths)[item]
imagePath=os.path.join(self.imagePaths,imagePath)
img=Image.open(imagePath)
img = np.array(img)
labelComopent =re.findall("(\d+)",self.labeldir)
#如果在标签中取不出对应tag
if len(labelComopent)==0:
raise ValueError
label=int(labelComopent[-1])
return img,label
def __len__(self):
return len(self.imagePaths)
#使用r标识路径防止转义:
rootdir=r"D:\17flowers"
labelList=os.listdir(rootdir)
allDataSet=[]
#生成各子数据集
for label in labelList:
allDataSet.append(myDataSet(rootdir,label))
'''
reduce() 函数会对参数序列中元素进行累积。
函数将一个数据集合(链表,元组等)中的所有数据进行下列操作:
用传给 reduce 中的函数 function(有两个参数)先对集合中的第 1、2 个元素进行操作,得到的结果再与第三个数据用 function 函数运算,最后得到一个结果。
'''
trainDataSet=reduce(lambda x,y:x+y,allDataSet)
#347
print(trainDataSet.__len__())
#(500, 689, 3)
print(trainDataSet[0][0].shape)
print(trainDataSet[0][1])