溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

Pytorch訓(xùn)練模型得到輸出后如何計(jì)算F1-Score和AUC

發(fā)布時間:2022-02-25 10:10:16 來源:億速云 閱讀:356 作者:小新 欄目:開發(fā)技術(shù)

這篇文章主要介紹了Pytorch訓(xùn)練模型得到輸出后如何計(jì)算F1-Score和AUC,具有一定借鑒價(jià)值,感興趣的朋友可以參考下,希望大家閱讀完這篇文章之后大有收獲,下面讓小編帶著大家一起了解一下。

1、計(jì)算F1-Score

對于二分類來說,假設(shè)batch size 大小為64的話,那么模型一個batch的輸出應(yīng)該是torch.size([64,2]),所以首先做的是得到這個二維矩陣的每一行的最大索引值,然后添加到一個列表中,同時把標(biāo)簽也添加到一個列表中,最后使用sklearn中計(jì)算F1的工具包進(jìn)行計(jì)算,代碼如下

import numpy as np
import sklearn.metrics import f1_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的預(yù)測輸出
    prob = prob.cpu().numpy() #先把prob轉(zhuǎn)到CPU上,然后再轉(zhuǎn)成numpy,如果本身在CPU上訓(xùn)練的話就不用先轉(zhuǎn)成CPU了
    prob_all.extend(np.argmax(prob,axis=1)) #求每一行的最大值索引
    label_all.extend(label)
print("F1-Score:{:.4f}".format(f1_score(label_all,prob_all)))

2、計(jì)算AUC

計(jì)算AUC的時候,本次使用的是sklearn中的roc_auc_score () 方法

輸入?yún)?shù):

y_true:真實(shí)的標(biāo)簽。形狀 (n_samples,) 或 (n_samples, n_classes)。二分類的形狀 (n_samples,1),而多標(biāo)簽情況的形狀 (n_samples, n_classes)。

y_score:目標(biāo)分?jǐn)?shù)。形狀 (n_samples,) 或 (n_samples, n_classes)。二分類情況形狀 (n_samples,1),“分?jǐn)?shù)必須是具有較大標(biāo)簽的類的分?jǐn)?shù)”,通俗點(diǎn)理解:模型打分的第二列。舉個例子:模型輸入的得分是一個數(shù)組 [0.98361117 0.01638886],索引是其類別,這里 “較大標(biāo)簽類的分?jǐn)?shù)”,指的是索引為 1 的分?jǐn)?shù):0.01638886,也就是正例的預(yù)測得分。

average='macro':二分類時,該參數(shù)可以忽略。用于多分類,' micro ':將標(biāo)簽指標(biāo)矩陣的每個元素看作一個標(biāo)簽,計(jì)算全局的指標(biāo)。' macro ':計(jì)算每個標(biāo)簽的指標(biāo),并找到它們的未加權(quán)平均值。這并沒有考慮標(biāo)簽的不平衡。' weighted ':計(jì)算每個標(biāo)簽的指標(biāo),并找到它們的平均值,根據(jù)支持度 (每個標(biāo)簽的真實(shí)實(shí)例的數(shù)量) 進(jìn)行加權(quán)。

sample_weight=None:樣本權(quán)重。形狀 (n_samples,),默認(rèn) = 無。

max_fpr=None

multi_class='raise':(多分類的問題在下一篇文章中解釋)

labels=None

輸出:

auc:是一個 float 的值。

import numpy as np
import sklearn.metrics import roc_auc_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的預(yù)測輸出
    prob_all.extend(prob[:,1].cpu().numpy()) #prob[:,1]返回每一行第二列的數(shù),根據(jù)該函數(shù)的參數(shù)可知,y_score表示的較大標(biāo)簽類的分?jǐn)?shù),因此就是最大索引對應(yīng)的那個值,而不是最大索引值
    label_all.extend(label)
print("AUC:{:.4f}".format(roc_auc_score(label_all,prob_all)))

補(bǔ)充:pytorch訓(xùn)練模型的一些坑

1. 圖像讀取

opencv的python和c++讀取的圖像結(jié)果不一致,是因?yàn)閜ython和c++采用的opencv版本不一樣,從而使用的解碼庫不同,導(dǎo)致讀取的結(jié)果不同。

2. 圖像變換

PIL和pytorch的圖像resize操作,與opencv的resize結(jié)果不一樣,這樣會導(dǎo)致訓(xùn)練采用PIL,預(yù)測時采用opencv,結(jié)果差別很大,尤其是在檢測和分割任務(wù)中比較明顯。

3. 數(shù)值計(jì)算

pytorch的torch.exp與c++的exp計(jì)算,10e-6的數(shù)值時候會有10e-3的誤差,對于高精度計(jì)算需要特別注意,比如

兩個輸入5.601597, 5.601601, 經(jīng)過exp計(jì)算后變成270.85862343143174, 270.85970686809225

感謝你能夠認(rèn)真閱讀完這篇文章,希望小編分享的“Pytorch訓(xùn)練模型得到輸出后如何計(jì)算F1-Score和AUC”這篇文章對大家有幫助,同時也希望大家多多支持億速云,關(guān)注億速云行業(yè)資訊頻道,更多相關(guān)知識等著你來學(xué)習(xí)!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI