溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用

發(fā)布時間:2022-12-15 09:54:29 來源:億速云 閱讀:154 作者:iii 欄目:開發(fā)技術(shù)

本篇內(nèi)容主要講解“pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用”,感興趣的朋友不妨來看看。本文介紹的方法操作簡單快捷,實(shí)用性強(qiáng)。下面就讓小編來帶大家學(xué)習(xí)“pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用”吧!

BN原理、作用

pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用

pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用

pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用

pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用

函數(shù)參數(shù)講解

BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  • 1.num_features:一般輸入?yún)?shù)的shape為batch_size*num_features*height*width,即為其中特征的數(shù)量,即為輸入BN層的通道數(shù);

  • 2.eps:分母中添加的一個值,目的是為了計(jì)算的穩(wěn)定性,默認(rèn)為:1e-5,避免分母為0;

  • 3.momentum:一個用于運(yùn)行過程中均值和方差的一個估計(jì)參數(shù)(我的理解是一個穩(wěn)定系數(shù),類似于SGD中的momentum的系數(shù));

  • 4.affine:當(dāng)設(shè)為true時,會給定可以學(xué)習(xí)的系數(shù)矩陣gamma和beta

一般來說pytorch中的模型都是繼承nn.Module類的,都有一個屬性trainning指定是否是訓(xùn)練狀態(tài),訓(xùn)練狀態(tài)與否將會影響到某些層的參數(shù)是否是固定的,比如BN層或者Dropout層。

通常用model.train()指定當(dāng)前模型model為訓(xùn)練狀態(tài),model.eval()指定當(dāng)前模型為測試狀態(tài)。

同時,BN的API中有幾個參數(shù)需要比較關(guān)心的,一個是affine指定是否需要仿射,還有個是track_running_stats指定是否跟蹤當(dāng)前batch的統(tǒng)計(jì)特性。

容易出現(xiàn)問題也正好是這三個參數(shù):trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False則γ=1,β=0,并且不能學(xué)習(xí)被更新。一般都會設(shè)置成affine=True。

trainning和track_running_stats,track_running_stats=True表示跟蹤整個訓(xùn)練過程中的batch的統(tǒng)計(jì)特性,得到方差和均值,而不只是僅僅依賴與當(dāng)前輸入的batch的統(tǒng)計(jì)特性。

相反的,如果track_running_stats=False那么就只是計(jì)算當(dāng)前輸入的batch的統(tǒng)計(jì)特性中的均值和方差了。

當(dāng)在推理階段的時候,如果track_running_stats=False,此時如果batch_size比較小,那么其統(tǒng)計(jì)特性就會和全局統(tǒng)計(jì)特性有著較大偏差,可能導(dǎo)致糟糕的效果。

如果BatchNorm2d的參數(shù)track_running_stats設(shè)置False,那么加載預(yù)訓(xùn)練后每次模型測試測試集的結(jié)果時都不一樣;track_running_stats設(shè)置為True時,每次得到的結(jié)果都一樣。

running_mean和running_var參數(shù)是根據(jù)輸入的batch的統(tǒng)計(jì)特性計(jì)算的,嚴(yán)格來說不算是“學(xué)習(xí)”到的參數(shù),不過對于整個計(jì)算是很重要的。

BN層中的running_mean和running_var的更新是在forward操作中進(jìn)行的,而不是在optimizer.step()中進(jìn)行的,因此如果處于訓(xùn)練中泰,就算不進(jìn)行手動step(),BN的統(tǒng)計(jì)特性也會變化。

model.train() #處于訓(xùn)練狀態(tài)
for data , label in self.dataloader:
    pred =model(data)  #在這里會更新model中的BN統(tǒng)計(jì)特性參數(shù),running_mean,running_var
    loss=self.loss(pred,label)
    #就算不進(jìn)行下列三行,BN的統(tǒng)計(jì)特性參數(shù)也會變化
    opt.zero_grad()
    loss.backward()
    opt.step()

這個時候,要用model.eval()轉(zhuǎn)到測試階段,才能固定住running_mean和running_var,有時候如果是先預(yù)訓(xùn)練模型然后加載模型,重新跑測試數(shù)據(jù)的時候,結(jié)果不同,有一點(diǎn)性能上的損失,這個時候基本上是training和track_running_stats設(shè)置的不對。

如果使用兩個模型進(jìn)行聯(lián)合訓(xùn)練,為了收斂更容易控制,先預(yù)訓(xùn)練好模型model_A,并且model_A內(nèi)還有若干BN層,后續(xù)需要將model_A作為一個inference推理模型和model_B聯(lián)合訓(xùn)練,此時希望model_A中的BN的統(tǒng)計(jì)特性量running_mean和running_var不會亂變化,因此就需要將model_A.eval()設(shè)置到測試模型,否則在trainning模式下,就算是不去更新模型的參數(shù),其BN都會變化,這將導(dǎo)致和預(yù)期不同的結(jié)果。

到此,相信大家對“pytorch中BatchNorm2d函數(shù)的參數(shù)怎么使用”有了更深的了解,不妨來實(shí)際操作一番吧!這里是億速云網(wǎng)站,更多相關(guān)內(nèi)容可以進(jìn)入相關(guān)頻道進(jìn)行查詢,關(guān)注我們,繼續(xù)學(xué)習(xí)!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI