您好,登錄后才能下訂單哦!
這篇文章將為大家詳細(xì)講解有關(guān)Pytorch中DataLoader, DataSet, Sampler之間有怎樣的關(guān)系呢,小編覺(jué)得挺實(shí)用的,因此分享給大家做個(gè)參考,希望大家閱讀完這篇文章后可以有所收獲。
以下內(nèi)容都是針對(duì)Pytorch 1.0-1.1介紹。
很多文章都是從Dataset等對(duì)象自下往上進(jìn)行介紹,但是對(duì)于初學(xué)者而言,其實(shí)這并不好理解,因?yàn)橛械臅r(shí)候會(huì)不自覺(jué)地陷入到一些細(xì)枝末節(jié)中去,而不能把握重點(diǎn),所以本文將會(huì)自上而下地對(duì)Pytorch數(shù)據(jù)讀取方法進(jìn)行介紹。
自上而下理解三者關(guān)系
首先我們看一下DataLoader.next的源代碼長(zhǎng)什么樣,為方便理解我只選取了num_works為0的情況(num_works簡(jiǎn)單理解就是能夠并行化地讀取數(shù)據(jù))。
class DataLoader(object): ... def __next__(self): if self.num_workers == 0: indices = next(self.sample_iter) # Sampler batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset if self.pin_memory: batch = _utils.pin_memory.pin_memory_batch(batch) return batch
在閱讀上面代碼前,我們可以假設(shè)我們的數(shù)據(jù)是一組圖像,每一張圖像對(duì)應(yīng)一個(gè)index,那么如果我們要讀取數(shù)據(jù)就只需要對(duì)應(yīng)的index即可,即上面代碼中的indices
,而選取index的方式有多種,有按順序的,也有亂序的,所以這個(gè)工作需要Sampler
完成,現(xiàn)在你不需要具體的細(xì)節(jié),后面會(huì)介紹,你只需要知道DataLoader和Sampler在這里產(chǎn)生關(guān)系。
那么Dataset和DataLoader在什么時(shí)候產(chǎn)生關(guān)系呢?沒(méi)錯(cuò)就是下面一行。我們已經(jīng)拿到了indices,那么下一步我們只需要根據(jù)index對(duì)數(shù)據(jù)進(jìn)行讀取即可了。
再下面的if
語(yǔ)句的作用簡(jiǎn)單理解就是,如果pin_memory=True
,那么Pytorch會(huì)采取一系列操作把數(shù)據(jù)拷貝到GPU,總之就是為了加速。
綜上可以知道DataLoader,Sampler和Dataset三者關(guān)系如下:
在閱讀后文的過(guò)程中,你始終需要將上面的關(guān)系記在心里,這樣能幫助你更好地理解。
Sampler
參數(shù)傳遞
要更加細(xì)致地理解Sampler原理,我們需要先閱讀一下DataLoader 的源代碼,如下:
class DataLoader(object): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
可以看到初始化參數(shù)里有兩種sampler:sampler
和batch_sampler
,都默認(rèn)為None
。前者的作用是生成一系列的index,而batch_sampler則是將sampler生成的indices打包分組,得到一個(gè)又一個(gè)batch的index。例如下面示例中,BatchSampler
將SequentialSampler
生成的index按照指定的batch size分組。
>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) >>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
Pytorch中已經(jīng)實(shí)現(xiàn)的Sampler
有如下幾種:
SequentialSampler
RandomSampler
WeightedSampler
SubsetRandomSampler
需要注意的是DataLoader的部分初始化參數(shù)之間存在互斥關(guān)系,這個(gè)你可以通過(guò)閱讀源碼更深地理解,這里只做總結(jié):
如何自定義Sampler和BatchSampler?
仔細(xì)查看源代碼其實(shí)可以發(fā)現(xiàn),所有采樣器其實(shí)都繼承自同一個(gè)父類,即Sampler
,其代碼定義如下:
class Sampler(object): r"""Base class for all Samplers. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a way to iterate over indices of dataset elements, and a :meth:`__len__` method that returns the length of the returned iterators. .. note:: The :meth:`__len__` method isn't strictly required by :class:`~torch.utils.data.DataLoader`, but is expected in any calculation involving the length of a :class:`~torch.utils.data.DataLoader`. """ def __init__(self, data_source): pass def __iter__(self): raise NotImplementedError def __len__(self): return len(self.data_source)
所以你要做的就是定義好__iter__(self)
函數(shù),不過(guò)要注意的是該函數(shù)的返回值需要是可迭代的。例如SequentialSampler
返回的是iter(range(len(self.data_source)))
。
另外BatchSampler
與其他Sampler的主要區(qū)別是它需要將Sampler作為參數(shù)進(jìn)行打包,進(jìn)而每次迭代返回以batch size為大小的index列表。也就是說(shuō)在后面的讀取數(shù)據(jù)過(guò)程中使用的都是batch sampler。
Dataset
Dataset定義方式如下:
class Dataset(object): def __init__(self): ... def __getitem__(self, index): return ... def __len__(self): return ...
上面三個(gè)方法是最基本的,其中__getitem__
是最主要的方法,它規(guī)定了如何讀取數(shù)據(jù)。但是它又不同于一般的方法,因?yàn)樗莗ython built-in方法,其主要作用是能讓該類可以像list一樣通過(guò)索引值對(duì)數(shù)據(jù)進(jìn)行訪問(wèn)。假如你定義好了一個(gè)dataset,那么你可以直接通過(guò)dataset[0]
來(lái)訪問(wèn)第一個(gè)數(shù)據(jù)。在此之前我一直沒(méi)弄清楚__getitem__
是什么作用,所以一直不知道該怎么進(jìn)入到這個(gè)函數(shù)進(jìn)行調(diào)試?,F(xiàn)在如果你想對(duì)__getitem__
方法進(jìn)行調(diào)試,你可以寫一個(gè)for循環(huán)遍歷dataset來(lái)進(jìn)行調(diào)試了,而不用構(gòu)建dataloader等一大堆東西了,建議學(xué)會(huì)使用ipdb
這個(gè)庫(kù),非常實(shí)用!??!以后有時(shí)間再寫一篇ipdb的使用教程。另外,其實(shí)我們通過(guò)最前面的Dataloader的__next__
函數(shù)可以看到DataLoader對(duì)數(shù)據(jù)的讀取其實(shí)就是用了for循環(huán)來(lái)遍歷數(shù)據(jù),不用往上翻了,我直接復(fù)制了一遍,如下:
class DataLoader(object): ... def __next__(self): if self.num_workers == 0: indices = next(self.sample_iter) batch = self.collate_fn([self.dataset[i] for i in indices]) # this line if self.pin_memory: batch = _utils.pin_memory.pin_memory_batch(batch) return batch
我們仔細(xì)看可以發(fā)現(xiàn),前面還有一個(gè)self.collate_fn
方法,這個(gè)是干嘛用的呢?在介紹前我們需要知道每個(gè)參數(shù)的意義:
indices
: 表示每一個(gè)iteration,sampler返回的indices,即一個(gè)batch size大小的索引列表self.dataset[i]
: 前面已經(jīng)介紹了,這里就是對(duì)第i個(gè)數(shù)據(jù)進(jìn)行讀取操作,一般來(lái)說(shuō)self.dataset[i]=(img, label)
看到這不難猜出collate_fn
的作用就是將一個(gè)batch的數(shù)據(jù)進(jìn)行合并操作。默認(rèn)的collate_fn
是將img和label分別合并成imgs和labels,所以如果你的__getitem__
方法只是返回 img, label
,那么你可以使用默認(rèn)的collate_fn
方法,但是如果你每次讀取的數(shù)據(jù)有img, box, label
等等,那么你就需要自定義collate_fn
來(lái)將對(duì)應(yīng)的數(shù)據(jù)合并成一個(gè)batch數(shù)據(jù),這樣方便后續(xù)的訓(xùn)練步驟。
關(guān)于Pytorch中DataLoader, DataSet, Sampler之間有怎樣的關(guān)系呢就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,可以學(xué)到更多知識(shí)。如果覺(jué)得文章不錯(cuò),可以把它分享出去讓更多的人看到。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。