溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

python PyTorch參數(shù)初始化和Finetune

發(fā)布時(shí)間:2020-09-16 06:06:15 來源:腳本之家 閱讀:146 作者:ycszen 欄目:開發(fā)技術(shù)

前言

這篇文章算是論壇PyTorch Forums關(guān)于參數(shù)初始化和finetune的總結(jié),也是我在寫代碼中用的算是“最佳實(shí)踐”吧。最后希望大家沒事多逛逛論壇,有很多高質(zhì)量的回答。

參數(shù)初始化

參數(shù)的初始化其實(shí)就是對(duì)參數(shù)賦值。而我們需要學(xué)習(xí)的參數(shù)其實(shí)都是Variable,它其實(shí)是對(duì)Tensor的封裝,同時(shí)提供了data,grad等借口,這就意味著我們可以直接對(duì)這些參數(shù)進(jìn)行操作賦值了。這就是PyTorch簡(jiǎn)潔高效所在。

python PyTorch參數(shù)初始化和Finetune

所以我們可以進(jìn)行如下操作進(jìn)行初始化,當(dāng)然其實(shí)有其他的方法,但是這種方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance來判斷m屬于什么類型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其實(shí)都是Variable,為了能學(xué)習(xí)參數(shù)以及后向傳播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加載了預(yù)訓(xùn)練模型的參數(shù)之后,我們需要finetune模型,可以使用不同的方式finetune。

局部微調(diào)

有時(shí)候我們加載了訓(xùn)練模型后,只想調(diào)節(jié)最后的幾層,其他層不訓(xùn)練。其實(shí)不訓(xùn)練也就意味著不進(jìn)行梯度計(jì)算,PyTorch中提供的requires_grad使得對(duì)訓(xùn)練的控制變得非常簡(jiǎn)單。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替換最后的全連接層, 改為訓(xùn)練100類
# 新構(gòu)造的模塊的參數(shù)默認(rèn)requires_grad為True
model.fc = nn.Linear(512, 100)

# 只優(yōu)化最后的分類層
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微調(diào)

有時(shí)候我們需要對(duì)全局都進(jìn)行finetune,只不過我們希望改換過的層和其他層的學(xué)習(xí)速率不一樣,這時(shí)候我們可以把其他層和新層在optimizer中單獨(dú)賦予不同的學(xué)習(xí)速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3來訓(xùn)練,model.fc.parameters使用1e-2來訓(xùn)練,momentum是二者共有的。

以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持億速云。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI