溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

從圖像超分辨率快速入門pytorch

發(fā)布時(shí)間:2020-07-04 07:05:22 來源:網(wǎng)絡(luò) 閱讀:952 作者:nineteens 欄目:編程語言

  前言

  最近又開始把pytorch拾起來,學(xué)習(xí)了github上一些項(xiàng)目之后,發(fā)現(xiàn)每個(gè)人都會(huì)用不同的方式來寫深度學(xué)習(xí)的訓(xùn)練代碼,而這些代碼對(duì)于初學(xué)者來說是難以閱讀的,因?yàn)殛P(guān)鍵和非關(guān)鍵代碼糅雜在一起,讓那些需要快速將代碼跑起來的初學(xué)者摸不著頭腦。

  所以,本文打算從最基本的出發(fā),只寫關(guān)鍵代碼,將完成一次深度學(xué)習(xí)訓(xùn)練需要哪些要素展現(xiàn)給各位初學(xué)者,以便你們能夠快速上手。等到能夠?qū)⒆约旱南敕ㄓ米詈?jiǎn)潔的方式寫出來并運(yùn)行起來之后,再對(duì)自己的代碼進(jìn)行重構(gòu)、擴(kuò)展。我認(rèn)為這種學(xué)習(xí)方式是較好的循序漸進(jìn)的學(xué)習(xí)方式。

  本文選擇超分辨率作為入門案例,一是因?yàn)橥ㄟ^結(jié)合案例能夠?qū)τ?xùn)練中涉及到的東西有較好的體會(huì),二是超分辨率是較為簡(jiǎn)單的任務(wù),我們本次教程的目的是教會(huì)大家如何使用pytorch,所以不應(yīng)該將難度設(shè)置在任務(wù)本身上。下面開始正文。。。

  正文

  單一圖像超分辨率(SISR)

  簡(jiǎn)單介紹一下圖像超分辨率這一任務(wù):超分辨率的任務(wù)就是將一張圖像的尺寸放大并且要求失真越小越好,舉例來說,我們需要將一張256*500的圖像放大2倍,那么放大后的圖像尺寸就應(yīng)該是512*1000。用深度學(xué)習(xí)的方法,我們通常會(huì)先將圖像縮小成原來的1/2,然后以原始圖像作為標(biāo)簽,進(jìn)行訓(xùn)練。訓(xùn)練的目標(biāo)是讓縮小后的圖像放大2倍后與原圖越近越好。所以通常會(huì)用L1或者L2作為損失函數(shù)。

  訓(xùn)練4要素

  一次訓(xùn)練要想完成,需要的要素我總結(jié)為4點(diǎn):

  網(wǎng)絡(luò)模型

  數(shù)據(jù)

  損失函數(shù)

  優(yōu)化器

  這4個(gè)對(duì)象都是一次訓(xùn)練必不可少的,通常情況下,需要我們自定義的是前兩個(gè):網(wǎng)絡(luò)模型和數(shù)據(jù),而后面兩個(gè)較為統(tǒng)一,而且pytorch也提供了非常全面的實(shí)現(xiàn)供我們使用,它們分別在torch.nn包和torch.optim包下面,使用的時(shí)候可以到pytorch官網(wǎng)進(jìn)行查看,后面我們用到的時(shí)候還會(huì)再次說明。

  網(wǎng)絡(luò)模型

  在網(wǎng)絡(luò)模型和數(shù)據(jù)兩個(gè)當(dāng)中,網(wǎng)絡(luò)模型是比較簡(jiǎn)單的,數(shù)據(jù)加載稍微麻煩些。我們先來看網(wǎng)絡(luò)模型的定義。自定義的網(wǎng)絡(luò)模型都必須繼承torch.nn.Module這個(gè)類,里面有兩個(gè)方法需要重寫:初始化方法__init__(self)和forward(self, *input)方法。在初始化方法中一般要寫我們需要哪些層(卷積層、全連接層等),而在forward方法中我們需要寫這些層的連接方式。舉一個(gè)通俗的例子,搭積木需要一個(gè)個(gè)的積木塊,這些積木塊放在__init__方法中,而規(guī)定將這些積木塊如何連接起來則是靠forward方法中的內(nèi)容。

  import torch.nn as nn

  import torch.nn.functional as F

  class VDSR(nn.Module):

  def __init__(self):

  super(VDSR, self).__init__()

  self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv6 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv7 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv8 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv9 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv10 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv11 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv13 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv14 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv15 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv16 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv17 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv18 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv19 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)

  self.conv20 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=True)

  def forward(self, x):

  ori = x

  x = F.relu(self.conv1(x))

  x = F.relu(self.conv2(x))

  x = F.relu(self.conv3(x))

  x = F.relu(self.conv4(x))

  x = F.relu(self.conv5(x))

  x = F.relu(self.conv6(x))

  x = F.relu(self.conv7(x))

  x = F.relu(self.conv8(x))

  x = F.relu(self.conv9(x))

  x = F.relu(self.conv10(x))

  x = F.relu(self.conv11(x))

  x = F.relu(self.conv12(x))

  x = F.relu(self.conv13(x))

  x = F.relu(self.conv14(x))

  x = F.relu(self.conv15(x))

  x = F.relu(self.conv16(x))

  x = F.relu(self.conv17(x))

  x = F.relu(self.conv18(x))

  x = F.relu(self.conv19(x))

  x = self.conv20(x)

  return x + ori

  上面代碼中展示的是我們要用到的模型VDSR,這個(gè)模型很簡(jiǎn)單,就是連續(xù)的20層卷積,外加一個(gè)跳線連接。結(jié)構(gòu)圖如下:

  在寫網(wǎng)絡(luò)模型時(shí),用到的各個(gè)層都在torch.nn這個(gè)包中,在寫自定義的網(wǎng)絡(luò)結(jié)構(gòu)時(shí)可以自行到pytorch官網(wǎng)的文檔中進(jìn)行查看。

  數(shù)據(jù)

  定義了網(wǎng)絡(luò)模型之后,我們?cè)賮砜础皵?shù)據(jù)”?!皵?shù)據(jù)”主要涉及到Dataset和DataLoader兩個(gè)概念。

  Dataset是數(shù)據(jù)加載的基礎(chǔ),我們一般在加載自己的數(shù)據(jù)集時(shí)都需要自定義一個(gè)Dataset,自定義的Dataset都需要繼承torch.utils.data.Dataset這個(gè)類,當(dāng)實(shí)現(xiàn)了__getitem__()和__len__()這兩個(gè)方法后,我們就自定義了一個(gè)Map-style datasets,Dataset是一個(gè)可迭代對(duì)象,通過下標(biāo)訪問的方式就能夠調(diào)用__getitem__()方法來實(shí)現(xiàn)數(shù)據(jù)加載。

  這里面最關(guān)鍵的就算是__getitem__()如何來寫了,我們需要讓__getitem__()的返回值是一對(duì),包括圖像和它的label,這里我們的任務(wù)是超分辨率,那么圖像和label分別是經(jīng)過下采樣的圖像和與其對(duì)應(yīng)的原始圖像。所以我們Dataset的__getitem__()方法返回值就應(yīng)該是兩個(gè)3D Tensor,分別表示兩種圖像。

  這里需要重點(diǎn)說明一下__getitem__()方法的返回值為什么應(yīng)該是3D Tensor。根據(jù)pytorch官網(wǎng)的說法,二維卷積層只接受4D Tensor,它的每一維表示的內(nèi)容分別是nSamples x nChannels x Height x Width,我們最后需要用批量的方式將數(shù)據(jù)送到網(wǎng)絡(luò)中,所以__getitem__()方法的返回值就應(yīng)該是后面三維的內(nèi)容,即便是我們的通道數(shù)為1,也必須有這一維的存在,否則就會(huì)報(bào)錯(cuò)。后面代碼中用到的unsqueeze(0)方法的作用就是如此。前面是說了為什么應(yīng)該是3D的,為什么應(yīng)該是Tensor呢?Tensor是跟NumPy中ndarray類似的東西,只是它能夠被用于GPU中來加速計(jì)算。

  下面來看一下我們的代碼:

  import os

  import random

  import cv2

  import torch

  from torch.utils.data import Dataset

  patch_size = 64

  def getPatch(y):

  h, w = y.shape

  randh = random.randrange(0, h - patch_size + 1)

  randw = random.randrange(0, w - patch_size + 1)

  lab = y[randh:randh + patch_size, randw:randw + patch_size]

  resized = cv2.resize(lab, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)

  rresized = cv2.resize(resized, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)

  return rresized, lab

  class MyDateSet(Dataset):

  def __init__(self, imageFolder):

  self.imageFolder = imageFolder

  self.images = os.listdir(imageFolder)

  def __len__(self):

  return len(self.images)

  def __getitem__(self, index):

  name = self.images[index]

  name = os.path.join(self.imageFolder, name)

  imread = cv2.imread(name)

  # 轉(zhuǎn)換顏色空間

  ycrcb = cv2.cvtColor(imread, cv2.COLOR_RGB2YCR_CB)

  # 提取y通道

  y = ycrcb[:, :, 0]

  # 裁剪成小塊

  img, lab = getPatch(y)

  # 轉(zhuǎn)為3D Tensor鄭州婦科醫(yī)院 http://www.sptdfk.com/

  return torch.from_numpy(img).unsqueeze(0), torch.from_numpy(lab).unsqueeze(0)

  其中MyDateSet的內(nèi)容也不長(zhǎng),包括了初始化方法、__getitem__()和__len__()兩個(gè)方法。__getitem__()有一個(gè)輸入值是下標(biāo)值,我們根據(jù)下標(biāo),利用OpenCV,讀取了圖像,并將其轉(zhuǎn)換顏色空間,超分訓(xùn)練的時(shí)候我們只用了其中的y通道。還對(duì)圖形進(jìn)行了裁剪,最后返回了兩個(gè)3D Tensor。

  在寫自定義數(shù)據(jù)集的時(shí)候,我們最需要關(guān)注的點(diǎn)就是__getitem__()方法的返回值是不是符合要求,能不能夠被送到網(wǎng)絡(luò)中去。至于中間該怎么操作,其實(shí)跟pytorch框架也沒什么關(guān)系,根據(jù)需要來做。

  訓(xùn)練

  寫好了Dataset之后,我們就能夠通過下標(biāo)的方式獲取圖像以及它的label。但是離開始訓(xùn)練還有兩個(gè)要素:損失函數(shù)和優(yōu)化器。前面我們也說了,這兩部分,pytorch官方提供了大量的實(shí)現(xiàn),多數(shù)情況下不需要我們自己來自定義,這里我們直接使用了提供的torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')作為損失函數(shù)和torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)作為優(yōu)化器。

  訓(xùn)練示例代碼:

  import torch

  import torch.nn as nn

  import torch.optim as optim

  import date

  import model

  date_set = date.MyDateSet("Train/")

  model = model.VDSR()

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  model.to(device)

  mse_loss = nn.MSELoss()

  adam = optim.Adam(model.parameters())

  for epoch in range(100):

  running_loss = 0.0

  for i in range(len(date_set)):

  rresized, y = date_set[i]

  adam.zero_grad()

  out = model(rresized.unsqueeze(0).to(device, torch.float))

  loss = mse_loss(out, y.unsqueeze(0).to(device, torch.float))

  loss.backward()

  adam.step()

  running_loss += loss

  if i % 100 == 99: # print every 100

  print('[%d, %5d] loss: %.3f' %

  (epoch + 1, i + 1, running_loss / 100))

  running_loss = 0.0

  print('Finished Training')

  整個(gè)訓(xùn)練代碼非常簡(jiǎn)潔,只有短短幾行,定義模型、將模型移至GPU、定義損失函數(shù)、定義優(yōu)化器(模型移動(dòng)至GPU一定要在定義優(yōu)化器之前,因?yàn)橐苿?dòng)前后的模型已經(jīng)不是同一個(gè)模型對(duì)象)。

  訓(xùn)練時(shí),先用zero_grad()來將上一次的梯度清零,然后將數(shù)據(jù)輸入網(wǎng)絡(luò),求誤差,誤差反向傳播求每個(gè)requires_grad=True的Tensor(也就是網(wǎng)絡(luò)權(quán)重)的梯度,根據(jù)優(yōu)化規(guī)則對(duì)網(wǎng)絡(luò)權(quán)重值進(jìn)行更新,在一次次的更新迭代中,網(wǎng)絡(luò)朝著loss降低的方向變化著。

  值的注意的是,圖像數(shù)據(jù)也需要移動(dòng)至GPU,并且需要將其類型轉(zhuǎn)換為與網(wǎng)絡(luò)模型的權(quán)重相同的torch.float

  DataLoader

  到前面為止,其實(shí)已經(jīng)能夠?qū)崿F(xiàn)訓(xùn)練的過程了,但是,通常情況下,我們都需要:

  將數(shù)據(jù)打包成一個(gè)批量送入網(wǎng)絡(luò)

  每次隨機(jī)將數(shù)據(jù)打亂送入網(wǎng)絡(luò)

  用多線程的方式加載數(shù)據(jù)(這樣能夠提升數(shù)據(jù)加載速度)

  這些事情不需要我們自己實(shí)現(xiàn),有torch.utils.data.DataLoader來幫我們實(shí)現(xiàn)。完整聲明如下:

  torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

  其中的sampler、batch_sampler、collate_fn都是可以有自定義實(shí)現(xiàn)的。我們簡(jiǎn)單的使用默認(rèn)的實(shí)現(xiàn)來構(gòu)造DataLoader。使用了DataLoader之后的訓(xùn)練代碼稍微有些不同,其中也添加了保存模型的代碼(只保存參數(shù)的方式):

  import torch

  import torch.nn as nn

  import torch.optim as optim

  from torch.utils.data import DataLoader

  import date

  import model

  date_set = date.MyDateSet("Train/")

  dataloader = DataLoader(date_set, batch_size=128,

  shuffle=True, drop_last=True)

  model = model.VDSR()

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  model.to(device)

  mse_loss = nn.MSELoss()

  adam = optim.Adam(model.parameters())

  def train():

  for epoch in range(1000):

  running_loss = 0.0

  for i, images in enumerate(dataloader):

  rresized, y = images

  adam.zero_grad()

  out = model(rresized.to(device, torch.float))

  loss = mse_loss(out, y.to(device, torch.float))

  loss.backward()

  adam.step()

  running_loss += loss

  if epoch % 10 == 9:

  PATH = './trainedModel/net_' + str(epoch + 1) + '.pth'

  torch.save(model.state_dict(), PATH)

  print('[%d] loss: %.3f' %

  (epoch + 1, running_loss / 3))

  print('Finished Training')

  if __name__ == '__main__':

  train()


向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI