0
点赞
收藏
分享

微信扫一扫

plt画连接矩阵带权重

菜菜捞捞 2022-05-03 阅读 70

代码

from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np 
import csv
import pandas as pd

def plot_cm(classes, cm,
            normalize=False,
            title=None,
            cmap=plt.cm.Blues):
    """Plot a square matrix (e.g. a weighted connection matrix) as an annotated heatmap.

    The matrix is printed to stdout, drawn with ``imshow`` plus a colorbar,
    and every cell is labelled with its value. Text is white on cells above
    half the maximum value and black elsewhere, so it stays readable.

    Parameters
    ----------
    classes : sequence of str
        Tick labels used for both the x axis (columns) and the y axis (rows).
        Expected to have one entry per row/column of ``cm``.
    cm : array_like, 2-D
        Matrix to plot. Any size is accepted (ticks follow ``cm.shape``).
        Square input is expected because ``classes`` labels both axes.
    normalize : bool, default False
        If True, each row is divided by its sum so that it adds up to 1.
        Rows whose sum is zero are left as zeros instead of becoming NaN.
    title : str, optional
        Axes title. Falls back to ``'Connection matrix'`` when empty or None.
    cmap : matplotlib colormap, default ``plt.cm.Blues``
        Colormap used for the heatmap.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.

    Notes
    -----
    ``plt.rcParams['font.size']`` and ``plt.rcParams['font.family']`` are
    set globally. Figures created later in the same session inherit them.
    """
    if not title:
        title = 'Connection matrix'

    cm = np.asarray(cm)
    print(cm)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide only where the row sum is non-zero; all-zero rows stay 0
        # instead of producing NaN (0/0) and a RuntimeWarning.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.rcParams['font.size'] = 10
    plt.rcParams['font.family'] = 'Times New Roman'

    n_rows, n_cols = cm.shape

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # One tick per column/row, labelled with the class names.
    ax.set(xticks=np.arange(n_cols),
           yticks=np.arange(n_rows),
           xticklabels=classes, yticklabels=classes,
           title=title)

    # Pin the y range to the cell edges (row 0 at the top). Presumably a
    # workaround for matplotlib versions that crop the first/last heatmap
    # rows — TODO confirm whether still needed.
    ax.set_ylim(n_rows - 0.5, -0.5)

    # Rotate the x tick labels so long names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Whole numbers for integer data; two decimals for normalized or float
    # data (the 'd' format cannot render floats).
    is_integer = np.issubdtype(cm.dtype, np.integer)
    fmt = 'd' if is_integer and not normalize else '.2f'
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax

# Seventeen single-letter labels, 'a' through 'q', one per row/column.
class_names = np.array(list("abcdefghijklmnopq"))

# 17x17 matrix of connection weights (zero on the diagonal); row i and
# column j are labelled class_names[i] and class_names[j] in the plot.
cm1 = np.array([
    [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
    [2, 0, 1, 1, 5, 2, 6, 1, 2, 5, 6, 5, 5, 9, 5, 25, 4],
    [3, 1, 0, 4, 5, 6, 1, 2, 4, 7, 8, 9, 5, 6, 8, 6, 3],
    [4, 1, 4, 0, 7, 4, 1, 8, 5, 2, 6, 9, 3, 5, 4, 8, 2],
    [5, 5, 5, 7, 0, 4, 1, 5, 1, 5, 1, 5, 5, 5, 5, 4, 5],
    [6, 2, 6, 4, 4, 0, 5, 6, 3, 2, 1, 4, 7, 8, 9, 4, 6],
    [7, 6, 1, 1, 1, 5, 0, 1, 2, 3, 4, 5, 6, 7, 4, 1, 5],
    [8, 1, 2, 8, 5, 6, 1, 0, 8, 6, 3, 2, 6, 2, 8, 5, 8],
    [9, 2, 4, 5, 1, 3, 2, 8, 0, 4, 1, 4, 7, 8, 9, 5, 1],
    [10, 5, 7, 2, 5, 2, 3, 6, 4, 0, 5, 7, 9, 6, 5, 8, 8],
    [11, 6, 8, 6, 1, 1, 4, 3, 1, 5, 0, 4, 1, 2, 5, 2, 9],
    [12, 5, 9, 9, 5, 4, 5, 2, 4, 7, 4, 0, 1, 2, 9, 9, 4],
    [13, 5, 5, 3, 5, 7, 6, 6, 7, 9, 1, 1, 0, 5, 4, 2, 7],
    [14, 9, 6, 5, 5, 8, 7, 2, 8, 6, 2, 2, 5, 0, 4, 4, 5],
    [15, 5, 8, 4, 5, 9, 4, 8, 9, 5, 5, 9, 4, 4, 0, 1, 5],
    [16, 25, 6, 8, 4, 4, 1, 5, 5, 8, 2, 9, 2, 4, 1, 0, 6],
    [17, 4, 3, 2, 5, 6, 5, 8, 1, 8, 9, 4, 7, 5, 5, 6, 0],
])

# Draw the raw (un-normalized) weights and open the figure window.
plot_cm(cm=cm1, classes=class_names, normalize=False)
plt.show()

结果 

 

举报

相关推荐

0 条评论