0
点赞
收藏
分享

微信扫一扫

paddlepaddle实现十二生肖的分类之数据的预处理(一)

Java架构领域 2022-03-11 阅读 73

数据集说明

数据集一共包含3个目录trainvalidtest,每个目录都包含了12生肖(类别)的图片,通过下面的链接可以直接下载数据集

数据下载地址:下载地址
项目地址:项目链接

数据分析

统计数据集中每个类别的数据分布情况

import os

def print_classes_info(mode="train",data_dir = "data/signs"):
    datasets_dir = os.path.join(data_dir,mode)
    classes_names = os.listdir(datasets_dir)
    #用来保存每个类别的数量信息
    classes_num_infos = dict()
    for class_name in classes_names:
        #获取类别的目录
        class_dir_path = os.path.join(datasets_dir,class_name)
        img_names = os.listdir(class_dir_path)
        #记录每个类别的图片数量
        classes_num_infos[class_name] = len(img_names)
    print("{}:{}".format(mode,classes_num_infos))

#打印数据的分布情况
print_classes_info("train")
print_classes_info("valid")
print_classes_info("test")

在训练集中每个类别包含600张图片,验证集中每个类别包含55张图片,测试集中每个类别包含55张图片,因为这里的数据都比较平衡,后面我们就不需要去考虑数据的平衡问题了。

数据加载器

基于paddlepaddle提供的paddle.io.Dataset类,封装一个十二生肖的数据加载器,用于后面的模型训练和评估,将图片的预处理也封装在里面

import os
import paddle
from paddle.vision import transforms
from PIL import Image
import numpy as np

class ZodiacDatasets(paddle.io.Dataset):
    """
    加载十二生肖数据
    """
    def __init__(self,mode="train",data_root="data/signs",img_size=(224,224)):
        self.data_root = data_root
        #判断mode是否正确
        if mode not in ["train","valid","test"]:
            assert("{} is illegal,mode need is one of train,valid,test")
        #获取数据集的目录
        self._data_dir_path = os.path.join(data_root,mode)
        #获取十二生肖的类别名称
        self._zodiac_names = sorted(os.listdir(self._data_dir_path))
        #用来保存图片的路径
        self._img_path_list = []
        for name in self._zodiac_names:
            img_dir_path = os.path.join(self._data_dir_path,name)
            img_name_list = os.listdir(img_dir_path)
            for img_name in img_name_list:
                img_path = os.path.join(img_dir_path,img_name)
                self._img_path_list.append(img_path)
        #定义图像的预处理函数
        if mode == "train":
            self._transform = transforms.Compose([
                transforms.RandomResizedCrop(img_size),   #缩放图片并随机裁剪图片为指定shape
                transforms.RandomHorizontalFlip(0.5),     #随机水平翻转图片的概率为0.5
                transforms.ToTensor(),                    #转换图片的格式由HWC ==> CHW
                transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])  #图片通道像素的标准化
            ])
        else:
            self._transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
            ])
    def __getitem__(self,index):
        """根据index获取图片数据
        """
        #获取图片的路径
        img_path = self._img_path_list[index]
        #获取图片的标签
        img_label = img_path.split("/")[-2]
        #将生肖的标签名称转换为数字标签
        label_index = self._zodiac_names.index(img_label)
        #读取图片
        img = Image.open(img_path)
        if img.mode != "RGB":
            img = img.convert("RGB")
        #图片的预处理
        img = self._transform(img)
        return img,np.array(label_index,dtype=np.int64)

    def __len__(self):
        """获取数据集的大小
        """
        return len(self._img_path_list)


#加载训练集
train_datasets = ZodiacDatasets(mode="train")
#统计训练集的大小
print(len(train_datasets))
for img,img_label in train_datasets:
    print(img.shape,img_label)
    break
举报

相关推荐

0 条评论