In PyTorch, state_dict is a simple Python dictionary object that maps each layer to its parameters (e.g., each layer's weights and biases).
(Note: only layers with learnable parameters, such as convolutional layers and linear layers, have entries in the model's state_dict.)
An optimizer object (Optimizer) also has a state_dict, which contains the optimizer's state as well as the hyperparameters in use (such as lr, momentum, weight_decay, etc.).
Notes:
1) The state_dict is generated automatically by PyTorch once a model or optimizer has been defined, and can be accessed directly. The conventional file extension for a saved state_dict is ".pt" or ".pth", i.e. PATH="./***.pt" in the command below.
torch.save(model.state_dict(), PATH)
2) load_state_dict is likewise automatically available on any model or optimizer and can be called directly:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
Note the importance of model.eval(): it appears as the last call in 2) because only after this command do the dropout and batch-normalization layers switch into evaluation mode, and these two layer types behave differently in training mode than in evaluation mode.
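The difference is easy to observe with a standalone Dropout layer (a minimal sketch, not the article's network): in training mode it randomly zeroes elements, while in evaluation mode it becomes the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

drop.train()    # training mode: each element is zeroed with probability p,
print(drop(x))  # and survivors are scaled by 1/(1-p) = 2

drop.eval()     # evaluation mode: dropout is a no-op
print(drop(x))  # all ones, identical to x
```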
模態(tài)字典(state_dict)的保存(model是一個(gè)網(wǎng)絡(luò)結(jié)構(gòu)類的對(duì)象)
1.1)僅保存學(xué)習(xí)到的參數(shù),用以下命令
torch.save(model.state_dict(), PATH)
1.2) To load model.state_dict, use:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
備注:model.load_state_dict的操作對(duì)象是 一個(gè)具體的對(duì)象,而不能是文件名
2.1)保存整個(gè)model的狀態(tài),用以下命令
torch.save(model, PATH)
2.2)加載整個(gè)model的狀態(tài),用以下命令:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()
state_dict 是一個(gè)python的字典格式,以字典的格式存儲(chǔ),然后以字典的格式被加載,而且只加載key匹配的項(xiàng)
如何僅加載某一層的訓(xùn)練的到的參數(shù)(某一層的state)
If you want to load parameters from one layer into another but some keys do not match, simply rename the parameter keys in the state_dict you are loading so they match the keys of the model you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
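A sketch of the renaming approach described above (the models, the key prefixes '0.' and 'conv1.', and the file name are all illustrative assumptions):

```python
import torch
import torch.nn as nn

# Source model: Sequential names its layer '0', so the saved keys are
# '0.weight' and '0.bias'
model = nn.Sequential(nn.Conv2d(3, 6, 5))
torch.save(model.state_dict(), 'demo_state.pt')

# Target model names the same layer 'conv1', so its expected keys are
# 'conv1.weight' and 'conv1.bias'
target = nn.Module()
target.conv1 = nn.Conv2d(3, 6, 5)

# Rename the keys in the loaded dict to match the target, then load
state = torch.load('demo_state.pt')
renamed = {k.replace('0.', 'conv1.'): v for k, v in state.items()}
target.load_state_dict(renamed)
```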
加載模型參數(shù)后,如何設(shè)置某層某參數(shù)的"是否需要訓(xùn)練"(param.requires_grad)
for param in list(model.pretrained.parameters()): param.requires_grad = False
注意: requires_grad的操作對(duì)象是tensor.
疑問:能否直接對(duì)某個(gè)層直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:經(jīng)測(cè)試,不可以.model.conv1 沒有requires_grad屬性.
全部測(cè)試代碼:
#-*-coding:utf-8-*- import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim # define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.pool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(6,16,5) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) def forward(self,x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1,16*5*5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # initial model model = TheModelClass() #initialize the optimizer optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9) # print the model's state_dict print("model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,'\t',model.state_dict()[param_tensor].size()) print("\noptimizer's state_dict") for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) print("\nprint particular param") print('\n',model.conv1.weight.size()) print('\n',model.conv1.weight) print("------------------------------------") torch.save(model.state_dict(),'./model_state_dict.pt') # model_2 = TheModelClass() # model_2.load_state_dict(torch.load('./model_state_dict')) # model.eval() # print('\n',model_2.conv1.weight) # print((model_2.conv1.weight == model.conv1.weight).size()) ## 僅僅加載某一層的參數(shù) conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight'] print(conv1_weight_state==model.conv1.weight) model_2 = TheModelClass() model_2.load_state_dict(torch.load('./model_state_dict.pt')) model_2.conv1.requires_grad=False print(model_2.conv1.requires_grad) print(model_2.conv1.bias.requires_grad)
以上這篇pytorch 狀態(tài)字典:state_dict使用詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持億速云。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。