溫馨提示×

Torch如何自定義一個Dataset類

小樊
118
2024-03-29 11:53:10

要自定義一個Dataset類,可以繼承自torch.utils.data.Dataset,并實現(xiàn)其中的__len__和__getitem__方法來定義數(shù)據(jù)集的長度和獲取數(shù)據(jù)的方式。

下面是一個簡單的例子:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

# 創(chuàng)建一個數(shù)據(jù)集實例
data = [1, 2, 3, 4, 5]
custom_dataset = CustomDataset(data)

# 獲取數(shù)據(jù)集的長度
print(len(custom_dataset))

# 獲取數(shù)據(jù)集中第一個樣本
print(custom_dataset[0])

在上面的例子中,我們定義了一個CustomDataset類,它接受一個數(shù)據(jù)列表作為輸入,并實現(xiàn)了__len__方法和__getitem__方法。通過實例化CustomDataset類,我們可以獲取數(shù)據(jù)集的長度并獲取數(shù)據(jù)集中的樣本。

0