溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶(hù)服務(wù)條款》

詳解Pytorch如何使用nii數(shù)據(jù)做輸入數(shù)據(jù)操作

發(fā)布時(shí)間:2020-07-21 15:04:46 來(lái)源:億速云 閱讀:1123 作者:小豬 欄目:開(kāi)發(fā)技術(shù)

小編這次要給大家分享的是詳解Pytorch如何使用nii數(shù)據(jù)做輸入數(shù)據(jù)操作,文章內(nèi)容豐富,感興趣的小伙伴可以來(lái)了解一下,希望大家閱讀完這篇文章之后能夠有所收獲。

使用pix2pix-gan做醫(yī)學(xué)圖像合成的時(shí)候,如果把nii數(shù)據(jù)轉(zhuǎn)成png格式會(huì)損失很多信息,以為png格式圖像的灰度值有256階,因此直接使用nii的醫(yī)學(xué)圖像做輸入會(huì)更好一點(diǎn)。

但是Pythorch中的Dataloader是不能直接讀取nii圖像的,因此加一個(gè)CreateNiiDataset的類(lèi)。

先來(lái)了解一下pytorch中讀取數(shù)據(jù)的主要途徑——Dataset類(lèi)。在自己構(gòu)建數(shù)據(jù)層時(shí)都要基于這個(gè)類(lèi),類(lèi)似于C++中的虛基類(lèi)。

自己構(gòu)建的數(shù)據(jù)層包含三個(gè)部分

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
 raise NotImplementedError
def __len__(self):
 raise NotImplementedError
def __add__(self, other):
 return ConcatDataset([self, other])

根據(jù)自己的需要編寫(xiě)CreateNiiDataset子類(lèi):

因?yàn)槲沂腔趆ttps://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

做pix2pix-gan的實(shí)驗(yàn),數(shù)據(jù)包含兩個(gè)部分mr 和 ct,不需要標(biāo)簽,因此上面的 def getitem(self, index):中不需要index這個(gè)參數(shù)了,類(lèi)似地,根據(jù)需要,加入自己的參數(shù),去掉不需要的參數(shù)。

class CreateNiiDataset(Dataset):
 def __init__(self, opt, transform = None, target_transform = None):
  self.path2 = opt.dataroot # parameter passing
  self.A = 'MR' 
  self.B = 'CT'
  lines = os.listdir(os.path.join(self.path2, self.A))
  lines.sort()
  imgs = []
  for line in lines:
   imgs.append(line)
  self.imgs = imgs
  self.transform = transform
  self.target_transform = target_transform

 def crop(self, image, crop_size):
  shp = image.shape
  scl = [int((shp[0] - crop_size[0]) / 2), int((shp[1] - crop_size[1]) / 2)]
  image_crop = image[scl[0]:scl[0] + crop_size[0], scl[1]:scl[1] + crop_size[1]]
  return image_crop

 def __getitem__(self, item):
  file = self.imgs[item]
  img1 = sitk.ReadImage(os.path.join(self.path2, self.A, file))
  img2 = sitk.ReadImage(os.path.join(self.path2, self.B, file))
  data1 = sitk.GetArrayFromImage(img1)
  data2 = sitk.GetArrayFromImage(img2)

  if data1.shape[0] != 256:
   data1 = self.crop(data1, [256, 256])
   data2 = self.crop(data2, [256, 256])
  if self.transform is not None:
   data1 = self.transform(data1)
   data2 = self.transform(data2)

  if np.min(data1)<0:
   data1 = (data1 - np.min(data1))/(np.max(data1)-np.min(data1))

  if np.min(data2)<0:
   #data2 = data2 - np.min(data2)
   data2 = (data2 - np.min(data2))/(np.max(data2)-np.min(data2))

  data = {}
  data1 = data1[np.newaxis, np.newaxis, :, :]
  data1_tensor = torch.from_numpy(np.concatenate([data1,data1,data1], 1))
  data1_tensor = data1_tensor.type(torch.FloatTensor)
  data['A'] = data1_tensor # should be a tensor in Float Tensor Type

  data2 = data2[np.newaxis, np.newaxis, :, :]
  data2_tensor = torch.from_numpy(np.concatenate([data2,data2,data2], 1))
  data2_tensor = data2_tensor.type(torch.FloatTensor)
  data['B'] = data2_tensor # should be a tensor in Float Tensor Type
  data['A_paths'] = [os.path.join(self.path2, self.A, file)] # should be a list, with path inside
  data['B_paths'] = [os.path.join(self.path2, self.B, file)]
  return data

 def load_data(self):
  return self

 def __len__(self):
  return len(self.imgs)

注意:最后輸出的data是一個(gè)字典,里面有四個(gè)keys=[‘A',‘B',‘A_paths',‘B_paths'], 一定要注意數(shù)據(jù)要轉(zhuǎn)成FloatTensor。

其次是data[‘A_paths'] 接收的值是一個(gè)list,一定要加[ ] 擴(kuò)起來(lái),要不然測(cè)試存圖的時(shí)候會(huì)有問(wèn)題,找這個(gè)問(wèn)題找了好久才發(fā)現(xiàn)。

然后直接在train.py的主函數(shù)里面把數(shù)據(jù)加載那行改掉就好了

data_loader = CreateNiiDataset(opt)
dataset = data_loader.load_data()

Over!

補(bǔ)充知識(shí):nii格式圖像存為npy格式

我就廢話(huà)不多說(shuō)了,大家還是直接看代碼吧!

import nibabel as nib
import os
import numpy as np
 
img_path = '/home/lei/train/img/'
seg_path = '/home/lei/train/seg/'
saveimg_path = '/home/lei/train/npy_img/'
saveseg_path = '/home/lei/train/npy_seg/'
 
img_names = os.listdir(img_path)
seg_names = os.listdir(seg_path)
 
for img_name in img_names:
 print(img_name)
 img = nib.load(img_path + img_name).get_data() #載入
 img = np.array(img)
 np.save(saveimg_path + str(img_name).split('.')[0] + '.npy', img) #保存
 
for seg_name in seg_names:
 print(seg_name)
 seg = nib.load(seg_path + seg_name).get_data()
 seg = np.array(seg)
 np.save(saveseg_path + str(seg_name).split('.')[0] + '.npy

看完這篇關(guān)于詳解Pytorch如何使用nii數(shù)據(jù)做輸入數(shù)據(jù)操作的文章,如果覺(jué)得文章內(nèi)容寫(xiě)得不錯(cuò)的話(huà),可以把它分享出去給更多人看到。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI