溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶(hù)服務(wù)條款》

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

發(fā)布時(shí)間:2022-03-25 16:13:13 來(lái)源:億速云 閱讀:359 作者:iii 欄目:開(kāi)發(fā)技術(shù)

本文小編為大家詳細(xì)介紹“基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作”,內(nèi)容詳細(xì),步驟清晰,細(xì)節(jié)處理妥當(dāng),希望這篇“基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作”文章能幫助大家解決疑惑,下面跟著小編的思路慢慢深入,一起來(lái)學(xué)習(xí)新知識(shí)吧。

mian.py文件是該項(xiàng)目的總文件,也是訓(xùn)練網(wǎng)絡(luò)模型的運(yùn)行文件,文本的介紹流程是隨著該文件一 一對(duì)代碼進(jìn)行介紹。

main.py代碼如下所示:

from dataset import data_dataloader    #電腦本地寫(xiě)的讀取數(shù)據(jù)的函數(shù)
from torch import nn                   #導(dǎo)入pytorch的nn模塊
from torch import optim                #導(dǎo)入pytorch的optim模塊
from network import Res_net            #電腦本地寫(xiě)的網(wǎng)絡(luò)框架的函數(shù)
from train import train                #電腦本地寫(xiě)的訓(xùn)練函數(shù)

def main():
    # 以下是通過(guò)Data_dataloader函數(shù)輸入為:數(shù)據(jù)的路徑,數(shù)據(jù)模式,數(shù)據(jù)大小,batch的大小,有幾線并用 (把dataset和Dataloader功能合在了一起)
    train_loader = data_dataloader(data_path='./data', mode='train', size=64, batch_size=24, num_workers=4)
    val_loader = data_dataloader(data_path='./data', mode='val', size=64, batch_size=24, num_workers=2)
    test_loader = data_dataloader(data_path='./data', mode='test', size=64, batch_size=24, num_workers=2)
    # 以下是超參數(shù)的定義
    lr = 1e-4           #學(xué)習(xí)率
    epochs = 10         #訓(xùn)練輪次
    model = Res_net(2)  # resnet網(wǎng)絡(luò)
    optimizer = optim.Adam(model.parameters(), lr=lr)  # 優(yōu)化器
    loss_function = nn.CrossEntropyLoss()  # 損失函數(shù)
    # 訓(xùn)練以及驗(yàn)證測(cè)試函數(shù)
    train(model=model, optimizer=optimizer, loss_function=loss_function, train_data=train_loader, val_data=val_loader,test_data= test_loader, epochs=epochs)
if __name__ == '__main__':
    main()

main.py流程圖如圖1所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖 1 main.py 代碼流程圖

1.dataset.py(先看代碼的總體流程再看介紹)

main.py()前五行分別是導(dǎo)入相應(yīng)的模塊,其中dataset,network以及train是本地編寫(xiě)的文件。在mian()函數(shù)中的前幾行代碼中,我們使用dataset.py文件中的Data_dataloader函數(shù)導(dǎo)入訓(xùn)練集、驗(yàn)證集和測(cè)試集。Dataset文件是導(dǎo)入我們自己的本地?cái)?shù)據(jù)庫(kù),其功能是得到所有的數(shù)據(jù),將其變成pytorch能夠識(shí)別的tensor數(shù)據(jù),然后得到圖片。

dataset.py文件代碼如下所示:

import torch
import os,glob
import random
import csv
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader

# 第一部分:通過(guò)三個(gè)步驟得到輸出的tensor類(lèi)型的數(shù)據(jù)
class Dataset_self(Dataset):                    #如果是nn.moduel 則是編寫(xiě)網(wǎng)絡(luò)模型框架,這里需要繼承的是dataset的數(shù)據(jù),所以括號(hào)中的是Dataset
    #第一步:初始化
    def __init__(self,root,mode,resize,):       #root是文件根目錄,mode是選擇什么樣的數(shù)據(jù)集,resize是圖像重新調(diào)整大小
        super(Dataset_self, self).__init__()
        self.resize = resize
        self.root = root
        self.name_label = {}       #創(chuàng)建一個(gè)字典來(lái)保存每個(gè)文件的標(biāo)簽
        #首先得到標(biāo)簽相對(duì)于的字典(標(biāo)簽和名稱(chēng)一一對(duì)應(yīng))
        for name in sorted(os.listdir(os.path.join(root))):     #排序并且用列表的形式打開(kāi)文件夾
            if not os.path.isdir(os.path.join(root,name)):      #不是文件夾就不需要讀取
                continue
            self.name_label[name] = len(self.name_label.keys())  #每個(gè)文件的名字為name_Label字典中有多少對(duì)鍵值對(duì)的個(gè)數(shù)
        #print(self.name_label)
        self.image,self.label = self.make_csv('images.csv')       #編寫(xiě)一共函數(shù)來(lái)讀取圖片和標(biāo)簽的路徑
        #在得到image和label的基礎(chǔ)上對(duì)圖片數(shù)據(jù)進(jìn)行一共劃分  (注意:如果需要交叉驗(yàn)證就不需要驗(yàn)證集,只劃分為訓(xùn)練集和測(cè)試集)
        if mode == 'train':
            self.image ,self.label= self.image[:int(0.6*len(self.image))],self.label[:int(0.6*len(self.label))]
        if mode == 'val':
            self.image ,self.label= self.image[int(0.6*len(self.image)):int(0.8*len(self.image))],self.label[int(0.6*len(self.label)):int(0.8*len(self.label))]
        if mode == 'test':
            self.image ,self.label= self.image[int(0.8*len(self.image)):],self.label[int(0.8*len(self.label)):]
    # 獲得圖片和標(biāo)簽的函數(shù)
    def make_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):  #如果不存在匯總的目錄就新建一個(gè)
            images = []
            for image in self.name_label.keys():                            # 讓image到name_label中的每個(gè)文件中去讀取圖片
                images += glob.glob(os.path.join(self.root,image,'*jpg'))   #加* 貪婪搜索關(guān)于jpg的所有文件
            #print('長(zhǎng)度為:{},第二張圖片為:{}'.format(len(images),images[1]))
            random.shuffle(images)                                         #把images列表中的數(shù)據(jù)洗牌
            # images[0]: ./data\ants\382971067_0bfd33afe0.jpg
            with open(os.path.join(self.root,filename),mode='w',newline='') as f :  #創(chuàng)建文件
                writer = csv.writer(f)
                for image in images:
                    name = image.split(os.sep)[-2]  #得到與圖片相對(duì)應(yīng)的標(biāo)簽
                    label = self.name_label[name]
                    writer.writerow([image,label])  #寫(xiě)入文件  第一行:./data\ants\382971067_0bfd33afe0.jpg,0
        images,labels = [],[]
        with open(os.path.join(self.root,filename)) as f:   #讀取文件
            reader = csv.reader(f)
            for row in reader:
                image, label = row
                label = int(label)
                images.append(image)
                labels.append(label)
        assert len(images) == len(labels)   #類(lèi)似if語(yǔ)句,只有兩者長(zhǎng)度一致才繼續(xù)執(zhí)行,否則報(bào)錯(cuò)
        return images,labels                #返回所有??!是所有的圖片和標(biāo)簽(此處的圖片不是圖片數(shù)據(jù)本身,而是它的文件目錄)
    #第二步:得到圖片數(shù)據(jù)的長(zhǎng)度(標(biāo)簽數(shù)據(jù)長(zhǎng)度與圖片一致)
    def __len__(self):
        return len(self.image)
    #第三步:讀取圖片和標(biāo)簽,并輸出
    def __getitem__(self, item):   # 單張返回張量的圖像與標(biāo)簽
        image,label = self.image[item],self.label[item]      #得到單張圖片和相應(yīng)的標(biāo)簽(此處都是image都是文件目錄)
        image = Image.open(image).convert('RGB')             #得到圖片數(shù)據(jù)
        #使用transform對(duì)圖片進(jìn)行處理以及變成tensor類(lèi)型數(shù)據(jù)
        transf = transforms.Compose([transforms.Resize((int(self.resize),int(self.resize))),
                                     transforms.RandomRotation(15),
                                     transforms.CenterCrop(self.resize),
                                     transforms.ToTensor(),  #先變成tensor類(lèi)型數(shù)據(jù),然后在進(jìn)行下面的標(biāo)準(zhǔn)化
                                     ])
        image = transf(image)
        label = torch.tensor(label)   #把圖片標(biāo)簽也變成tensor類(lèi)型
        return image,label
#第二部分:使用pytorch自帶的DataLoader函數(shù)批量得到圖片數(shù)據(jù)
def data_dataloader(data_path,mode,size,batch_size,num_workers):   #用一個(gè)函數(shù)加載上訴的數(shù)據(jù),data_path、mode和size分別是以上定義的Dataset_self()中的參數(shù),batch_size是一次性輸出多少?gòu)垐D像,num_worker是同時(shí)處理幾張圖像
    dataset = Dataset_self(data_path,mode,size)
    dataloader = DataLoader(dataset,batch_size,num_workers)  #使用pytorch中的dataloader函數(shù)得到數(shù)據(jù)
    return dataloader
#測(cè)試
def main():
    test = Dataset_self('./data','train',64)
if __name__ == '__main__':
    main()

dataset.py流程圖2所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖2 dataset.py流程圖

如以上代碼所示,使用pytorch加載自定義的數(shù)據(jù)集時(shí),需要定義一個(gè)dataset的對(duì)象,然后定義一個(gè)dataloaber的對(duì)象,最后對(duì)dataloaber反復(fù)得到訓(xùn)練數(shù)據(jù)和標(biāo)簽。所以本文件主要分為兩個(gè)部分:自定義的dataset部分和使用pytorch中dataloaber來(lái)得到訓(xùn)練數(shù)據(jù)的部分。

代碼首先是導(dǎo)入必要的python庫(kù),然后編寫(xiě)第一部分。第一部分主要是通過(guò)三個(gè)步驟來(lái)得到單張輸出的tensor類(lèi)型圖片和標(biāo)簽。

三個(gè)步驟分別是:初始化、獲得數(shù)據(jù)的長(zhǎng)度以及讀取數(shù)據(jù)和標(biāo)簽。其中初始化是為了得到一個(gè)文件,文件中保存所有圖片相對(duì)應(yīng)的目錄以及其標(biāo)簽,再將得到的文件讀出分為訓(xùn)練集、驗(yàn)證集和測(cè)試集。具體實(shí)現(xiàn)如上述代碼所示,首先在初始化的函數(shù)中定義變量resize、root和name_label,方便與后面的函數(shù)調(diào)用:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖3 Dataset_self中參數(shù)的初始化

然后,我們編寫(xiě)代碼讀取根目錄,得到分類(lèi)名字及其相對(duì)應(yīng)的標(biāo)簽:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖4 標(biāo)簽的獲得

代碼中,首先使用os庫(kù)來(lái)把根目錄內(nèi)的文件變成列表被讀取出來(lái),然后把根目錄內(nèi)所有文件名保存在name_label字典中,在分別依照存儲(chǔ)進(jìn)字典的個(gè)數(shù)來(lái)給標(biāo)簽數(shù)值化。(第一個(gè)讀取進(jìn)字典的標(biāo)簽就是0,第二個(gè)是1,其余文件以此類(lèi)推)

得到標(biāo)簽字典后,我們編寫(xiě)一個(gè)函數(shù)來(lái)獲得所有圖片的目錄,便于下面步驟的圖片讀?。?/p>

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖5 圖片和標(biāo)簽的讀取

編寫(xiě)make_csv函數(shù),來(lái)得到image和label(image是每張圖片的目錄,label是相對(duì)應(yīng)的標(biāo)簽)。

make_csv函數(shù)中,首先判斷是否以及存在我們需要的文件,如果存在則直接讀取,如果不存在就先生成一個(gè)存儲(chǔ)所有圖片目錄和標(biāo)簽的文件。

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖6 make_csv函數(shù)

當(dāng)文件不存在時(shí)(第一行語(yǔ)句的判斷),我們編寫(xiě)文件的思路是先編寫(xiě)一個(gè)列表來(lái)保存所有的圖片目錄,然后再創(chuàng)建文件使用csv庫(kù)把列表數(shù)據(jù)寫(xiě)入文件中。所以在判斷語(yǔ)句下面,我們得到一個(gè)空的images列表,然后遍歷name_label中的keys,對(duì)于name_label來(lái)說(shuō),它是一個(gè)key是文件名,value是標(biāo)簽(數(shù)值)的字典,因?yàn)槭怯胦s庫(kù)把文件讀取成為字典的,所以遍歷字典內(nèi)的key時(shí),是讀取的是相對(duì)應(yīng)的文件。所以上圖第四行代碼中是分別讀取文件中的圖片,然后使用glob庫(kù)分別把所有jpg文件存儲(chǔ)到images列表里面。在列表中images[0]是:./data\ants\382971067_0bfd33afe0.jpg

在得到圖片目錄列表后,首先將列表內(nèi)的數(shù)據(jù)隨機(jī)排列,然后創(chuàng)造一個(gè)文件,在列表images中的目錄得到標(biāo)簽名稱(chēng),用name_label得到標(biāo)簽名稱(chēng)相對(duì)應(yīng)的數(shù)值,最后寫(xiě)入文件中。文件第一行是:./data\ants\382971067_0bfd33afe0.jpg,0(圖片相對(duì)目錄和相對(duì)于的標(biāo)簽)

得到文件后,因?yàn)槲覀冃枰氖敲繌垐D片的目錄而不是文件(主要是為了后面反復(fù)調(diào)試,所以得到一個(gè)文件做中轉(zhuǎn)站),所以我們需要用兩個(gè)列表來(lái)得到圖片目錄和相對(duì)應(yīng)的標(biāo)簽值,最后分別把文件中的數(shù)據(jù)寫(xiě)入列表中,得到圖片和標(biāo)簽列表。

至此,我們就能通過(guò)函數(shù)make_csv來(lái)得到image和label。得到這兩個(gè)列表后,我們對(duì)其進(jìn)行切割,因?yàn)榱斜砝锩媸潜4娴乃詳?shù)據(jù),所以我們需要分割為訓(xùn)練集、驗(yàn)證集和測(cè)試集。代碼很簡(jiǎn)單,(如果需要交叉驗(yàn)證則只需要?jiǎng)澐殖鲇?xùn)練集和測(cè)試集即可)如下圖所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖7 數(shù)據(jù)集的劃分

以上是第一步初始化的過(guò)程,第二步讀取圖像長(zhǎng)度:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖8 讀取圖像長(zhǎng)度

很簡(jiǎn)單,一個(gè)len()函數(shù)就搞定,其主要功能是知道一共有多少數(shù)據(jù)。

第三步:讀取數(shù)據(jù)和標(biāo)簽,讀取數(shù)據(jù)是一張一張來(lái)讀取的,所以首先從image和label列表中得到單個(gè)數(shù)據(jù),因?yàn)閕mage列表中保存的是圖片的目錄,所以先讀取RGB格式的圖片,然后使用transform對(duì)圖片進(jìn)行相應(yīng)的處理(尺寸,圖片變化,變成tensor類(lèi)型等),最后也將label變成tensor類(lèi)型然后把圖片數(shù)據(jù)和標(biāo)簽數(shù)據(jù)返回即可,代碼如下圖所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖8 讀取圖像和標(biāo)簽

第一部分是讀取圖片和圖片相對(duì)應(yīng)的標(biāo)簽,流程是三步:初始化、得到數(shù)據(jù)長(zhǎng)度和讀取單張數(shù)據(jù),對(duì)于pytorch的dataset處理都是基于這三步。其中算法邏輯并不復(fù)雜,主要是需要使用的語(yǔ)句有點(diǎn)多,需要仔細(xì)思考其中的邏輯。

第二部分相對(duì)于第一部分要簡(jiǎn)單很多,甚至可以把這部分放到main()函數(shù)中運(yùn)行。其主要內(nèi)容是通過(guò)第一部分得到的dataset_self來(lái)得到數(shù)據(jù),然后使用pytorch自帶的dataloader得到放入模型中訓(xùn)練的數(shù)據(jù)集,代碼如下圖所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖9 數(shù)據(jù)集的獲取

Dataset部分其功能簡(jiǎn)單概括就是將本地?cái)?shù)據(jù)集中的圖片和標(biāo)簽變成tensor類(lèi)型數(shù)據(jù)讀取為需要使用的數(shù)據(jù)集。

2.network.py

main.py()中,我們定義了一些超參數(shù)等,分別有學(xué)習(xí)率,訓(xùn)練輪次,訓(xùn)練模型,優(yōu)化器以及損失函數(shù)。對(duì)于訓(xùn)練模型,本文使用的是本地編寫(xiě)的一個(gè)小型的Resnet模型。其代碼如下所示:

import torch
from torch import nn

# 先寫(xiě)好resnet的block塊
class Res_block(nn.Module):
    def __init__(self,in_num,out_num,stride):
        super(Res_block, self).__init__()
        self.cov1 = nn.Conv2d(in_num,out_num,(3,3),stride=stride,padding=1)    #(3,3)  padding=1 則圖像大小不變,stride為幾圖像就縮小幾倍,能極大減少參數(shù)
        self.bn1 = nn.BatchNorm2d(out_num)
        self.cov2 = nn.Conv2d(out_num,out_num,(3,3),padding=1)
        self.bn2 = nn.BatchNorm2d(out_num)
        self.extra = nn.Sequential(
                nn.Conv2d(in_num,out_num,(1,1),stride=stride),
                nn.BatchNorm2d(out_num)
            )   #使得輸入前后的圖像數(shù)據(jù)大小是一致的
        self.relu = nn.ReLU()
    def forward(self,x):
        out = self.relu(self.bn1(self.cov1(x)))
        out = self.relu(self.bn2(self.cov2(out)))
        out = self.extra(x) + out
        return out
class Res_net(nn.Module):
    def __init__(self,num_class):
        super(Res_net, self).__init__()
        self.init = nn.Sequential(
            nn.Conv2d(3,16,(3,3)),
            nn.BatchNorm2d(16)
        )   #預(yù)處理
        self.bn1 = Res_block(16,32,2)
        self.bn2 = Res_block(32,64,2)
        self.bn3 = Res_block(64,128,2)
        self.bn4 = Res_block(128,256,2)
        self.fl = nn.Flatten()
        self.linear1 = nn.Linear(8192,10)
        self.linear2 = nn.Linear(10,num_class)
        out = self.relu(self.init(x))
        #print('inint:',out.shape)
        out = self.bn1(out)
        #print('bn1:', out.shape)
        out = self.bn2(out)
        #print('bn2:', out.shape)
        out = self.bn3(out)
        #print('bn3:', out.shape)
        out = self.fl(out)
        #print('flatten:', out.shape)
        out = self.relu(self.linear1(out))
        #print('linear1:', out.shape)
        out = self.relu(self.linear2(out))
        #print('linear2:', out.shape)
#測(cè)試
def main():
    x = torch.randn(2,3,64,64)
    net = Res_net(2)
    out = net(x)
    print(out.shape)
if __name__ == '__main__':
    main()

network.py流程圖如圖10所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖10 network.py流程圖

Resnet模型網(wǎng)絡(luò)主要是兩部分,首先編寫(xiě)resnet中的每個(gè)殘差塊,然后編寫(xiě)整個(gè)網(wǎng)絡(luò)。在開(kāi)始介紹代碼之前,首先用我的理解來(lái)介紹一下Resnet,也就是殘差網(wǎng)絡(luò)的思想與邏輯(具體可以搜索其他資料查看)。殘差網(wǎng)絡(luò)其主要的目的是能夠訓(xùn)練一個(gè)深層次的網(wǎng)絡(luò),希望是隨著網(wǎng)絡(luò)的加深,效果越來(lái)越好。但是因?yàn)榫W(wǎng)絡(luò)加深,很有可能一些參數(shù)會(huì)得不到訓(xùn)練(一次次的迭代,使得梯度消失),所有Resnet網(wǎng)絡(luò)巧妙的運(yùn)用了一個(gè)殘差塊來(lái)解決因?yàn)榫W(wǎng)絡(luò)模型太深而導(dǎo)致其梯度消失的問(wèn)題,如圖11所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖11 殘差塊

簡(jiǎn)單來(lái)說(shuō)就是在x通過(guò)兩個(gè)層后,在和x本身相加,如此在反向傳播的過(guò)程中,f(x)+x求帶就變成如此就在回傳給x上面的隱藏層的時(shí)候就不會(huì)發(fā)生梯度消失(至少有個(gè)1)。如果在x輸入殘差塊前有n層,那么就算殘差快內(nèi)的隱藏層因?yàn)樘荻认У膯?wèn)題而沒(méi)有訓(xùn)練好,但是至少x輸入之前的n層是訓(xùn)練好了的,這樣只要?dú)埐羁熘械碾[藏層能訓(xùn)練好一部分,神經(jīng)網(wǎng)絡(luò)的準(zhǔn)確度就很有可能在原來(lái)基礎(chǔ)上增加。(還是得好好研究,這里Resnet的解釋可能并沒(méi)有那么準(zhǔn)確)

基于上述殘差塊的圖片,我們先定義好殘差塊,代碼如下圖12所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖12 殘差塊的定義

其流程圖如圖13:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖13 殘差塊定義流程圖

當(dāng)殘差塊寫(xiě)好后,就可以編寫(xiě)一個(gè)簡(jiǎn)單的Resnet網(wǎng)絡(luò),代碼如圖14所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖14 簡(jiǎn)單Resnet網(wǎng)絡(luò)模型

上述代碼中,首先通過(guò)一層正常的卷積層后,再通過(guò)3個(gè)殘差塊,最后通過(guò)兩層線性層,代碼十分比較簡(jiǎn)單。在定義好殘差塊之后,調(diào)用pytorch本身自帶的函數(shù)即可完成。唯一需要注意的地方是參數(shù)的設(shè)置,該網(wǎng)絡(luò)一般來(lái)說(shuō)都是維度在慢慢增加,圖像的尺寸慢慢減少。

3.train.py

train.py是整個(gè)模型的訓(xùn)練過(guò)程,本文將其打包成為一個(gè)函數(shù),然后在mian.py中調(diào)用,因?yàn)榛旧暇W(wǎng)絡(luò)的訓(xùn)練過(guò)程都大同小異,一般都是用訓(xùn)練集訓(xùn)練,在驗(yàn)證集上得到最好的輪次,最后保存網(wǎng)絡(luò)參數(shù)并且在測(cè)試集上檢測(cè),所以這里直接將訓(xùn)練過(guò)程和驗(yàn)證過(guò)程打包成為函數(shù),便于以后項(xiàng)目的直接調(diào)用。

train.py代碼如下所示:

import torch
from torch import optim
from torch.utils.data import DataLoader
from dataset import Dataset_self
from network import Res_net
from torch import nn
from matplotlib import pyplot as plt
import numpy as np

def evaluate(model,loader):   #計(jì)算每次訓(xùn)練后的準(zhǔn)確率
    correct = 0
    total = len(loader.dataset)
    for x,y in loader:
        logits = model(x)
        pred = logits.argmax(dim=1)     #得到logits中分類(lèi)值(要么是[1,0]要么是[0,1]表示分成兩個(gè)類(lèi)別)
        correct += torch.eq(pred,y).sum().float().item()        #用logits和標(biāo)簽label想比較得到分類(lèi)正確的個(gè)數(shù)
    return correct/total
#把訓(xùn)練的過(guò)程定義為一個(gè)函數(shù)
def train(model,optimizer,loss_function,train_data,val_data,test_data,epochs):  #輸入:網(wǎng)絡(luò)架構(gòu),優(yōu)化器,損失函數(shù),訓(xùn)練集,驗(yàn)證集,測(cè)試集,輪次
    best_acc,best_epoch =0,0      #輸出驗(yàn)證集中準(zhǔn)確率最高的輪次和準(zhǔn)確率
    train_list,val_List = [],[]   # 創(chuàng)建列表保存每一次的acc,用來(lái)最后的畫(huà)圖
    for epoch in range(epochs):
            print('============第{}輪============'.format(epoch + 1))
            for steps,(x,y) in enumerate(train_data):   #  for x,y in train_data
                logits = model(x)                   #數(shù)據(jù)放入網(wǎng)絡(luò)中
                loss = loss_function(logits,y)      #得到損失值
                optimizer.zero_grad()               #優(yōu)化器先清零,不然會(huì)疊加上次的數(shù)值
                loss.backward()                     #后向傳播
                optimizer.step()
            train_acc =evaluate(model,train_data)
            train_list.append(train_acc)
            print('train_acc',train_acc)
            #if epoch % 1 == 2:   #這里可以設(shè)置每?jī)纱斡?xùn)練驗(yàn)證一次
            val_acc = evaluate(model,val_data)
            print('val_acc=',val_acc)
            val_List.append((val_acc))
            if val_acc > best_acc:  #判斷每次在驗(yàn)證集上的準(zhǔn)確率是否為最大
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(),'best.mdl')   #保存驗(yàn)證集上最大的準(zhǔn)確率
    print('===========================分割線===========================')
    print('best acc:',best_acc,'best_epoch:',best_epoch)
    #在測(cè)試集上檢測(cè)訓(xùn)練好后模型的準(zhǔn)確率
    model.load_state_dict((torch.load('best.mdl')))
    print('detect the test data!')
    test_acc = evaluate(model,test_data)
    print('test_acc:',test_acc)
    train_list_file = np.array(train_list)
    np.save('train_list.npy',train_list_file)
    val_list_file = np.array(val_List)
    np.save('val_list.npy',val_list_file)
    #畫(huà)圖
    x_label = range(1,len(val_List)+1)
    plt.plot(x_label,train_list,'bo',label='train acc')
    plt.plot(x_label,val_List,'b',label='validation acc')
    plt.title('train and validation accuracy')
    plt.xlabel('epochs')
    plt.legend()
    plt.show()
#測(cè)試
def main():
    train_dataset = Dataset_self('./data', 'train', 64)
    vali_dataset = Dataset_self('./data', 'val', 64)
    test_dataset = Dataset_self('./data', 'test', 64)
    train_loaber = DataLoader(train_dataset, 24, num_workers=4)
    val_loaber = DataLoader(vali_dataset, 24, num_workers=2)
    test_loaber = DataLoader(test_dataset, 24, num_workers=2)
    lr = 1e-4
    epochs = 5
    model = Res_net(2)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    train(model,optimizer,criteon,train_loaber,val_loaber,test_loaber,epochs)
if __name__ == '__main__':
    main()

  train.py流程圖如圖15所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖15 train.py流程圖

上述代碼中,第一個(gè)函數(shù)的定義是為了得到一次訓(xùn)練(或者驗(yàn)證或者測(cè)試)后的準(zhǔn)確率,也就是跑完一次所有訓(xùn)練集后,模型的準(zhǔn)確率是多少。其代碼內(nèi)容并不復(fù)雜,先得到經(jīng)過(guò)模型logits中的分類(lèi)標(biāo)簽(是[1,0]還是[0,1],表示分成兩類(lèi))pred,然后用logits與標(biāo)簽進(jìn)行比較,從而得到一個(gè)batch_size中分類(lèi)正確的個(gè)數(shù),然后累加起來(lái),得到一次訓(xùn)練中網(wǎng)絡(luò)對(duì)數(shù)據(jù)集分類(lèi)正確的個(gè)數(shù)(correct),最后讓其除以數(shù)據(jù)集的個(gè)數(shù)從而得到準(zhǔn)確率并且返回其數(shù)值。

對(duì)于第二個(gè)函數(shù),train的函數(shù)的定義,其主要內(nèi)容是在訓(xùn)練集上訓(xùn)練,每一輪次訓(xùn)練好之后放在驗(yàn)證集上驗(yàn)證(可以是每?jī)纱位蛘呷危?,?zhí)行完所有輪次后,保存在驗(yàn)證集上最好的一次的網(wǎng)絡(luò)參數(shù)與輪次,最后加載保存的網(wǎng)絡(luò)參數(shù)對(duì)測(cè)試集進(jìn)行檢測(cè)。

train函數(shù)內(nèi)部首先定義驗(yàn)證集中最好的準(zhǔn)確率和最好的輪次,然后創(chuàng)建兩個(gè)列表來(lái)保存每一次的訓(xùn)練集和驗(yàn)證集的準(zhǔn)確率(用來(lái)畫(huà)圖查看),然后就是進(jìn)行epochs次訓(xùn)練。

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖16 trian函數(shù)內(nèi)參數(shù)的定義

訓(xùn)練中,如果直接是用x,y來(lái)獲得數(shù)據(jù)的圖片和標(biāo)簽則可以使用標(biāo)注里面的代碼,而使用enumerate函數(shù),其主要是為了給每次得到的數(shù)據(jù)(x,y)標(biāo)上一個(gè)索引,這個(gè)索引是steps,從0開(kāi)始(這里沒(méi)有使用到steps參數(shù))。在每次執(zhí)行中,圖片數(shù)據(jù)x會(huì)被放入網(wǎng)絡(luò)模型model中被處理,然后使用定義的loss_function函數(shù)得到預(yù)測(cè)和正確標(biāo)簽之間的損失值。優(yōu)化器先清零(不然會(huì)有數(shù)值疊加),然后讓損失值loss執(zhí)行反向傳播操作(鏈?zhǔn)角髮?dǎo)),最后優(yōu)化器執(zhí)行優(yōu)化功能,如此便實(shí)現(xiàn)了模型的一次訓(xùn)練與參數(shù)更新。

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖17 模型的訓(xùn)練步驟

