0
点赞
收藏
分享

微信扫一扫

学习笔记31-混淆矩阵可视化分类模型预测结果(附代码)

Greatiga 2022-04-06 阅读 59

混淆矩阵是机器学习中总结分类模型预测结果的情形分析表,以矩阵形式将数据集中的记录按照真实的类别与分类模型预测的类别判断两个标准进行汇总。其中矩阵的行表示真实值,矩阵的列表示预测值。

sklearn.metrics.confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

  • y_true真实标签值
  • y_pred预测标签值
  • labels=None类别,可以手动设置,也可自动生成。
  • sample_weight 是样本权重

定义/绘制混淆矩阵完整代码

from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
#定义和绘制混淆矩阵
def get_confusion_matrix(trues, preds):
  labels = [0,1,2,3] #类别
  conf_matrix = confusion_matrix(trues, preds, labels=labels, sample_weight=None)
  return conf_matrix

def plot_confusion_matrix(conf_matrix):
  plt.imshow(conf_matrix, cmap=plt.cm.Greens)#可以改变颜色
  indices = range(conf_matrix.shape[0])
  labels = [0,1,2,3]
  plt.xticks(indices, labels)
  plt.yticks(indices, labels)
  plt.colorbar()
  plt.xlabel('y_true')
  plt.ylabel('y_pred')
  # 显示数据
  for first_index in range(conf_matrix.shape[0]):#trues
    for second_index in range(conf_matrix.shape[1]):#preds
      plt.text(first_index, second_index, conf_matrix[first_index, second_index])
  plt.savefig('heatmap_confusion_matrix.jpg')
  plt.show()

在这里插入图片描述
在这里插入图片描述
plt.imshow(conf_matrix, cmap=plt.cm.YlGnBu)
利用Matplotlib的imshow()函数,颜色映射cmap的可取值如下:
Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Wistia, Wistia_r, YlGn, YlGnBu, YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r, afmhot, afmhot_r, autumn, autumn_r, binary, binary_r, bone, bone_r, brg, brg_r, bwr, bwr_r, cividis, cividis_r, cool, cool_r, coolwarm, coolwarm_r, copper, copper_r, cubehelix, cubehelix_r, flag, flag_r, gist_earth, gist_earth_r, gist_gray, gist_gray_r, gist_heat, gist_heat_r, gist_ncar, gist_ncar_r, gist_rainbow, gist_rainbow_r, gist_stern, gist_stern_r, gist_yarg, gist_yarg_r, gnuplot, gnuplot2, gnuplot2_r, gnuplot_r, gray, gray_r, hot, hot_r, hsv, hsv_r, inferno, inferno_r, jet, jet_r, magma, magma_r, nipy_spectral, nipy_spectral_r, ocean, ocean_r, pink, pink_r, plasma, plasma_r, prism, prism_r, rainbow, rainbow_r, seismic, seismic_r, spring, spring_r, summer, summer_r, tab10, tab10_r, tab20, tab20_r, tab20b, tab20b_r, tab20c, tab20c_r, terrain, terrain_r, viridis, viridis_r, winter, winter_r
在这里插入图片描述
想要利用混淆矩阵来评价模型时,可将代码1和2写在predict.py代码中直接调用即可。1是定义混淆矩阵,2是绘制混淆矩阵。

  1. conf_matrix = get_confusion_matrix(test_trues, test_preds)
  2. plot_confusion_matrix(conf_matrix)
举报

相关推荐

0 条评论