溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

發(fā)布時間:2020-09-13 17:47:43 來源:腳本之家 閱讀:400 作者:建森要健身 欄目:開發(fā)技術(shù)

最近搞了搞minist手寫數(shù)據(jù)集的神經(jīng)網(wǎng)絡搭建,一個數(shù)據(jù)集里面很多個數(shù)據(jù),不能一次喂入,所以需要分成一小塊一小塊喂入搭建好的網(wǎng)絡。

pytorch中有很方便的dataloader函數(shù)來方便我們進行批處理,做了簡單的例子,過程很簡單,就像把大象裝進冰箱里一共需要幾步?

第一步:打開冰箱門。

我們要創(chuàng)建torch能夠識別的數(shù)據(jù)集類型(pytorch中也有很多現(xiàn)成的數(shù)據(jù)集類型,以后再說)。

首先我們建立兩個向量X和Y,一個作為輸入的數(shù)據(jù),一個作為正確的結(jié)果:

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

隨后我們需要把X和Y組成一個完整的數(shù)據(jù)集,并轉(zhuǎn)化為pytorch能識別的數(shù)據(jù)集類型:

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

我們來看一下這些數(shù)據(jù)的數(shù)據(jù)類型:

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

可以看出我們把X和Y通過Data.TensorDataset() 這個函數(shù)拼裝成了一個數(shù)據(jù)集,數(shù)據(jù)集的類型是【TensorDataset】。

好了,第一步結(jié)束了,冰箱門打開了。

第二步:把大象裝進去。

就是把上一步做成的數(shù)據(jù)集放入Data.DataLoader中,可以生成一個迭代器,從而我們可以方便的進行批處理。

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

DataLoader中也有很多其他參數(shù):

  1. dataset:Dataset類型,從其中加載數(shù)據(jù)
  2. batch_size:int,可選。每個batch加載多少樣本
  3. shuffle:bool,可選。為True時表示每個epoch都對數(shù)據(jù)進行洗牌
  4. sampler:Sampler,可選。從數(shù)據(jù)集中采樣樣本的方法。
  5. num_workers:int,可選。加載數(shù)據(jù)時使用多少子進程。默認值為0,表示在主進程中加載數(shù)據(jù)。
  6. collate_fn:callable,可選。
  7. pin_memory:bool,可選
  8. drop_last:bool,可選。True表示如果最后剩下不完全的batch,丟棄。False表示不丟棄。

好了,第二步結(jié)束了,大象裝進去了。

第三步:把冰箱門關(guān)上。

好啦,現(xiàn)在我們就可以愉快的用我們上面定義好的迭代器進行訓練啦。

在這里我們利用print來模擬我們的訓練過程,即我們在這里對搭建好的網(wǎng)絡進行喂入。

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

輸出的結(jié)果是:

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

可以看到,我們一共訓練了所有的數(shù)據(jù)訓練了5次。數(shù)據(jù)中一共10組,我們設(shè)置的mini-batch是3,即每一次我們訓練網(wǎng)絡的時候喂入3組數(shù)據(jù),到了最后一次我們只有1組數(shù)據(jù)了,比mini-batch小,我們就僅輸出這一個。

此外,還可以利用python中的enumerate(),是對所有可以迭代的數(shù)據(jù)類型(含有很多東西的list等等)進行取操作的函數(shù),用法如下:

pytorch中如何使用DataLoader對數(shù)據(jù)集進行批處理的方法

以上就是本文的全部內(nèi)容,希望對大家的學習有所幫助,也希望大家多多支持億速云。

向AI問一下細節(jié)

免責聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI