0
点赞
收藏
分享

微信扫一扫

用于生成测试及训练集的代码

左手梦圆 2022-02-16 阅读 75

coding: utf-8

In[1]:

import os
import random
import shutil
import numpy as np

In[2]:

数据集路径,原始数据集以每分类一个文件夹进行存放:如"D:/17flowers"下有若干个类别文件夹。

DATASET_DIR = “D:/17flowers”

数据切分后存放路径,注意,两路径不能重合

NEW_DIR = “D:/17flowersTrainDatas/data”

测试集占比

num_test = 0.2

In[3]:

打乱所有种类数据,并分割训练集和测试集

def shuffle_all_files(dataset_dir, new_dir, num_test):
# 先删除已有new_dir文件夹
if not os.path.exists(new_dir):
pass
else:
# 递归删除文件夹
shutil.rmtree(new_dir)
# 重新创建new_dir文件夹
os.makedirs(new_dir)
# 在new_dir文件夹目录下创建train文件夹
train_dir = os.path.join(new_dir, ‘train’)
os.makedirs(train_dir)
# 在new_dir文件夹目录下创建test文件夹
test_dir = os.path.join(new_dir, ‘test’)
os.makedirs(test_dir)
# 原始数据类别列表
directories = []
# 新训练集类别列表
train_directories = []
# 新测试集类别列表
test_directories = []
# 类别名称列表
class_names = []
# 循环所有类别
for filename in os.listdir(dataset_dir):
# 原始数据类别路径
path = os.path.join(dataset_dir, filename)
# 新训练集类别路径
train_path = os.path.join(train_dir, filename)
# 新测试集类别路径
test_path = os.path.join(test_dir, filename)
# 判断该路径是否为文件夹
if os.path.isdir(path):
# 加入原始数据类别列表
directories.append(path)
# 加入新训练集类别列表
train_directories.append(train_path)
# 新建类别文件夹
os.makedirs(train_path)
# 加入新测试集类别列表
test_directories.append(test_path)
# 新建类别文件夹
os.makedirs(test_path)
# 加入类别名称列表
class_names.append(filename)
print(‘类别列表:’,class_names)

# 循环每个分类的文件夹
for i in range(len(directories)):
    # 保存原始图片路径
    photo_filenames = []
    # 保存新训练集图片路径
    train_photo_filenames = []
    # 保存新测试集图片路径
    test_photo_filenames = []
    # 得到所有图片的路径
    for filename in os.listdir(directories[i]):
        # 原始图片路径
        path = os.path.join(directories[i], filename)
        # 训练图片路径
        train_path = os.path.join(train_directories[i], filename)
        # 测试集图片路径
        test_path = os.path.join(test_directories[i], filename)
        # 保存图片路径
        photo_filenames.append(path)
        train_photo_filenames.append(train_path)
        test_photo_filenames.append(test_path)
    # list转array
    photo_filenames = np.array(photo_filenames)
    train_photo_filenames = np.array(train_photo_filenames)
    test_photo_filenames = np.array(test_photo_filenames)
    # 打乱索引
    index = [i for i in range(len(photo_filenames))] 
    random.shuffle(index)
    # 对3个list进行相同的打乱,保证在3个list中索引一致
    photo_filenames = photo_filenames[index]
    train_photo_filenames = train_photo_filenames[index]
    test_photo_filenames = test_photo_filenames[index]
    # 计算测试集数据个数
    test_sample_index = int((1-num_test) * float(len(photo_filenames)))
    # 复制测试集图片
    for j in range(test_sample_index, len(photo_filenames)):
        # 复制图片
        shutil.copyfile(photo_filenames[j], test_photo_filenames[j])
    # 复制训练集图片
    for j in range(0, test_sample_index):
        # 复制图片
        shutil.copyfile(photo_filenames[j], train_photo_filenames[j])

In[4]:

打乱并切分数据集

shuffle_all_files(DATASET_DIR, NEW_DIR, num_test)

举报

相关推荐

0 条评论