# Code
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import csv
import pandas as pd
def plot_cm(classes, cm,
            normalize=False,
            title=None,
            cmap=plt.cm.Blues):
    """Print and plot a square matrix as an annotated heatmap.

    The matrix is printed to stdout, drawn with ``imshow`` plus a colorbar,
    and every cell is annotated with its value. Row-normalization can be
    enabled with ``normalize=True``.

    Parameters
    ----------
    classes : sequence of str
        Tick labels used for both axes; one per row (and column) of ``cm``.
    cm : array_like
        Square 2-D matrix to plot (e.g. a confusion or connection matrix).
    normalize : bool, optional
        If True, divide each row by its sum before plotting; cells are then
        annotated with two decimals. Rows that sum to zero are shown as 0.
    title : str, optional
        Axes title; ``'Connection matrix'`` is used when this is falsy.
    cmap : matplotlib colormap, optional
        Colormap passed to ``imshow``; defaults to ``plt.cm.Blues``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.

    Raises
    ------
    ValueError
        If ``cm`` is not a square 2-D matrix, or ``len(classes)`` does not
        match its size.

    Notes
    -----
    Side effects: prints the (un-normalized) matrix to stdout and sets the
    global ``plt.rcParams`` font size and family.
    """
    if not title:
        # Both modes share the same default title.
        title = 'Connection matrix'

    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(
            "cm must be a square 2-D matrix, got shape {}".format(cm.shape))
    n = cm.shape[0]
    if len(classes) != n:
        raise ValueError(
            "got {} class labels for a {}x{} matrix".format(len(classes), n, n))

    print(cm)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide only where the row sum is non-zero: all-zero rows stay 0
        # instead of becoming NaN (0/0) with a RuntimeWarning.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)

    # NOTE: these mutate matplotlib's global configuration.
    plt.rcParams['font.size'] = 10
    plt.rcParams['font.family'] = 'Times New Roman'

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)

    # One tick per row/column, labelled with the class names. Derived from
    # the matrix size (this used to be hard-coded to 17).
    ticks = np.arange(n)
    ax.set(xticks=ticks, yticks=ticks,
           xticklabels=classes, yticklabels=classes,
           title=title)
    # Pin the y-limits to the outer cell edges with row 0 on top; presumably
    # a workaround for matplotlib versions that cropped the first/last rows.
    ax.set_ylim(n - 0.5, -0.5)
    # Rotate the x tick labels so long names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Integer matrices are annotated as integers; anything else (normalized
    # or float input) uses two decimals, because format(float, 'd') raises.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Use white text on dark cells and black text on light cells.
    thresh = cm.max() / 2.
    for i in range(n):
        for j in range(n):
            value = cm[i, j]
            ax.text(j, i, format(value, fmt),
                    ha="center", va="center",
                    color="white" if value > thresh else "black")
    fig.tight_layout()
    return ax
# Demo: 17 single-letter labels ("a" .. "q") and a 17x17 integer matrix with
# a zero diagonal, plotted with raw (un-normalized) values.
class_names = np.array(list("abcdefghijklmnopq"))
cm1 = np.array([
    [ 0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17],
    [ 2,  0,  1,  1,  5,  2,  6,  1,  2,  5,  6,  5,  5,  9,  5, 25,  4],
    [ 3,  1,  0,  4,  5,  6,  1,  2,  4,  7,  8,  9,  5,  6,  8,  6,  3],
    [ 4,  1,  4,  0,  7,  4,  1,  8,  5,  2,  6,  9,  3,  5,  4,  8,  2],
    [ 5,  5,  5,  7,  0,  4,  1,  5,  1,  5,  1,  5,  5,  5,  5,  4,  5],
    [ 6,  2,  6,  4,  4,  0,  5,  6,  3,  2,  1,  4,  7,  8,  9,  4,  6],
    [ 7,  6,  1,  1,  1,  5,  0,  1,  2,  3,  4,  5,  6,  7,  4,  1,  5],
    [ 8,  1,  2,  8,  5,  6,  1,  0,  8,  6,  3,  2,  6,  2,  8,  5,  8],
    [ 9,  2,  4,  5,  1,  3,  2,  8,  0,  4,  1,  4,  7,  8,  9,  5,  1],
    [10,  5,  7,  2,  5,  2,  3,  6,  4,  0,  5,  7,  9,  6,  5,  8,  8],
    [11,  6,  8,  6,  1,  1,  4,  3,  1,  5,  0,  4,  1,  2,  5,  2,  9],
    [12,  5,  9,  9,  5,  4,  5,  2,  4,  7,  4,  0,  1,  2,  9,  9,  4],
    [13,  5,  5,  3,  5,  7,  6,  6,  7,  9,  1,  1,  0,  5,  4,  2,  7],
    [14,  9,  6,  5,  5,  8,  7,  2,  8,  6,  2,  2,  5,  0,  4,  4,  5],
    [15,  5,  8,  4,  5,  9,  4,  8,  9,  5,  5,  9,  4,  4,  0,  1,  5],
    [16, 25,  6,  8,  4,  4,  1,  5,  5,  8,  2,  9,  2,  4,  1,  0,  6],
    [17,  4,  3,  2,  5,  6,  5,  8,  1,  8,  9,  4,  7,  5,  5,  6,  0],
])

# Draw the heatmap, then open the figure window.
plot_cm(class_names, cm1, normalize=False)
plt.show()
# Result