溫馨提示×

pytorch如何自定義數(shù)據(jù)集

小億
86
2024-06-04 17:11:27
欄目: 深度學習

要在PyTorch中自定義數(shù)據(jù)集,需要創(chuàng)建一個繼承自torch.utils.data.Dataset的類,并且實現(xiàn)__len____getitem__方法。

下面是一個簡單的例子,展示如何自定義一個數(shù)據(jù)集類:

import torch
from torch.utils.data import Dataset

# 自定義數(shù)據(jù)集類
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

# 創(chuàng)建數(shù)據(jù)集實例
data = [1, 2, 3, 4, 5]
dataset = CustomDataset(data)

# 使用DataLoader加載數(shù)據(jù)集
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

# 遍歷數(shù)據(jù)集
for batch in dataloader:
    print(batch)

在上面的例子中,我們創(chuàng)建了一個CustomDataset類,該類接收一個數(shù)據(jù)列表并實現(xiàn)了__len____getitem__方法。然后我們創(chuàng)建了一個數(shù)據(jù)集實例dataset并使用DataLoader加載數(shù)據(jù)集。最后我們遍歷了數(shù)據(jù)集并打印了每個batch的數(shù)據(jù)。

通過自定義數(shù)據(jù)集類,我們可以靈活地處理各種不同格式的數(shù)據(jù),并且可以方便地與PyTorch的數(shù)據(jù)加載工具進行集成。

0