溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch測試時要加上model.eval()的原因

發(fā)布時間:2021-05-24 09:21:23 來源:億速云 閱讀:1095 作者:小新 欄目:開發(fā)技術(shù)

這篇文章將為大家詳細(xì)講解有關(guān)pytorch測試時要加上model.eval()的原因,小編覺得挺實(shí)用的,因此分享給大家做個參考,希望大家閱讀完這篇文章后可以有所收獲。

Do need to use model.eval() when I test?

Sure, Dropout works as a regularization for preventing overfitting during training.

It randomly zeros the elements of inputs in Dropout layer on forward call.

It should be disabled during testing since you may want to use full model (no element is masked)

使用PyTorch進(jìn)行訓(xùn)練和測試時一定注意要把實(shí)例化的model指定train/eval,eval()時,框架會自動把BN和DropOut固定住,不會取平均,而是用訓(xùn)練好的值,不然的話,一旦test的batch_size過小,很容易就會被BN層導(dǎo)致生成圖片顏色失真極大!?。。。?!

補(bǔ)充:pytorch中model eval和torch no grad()的區(qū)別

model.eval()和with torch.no_grad()的區(qū)別

在PyTorch中進(jìn)行validation時,會使用model.eval()切換到測試模式,在該模式下,

主要用于通知dropout層和batchnorm層在train和val模式間切換

在train模式下,dropout網(wǎng)絡(luò)層會按照設(shè)定的參數(shù)p設(shè)置保留激活單元的概率(保留概率=p); batchnorm層會繼續(xù)計(jì)算數(shù)據(jù)的mean和var等參數(shù)并更新。

在val模式下,dropout層會讓所有的激活單元都通過,而batchnorm層會停止計(jì)算和更新mean和var,直接使用在訓(xùn)練階段已經(jīng)學(xué)出的mean和var值。

該模式不會影響各層的gradient計(jì)算行為,即gradient計(jì)算和存儲與training模式一樣,只是不進(jìn)行反傳(backprobagation)

而with torch.no_grad()則主要是用于停止autograd模塊的工作,以起到加速和節(jié)省顯存的作用,具體行為就是停止gradient計(jì)算,從而節(jié)省了GPU算力和顯存,但是并不會影響dropout和batchnorm層的行為。

使用場景

如果不在意顯存大小和計(jì)算時間的話,僅僅使用model.eval()已足夠得到正確的validation的結(jié)果;而with torch.zero_grad()則是更進(jìn)一步加速和節(jié)省gpu空間(因?yàn)椴挥糜?jì)算和存儲gradient),從而可以更快計(jì)算,也可以跑更大的batch來測試。

補(bǔ)充:Pytorch的modle.train,model.eval,with torch.no_grad的個人理解

1. 最近在學(xué)習(xí)pytorch過程中遇到了幾個問題

不理解為什么在訓(xùn)練和測試函數(shù)中model.eval(),和model.train()的區(qū)別,經(jīng)查閱后做如下整理

一般情況下,我們訓(xùn)練過程如下:

1、拿到數(shù)據(jù)后進(jìn)行訓(xùn)練,在訓(xùn)練過程中,使用

model.train():告訴我們的網(wǎng)絡(luò),這個階段是用來訓(xùn)練的,可以更新參數(shù)。

2、訓(xùn)練完成后進(jìn)行預(yù)測,在預(yù)測過程中,使用

model.eval() : 告訴我們的網(wǎng)絡(luò),這個階段是用來測試的,于是模型的參數(shù)在該階段不進(jìn)行更新。

2. 但是為什么在eval()階段會使用with torch.no_grad()?

with torch.no_grad - disables tracking of gradients in autograd.

model.eval() changes the forward() behaviour of the module it is called upon

eg, it disables dropout and has batch norm use the entire population statistics

總結(jié)一下就是說,在eval階段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都會進(jìn)行預(yù)測,而使用no_grad則設(shè)置讓梯度Autograd設(shè)置為False(因?yàn)樵谟?xùn)練中我們默認(rèn)是True),這樣保證了反向過程為純粹的測試,而不變參數(shù)。

pytorch的優(yōu)點(diǎn)

1.PyTorch是相當(dāng)簡潔且高效快速的框架;2.設(shè)計(jì)追求最少的封裝;3.設(shè)計(jì)符合人類思維,它讓用戶盡可能地專注于實(shí)現(xiàn)自己的想法;4.與google的Tensorflow類似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開發(fā)更新;5.PyTorch作者親自維護(hù)的論壇 供用戶交流和求教問題6.入門簡單

關(guān)于“pytorch測試時要加上model.eval()的原因”這篇文章就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,使各位可以學(xué)到更多知識,如果覺得文章不錯,請把它分享出去讓更多的人看到。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI