0
点赞
收藏
分享

微信扫一扫

python_验证曲线_查看不同超参数对模型性能的影响

奔跑的酆 2022-07-18 阅读 80


python_验证曲线_查看不同超参数对模型性能的影响

# 可视化超参数值的效果
# 了解不同超参数对模型性能的影响
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

# Load the handwritten-digits dataset bundled with scikit-learn.
digits = load_digits()

# Split the dataset into the feature matrix and the target vector.
features = digits.data
target = digits.target

# Candidate values for the hyperparameter under study: every odd
# integer from 1 through 249.
param_range = np.arange(1, 250, 2)

# For every candidate value, score the model on the training folds and
# on the held-out folds of a 3-fold cross-validation. Each returned
# array holds one row per candidate value and one column per fold.
train_scores, test_scores = validation_curve(
    estimator=RandomForestClassifier(),  # classifier being evaluated
    X=features,  # feature matrix
    y=target,  # target vector
    param_name="n_estimators",  # hyperparameter to vary
    param_range=param_range,  # values tried for that hyperparameter
    cv=3,  # number of cross-validation folds
    scoring="accuracy",  # performance metric
    n_jobs=-1,  # run the fits in parallel on all available cores
)

# Collapse the per-fold scores (axis 1) into a mean and a standard
# deviation for each hyperparameter value -- training scores first.
train_mean = train_scores.mean(axis=1)
train_std = train_scores.std(axis=1)

# Same aggregation for the held-out (cross-validation) scores.
test_mean = test_scores.mean(axis=1)
test_std = test_scores.std(axis=1)

# Plot the mean accuracy on the training folds and on the held-out
# folds against the hyperparameter value (number of trees).
plt.plot(param_range, train_mean, label="Training score", color="black")
plt.plot(param_range, test_mean, label="Cross-validation score", color="dimgrey")

# Shade a band of one standard deviation on either side of each mean
# to show how much the score varies between folds.
plt.fill_between(param_range, train_mean - train_std, train_mean + train_std, color="gray")
plt.fill_between(param_range, test_mean - test_std, test_mean + test_std, color="gainsboro")

# Title, axis labels and legend, then render the figure.
# (Title typo fixed: "Random Fores" -> "Random Forest".)
plt.title("Validation Curve With Random Forest")
plt.xlabel("Number Of Trees")
plt.ylabel("Accuracy Score")
plt.tight_layout()
plt.legend(loc="best")
plt.show()

python_验证曲线_查看不同超参数对模型性能的影响_git


举报

相关推荐

0 条评论