PyTorch中的數(shù)據(jù)加載器怎么使用

小億
90
2024-03-14 13:34:28

在PyTorch中,數(shù)據(jù)加載器可以通過(guò)torch.utils.data.DataLoader來(lái)實(shí)現(xiàn)。數(shù)據(jù)加載器可以幫助用戶批量加載數(shù)據(jù),并可以在訓(xùn)練過(guò)程中對(duì)數(shù)據(jù)進(jìn)行隨機(jī)排列、并行加載等操作。

下面是一個(gè)簡(jiǎn)單的示例,演示如何使用數(shù)據(jù)加載器來(lái)加載一個(gè)簡(jiǎn)單的數(shù)據(jù)集:

import torch
from torch.utils.data import Dataset, DataLoader

# 創(chuàng)建一個(gè)自定義的數(shù)據(jù)集類(lèi)
class CustomDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(100, 3)  # 100個(gè)3維的隨機(jī)數(shù)據(jù)
        self.targets = torch.randint(0, 2, (100,))  # 100個(gè)隨機(jī)目標(biāo)標(biāo)簽

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

# 創(chuàng)建數(shù)據(jù)集實(shí)例
dataset = CustomDataset()

# 創(chuàng)建數(shù)據(jù)加載器實(shí)例
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 遍歷數(shù)據(jù)加載器
for i, (data, target) in enumerate(data_loader):
    print(f'Batch {i}:')
    print('Data:', data)
    print('Target:', target)

在上述示例中,首先定義了一個(gè)自定義的數(shù)據(jù)集類(lèi)CustomDataset,然后創(chuàng)建了一個(gè)數(shù)據(jù)集實(shí)例dataset。接著利用DataLoader類(lèi)來(lái)創(chuàng)建一個(gè)數(shù)據(jù)加載器實(shí)例data_loader,并指定了批量大小為32且開(kāi)啟了數(shù)據(jù)隨機(jī)排列。最后通過(guò)對(duì)數(shù)據(jù)加載器進(jìn)行遍歷,便可以逐批次地獲取數(shù)據(jù)和標(biāo)簽。

0