1 引言
在使用机器学习训练模型算法的过程中,为提高模型的泛化能力、防止过拟合等目的,需要将整体数据划分为训练集和测试集两部分,训练集用于模型训练,测试集用于模型的验证。此时,使用train_test_split函数可便捷高效的实现数据训练集与测试集的划分。
2 train_test_split介绍
train_test_split
函数来自scikit-learn
库(也称为sklearn),安装命令:
pip install sklearn
函数的导入:
from sklearn.model_selection import train_test_split
1.1 函数定义
def train_test_split(*arrays,test_size=None,train_size=None,random_state=None,
shuffle=True,stratify=None,):
1.2 参数说明
1.3 返回值
1.4 注意事项
3 train_test_split使用
3.1 使用train_test_split分割Iris数据
from sklearn import datasets
from sklearn.model_selection import train_test_split
# 加载Iris数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)
print(X_train)
print(X_test)
结果展示:
X_train=[[6.5 2.8 4.6 1.5]
[6.7 2.5 5.8 1.8]
[6.8 3. 5.5 2.1]
[5.1 3.5 1.4 0.3]
[6. 2.2 5. 1.5]
......此处数据省略
[4.9 3.6 1.4 0.1]]
X_test=[[5.8 4. 1.2 0.2]
[5.1 2.5 3. 1.1]
[6.6 3. 4.4 1.4]
[5.4 3.9 1.3 0.4]
[7.9 3.8 6.4 2. ]
......此处数据省略
[5.2 3.4 1.4 0.2]]
3.2 使用train_test_split分割水果识别数据
在/opt/dataset下存放着水果图片的分类数据文件夹(文件夹名称为标签),每个文件夹下存储着多张对应标签的水果图片,如下所示:
以apple文件夹为例,图片内容如下:
数据加载和分割数据集的代码如下:
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
# 图像变换
transform = transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]
), ])
# 加载数据集
dataset = ImageFolder('/opt/dataset', transform=transform)
# 划分训练集与测试集
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.2, random_state=10)
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)