您好,登錄后才能下訂單哦!
小編給大家分享一下pytorch中優(yōu)化器optimizer.param_groups用法的示例分析,希望大家閱讀完這篇文章之后都有所收獲,下面讓我們一起去探討吧!
optimizer.param_groups
: 是長(zhǎng)度為2的list,其中的元素是2個(gè)字典;
optimizer.param_groups[0]
: 長(zhǎng)度為6的字典,包括[‘a(chǎn)msgrad', ‘params', ‘lr', ‘betas', ‘weight_decay', ‘eps']這6個(gè)參數(shù);
optimizer.param_groups[1]
: 好像是表示優(yōu)化器的狀態(tài)的一個(gè)字典;
import torch import torch.optim as optimh3 w1 = torch.randn(3, 3) w1.requires_grad = True w2 = torch.randn(3, 3) w2.requires_grad = True o = optim.Adam([w1]) print(o.param_groups)
[{'amsgrad': False, 'betas': (0.9, 0.999), 'eps': 1e-08, 'lr': 0.001, 'params': [tensor([[ 2.9064, -0.2141, -0.4037], [-0.5718, 1.0375, -0.6862], [-0.8372, 0.4380, -0.1572]])], 'weight_decay': 0}]
Per the docs, the add_param_group method accepts a param_group parameter that is a dict. Example of use:h3import torch import torch.optim as optimh3 w1 = torch.randn(3, 3) w1.requires_grad = True w2 = torch.randn(3, 3) w2.requires_grad = True o = optim.Adam([w1]) print(o.param_groups) givesh3[{'amsgrad': False, 'betas': (0.9, 0.999), 'eps': 1e-08, 'lr': 0.001, 'params': [tensor([[ 2.9064, -0.2141, -0.4037], [-0.5718, 1.0375, -0.6862], [-0.8372, 0.4380, -0.1572]])], 'weight_decay': 0}] nowh3o.add_param_group({'params': w2}) print(o.param_groups)
[{'amsgrad': False, 'betas': (0.9, 0.999), 'eps': 1e-08, 'lr': 0.001, 'params': [tensor([[ 2.9064, -0.2141, -0.4037], [-0.5718, 1.0375, -0.6862], [-0.8372, 0.4380, -0.1572]])], 'weight_decay': 0}, {'amsgrad': False, 'betas': (0.9, 0.999), 'eps': 1e-08, 'lr': 0.001, 'params': [tensor([[-0.0560, 0.4585, -0.7589], [-0.1994, 0.4557, 0.5648], [-0.1280, -0.0333, -1.1886]])], 'weight_decay': 0}]
# 動(dòng)態(tài)修改學(xué)習(xí)率 for param_group in optimizer.param_groups: param_group["lr"] = lr # 得到學(xué)習(xí)率optimizer.param_groups[0]["lr"] h3# print('查看optimizer.param_groups結(jié)構(gòu):') # i_list=[i for i in optimizer.param_groups[0].keys()] # print(i_list) ['amsgrad', 'params', 'lr', 'betas', 'weight_decay', 'eps']
補(bǔ)充:pytorch中的優(yōu)化器總結(jié)
# -*- coding: utf-8 -*- #@Time :2019/7/3 22:31 #@Author :XiaoMa from torch import nn as nn import torch as t from torch.autograd import Variable as V #定義一個(gè)LeNet網(wǎng)絡(luò) class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.features=nn.Sequential( nn.Conv2d(3,6,5), nn.ReLU(), nn.MaxPool2d(2,2), nn.Conv2d(6,16,5), nn.ReLU(), nn.MaxPool2d(2,3) ) self.classifier=nn.Sequential(\ nn.Linear(16*5*5,120), nn.ReLU(), nn.Linear(120,84), nn.ReLU(), nn.Linear(84,10) ) def forward(self, x): x=self.features(x) x=x.view(-1,16*5*5) x=self.classifier(x) return x net=Net() from torch import optim #優(yōu)化器 optimizer=optim.SGD(params=net.parameters(),lr=1) optimizer.zero_grad() #梯度清零,相當(dāng)于net.zero_grad() input=V(t.randn(1,3,32,32)) output=net(input) output.backward(output) #fake backward optimizer.step() #執(zhí)行優(yōu)化 #為不同子網(wǎng)絡(luò)設(shè)置不同的學(xué)習(xí)率,在finetune中經(jīng)常用到 #如果對(duì)某個(gè)參數(shù)不指定學(xué)習(xí)率,就使用默認(rèn)學(xué)習(xí)率 optimizer=optim.SGD( [{'param':net.features.parameters()}, #學(xué)習(xí)率為1e-5 {'param':net.classifier.parameters(),'lr':1e-2}],lr=1e-5 ) #只為兩個(gè)全連接層設(shè)置較大的學(xué)習(xí)率,其余層的學(xué)習(xí)率較小 special_layers=nn.ModuleList([net.classifier[0],net.classifier[3]]) special_layers_params=list(map(id,special_layers.parameters())) base_params=filter(lambda p:id(p) not in special_layers_params,net.parameters()) optimizer=t.optim.SGD([ {'param':base_params}, {'param':special_layers.parameters(),'lr':0.01} ],lr=0.001)
一種是修改optimizer.param_groups中對(duì)應(yīng)的學(xué)習(xí)率,另一種是新建優(yōu)化器(更簡(jiǎn)單也是更推薦的做法),由于optimizer十分輕量級(jí),構(gòu)建開(kāi)銷很小,故可以構(gòu)建新的optimizer。
但是新建優(yōu)化器會(huì)重新初始化動(dòng)量等狀態(tài)信息,這對(duì)使用動(dòng)量的優(yōu)化器來(lái)說(shuō)(如自帶的momentum的sgd),可能會(huì)造成損失函數(shù)在收斂過(guò)程中出現(xiàn)震蕩。
如:
#調(diào)整學(xué)習(xí)率,新建一個(gè)optimizer old_lr=0.1 optimizer=optim.SGD([ {'param':net.features.parameters()}, {'param':net.classifiers.parameters(),'lr':old_lr*0.5}],lr=1e-5)
看完了這篇文章,相信你對(duì)“pytorch中優(yōu)化器optimizer.param_groups用法的示例分析”有了一定的了解,如果想了解更多相關(guān)知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道,感謝各位的閱讀!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。