溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶(hù)服務(wù)條款》

torch.utils.data.DataLoader與迭代器轉(zhuǎn)換的方法

發(fā)布時(shí)間:2022-02-21 09:32:55 來(lái)源:億速云 閱讀:244 作者:iii 欄目:開(kāi)發(fā)技術(shù)

這篇文章主要介紹“torch.utils.data.DataLoader與迭代器轉(zhuǎn)換的方法”的相關(guān)知識(shí),小編通過(guò)實(shí)際案例向大家展示操作過(guò)程,操作方法簡(jiǎn)單快捷,實(shí)用性強(qiáng),希望這篇“torch.utils.data.DataLoader與迭代器轉(zhuǎn)換的方法”文章能幫助大家解決問(wèn)題。

在做實(shí)驗(yàn)時(shí),我們常常會(huì)使用用開(kāi)源的數(shù)據(jù)集進(jìn)行測(cè)試。而Pytorch中內(nèi)置了許多數(shù)據(jù)集,這些數(shù)據(jù)集我們常常使用DataLoader類(lèi)進(jìn)行加載。
如下面這個(gè)我們使用DataLoader類(lèi)加載torch.vision中的FashionMNIST數(shù)據(jù)集。

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

我們接下來(lái)定義Dataloader對(duì)象用于加載這兩個(gè)數(shù)據(jù)集:

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

那么這個(gè)train_dataloader究竟是什么類(lèi)型呢?

print(type(train_dataloader))  # <class 'torch.utils.data.dataloader.DataLoader'>

我們可以將先其轉(zhuǎn)換為迭代器類(lèi)型。

print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>

然后再使用next(iter(train_dataloader))從迭代器里取數(shù)據(jù),如下所示:

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

可以看到我們成功獲取了數(shù)據(jù)集中第一張圖片的信息,控制臺(tái)打印:

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 2

圖片可視化顯示如下:

torch.utils.data.DataLoader與迭代器轉(zhuǎn)換的方法

不過(guò)有讀者可能就會(huì)產(chǎn)生疑問(wèn),很多時(shí)候我們并沒(méi)有將DataLoader類(lèi)型強(qiáng)制轉(zhuǎn)換成迭代器類(lèi)型呀,大多數(shù)時(shí)候我們會(huì)寫(xiě)如下代碼:

for train_features, train_labels in train_dataloader: 
    print(train_features.shape) # torch.Size([64, 1, 28, 28])
    print(train_features[0].shape) # torch.Size([1, 28, 28])
    print(train_features[0].squeeze().shape) # torch.Size([28, 28])
    
    img = train_features[0].squeeze()
    label = train_labels[0]
    plt.imshow(img, cmap="gray")
    plt.show()
    print(f"Label: {label}")

可以看到,該代碼也能夠正常迭代訓(xùn)練數(shù)據(jù),前三個(gè)樣本的控制臺(tái)打印輸出為:

torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 7
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 4
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 1

那么為什么我們這里沒(méi)有顯式將Dataloader轉(zhuǎn)換為迭代器類(lèi)型呢,其實(shí)是Python語(yǔ)言for循環(huán)的一種機(jī)制,一旦我們用for ... in ...句式來(lái)迭代一個(gè)對(duì)象,那么Python解釋器就會(huì)偷偷地自動(dòng)幫我們創(chuàng)建好迭代器,也就是說(shuō)

for train_features, train_labels in train_dataloader:

實(shí)際上等同于

for train_features, train_labels in iter(train_dataloader):

更進(jìn)一步,這實(shí)際上等同于

train_iterator = iter(train_dataloader)
try:
    while True:
        train_features, train_labels = next(train_iterator)
except StopIteration:
    pass

推而廣之,我們?cè)谟肞ython迭代直接迭代列表時(shí):

for x in [1, 2, 3, 4]:

其實(shí)Python解釋器已經(jīng)為我們隱式轉(zhuǎn)換為迭代器了:

list_iterator = iter([1, 2, 3, 4])
try:
    while True:
        x = next(list_iterator)
except StopIteration:
    pass

關(guān)于“torch.utils.data.DataLoader與迭代器轉(zhuǎn)換的方法”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識(shí),可以關(guān)注億速云行業(yè)資訊頻道,小編每天都會(huì)為大家更新不同的知識(shí)點(diǎn)。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI