要實現(xiàn)自定義數(shù)據(jù)集類,需要繼承PyTorch中的Dataset類,并重寫其中的兩個方法:len__和__getitem。下面是一個簡單的例子,演示如何實現(xiàn)一個自定義數(shù)據(jù)集類:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data_point = self.data[index]
target = self.targets[index]
return data_point, target
在上面的例子中,CustomDataset類接收兩個參數(shù)data和targets作為初始化參數(shù),分別表示數(shù)據(jù)和標簽。然后重寫了__len__方法,返回數(shù)據(jù)集的長度,重寫了__getitem__方法,根據(jù)索引index返回對應的數(shù)據(jù)點和標簽。
使用這個自定義數(shù)據(jù)集類的方法如下:
data = [...] # your data
targets = [...] # your targets
custom_dataset = CustomDataset(data, targets)
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)
for data, target in data_loader:
# do something with data and target
這樣就可以通過自定義數(shù)據(jù)集類來加載自己的數(shù)據(jù)集,并使用DataLoader來批量加載數(shù)據(jù)。