溫馨提示×

PyTorch中怎么實現(xiàn)自定義數(shù)據(jù)集類

小億
91
2024-05-10 15:49:56
欄目: 深度學習

要實現(xiàn)自定義數(shù)據(jù)集類,需要繼承PyTorch中的Dataset類,并重寫其中的兩個方法:len__和__getitem。下面是一個簡單的例子,演示如何實現(xiàn)一個自定義數(shù)據(jù)集類:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        data_point = self.data[index]
        target = self.targets[index]
        
        return data_point, target

在上面的例子中,CustomDataset類接收兩個參數(shù)data和targets作為初始化參數(shù),分別表示數(shù)據(jù)和標簽。然后重寫了__len__方法,返回數(shù)據(jù)集的長度,重寫了__getitem__方法,根據(jù)索引index返回對應的數(shù)據(jù)點和標簽。

使用這個自定義數(shù)據(jù)集類的方法如下:

data = [...] # your data
targets = [...] # your targets

custom_dataset = CustomDataset(data, targets)
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)

for data, target in data_loader:
    # do something with data and target

這樣就可以通過自定義數(shù)據(jù)集類來加載自己的數(shù)據(jù)集,并使用DataLoader來批量加載數(shù)據(jù)。

0