import os.path
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
from torchvision import transforms
class myDataSet(Dataset):
    """Dataset over the image files of one label directory.

    Every file in ``rootdir/labeldir`` gets the same integer label, parsed from
    the last run of digits in ``labeldir`` (e.g. ``"class_7"`` -> 7).
    Items are ``(img, label)``. ``img`` is the NumPy array decoded by OpenCV,
    with its first three channels reordered from BGR to RGB.
    """

    # Matches each run of decimal digits. A raw string avoids the invalid "\d"
    # escape warning that the plain string literal triggers.
    _LABEL_PATTERN = re.compile(r"(\d+)")

    def __init__(self, rootdir, labeldir):
        """
        Args:
            rootdir: directory that contains the label directories.
            labeldir: name of the label directory inside ``rootdir``.
        """
        self.rootdir = rootdir
        self.labeldir = labeldir
        # Despite the name, this is the path of the directory holding the images.
        self.imagePaths = os.path.join(rootdir, labeldir)
        # List the directory once and sort it. This keeps indices stable and
        # consistent between __len__ and __getitem__. Sub-directories are
        # skipped because they cannot be read as image files.
        self._fileNames = sorted(
            name
            for name in os.listdir(self.imagePaths)
            if os.path.isfile(os.path.join(self.imagePaths, name))
        )

    def _label(self):
        """Return the integer label parsed from ``labeldir``.

        Raises:
            ValueError: if ``labeldir`` contains no digits.
        """
        digitRuns = self._LABEL_PATTERN.findall(self.labeldir)
        if not digitRuns:
            raise ValueError(f"no digits in label directory name {self.labeldir!r}")
        return int(digitRuns[-1])

    def __getitem__(self, item):
        """Load image number ``item`` and return ``(img, label)``.

        Raises:
            IndexError: if ``item`` is out of range.
            ValueError: if the file cannot be decoded as an image, or the label
                directory name contains no digits.
        """
        imagePath = os.path.join(self.imagePaths, self._fileNames[item])
        # np.fromfile + imdecode is presumably used instead of cv2.imread so
        # that non-ASCII paths work on Windows. -1 keeps the image's own
        # channel count (cv2.IMREAD_UNCHANGED).
        img = cv2.imdecode(np.fromfile(imagePath, np.uint8), -1)
        if img is None:
            raise ValueError(f"cannot decode image file {imagePath!r}")
        # OpenCV yields BGR(A). Reorder only the colour channels to RGB, leave
        # any alpha channel in place, and pass 2-D grayscale through unchanged.
        # Fancy indexing returns a contiguous copy, so torch can wrap it later.
        if img.ndim == 3 and img.shape[2] >= 3:
            img = img[:, :, [2, 1, 0, *range(3, img.shape[2])]]
        return img, self._label()

    def __len__(self):
        """Return the number of files in the label directory."""
        return len(self._fileNames)
# Preview script: build one dataset per label folder under rootdir, concatenate
# them, and log the first images to TensorBoard, both raw and after the
# center-crop / normalize / resize pipeline.
from torch.utils.data import ConcatDataset

rootdir = r"D:\17flowers"
# Only sub-directories are label folders. Stray files in the root would
# otherwise become bogus datasets. Sorted so the dataset order is deterministic.
labelList = sorted(
    name for name in os.listdir(rootdir) if os.path.isdir(os.path.join(rootdir, name))
)
allDataSet = [myDataSet(rootdir, label) for label in labelList]
if not allDataSet:
    raise FileNotFoundError(f"no label sub-directories found in {rootdir!r}")
# One flat ConcatDataset instead of the nested pairwise ones built by reduce(+).
trainDataSet = ConcatDataset(allDataSet)

t1 = transforms.ToTensor()
t2 = transforms.CenterCrop(300)
t3 = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
t4 = transforms.Resize((500, 500))
t = transforms.Compose([t1, t2, t3, t4])

# Number of leading samples to preview (indices 0..10).
numPreview = 11
writer = Writer("./myBorderText")
try:
    for index in range(min(numPreview, len(trainDataSet))):
        img, label = trainDataSet[index]
        writer.add_image("图片未处理", img_tensor=img, global_step=index, dataformats="HWC")
        # Normalize maps pixels to [-1, 1], but add_image scales float images
        # by 255 and clips, so undo the normalization to display [0, 1] values.
        writer.add_image("图片中心裁剪处理", img_tensor=t(img) * 0.5 + 0.5, global_step=index)
finally:
    writer.close()