溫馨提示×

在PyTorch中如何加載和處理數(shù)據(jù)集

小樊
94
2024-03-05 18:13:55
欄目: 編程語言

在PyTorch中,通常通過使用torch.utils.data.Datasettorch.utils.data.DataLoader來加載和處理數(shù)據(jù)集。

首先,創(chuàng)建一個自定義的數(shù)據(jù)集類,繼承自torch.utils.data.Dataset,并實現(xiàn)__len____getitem__方法。在__getitem__方法中,可以根據(jù)索引加載和預(yù)處理數(shù)據(jù)。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        # 進行數(shù)據(jù)預(yù)處理
        return sample

然后,實例化自定義數(shù)據(jù)集類并使用torch.utils.data.DataLoader創(chuàng)建一個數(shù)據(jù)加載器,指定批量大小和是否打亂數(shù)據(jù)。

data = [...]  # 數(shù)據(jù)集

dataset = CustomDataset(data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

最后,可以通過迭代數(shù)據(jù)加載器來訪問數(shù)據(jù)集中的數(shù)據(jù)。

for batch in dataloader:
    # 處理批量數(shù)據(jù)
    pass

0