import os.path
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
class myDataSet(Dataset):
    """Dataset over the files of a single class folder.

    The folder is ``os.path.join(rootdir, labeldir)``; every entry in it is
    treated as one image sample. All samples share the same integer label,
    parsed from the last run of digits in ``labeldir``
    (e.g. ``"image_07"`` -> 7).
    """

    # Raw string: "\d" inside a plain string literal is an invalid escape
    # sequence (DeprecationWarning; SyntaxWarning since Python 3.12).
    _LABEL_PATTERN = re.compile(r"(\d+)")

    def __init__(self, rootdir, labeldir):
        """Record the location of the class folder (no filesystem access).

        Args:
            rootdir: Dataset root directory.
            labeldir: Name of the class sub-folder inside ``rootdir``; its
                last run of digits is used as the label.
        """
        self.rootdir = rootdir
        self.labeldir = labeldir
        self.imagePaths = os.path.join(rootdir, labeldir)
        # Directory listing, filled lazily by _files() so construction stays
        # cheap and side-effect free.
        self._fileNames = None

    def _files(self):
        """Return the sorted entry names of ``self.imagePaths``, listed once.

        Caching keeps ``__len__`` and ``__getitem__`` consistent with each
        other and avoids re-reading the directory for every sample; sorting
        makes the index -> file mapping deterministic because ``os.listdir``
        returns entries in arbitrary order.
        """
        if self._fileNames is None:
            self._fileNames = sorted(os.listdir(self.imagePaths))
        return self._fileNames

    def _label(self):
        """Return the integer class id encoded in ``self.labeldir``.

        Raises:
            ValueError: if ``labeldir`` contains no digits.
        """
        digits = self._LABEL_PATTERN.findall(self.labeldir)
        if not digits:
            raise ValueError(
                f"no numeric label found in folder name {self.labeldir!r}"
            )
        return int(digits[-1])

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(image_array, label)``.

        The image is opened with PIL and converted to a NumPy array; its
        shape and dtype follow the file's image mode.

        Raises:
            IndexError: if ``item`` is out of range for the folder.
            ValueError: if no label can be parsed from ``labeldir``.
        """
        imagePath = os.path.join(self.imagePaths, self._files()[item])
        # Context manager closes the underlying file once the pixel data has
        # been copied into the array (np.array forces the image to load).
        with Image.open(imagePath) as pic:
            img = np.array(pic)
        return img, self._label()

    def __len__(self):
        """Return the number of entries (samples) in the class folder."""
        return len(self._files())
# Dataset root; each sub-folder is expected to be one class whose name
# carries the numeric class id (parsed by myDataSet).
rootdir = r"D:\17flowers"
# Keep only sub-directories (a stray file in the root would otherwise become
# a bogus "class" and fail later on os.listdir), and sort so the class order,
# and therefore the concatenated sample order, is deterministic.
labelList = sorted(
    entry
    for entry in os.listdir(rootdir)
    if os.path.isdir(os.path.join(rootdir, entry))
)
# One dataset per class folder.
allDataSet = [myDataSet(rootdir, label) for label in labelList]
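'''
reduce() 函数会对可迭代对象(列表、元组等)中的元素进行累积运算:
先用传给 reduce 的二元函数 function 处理第 1、2 个元素,
再把得到的结果与第 3 个元素用 function 运算,依此类推,最后得到一个结果。
此处 function 为 lambda x, y: x + y,而 Dataset 的 + 运算返回 ConcatDataset,
因此结果是把各个类别的数据集依次拼接而成的一个数据集。
注意:序列为空且未提供初始值时,reduce 会抛出 TypeError。
'''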
# reduce() on an empty sequence raises an opaque TypeError; fail with a
# message that points at the actual cause instead.
if not allDataSet:
    raise ValueError(f"no class folders found under {rootdir!r}")
# Chain the per-class datasets into one: Dataset.__add__ returns a
# ConcatDataset, so folding with "+" concatenates them in list order.
trainDataSet = reduce(lambda x, y: x + y, allDataSet)
writer = Writer("./myBorderText")
try:
    # Log each sample's label as a scalar, using its position in the
    # concatenated dataset as the step.
    for index, (_image, target) in enumerate(trainDataSet):
        writer.add_scalar("labelb标识", scalar_value=target, global_step=index)
finally:
    # Flush and release the event file even if loading a sample fails.
    writer.close()