而后面的代碼,每訓(xùn)練一次網(wǎng)絡(luò)模型,就把驗(yàn)證集放入網(wǎng)絡(luò)模型中,測(cè)試網(wǎng)絡(luò)模型訓(xùn)練得怎么樣,然后保存下epochs次數(shù)中最好準(zhǔn)確率的網(wǎng)絡(luò)模型參數(shù)與輪次。最后加載保存下的網(wǎng)絡(luò)模型參數(shù),在測(cè)試集上檢測(cè)準(zhǔn)確率如何。

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖18 模型參數(shù)的保存與測(cè)試

最后幾句代碼是將保存下來(lái)的準(zhǔn)確率做圖,有一點(diǎn)需要注意,因?yàn)檫@里是每次訓(xùn)練后都在驗(yàn)證集上檢測(cè)過(guò),所以坐標(biāo)軸的長(zhǎng)度就用訓(xùn)練集準(zhǔn)確率的長(zhǎng)度來(lái)表示兩個(gè)不同數(shù)據(jù)的長(zhǎng)度。

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖19 做圖

4.結(jié)果與總結(jié)

本文項(xiàng)目是使用Resnet模型來(lái)識(shí)別螞蟻和蜜蜂,其一共有三百九十六張的數(shù)據(jù),訓(xùn)練集只有兩百多張(數(shù)據(jù)集很小),運(yùn)行十輪后,分別對(duì)訓(xùn)練集和測(cè)試集在每一輪的準(zhǔn)確率如圖所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖20 train and validation accuracy

測(cè)試集的準(zhǔn)確率如圖所示:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖21 測(cè)試集準(zhǔn)確率

最后得到的效果不理想,很大可能是數(shù)據(jù)集太少導(dǎo)致導(dǎo)致模型泛化能力變?nèi)酰P桶延?xùn)練集都記下來(lái)了),對(duì)于這樣的問(wèn)題可以嘗試通過(guò)交叉驗(yàn)證(效果可能有一定程度的提升)或者增加數(shù)據(jù)集的方法來(lái)增強(qiáng)模型的泛化能力。對(duì)精度的提升,會(huì)在后續(xù)的文章中進(jìn)行討論。

在得到模型參數(shù)后,我隨便在網(wǎng)上找了兩張螞蟻的圖片放進(jìn)模型檢測(cè)看效果如何:

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖22 第一次測(cè)試

基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作

圖23 第二次測(cè)試

第一次測(cè)試識(shí)別出來(lái)了是螞蟻,但是第二次就失敗了,有可能是模型沒(méi)有看到過(guò)黑色的蜜蜂所以把黑色的都當(dāng)成了螞蟻吧,總之改模型還有很多需要改進(jìn)的地方。

附上單張檢測(cè)的代碼:

from network import Res_net
import torch
from PIL import Image
import torchvision

#導(dǎo)入圖片
img = '1.jpg'
img =Image.open(img)
tf = torchvision.transforms.Compose([torchvision.transforms.Resize((64,64)),torchvision.transforms.ToTensor()])
img = tf(img)
image = torch.reshape(img,(1,3,64,64))
#加載模型
net = Res_net(2)
net.load_state_dict(torch.load('best.mdl'))
with torch.no_grad():
    out = net(image)
#確定分類(lèi)
class_cl =out.argmax(dim=1)
class_num = class_cl.numpy()
if class_num == 0:
    print('這張照片是螞蟻')
else:
    print('這張照片是蜜蜂')

讀到這里,這篇“基于pytorch怎么實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集操作”文章已經(jīng)介紹完畢,想要掌握這篇文章的知識(shí)點(diǎn)還需要大家自己動(dòng)手實(shí)踐使用過(guò)才能領(lǐng)會(huì),如果想了解更多相關(guān)內(nèi)容的文章,歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI