溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch實現(xiàn)unet網(wǎng)絡(luò)的方法

發(fā)布時間:2020-06-25 21:17:41 來源:億速云 閱讀:541 作者:Leah 欄目:開發(fā)技術(shù)

這期內(nèi)容當(dāng)中小編將會給大家?guī)碛嘘P(guān)pytorch實現(xiàn)unet網(wǎng)絡(luò)的方法,文章內(nèi)容豐富且以專業(yè)的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。

設(shè)計神經(jīng)網(wǎng)絡(luò)的一般步驟:

1. 設(shè)計框架

2. 設(shè)計骨干網(wǎng)絡(luò)

Unet網(wǎng)絡(luò)設(shè)計的步驟:

1. 設(shè)計Unet網(wǎng)絡(luò)工廠模式

2. 設(shè)計編解碼結(jié)構(gòu)

3. 設(shè)計卷積模塊

4. unet實例模塊

Unet網(wǎng)絡(luò)最重要的特征:

1. 編解碼結(jié)構(gòu)。

2. 解碼結(jié)構(gòu),比FCN更加完善,采用連接方式。

3. 本質(zhì)是一個框架,編碼部分可以使用很多圖像分類網(wǎng)絡(luò)。

示例代碼:

import torch
import torch.nn as nn

class Unet(nn.Module):
 #初始化參數(shù):Encoder,Decoder,bridge
 #bridge默認(rèn)值為無,如果有參數(shù)傳入,則用該參數(shù)替換None
 def __init__(self,Encoder,Decoder,bridge = None):
  super(Unet,self).__init__()
  self.encoder = Encoder(encoder_blocks)
  self.decoder = Decoder(decoder_blocks)
  self.bridge = bridge
 def forward(self,x):
  res = self.encoder(x)
  out,skip = res[0],res[1,:]
  if bridge is not None:
   out = bridge(out)
  out = self.decoder(out,skip)
  return out
#設(shè)計編碼模塊
class Encoder(nn.Module):
 def __init__(self,blocks):
  super(Encoder,self).__init__()
  #assert:斷言函數(shù),避免出現(xiàn)參數(shù)錯誤
  assert len(blocks) > 0
  #nn.Modulelist():模型列表,所有的參數(shù)可以納入網(wǎng)絡(luò),但是沒有forward函數(shù)
  self.blocks = nn.Modulelist(blocks)
 def forward(self,x):
  skip = []
  for i in range(len(self.blocks) - 1):
   x = self.blocks[i](x)
   skip.append(x)
  res = [self.block[i+1](x)]
  #列表之間可以通過+號拼接
  res += skip
  return res
#設(shè)計Decoder模塊
class Decoder(nn.Module):
 def __init__(self,blocks):
  super(Decoder, self).__init__()
  assert len(blocks) > 0
  self.blocks = nn.Modulelist(blocks)
 def ceter_crop(self,skips,x):
  _,_,height1,width2 = skips.shape()
  _,_,height2,width3 = x.shape()
  #對圖像進(jìn)行剪切處理,拼接的時候保持對應(yīng)size參數(shù)一致
  ht,wt = min(height1,height2),min(width2,width3)
  dh2 = (height1 - height2)//2 if height1 > height2 else 0
  dw1 = (width2 - width3)//2 if width2 > width3 else 0
  dh3 = (height2 - height1)//2 if height2 > height1 else 0
  dw2 = (width3 - width2)//2 if width3 > width2 else 0
  return skips[:,:,dh2:(dh2 + ht),dw1:(dw1 + wt)],\
    x[:,:,dh3:(dh3 + ht),dw2 : (dw2 + wt)]

 def forward(self, skips,x,reverse_skips = True):
  assert len(skips) == len(blocks) - 1
  if reverse_skips is True:
   skips = skips[: : -1]
  x = self.blocks[0](x)
  for i in range(1, len(self.blocks)):
   skip = skips[i-1]
   x = torch.cat(skip,x,1)
   x = self.blocks[i](x)
  return x
#定義了一個卷積block
def unet_convs(in_channels,out_channels,padding = 0):
 #nn.Sequential:與Modulelist相比,包含了forward函數(shù)
 return nn.Sequential(
  nn.Conv2d(in_channels, out_channels, kernal_size = 3, padding = padding, bias = False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace = True),
  nn.Conv2d(in_channels, out_channels, kernal_size=3, padding=padding, bias=False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace=True),
 )
#實例化Unet模型
def unet(in_channels,out_channels):
 encoder_blocks = [unet_convs(in_channels, 64),\
      nn.Sequential(nn.Maxpool2d(kernal_size = 2, stride = 2, ceil_mode = True),\
         unet_convs(64,128)), \
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(128, 256)),
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(256, 512)),
      ]
 bridge = nn.Sequential(unet_convs(512, 1024))
 decoder_blocks = [nn.conTranpose2d(1024, 512), \
      nn.Sequential(unet_convs(1024, 512),
         nn.conTranpose2d(512, 256)),\
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(256, 128),
         nn.conTranpose2d(128, 64))
      ]
 return Unet(encoder_blocks,decoder_blocks,bridge)

補充知識:Pytorch搭建U-Net網(wǎng)絡(luò)

U-Net: Convolutional Networks for Biomedical Image Segmentation

pytorch實現(xiàn)unet網(wǎng)絡(luò)的方法

import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary

class DoubleConv(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(DoubleConv, self).__init__()
  self.conv = nn.Sequential(
   nn.Conv2d(in_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True),
   nn.Conv2d(out_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True)
  )

 def forward(self, input):
  return self.conv(input)

class Unet(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(Unet, self).__init__()
  self.conv1 = DoubleConv(in_ch, 64)
  self.pool1 = nn.MaxPool2d(2)
  self.conv2 = DoubleConv(64, 128)
  self.pool2 = nn.MaxPool2d(2)
  self.conv3 = DoubleConv(128, 256)
  self.pool3 = nn.MaxPool2d(2)
  self.conv4 = DoubleConv(256, 512)
  self.pool4 = nn.MaxPool2d(2)
  self.conv5 = DoubleConv(512, 1024)
  # 逆卷積,也可以使用上采樣
  self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
  self.conv6 = DoubleConv(1024, 512)
  self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
  self.conv7 = DoubleConv(512, 256)
  self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
  self.conv8 = DoubleConv(256, 128)
  self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
  self.conv9 = DoubleConv(128, 64)
  self.conv10 = nn.Conv2d(64, out_ch, 1)

 def forward(self, x):
  c1 = self.conv1(x)
  crop1 = c1[:,:,88:480,88:480]
  p1 = self.pool1(c1)
  c2 = self.conv2(p1)
  crop2 = c2[:,:,40:240,40:240]
  p2 = self.pool2(c2)
  c3 = self.conv3(p2)
  crop3 = c3[:,:,16:120,16:120]
  p3 = self.pool3(c3)
  c4 = self.conv4(p3)
  crop4 = c4[:,:,4:60,4:60]
  p4 = self.pool4(c4)
  c5 = self.conv5(p4)
  up_6 = self.up6(c5)
  merge6 = torch.cat([up_6, crop4], dim=1)
  c6 = self.conv6(merge6)
  up_7 = self.up7(c6)
  merge7 = torch.cat([up_7, crop3], dim=1)
  c7 = self.conv7(merge7)
  up_8 = self.up8(c7)
  merge8 = torch.cat([up_8, crop2], dim=1)
  c8 = self.conv8(merge8)
  up_9 = self.up9(c8)
  merge9 = torch.cat([up_9, crop1], dim=1)
  c9 = self.conv9(merge9)
  c10 = self.conv10(c9)
  out = nn.Sigmoid()(c10)
  return out

if __name__=="__main__":
 test_input=torch.rand(1, 1, 572, 572)
 model=Unet(in_ch=1, out_ch=2)
 summary(model, (1,572,572))
 ouput=model(test_input)
 print(ouput.size())
----------------------------------------------------------------
  Layer (type)    Output Shape   Param #
================================================================
   Conv2d-1   [-1, 64, 570, 570]    640
  BatchNorm2d-2   [-1, 64, 570, 570]    128
    ReLU-3   [-1, 64, 570, 570]    0
   Conv2d-4   [-1, 64, 568, 568]   36,928
  BatchNorm2d-5   [-1, 64, 568, 568]    128
    ReLU-6   [-1, 64, 568, 568]    0
  DoubleConv-7   [-1, 64, 568, 568]    0
   MaxPool2d-8   [-1, 64, 284, 284]    0
   Conv2d-9  [-1, 128, 282, 282]   73,856
  BatchNorm2d-10  [-1, 128, 282, 282]    256
    ReLU-11  [-1, 128, 282, 282]    0
   Conv2d-12  [-1, 128, 280, 280]   147,584
  BatchNorm2d-13  [-1, 128, 280, 280]    256
    ReLU-14  [-1, 128, 280, 280]    0
  DoubleConv-15  [-1, 128, 280, 280]    0
  MaxPool2d-16  [-1, 128, 140, 140]    0
   Conv2d-17  [-1, 256, 138, 138]   295,168
  BatchNorm2d-18  [-1, 256, 138, 138]    512
    ReLU-19  [-1, 256, 138, 138]    0
   Conv2d-20  [-1, 256, 136, 136]   590,080
  BatchNorm2d-21  [-1, 256, 136, 136]    512
    ReLU-22  [-1, 256, 136, 136]    0
  DoubleConv-23  [-1, 256, 136, 136]    0
  MaxPool2d-24   [-1, 256, 68, 68]    0
   Conv2d-25   [-1, 512, 66, 66]  1,180,160
  BatchNorm2d-26   [-1, 512, 66, 66]   1,024
    ReLU-27   [-1, 512, 66, 66]    0
   Conv2d-28   [-1, 512, 64, 64]  2,359,808
  BatchNorm2d-29   [-1, 512, 64, 64]   1,024
    ReLU-30   [-1, 512, 64, 64]    0
  DoubleConv-31   [-1, 512, 64, 64]    0
  MaxPool2d-32   [-1, 512, 32, 32]    0
   Conv2d-33   [-1, 1024, 30, 30]  4,719,616
  BatchNorm2d-34   [-1, 1024, 30, 30]   2,048
    ReLU-35   [-1, 1024, 30, 30]    0
   Conv2d-36   [-1, 1024, 28, 28]  9,438,208
  BatchNorm2d-37   [-1, 1024, 28, 28]   2,048
    ReLU-38   [-1, 1024, 28, 28]    0
  DoubleConv-39   [-1, 1024, 28, 28]    0
 ConvTranspose2d-40   [-1, 512, 56, 56]  2,097,664
   Conv2d-41   [-1, 512, 54, 54]  4,719,104
  BatchNorm2d-42   [-1, 512, 54, 54]   1,024
    ReLU-43   [-1, 512, 54, 54]    0
   Conv2d-44   [-1, 512, 52, 52]  2,359,808
  BatchNorm2d-45   [-1, 512, 52, 52]   1,024
    ReLU-46   [-1, 512, 52, 52]    0
  DoubleConv-47   [-1, 512, 52, 52]    0
 ConvTranspose2d-48  [-1, 256, 104, 104]   524,544
   Conv2d-49  [-1, 256, 102, 102]  1,179,904
  BatchNorm2d-50  [-1, 256, 102, 102]    512
    ReLU-51  [-1, 256, 102, 102]    0
   Conv2d-52  [-1, 256, 100, 100]   590,080
  BatchNorm2d-53  [-1, 256, 100, 100]    512
    ReLU-54  [-1, 256, 100, 100]    0
  DoubleConv-55  [-1, 256, 100, 100]    0
 ConvTranspose2d-56  [-1, 128, 200, 200]   131,200
   Conv2d-57  [-1, 128, 198, 198]   295,040
  BatchNorm2d-58  [-1, 128, 198, 198]    256
    ReLU-59  [-1, 128, 198, 198]    0
   Conv2d-60  [-1, 128, 196, 196]   147,584
  BatchNorm2d-61  [-1, 128, 196, 196]    256
    ReLU-62  [-1, 128, 196, 196]    0
  DoubleConv-63  [-1, 128, 196, 196]    0
 ConvTranspose2d-64   [-1, 64, 392, 392]   32,832
   Conv2d-65   [-1, 64, 390, 390]   73,792
  BatchNorm2d-66   [-1, 64, 390, 390]    128
    ReLU-67   [-1, 64, 390, 390]    0
   Conv2d-68   [-1, 64, 388, 388]   36,928
  BatchNorm2d-69   [-1, 64, 388, 388]    128
    ReLU-70   [-1, 64, 388, 388]    0
  DoubleConv-71   [-1, 64, 388, 388]    0
   Conv2d-72   [-1, 2, 388, 388]    130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])

上述就是小編為大家分享的pytorch實現(xiàn)unet網(wǎng)絡(luò)的方法了,如果剛好有類似的疑惑,不妨參照上述分析進(jìn)行理解。如果想知道更多相關(guān)知識,歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI