溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

如何進行PyTorch對象識別

發(fā)布時間:2021-12-04 18:25:47 來源:億速云 閱讀:178 作者:柒染 欄目:互聯(lián)網(wǎng)科技

這期內(nèi)容當(dāng)中小編將會給大家?guī)碛嘘P(guān)如何進行PyTorch對象識別,文章內(nèi)容豐富且以專業(yè)的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。

Keras是一個很棒的庫,它提供了一個簡單的API來構(gòu)建神經(jīng)網(wǎng)絡(luò),但最近對PyTorch的興奮感最終讓我對探索這個庫產(chǎn)生了興趣。雖然我是一個"盲目追隨炒作"的人,但是研究人員的采用和fast.ai的推崇使我確信在這個深度學(xué)習(xí)的新入口中必定有新的東西值得我去探尋。

由于學(xué)習(xí)新技術(shù)的最佳方法是使用它來解決問題,所以我學(xué)習(xí)PyTorch的工作始于一個簡單的項目:使用預(yù)先訓(xùn)練的卷積神經(jīng)網(wǎng)絡(luò)進行對象識別任務(wù)。我們將看到如何使用PyTorch來實現(xiàn)這一目標,并在此過程中學(xué)習(xí)一些關(guān)于庫和遷移學(xué)習(xí)的重要概念。

雖然PyTorch可能不適合所有人,但在這一點上,很難說出哪個深度學(xué)習(xí)庫會脫穎而出,而能夠快速學(xué)習(xí)和使用不同的工具對于成為數(shù)據(jù)科學(xué)家來說至關(guān)重要。

該項目的完整代碼在GitHub上以Jupyter Notebook的形式提供(https://github.com/WillKoehrsen/pytorch_challenge/blob/master/Transfer%20Learning%20in%20PyTorch.ipynb)。這個項目源于我參加Udacity PyTorch獎學(xué)金挑戰(zhàn)(https://www.udacity.com/facebook-pytorch-scholarship)。

如何進行PyTorch對象識別

從受過訓(xùn)練的網(wǎng)絡(luò)預(yù)測

遷移學(xué)習(xí)法

我們的任務(wù)是訓(xùn)練可以識別圖像中物體的卷積神經(jīng)網(wǎng)絡(luò)(CNN)。我們將使用Caltech 101數(shù)據(jù)集(http://www.vision.caltech.edu/Image_Datasets/Caltech201/),該數(shù)據(jù)集包含101個類別的圖像。大多數(shù)類別只有50個圖像,這些圖像通常不足以讓神經(jīng)網(wǎng)絡(luò)學(xué)會高精度。因此,我們將使用預(yù)先構(gòu)建和預(yù)先訓(xùn)練的模型來應(yīng)用遷移學(xué)習(xí),而不是從頭開始構(gòu)建和訓(xùn)練CNN。

遷移學(xué)習(xí)的基本前提很簡單:采用在大型數(shù)據(jù)集上訓(xùn)練的模型,并將其轉(zhuǎn)移到較小的數(shù)據(jù)集上。對于使用CNN的對象識別,我們凍結(jié)網(wǎng)絡(luò)的早期卷積層,并且僅訓(xùn)練進行預(yù)測的最后幾層。這個想法是卷積層提取適用于圖像的一般,低級特征(例如邊緣、圖案、漸變)后面的圖層識別圖像中的特定特征,如眼睛或車輪。

因此,我們可以使用在大規(guī)模數(shù)據(jù)集(通常是Imagenet)中訓(xùn)練不相關(guān)類別的網(wǎng)絡(luò),并將其應(yīng)用于我們自己的問題中,因為圖像之間共享通用的低級特征。Caltech 101數(shù)據(jù)集中的圖像與Imagenet數(shù)據(jù)集中的圖像非常相似,模型在Imagenet上學(xué)習(xí)的知識應(yīng)該很容易轉(zhuǎn)移到此任務(wù)中。(http://www.image-net.org/)

如何進行PyTorch對象識別

遷移學(xué)習(xí)背后的理念

以下是物體識別的遷移學(xué)習(xí)的概要:

  1. 加載在大型數(shù)據(jù)集上訓(xùn)練的預(yù)訓(xùn)練CNN模型

  2. 凍結(jié)模型的下卷積層中的參數(shù)(權(quán)重)

  3. 添加具有多層可訓(xùn)練參數(shù)的自定義分類器以進行建模

  4. 訓(xùn)練可用于任務(wù)的訓(xùn)練數(shù)據(jù)的分類器層

  5. 根據(jù)需要微調(diào)超參數(shù)并解凍更多層

事實證明,這種方法適用于廣泛的領(lǐng)域。這是一個很好的工具,通常是面對新的圖像識別問題時應(yīng)該嘗試的第一種方法。

數(shù)據(jù)設(shè)置

對于所有數(shù)據(jù)科學(xué)問題,正確格式化數(shù)據(jù)將決定項目的成功或失敗。幸運的是,Caltech 101數(shù)據(jù)集圖像清晰,并以正確的格式存儲。如果我們正確設(shè)置數(shù)據(jù)目錄,PyTorch可以很容易地將正確的標簽與每個類關(guān)聯(lián)起來。我將數(shù)據(jù)分為訓(xùn)練,驗證和測試集,分別為50%,25%,25%,然后按如下方式構(gòu)建目錄:

如何進行PyTorch對象識別

按類別劃分的訓(xùn)練圖像數(shù)量(我可以互換地使用術(shù)語類別和類別):

如何進行PyTorch對象識別

按類別劃分的訓(xùn)練圖像數(shù)量

我們希望模型在具有更多示例的類上做得更好,因為它可以更好地學(xué)習(xí)將特性映射到標簽。為了處理有限數(shù)量的訓(xùn)練樣例,我們將在訓(xùn)練期間使用數(shù)據(jù)增加。

作為另一項數(shù)據(jù)探索,我們還可以查看大小分布。

如何進行PyTorch對象識別

按類別分布平均圖像大小(以像素為單位)

Imagenet模型需要224 x 224的輸入大小,因此其中一個預(yù)處理步驟將是調(diào)整圖像大小。預(yù)處理也是我們?yōu)橛?xùn)練數(shù)據(jù)實施數(shù)據(jù)增強的地方。

數(shù)據(jù)增強

數(shù)據(jù)增強的想法是通過對圖像應(yīng)用隨機變換來人為地增加模型看到的訓(xùn)練圖像的數(shù)量。例如,我們可以隨機旋轉(zhuǎn)或裁剪圖像或水平翻轉(zhuǎn)它們。我們希望我們的模型能夠區(qū)分對象,而不管方向如何,數(shù)據(jù)增強也可以使模型對輸入數(shù)據(jù)的轉(zhuǎn)換不變。

無論大象朝哪個方向走,大象仍然是大象!

如何進行PyTorch對象識別

訓(xùn)練數(shù)據(jù)的圖像變換

通常僅在訓(xùn)練期間進行增強(盡管在fast.ai庫中可以進行測試時間增加)。每個時期 - 通過所有訓(xùn)練圖像的一次迭代 - 對每個訓(xùn)練圖像應(yīng)用不同的隨機變換。這意味著如果我們迭代數(shù)據(jù)20次,我們的模型將看到每個圖像的20個略有不同的版本。整體結(jié)果應(yīng)該是一個模型,它可以學(xué)習(xí)對象本身,而不是如何呈現(xiàn)它們或圖像中的工件。

圖像預(yù)處理

這是處理圖像數(shù)據(jù)最重要的一步。在圖像預(yù)處理期間,我們同時為網(wǎng)絡(luò)準備圖像并將數(shù)據(jù)增強應(yīng)用于訓(xùn)練集。每個模型都有不同的輸入要求,但如果我們讀完Imagenet所需的內(nèi)容,我們就會發(fā)現(xiàn)我們的圖像需要為224x224并標準化為一個范圍。

要在PyTorch中處理圖像,我們使用遷移,即應(yīng)用于數(shù)組的簡單操作。驗證(和測試)遷移如下:

  • 調(diào)整

  • 中心裁剪為224 x 224

  • 遷移為張量

  • 用均值和標準差標準化

通過這些遷移的最終結(jié)果是可以進入我們網(wǎng)絡(luò)的張量。訓(xùn)練變換是相似的,但增加了隨機增強。

首先,我們定義訓(xùn)練和驗證轉(zhuǎn)換:

如何進行PyTorch對象識別

如何進行PyTorch對象識別

然后,我們創(chuàng)建數(shù)據(jù)集和數(shù)據(jù)閱讀器。ImageFolder創(chuàng)建數(shù)據(jù)集,PyTorch將自動將圖像與正確的標簽關(guān)聯(lián),前提是我們的目錄設(shè)置如上述。然后將數(shù)據(jù)集傳遞給DataLoader,這是一個產(chǎn)生批量圖像和標簽的迭代器。

如何進行PyTorch對象識別

我們可以使用以下方法查看DataLoader的迭代行為:

如何進行PyTorch對象識別

批處理的形狀是(batch_size,color_channels,height,width)。在訓(xùn)練、驗證和最終測試期間,我們將遍歷DataLoaders,一次通過包含一個時期的完整數(shù)據(jù)集。每個時期,訓(xùn)練DataLoader將對圖像應(yīng)用稍微不同的隨機變換以進行訓(xùn)練數(shù)據(jù)增強。

用于圖像識別的預(yù)訓(xùn)練模型

隨著我們的數(shù)據(jù)的成形,我們接下來將注意力轉(zhuǎn)向模型。為此,我們將使用預(yù)先訓(xùn)練的卷積神經(jīng)網(wǎng)絡(luò)。PyTorch有許多模型已經(jīng)在Imagenet的1000個類中訓(xùn)練了數(shù)百萬個圖像。完整的模型列表可以在這里看到(https://pytorch.org/docs/stable/torchvision/models.html)。這些模型在Imagenet上的性能如下所示:

如何進行PyTorch對象識別

PyTorch中的預(yù)訓(xùn)練模型和Imagenet上的性能

對于此實現(xiàn),我們將使用VGG-16。雖然它沒有記錄最低的錯誤,但我發(fā)現(xiàn)它適用于任務(wù),并且比其他模型訓(xùn)練得更快。使用預(yù)訓(xùn)練模型的過程已經(jīng)建立:

  1. 從在大型數(shù)據(jù)集上訓(xùn)練的網(wǎng)絡(luò)加載預(yù)訓(xùn)練的權(quán)重

  2. 凍結(jié)較低(卷積)圖層中的所有權(quán)重:根據(jù)新任務(wù)與原始數(shù)據(jù)集的相似性調(diào)整要凍結(jié)的圖層

  3. 用自定義分類器替換網(wǎng)絡(luò)的上層:輸出數(shù)必須設(shè)置為等于類的數(shù)量

  4. 僅為任務(wù)訓(xùn)練自定義分類器層,從而優(yōu)化較小數(shù)據(jù)集的模型

在PyTorch中加載預(yù)先訓(xùn)練的模型很簡單:

如何進行PyTorch對象識別

這個模型有超過1.3億個參數(shù),但我們只訓(xùn)練最后幾個完全連接的層。首先,我們凍結(jié)所有模型的權(quán)重:

如何進行PyTorch對象識別

然后,我們使用以下圖層添加我們自己的自定義分類器:

  • 與ReLU激活完全連接,shape =(n_inputs,256)

  • Dropout有40%的可能性下降

  • 與log softmax輸出完全連接,shape =(256,n_classes)

如何進行PyTorch對象識別

將額外圖層添加到模型時,默認情況下將它們設(shè)置為可訓(xùn)練(require_grad = True)。對于VGG-16,我們只改變最后一個原始的全連接層。卷積層和前5個完全連接層中的所有權(quán)重都是不可訓(xùn)練的。

如何進行PyTorch對象識別

網(wǎng)絡(luò)的最終輸出是我們數(shù)據(jù)集中100個類中每個類的對數(shù)概率。 該模型共有1.35億個參數(shù),其中只有100多萬個將被訓(xùn)練。

如何進行PyTorch對象識別

將模型移動到GPU(s)

PyTorch的最佳方面之一是可以輕松地將模型的不同部分移動到一個或多個gpus(https://pytorch.org/docs/stable/notes/cuda.html),以便你可以充分利用你的硬件。由于我使用2 gpus進行訓(xùn)練,我首先將模型移動到cuda,然后創(chuàng)建一個分布在gpus上的DataParallel模型:

如何進行PyTorch對象識別

(這個筆記本應(yīng)該在一個gpu上運行,以便在合理的時間內(nèi)完成。對CPU的加速可以輕松達到10倍或更多。)

訓(xùn)練損失和優(yōu)化

訓(xùn)練損失(預(yù)測和真值之間的誤差或差異)是負對數(shù)似然(NLL:https://ljvmiranda921.github.io/notebook/2017/08/13/softmax-and-the-negative-log-likelihood/)。(PyTorch中的NLL損失需要對數(shù)概率,因此我們從模型的最后一層傳遞原始輸出。)PyTorch使用自動微分,這意味著張量不僅跟蹤它們的值,而且還跟蹤每個操作(乘法,加法,激活等)。這意味著我們可以針對任何先前張量計算網(wǎng)絡(luò)中任何張量的梯度。

這在實踐中意味著損失不僅跟蹤誤差,而且跟蹤模型中每個權(quán)重和偏差對誤差的貢獻。在我們計算損失后,我們可以找到相對于每個模型參數(shù)的損失梯度,這個過程稱為反向傳播。一旦我們獲得了梯度,我們就會使用它們來更新參數(shù)和優(yōu)化器。

優(yōu)化器是Adam(https://machinelearningmastery.com/adam-optimization-algorithm-for-deep-learning/),梯度下降的有效變體,通常不需要手動調(diào)整學(xué)習(xí)速率。在訓(xùn)練期間,優(yōu)化器使用損失的梯度來嘗試通過調(diào)整參數(shù)來減少模型輸出的誤差("優(yōu)化")。只會優(yōu)化我們在自定義分類器中添加的參數(shù)。

損失和優(yōu)化器初始化如下:

如何進行PyTorch對象識別

通過預(yù)先訓(xùn)練的模型,自定義分類器,損失,優(yōu)化器以及最重要的數(shù)據(jù),我們已準備好進行訓(xùn)練。

訓(xùn)練

PyTorch中的模型訓(xùn)練比Keras中的實際操作多一些,因為我們必須自己進行反向傳播和參數(shù)更新步驟。主循環(huán)迭代多個時期,并且在每個時期迭代通過DataLoader。 DataLoader生成一批我們通過模型的數(shù)據(jù)和目標。在每個訓(xùn)練批次之后,我們計算損失,相對于模型參數(shù)反向傳播損失的梯度,然后用優(yōu)化器更新參數(shù)。

我建議你查看筆記本上的完整訓(xùn)練詳細信息(https://github.com/WillKoehrsen/pytorch_challenge/blob/master/Transfer%20Learning%20in%20PyTorch.ipynb),但基本的偽代碼如下:

如何進行PyTorch對象識別

我們可以繼續(xù)迭代數(shù)據(jù),直到達到給定數(shù)量的時期。然而,這種方法的一個問題是,我們的模型最終將過度擬合訓(xùn)練數(shù)據(jù)。為了防止這種情況,我們使用驗證數(shù)據(jù)并早期停止。

早期停止

早期停止(https://en.wikipedia.org/wiki/Early_stopping)意味著當(dāng)驗證損失在許多時期沒有減少時停止訓(xùn)練。在我們繼續(xù)訓(xùn)練時,訓(xùn)練損失只會減少,但驗證損失最終會達到最低限度并達到穩(wěn)定水平或開始增加。理想情況下,當(dāng)驗證損失最小時,我們希望停止訓(xùn)練,希望此模型能夠最好地推廣到測試數(shù)據(jù)。當(dāng)使用早期停止時,驗證損失減少的每個時期,我們保存參數(shù),以便我們以后可以檢索具有最佳驗證性能的那些。

我們通過在每個訓(xùn)練時期結(jié)束時迭代驗證DataLoader來實現(xiàn)早期停止。我們計算驗證損失并將其與最低驗證損失進行比較。如果到目前為止損失最小,我們保存模型。如果在一定數(shù)量的時期內(nèi)損失沒有改善,我們停止訓(xùn)練并返回已保存到磁盤的最佳模型。

同樣,完整的代碼在筆記本中,但偽代碼是:

如何進行PyTorch對象識別

如何進行PyTorch對象識別

如何進行PyTorch對象識別

如何進行PyTorch對象識別

要了解早期停止的好處,我們可以查看顯示訓(xùn)練和驗證損失和準確性的訓(xùn)練曲線:

如何進行PyTorch對象識別

如何進行PyTorch對象識別

負對數(shù)似然和準確性訓(xùn)練曲線

正如預(yù)期的那樣,隨著進一步的訓(xùn)練,訓(xùn)練損失只會繼續(xù)減少。另一方面,驗證損失達到最低和穩(wěn)定的狀態(tài)。在某一時期,進一步訓(xùn)練是沒有回報的(甚至是負回報)。我們的模型將僅開始記憶訓(xùn)練數(shù)據(jù),并且無法推廣到測試數(shù)據(jù)。

如果沒有早期停止,我們的模型將訓(xùn)練超過必要的時間并且將過度訓(xùn)練數(shù)據(jù)。

我們從訓(xùn)練曲線中可以看到的另一點是我們的模型并沒有過度擬合??偸谴嬖谝恍┻^度擬合,但是在第一個可訓(xùn)練的完全連接層之后的退出可以防止訓(xùn)練和驗證損失過多。

做出預(yù)測:推論

在筆記本中我處理了一些無聊但必要的保存和加載PyTorch模型的細節(jié),但在這里我們將移動到最佳部分:對新圖像進行預(yù)測。我們知道我們的模型在訓(xùn)練甚至驗證數(shù)據(jù)方面做得很好,但最終的測試是它如何在一個前所未見的保持測試集上的執(zhí)行。我們保存了25%的數(shù)據(jù),以確定我們的模型是否可以推廣到新數(shù)據(jù)。

使用訓(xùn)練過的模型進行預(yù)測非常簡單。我們使用與訓(xùn)練和驗證相同的語法:

如何進行PyTorch對象識別

我們概率的形狀是(batch_size,n_classes),因為我們有每個類的概率。我們可以通過找出每個示例的最高概率來找到準確性,并將它們與標簽進行比較:

如何進行PyTorch對象識別

在診斷用于對象識別的網(wǎng)絡(luò)時(https://www.coursera.org/lecture/machine-learning/model-selection-and-train-validation-test-sets-QGKbr),查看測試集的整體性能和單個預(yù)測會很有幫助。

模型結(jié)果

以下是模型的兩個預(yù)測:

如何進行PyTorch對象識別

如何進行PyTorch對象識別

這些都很簡單,所以我很高興模型沒有問題!

我們不僅僅想關(guān)注正確的預(yù)測,我們還將很快就會看到一些錯誤的輸出?,F(xiàn)在讓我們評估整個測試集的性能。為此,我們希望迭代測試DataLoader并計算每個示例的損失和準確性。

用于對象識別的卷積神經(jīng)網(wǎng)絡(luò)通常根據(jù)topk精度(https://stats.stackexchange.com/questions/95391/what-is-the-definition-of-top-n-accuracy)來測量。這是指真實的類是否屬于k最可能預(yù)測的類中。例如,前5個準確度是5個最高概率預(yù)測中正確等級的百分比。你可以從PyTorch張量中獲取topk最可能的概率和類,如下所示:

如何進行PyTorch對象識別

在整個測試集上評估模型,我們計算指標:

如何進行PyTorch對象識別

這些與驗證數(shù)據(jù)中接近90%的top1精度相比是有利的。總的來說,我們得出結(jié)論,我們的預(yù)訓(xùn)練模型能夠成功地將其知識從Imagenet轉(zhuǎn)移到我們較小的數(shù)據(jù)集。

模型調(diào)查

盡管該模型表現(xiàn)良好,但仍有可能采取一些步驟可以使其變得更好。通常,弄清楚如何改進模型的最佳方法是調(diào)查其錯誤(注意:這也是一種有效的自我改進方法。)

我們的模型不太適合識別鱷魚,所以我們來看看這個類別的一些測試預(yù)測:

如何進行PyTorch對象識別

如何進行PyTorch對象識別

如何進行PyTorch對象識別

考慮到鱷魚和鱷魚頭之間的微妙區(qū)別,以及第二張圖像的難度,我會說我們的模型在這些預(yù)測中并非完全不合理。圖像識別的最終目標是超越人類的能力,我們的模型幾乎已經(jīng)接近了!

最后,我們希望模型在具有更多圖像的類別上表現(xiàn)更好,因此我們可以查看給定類別中的準確度圖表與該類別中的訓(xùn)練圖像數(shù)量:

如何進行PyTorch對象識別

在訓(xùn)練圖像的數(shù)量和前一個測試精度之間似乎存在正相關(guān)關(guān)系。這表明更多的訓(xùn)練數(shù)據(jù)增加是有所幫助的,或者我們應(yīng)該對測試時間進行增加。我們還可以嘗試不同的預(yù)訓(xùn)練模型,或者構(gòu)建另一個自定義分類器。目前,深度學(xué)習(xí)仍然是一個經(jīng)驗領(lǐng)域,這意味著經(jīng)常需要實驗!

結(jié)論

雖然有更容易使用的深度學(xué)習(xí)庫,但PyTorch的優(yōu)點是速度快,對模型架構(gòu)/訓(xùn)練的各個方面的控制好,能使張量自動區(qū)分的反向傳播,以及由于PyTorch圖的動態(tài)特性而易于調(diào)試的代碼。對于生產(chǎn)代碼或你自己的項目,我不確定使用PyTorch而不是具有更溫和學(xué)習(xí)曲線的庫(例如Keras)還存在令人信服的論據(jù),但知道如何使用不同選項會很有幫助。

通過這個項目,我們能夠看到使用PyTorch的基礎(chǔ)知識以及遷移學(xué)習(xí)的概念,這是一種有效的對象識別方法。我們可以使用已在大型數(shù)據(jù)集上進行過訓(xùn)練的現(xiàn)有體系結(jié)構(gòu),然后根據(jù)我們的任務(wù)調(diào)整它們,而不是從頭開始訓(xùn)練模型。這無疑減少了訓(xùn)練的時間并且通常導(dǎo)致更好的整體性能。這個項目的成果是對遷移學(xué)習(xí)和PyTorch一些知識的應(yīng)用,我們可以構(gòu)建它來構(gòu)建更復(fù)雜的應(yīng)用程序。

我們確實生活在一個令人難以置信的深度學(xué)習(xí)時代,任何人都可以利用輕松可用的資源建立深度學(xué)習(xí)模型!現(xiàn)在是時候,通過構(gòu)建自己的項目來更好的利用這些資源了。

上述就是小編為大家分享的如何進行PyTorch對象識別了,如果剛好有類似的疑惑,不妨參照上述分析進行理解。如果想知道更多相關(guān)知識,歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI