溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch分類模型繪制混淆矩陣及可視化的方法

發(fā)布時間:2022-04-07 13:39:14 來源:億速云 閱讀:973 作者:iii 欄目:開發(fā)技術(shù)

本文小編為大家詳細(xì)介紹“pytorch分類模型繪制混淆矩陣及可視化的方法”,內(nèi)容詳細(xì),步驟清晰,細(xì)節(jié)處理妥當(dāng),希望這篇“pytorch分類模型繪制混淆矩陣及可視化的方法”文章能幫助大家解決疑惑,下面跟著小編的思路慢慢深入,一起來學(xué)習(xí)新知識吧。

Step 1. 獲取混淆矩陣

#首先定義一個 分類數(shù)*分類數(shù) 的空混淆矩陣
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以顯著降低測試用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一維,所以我們要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 將變量轉(zhuǎn)為gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #記錄混淆矩陣參數(shù)
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩陣的求取用到了confusion_matrix函數(shù),其定義如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在當(dāng)我們的程序執(zhí)行結(jié)束 test_loader 后,我們可以得到本次數(shù)據(jù)的 混淆矩陣,接下來就要計算其 識別正確的個數(shù)以及混淆矩陣可視化:

conf_matrix=np.array(conf_matrix.cpu())# 將混淆矩陣從gpu轉(zhuǎn)到cpu再轉(zhuǎn)到np
corrects=conf_matrix.diagonal(offset=0)#抽取對角線的每種分類的識別正確個數(shù)
per_kinds=conf_matrix.sum(axis=1)#抽取每個分類數(shù)據(jù)總的測試條數(shù)

 print("混淆矩陣總元素個數(shù):{0},測試集總個數(shù):{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 獲取每種Emotion的識別準(zhǔn)確率
 print("每種情感總個數(shù):",per_kinds)
 print("每種情感預(yù)測正確的個數(shù):",corrects)
 print("每種情感的識別準(zhǔn)確率為:{0}".format([rate*100 for rate in corrects/per_kinds]))

執(zhí)行此步的輸出結(jié)果如下所示:

pytorch分類模型繪制混淆矩陣及可視化的方法

Step 2. 混淆矩陣可視化

對上邊求得的混淆矩陣可視化

# 繪制混淆矩陣
Emotion=8#這個數(shù)值是具體的分類數(shù),大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每種類別的標(biāo)簽

# 顯示數(shù)據(jù)
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在圖中標(biāo)注數(shù)量/概率信息
thresh = conf_matrix.max() / 2	#數(shù)值顏色閾值,如果數(shù)值超過這個,就顏色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意這里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保證圖不重疊
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X軸字體傾斜45°
plt.show()
plt.close()

好了,以下就是最終的可視化的混淆矩陣?yán)玻?/p>

pytorch分類模型繪制混淆矩陣及可視化的方法

其它分類指標(biāo)的獲取

例如 F1分?jǐn)?shù)、TP、TN、FP、FN、精確率、召回率 等指標(biāo), 待補充哈(因為暫時還沒用到)~

pytorch分類模型繪制混淆矩陣及可視化的方法

讀到這里,這篇“pytorch分類模型繪制混淆矩陣及可視化的方法”文章已經(jīng)介紹完畢,想要掌握這篇文章的知識點還需要大家自己動手實踐使用過才能領(lǐng)會,如果想了解更多相關(guān)內(nèi)容的文章,歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI