溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶(hù)服務(wù)條款》

Pytorch數(shù)據(jù)讀取與預(yù)處理的實(shí)現(xiàn)方法

發(fā)布時(shí)間:2021-04-01 11:15:27 來(lái)源:億速云 閱讀:308 作者:小新 欄目:開(kāi)發(fā)技術(shù)

這篇文章給大家分享的是有關(guān)Pytorch數(shù)據(jù)讀取與預(yù)處理的實(shí)現(xiàn)方法的內(nèi)容。小編覺(jué)得挺實(shí)用的,因此分享給大家做個(gè)參考,一起跟隨小編過(guò)來(lái)看看吧。

  在煉丹時(shí),數(shù)據(jù)的讀取與預(yù)處理是關(guān)鍵一步。不同的模型所需要的數(shù)據(jù)以及預(yù)處理方式各不相同,如果每個(gè)輪子都我們自己寫(xiě)的話(huà),是很浪費(fèi)時(shí)間和精力的。Pytorch幫我們實(shí)現(xiàn)了方便的數(shù)據(jù)讀取與預(yù)處理方法,下面記錄兩個(gè)DEMO,便于加快以后的代碼效率。

  根據(jù)數(shù)據(jù)是否一次性讀取完,將DEMO分為:

  1、串行式讀取。也就是一次性讀取完所有需要的數(shù)據(jù)到內(nèi)存,模型訓(xùn)練時(shí)不會(huì)再訪(fǎng)問(wèn)外存。通常用在內(nèi)存足夠的情況下使用,速度更快。

  2、并行式讀取。也就是邊訓(xùn)練邊讀取數(shù)據(jù)。通常用在內(nèi)存不夠的情況下使用,會(huì)占用計(jì)算資源,如果分配的好的話(huà),幾乎不損失速度。

  Pytorch官方的數(shù)據(jù)提取方式盡管方便編碼,但由于它提取數(shù)據(jù)方式比較死板,會(huì)浪費(fèi)資源,下面對(duì)其進(jìn)行分析。

1  串行式讀取

1.1  DEMO代碼

import torch 
from torch.utils.data import Dataset,DataLoader 
  
class MyDataSet(Dataset):# ————1————
 def __init__(self):  
  self.data = torch.tensor(range(10)).reshape([5,2])
  self.label = torch.tensor(range(5))

 def __getitem__(self, index):  
  return self.data[index], self.label[index]

 def __len__(self):  
  return len(self.data)
 
my_data_set = MyDataSet()# ————2————
my_data_loader = DataLoader(
 dataset=my_data_set,  # ————3————
 batch_size=2,     # ————4————
 shuffle=True,     # ————5————
 sampler=None,     # ————6————
 batch_sampler=None,  # ————7———— 
 num_workers=0 ,    # ————8———— 
 collate_fn=None,    # ————9———— 
 pin_memory=True,    # ————10———— 
 drop_last=True     # ————11————
)

for i in my_data_loader: # ————12————
 print(i)

  注釋處解釋如下:

  1、重寫(xiě)數(shù)據(jù)集類(lèi),用于保存數(shù)據(jù)。除了 __init__() 外,必須實(shí)現(xiàn) __getitem__() 和 __len__() 兩個(gè)方法。前一個(gè)方法用于輸出索引對(duì)應(yīng)的數(shù)據(jù)。后一個(gè)方法用于獲取數(shù)據(jù)集的長(zhǎng)度。

  2~5、 2準(zhǔn)備好數(shù)據(jù)集后,傳入DataLoader來(lái)迭代生成數(shù)據(jù)。前三個(gè)參數(shù)分別是傳入的數(shù)據(jù)集對(duì)象、每次獲取的批量大小、是否打亂數(shù)據(jù)集輸出。

  6、采樣器,如果定義這個(gè),shuffle只能設(shè)置為False。所謂采樣器就是用于生成數(shù)據(jù)索引的可迭代對(duì)象,比如列表。因此,定義了采樣器,采樣都按它來(lái),shuffle再打亂就沒(méi)意義了。

  7、批量采樣器,如果定義這個(gè),batch_size、shuffle、sampler、drop_last都不能定義。實(shí)際上,如果沒(méi)有特殊的數(shù)據(jù)生成順序的要求,采樣器并沒(méi)有必要定義。torch.utils.data 中的各種 Sampler 就是采樣器類(lèi),如果需要,可以使用它們來(lái)定義。

  8、用于生成數(shù)據(jù)的子進(jìn)程數(shù)。默認(rèn)為0,不并行。

  9、拼接多個(gè)樣本的方法,默認(rèn)是將每個(gè)batch的數(shù)據(jù)在第一維上進(jìn)行拼接。這樣可能說(shuō)不清楚,并且由于這里可以探究一下獲取數(shù)據(jù)的速度,后面再詳細(xì)說(shuō)明。

  10、是否使用鎖頁(yè)內(nèi)存。用的話(huà)會(huì)更快,內(nèi)存不充足最好別用。

  11、是否把最后小于batch的數(shù)據(jù)丟掉。

  12、迭代獲取數(shù)據(jù)并輸出。

1.2  速度探索

  首先看一下DEMO的輸出:

Pytorch數(shù)據(jù)讀取與預(yù)處理的實(shí)現(xiàn)方法

  輸出了兩個(gè)batch的數(shù)據(jù),每組數(shù)據(jù)中data和label都正確排列,符合我們的預(yù)期。那么DataLoader是怎么把數(shù)據(jù)整合起來(lái)的呢?首先,我們把collate_fn定義為直接映射(不用它默認(rèn)的方法),來(lái)查看看每次DataLoader從MyDataSet中讀取了什么,將上面部分代碼修改如下:

my_data_loader = DataLoader(
 dataset=my_data_set,  
 batch_size=2,      
 shuffle=True,      
 sampler=None,     
 batch_sampler=None,  
 num_workers=0 ,    
 collate_fn=lambda x:x, #修改處
 pin_memory=True,    
 drop_last=True     
)

  結(jié)果如下:

Pytorch數(shù)據(jù)讀取與預(yù)處理的實(shí)現(xiàn)方法

  輸出還是兩個(gè)batch,然而每個(gè)batch中,單個(gè)的data和label是在一個(gè)list中的。似乎可以看出,DataLoader是一個(gè)一個(gè)讀取MyDataSet中的數(shù)據(jù)的,然后再進(jìn)行相應(yīng)數(shù)據(jù)的拼接。為了驗(yàn)證這點(diǎn),代碼修改如下:

import torch 
from torch.utils.data import Dataset,DataLoader 
  
class MyDataSet(Dataset): 
 def __init__(self):  
  self.data = torch.tensor(range(10)).reshape([5,2])
  self.label = torch.tensor(range(5))

 def __getitem__(self, index):  
  print(index)     #修改處2
  return self.data[index], self.label[index]

 def __len__(self):  
  return len(self.data)
 
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
 dataset=my_data_set,  
 batch_size=2,      
 shuffle=True,      
 sampler=None,     
 batch_sampler=None,  
 num_workers=0 ,    
 collate_fn=lambda x:x, #修改處1
 pin_memory=True,    
 drop_last=True     
)

for i in my_data_loader: 
 print(i)

  輸出如下:

Pytorch數(shù)據(jù)讀取與預(yù)處理的實(shí)現(xiàn)方法

  驗(yàn)證了前面的猜想,的確是一個(gè)一個(gè)讀取的。如果數(shù)據(jù)集定義的不是格式化的數(shù)據(jù),那還好,但是我這里定義的是tensor,是可以直接通過(guò)列表來(lái)索引對(duì)應(yīng)的tensor的。因此,DataLoader的操作比直接索引多了拼接這一步,肯定是會(huì)慢很多的。一兩次的讀取還好,但在訓(xùn)練中,大量的讀取累加起來(lái),就會(huì)浪費(fèi)很多時(shí)間了。

  自定義一個(gè)DataLoader可以證明這一點(diǎn),代碼如下:

import torch 
from torch.utils.data import Dataset,DataLoader 
from time import time
  
class MyDataSet(Dataset): 
 def __init__(self):  
  self.data = torch.tensor(range(100000)).reshape([50000,2])
  self.label = torch.tensor(range(50000))

 def __getitem__(self, index):  
  return self.data[index], self.label[index]

 def __len__(self):  
  return len(self.data)

# 自定義DataLoader
class MyDataLoader():
 def __init__(self, dataset,batch_size):
  self.dataset = dataset
  self.batch_size = batch_size
 def __iter__(self):
  self.now = 0
  self.shuffle_i = np.array(range(self.dataset.__len__())) 
  np.random.shuffle(self.shuffle_i)
  return self
 
 def __next__(self): 
  self.now += self.batch_size
  if self.now <= len(self.shuffle_i):
   indexes = self.shuffle_i[self.now-self.batch_size:self.now]
   return self.dataset.__getitem__(indexes)
  else:
   raise StopIteration

# 使用官方DataLoader
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
 dataset=my_data_set,  
 batch_size=256,      
 shuffle=True,      
 sampler=None,     
 batch_sampler=None,  
 num_workers=0 ,    
 collate_fn=None, 
 pin_memory=True,    
 drop_last=True     
)

start_t = time()
for t in range(10):
 for i in my_data_loader: 
  pass
print("官方:", time() - start_t)
 
 
#自定義DataLoader
my_data_set = MyDataSet() 
my_data_loader = MyDataLoader(my_data_set,256)

start_t = time()
for t in range(10):
 for i in my_data_loader: 
  pass
print("自定義:", time() - start_t)

運(yùn)行結(jié)果如下:

Pytorch數(shù)據(jù)讀取與預(yù)處理的實(shí)現(xiàn)方法

  以上使用batch大小為256,僅各讀取10 epoch的數(shù)據(jù),都有30多倍的時(shí)間上的差距,更大的batch差距會(huì)更明顯。另外,這里用于測(cè)試的每個(gè)數(shù)據(jù)只有兩個(gè)浮點(diǎn)數(shù),如果是圖像,所需的時(shí)間可能會(huì)增加幾百倍。因此,如果數(shù)據(jù)量和batch都比較大,并且數(shù)據(jù)是格式化的,最好自己寫(xiě)數(shù)據(jù)生成器。

2  并行式讀取

2.1  DEMO代碼

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader 
from torchvision import transforms 
from torchvision.datasets import ImageFolder 
 
path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128'
my_data_set = ImageFolder(      #————1————
 root = path,            #————2————
 transform = transforms.Compose([  #————3————
  transforms.ToTensor(),
  transforms.CenterCrop(64)
 ]),
 loader = plt.imread         #————4————
)
my_data_loader = DataLoader(
 dataset=my_data_set,   
 batch_size=128,       
 shuffle=True,       
 sampler=None,       
 batch_sampler=None,    
 num_workers=0,      
 collate_fn=None,      
 pin_memory=True,      
 drop_last=True 
)      

for i in my_data_loader: 
 print(i)

  注釋處解釋如下:

  1/2、ImageFolder類(lèi)繼承自DataSet類(lèi),因此可以按索引讀取圖像。路徑必須包含文件夾,ImageFolder會(huì)給每個(gè)文件夾中的圖像添加索引,并且每張圖像會(huì)給予其所在文件夾的標(biāo)簽。舉個(gè)例子,代碼中my_data_set[0] 輸出的是圖像對(duì)象和它對(duì)應(yīng)的標(biāo)簽組成的列表。

  3、圖像到格式化數(shù)據(jù)的轉(zhuǎn)換組合。更多的轉(zhuǎn)換方法可以看 transform 模塊。

  4、圖像法的讀取方式,默認(rèn)是PIL.Image.open(),但我發(fā)現(xiàn)plt.imread()更快一些。

  由于是邊訓(xùn)練邊讀取,transform會(huì)占用很多時(shí)間,因此可以先將圖像轉(zhuǎn)換為需要的形式存入外存再讀取,從而避免重復(fù)操作。

  其中transform.ToTensor()會(huì)把正常讀取的圖像轉(zhuǎn)換為torch.tensor,并且像素值會(huì)映射至[0,1][0,1]。由于plt.imread()讀取png圖像時(shí),像素值在[0,1][0,1],而讀取jpg圖像時(shí),像素值卻在[0,255][0,255],因此使用transform.ToTensor()能將圖像像素區(qū)間統(tǒng)一化。

感謝各位的閱讀!關(guān)于“Pytorch數(shù)據(jù)讀取與預(yù)處理的實(shí)現(xiàn)方法”這篇文章就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,讓大家可以學(xué)到更多知識(shí),如果覺(jué)得文章不錯(cuò),可以把它分享出去讓更多的人看到吧!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀(guān)點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI