溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

怎么在python中利用PyTorch實現(xiàn)預(yù)訓(xùn)練

發(fā)布時間:2021-04-09 15:54:47 來源:億速云 閱讀:166 作者:Leah 欄目:開發(fā)技術(shù)

本篇文章給大家分享的是有關(guān)怎么在python中利用PyTorch實現(xiàn)預(yù)訓(xùn)練,小編覺得挺實用的,因此分享給大家學(xué)習(xí),希望大家閱讀完這篇文章后可以有所收獲,話不多說,跟著小編一起來看看吧。

直接加載預(yù)訓(xùn)練模型

如果我們使用的模型和原模型完全一樣,那么我們可以直接加載別人訓(xùn)練好的模型:

my_resnet = MyResNet(*args, **kwargs)
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

當然這樣的加載方法是基于PyTorch推薦的存儲模型的方法:

torch.save(my_resnet.state_dict(), "my_resnet.pth")

還有第二種加載方法:

my_resnet = torch.load("my_resnet.pth")

加載部分預(yù)訓(xùn)練模型

其實大多數(shù)時候我們需要根據(jù)我們的任務(wù)調(diào)節(jié)我們的模型,所以很難保證模型和公開的模型完全一樣,但是預(yù)訓(xùn)練模型的參數(shù)確實有助于提高訓(xùn)練的準確率,為了結(jié)合二者的優(yōu)點,就需要我們加載部分預(yù)訓(xùn)練模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
# 將pretrained_dict里不屬于model_dict的鍵剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現(xiàn)有的model_dict
model_dict.update(pretrained_dict)
# 加載我們真正需要的state_dict
model.load_state_dict(model_dict)

因為需要剔除原模型中不匹配的鍵,也就是層的名字,所以我們的新模型改變了的層需要和原模型對應(yīng)層的名字不一樣,比如:resnet最后一層的名字是fc(PyTorch中),那么我們修改過的resnet的最后一層就不能取這個名字,可以叫fc_

微改基礎(chǔ)模型預(yù)訓(xùn)練

對于改動比較大的模型,我們可能需要自己實現(xiàn)一下再加載別人的預(yù)訓(xùn)練參數(shù)。但是,對于一些基本模型PyTorch中已經(jīng)有了,而且我只想進行一些小的改動那么怎么辦呢?難道我又去實現(xiàn)一遍嗎?當然不是。

我們首先看看怎么進行微改模型。

微改基礎(chǔ)模型

PyTorch中的torchvision里已經(jīng)有很多常用的模型了,可以直接調(diào)用:

  1. AlexNet

  2. VGG

  3. ResNet

  4. SqueezeNet

  5. DenseNet

import torchvision.models as models

resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

但是對于我們的任務(wù)而言有些層并不是直接能用,需要我們微微改一下,比如,resnet最后的全連接層是分1000類,而我們只有21類;又比如,resnet第一層卷積接收的通道是3, 我們可能輸入圖片的通道是4,那么可以通過以下方法修改:

resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
resnet.fc = nn.Linear(2048, 21)

簡單預(yù)訓(xùn)練

模型已經(jīng)改完了,接下來我們就進行簡單預(yù)訓(xùn)練吧。

我們先從torchvision中調(diào)用基本模型,加載預(yù)訓(xùn)練模型,然后,重點來了,將其中的層直接替換為我們需要的層即可:

resnet = torchvision.models.resnet152(pretrained=True)
# 原本為1000類,改為10類
resnet.fc = torch.nn.Linear(2048, 10)

以上就是怎么在python中利用PyTorch實現(xiàn)預(yù)訓(xùn)練,小編相信有部分知識點可能是我們?nèi)粘9ぷ鲿姷交蛴玫降摹OM隳芡ㄟ^這篇文章學(xué)到更多知識。更多詳情敬請關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細節(jié)

免責聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI