
PyTorch state dictionaries: a detailed guide to using state_dict

Published: 2020-08-19 21:38:20 | Source: 腳本之家 | Views: 983 | Author: wzg2016 | Category: Development

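In PyTorch, a state_dict is a plain Python dictionary that maps each layer to its parameters (for example, each layer's weights and biases).

(Note: only layers that hold learnable parameters, such as convolutional and linear layers, or registered buffers, such as the running statistics of batch normalization, have entries in the model's state_dict.)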
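As a small illustration of this mapping, the sketch below builds a hypothetical network (`TinyNet` is an illustrative name, not part of the original article) and lists its state_dict entries. The pooling layer has no parameters, so it contributes no keys.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only to inspect state_dict keys."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 5)   # has weight and bias
        self.pool = nn.MaxPool2d(2, 2)   # no parameters, no entries
        self.fc = nn.Linear(6, 2)        # has weight and bias


tiny = TinyNet()
keys = list(tiny.state_dict().keys())

# Each key is "<attribute name>.<parameter name>".
for name, tensor in tiny.state_dict().items():
    print(name, tuple(tensor.shape))
```

Keys follow the attribute names used in `__init__`, which is why renaming an attribute changes the keys a saved file must match.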

An Optimizer object also has a state_dict. It holds the optimizer's internal state together with the hyperparameters in use (such as lr, momentum, and weight_decay).
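A minimal sketch of what an optimizer's state_dict contains, assuming a single `nn.Linear` layer and SGD (both chosen here only for illustration). The hyperparameters live under `param_groups`; the per-parameter `state` stays empty until the optimizer has taken a step.

```python
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(4, 2)
opt = optim.SGD(layer.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)

opt_state = opt.state_dict()
print(list(opt_state.keys()))

# Hyperparameters are stored per parameter group.
group = opt_state["param_groups"][0]
print(group["lr"], group["momentum"], group["weight_decay"])

# Parameters are referenced by index, not by tensor.
print(group["params"])
```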

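Notes:

1) A state_dict is available as soon as a model or optimizer has been constructed; call state_dict() on the object directly. State dicts are conventionally saved to files with a ".pt" or ".pth" extension, which is what PATH = "./***.pt" stands for in the command below.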

torch.save(model.state_dict(), PATH)

2) load_state_dict is likewise a method that every model and optimizer provides, and it can be called directly:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

Note on the importance of model.eval(): step 2) ends with model.eval() because dropout and batch normalization layers only switch to evaluation mode after that call. These two kinds of layer behave differently in training mode and in evaluation mode.
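The difference between the two modes can be seen on a standalone dropout layer. This is only a sketch with an arbitrary input of ones and p=0.5: in training mode about half the elements are zeroed and the rest are scaled by 1/(1-p), while in evaluation mode the layer passes its input through unchanged.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()            # training mode: random zeroing plus rescaling
train_out = drop(x)

drop.eval()             # evaluation mode: dropout is a no-op
eval_out = drop(x)

print(sorted(train_out.unique().tolist()))
print(torch.equal(eval_out, x))
```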

Saving the state dictionary (state_dict), where model is an instance of a network class

1.1) To save only the learned parameters:

 torch.save(model.state_dict(), PATH)

1.2) To load a saved state_dict into a model:

 model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

Note: model.load_state_dict takes a state dictionary object, not a file name, so the file has to be deserialized with torch.load first.
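The round trip can be sketched as follows, assuming two `nn.Linear` layers of the same shape and a temporary file (the file name is illustrative). Passing a path string to load_state_dict raises an exception; the exact exception type depends on the PyTorch version, so the sketch catches both likely ones.

```python
import os
import tempfile

import torch
import torch.nn as nn

src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear_state_dict.pt")
    torch.save(src.state_dict(), path)   # dict -> file
    state = torch.load(path)             # file -> dict
    dst.load_state_dict(state)           # dict -> module

# A file name is not a state dictionary.
error_name = None
try:
    dst.load_state_dict("linear_state_dict.pt")
except (TypeError, AttributeError) as exc:
    error_name = type(exc).__name__

print(torch.equal(dst.weight, src.weight), error_name)
```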

2.1) To save the entire model (the pickled module object, not just its parameters):

torch.save(model,PATH)

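2.2) To load the entire model:

 # The model class must be defined or importable where this runs.
 # On PyTorch 2.6 and later, torch.load defaults to weights_only=True,
 # so loading a full pickled model requires passing weights_only=False.
 model = torch.load(PATH)
 model.eval()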

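A state_dict is a Python dictionary: it is stored as a dictionary and loaded as a dictionary, and entries are matched by key. By default (strict=True) load_state_dict raises an error if the keys do not match the model exactly; with strict=False only the entries whose keys match are loaded.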
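A short sketch of key matching, using a hypothetical `nn.Sequential` model and a deliberately incomplete dictionary that contains only the first layer's entries. With the default strict=True the load fails; with strict=False the matching keys are loaded and the rest are reported back.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
# Keep only the entries of layer "0" to simulate a partial checkpoint.
partial = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}

fresh = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

try:
    fresh.load_state_dict(partial)       # strict=True by default
    strict_failed = False
except RuntimeError:
    strict_failed = True

result = fresh.load_state_dict(partial, strict=False)
print(strict_failed)
print(result.missing_keys)      # keys the model expected but did not receive
print(result.unexpected_keys)   # keys in the dict that the model does not have
```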

How to load the trained parameters of a single layer (one layer's state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
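The key renaming described above might look like the following sketch. `Source` and `Target` are hypothetical classes invented for this example: the convolution has the same shape in both but is stored under a different attribute name, so its keys must be rewritten before loading.

```python
import torch
import torch.nn as nn


class Source(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)


class Target(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Conv2d(3, 6, 5)  # same shape, different name
        self.head = nn.Linear(6, 2)         # not present in Source


source_state = Source().state_dict()

# Rewrite "conv1.*" keys to the names the target model expects.
renamed = {
    key.replace("conv1.", "features.", 1): value
    for key, value in source_state.items()
    if key.startswith("conv1.")
}

target = Target()
result = target.load_state_dict(renamed, strict=False)
print(sorted(renamed), result.missing_keys)
```

strict=False is needed here because `head` has no counterpart in the source dictionary; its keys come back in `result.missing_keys`.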

After loading the parameters, how to set whether a given layer's parameters should be trained (param.requires_grad)

for param in model.pretrained.parameters():
    param.requires_grad = False

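Note: requires_grad is an attribute of tensors.

Question: can requires_grad be set directly on a layer, for example model.conv1.requires_grad = False?

Answer: tested, and no. nn.Module has no requires_grad attribute of its own, so the assignment runs without error but only creates an ordinary Python attribute on the module; the layer's parameters are unaffected (model.conv1.bias.requires_grad is still True). Set the flag on the parameters themselves, or call model.conv1.requires_grad_(False).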
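The behaviour in question can be checked with a short sketch, assuming a hypothetical `nn.Sequential` network. Plain assignment on the module leaves the parameter flags untouched, whereas `requires_grad_` (a method of `nn.Module`) updates every parameter of that layer. Passing only the still-trainable parameters to the optimizer is one common follow-up, shown here as an option rather than a requirement.

```python
import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(nn.Conv2d(3, 6, 5), nn.ReLU(), nn.Conv2d(6, 16, 5))

# Plain assignment only creates a Python attribute on the module.
net[0].requires_grad = False
flags_after_assignment = [p.requires_grad for p in net[0].parameters()]

# The in-place method sets the flag on each parameter of the layer.
net[0].requires_grad_(False)
flags_after_method = [p.requires_grad for p in net[0].parameters()]

# Hand only the parameters that still require gradients to the optimizer.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.001, momentum=0.9)

print(flags_after_assignment, flags_after_method, len(trainable))
```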

Full test code:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for name, tensor in model.state_dict().items():
    print(name, '\t', tensor.size())

print("\noptimizer's state_dict")
for var_name, value in optimizer.state_dict().items():
    print(var_name, '\t', value)
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(), './model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict.pt'))
# model_2.eval()
# print('\n', model_2.conv1.weight)
# print(torch.equal(model_2.conv1.weight, model.conv1.weight))
# load the parameters of a single layer only
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(torch.equal(conv1_weight_state, model.conv1.weight))

model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
# Assigning to the module only creates a plain Python attribute;
# the layer's parameters are not frozen by it.
model_2.conv1.requires_grad = False
print(model_2.conv1.requires_grad)       # the attribute just assigned
print(model_2.conv1.bias.requires_grad)  # the parameter's own flag, unchanged
