溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Pytorch如何繼承Subset類完成自定義數(shù)據(jù)拆分

發(fā)布時(shí)間:2022-02-21 09:30:51 來(lái)源:億速云 閱讀:343 作者:iii 欄目:開發(fā)技術(shù)

這篇文章主要介紹“Pytorch如何繼承Subset類完成自定義數(shù)據(jù)拆分”,在日常操作中,相信很多人在Pytorch如何繼承Subset類完成自定義數(shù)據(jù)拆分問(wèn)題上存在疑惑,小編查閱了各式資料,整理出簡(jiǎn)單好用的操作方法,希望對(duì)大家解答”Pytorch如何繼承Subset類完成自定義數(shù)據(jù)拆分”的疑惑有所幫助!接下來(lái),請(qǐng)跟著小編一起來(lái)學(xué)習(xí)吧!

下面是加載內(nèi)置訓(xùn)練數(shù)據(jù)集的常見操作:

from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize
RAW_DATA_PATH = './rawdata'
transform = Compose(
        [ToTensor(),
         Normalize((0.1307,), (0.3081,))
         ]
    )
train_data = FashionMNIST(
        root=RAW_DATA_PATH,
        download=True,
        train=True,
        transform=transform
    )

這里的train_data 做為 dataset 對(duì)象,它擁有許多熟悉,我們可以通過(guò)以下方法獲取樣本數(shù)據(jù)的分類類別集合、樣本的特征維度、樣本的標(biāo)簽集合等信息。

classes = train_data.classes
num_features = train_data.data[0].shape[0]
train_labels = train_data.targets

print(classes)
print(num_features)
print(train_labels)

輸出如下:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0,  ..., 3, 0, 5])

但是,我們常常會(huì)在訓(xùn)練集的基礎(chǔ)上拆分出驗(yàn)證集(或者只用部分?jǐn)?shù)據(jù)來(lái)進(jìn)行訓(xùn)練)。我們想到的第一個(gè)方法是使用 torch.utils.data.random_split 對(duì) dataset 進(jìn)行劃分,下面我們假設(shè)劃分10000個(gè)樣本做為訓(xùn)練集,其余樣本做為驗(yàn)證集:

from torch.utils.data import random_split
k = 10000
train_data, valid_data = random_split(train_data, [k, len(train_data)-k])

注意我們?nèi)绻蛴?train_data 和 valid_data 的類型,可以看到顯示:

<class 'torch.utils.data.dataset.Subset'>

已經(jīng)不再是torchvision.datasets.mnist.FashionMNIST 對(duì)象,而是一個(gè)所謂的 Subset 對(duì)象!此時(shí) Subset 對(duì)象雖然仍然還存有 data 屬性,但是內(nèi)置的 target classes 屬性已經(jīng)不復(fù)存在,

比如如果我們強(qiáng)行訪問(wèn) valid_data 的 target 屬性:

valid_target = valid_data.target

就會(huì)報(bào)如下錯(cuò)誤:

'Subset' object has no attribute 'target'

但如果我們?cè)诤罄m(xù)的代碼中常常會(huì)將拆分后的數(shù)據(jù)集也默認(rèn)為 dataset 對(duì)象,那么該如何做到代碼的一致性呢?

這里有一個(gè)trick,那就是以繼承 SubSet 類的方式的方式定義一個(gè)新的 CustomSubSet 類,使新類在保持 SubSet 類的基本屬性的基礎(chǔ)上,擁有和原本數(shù)據(jù)集類相似的屬性,如 targets classes 等:

from torch.utils.data import Subset
class CustomSubset(Subset):
    '''A custom subset class'''
    def __init__(self, dataset, indices):
        super().__init__(dataset, indices)
        self.targets = dataset.targets # 保留targets屬性
        self.classes = dataset.classes # 保留classes屬性

    def __getitem__(self, idx): #同時(shí)支持索引訪問(wèn)操作
        x, y = self.dataset[self.indices[idx]]      
        return x, y 

    def __len__(self): # 同時(shí)支持取長(zhǎng)度操作
        return len(self.indices)

然后就引出了第二種劃分方法,即通過(guò)初始化 CustomSubset 對(duì)象的方式直接對(duì)數(shù)據(jù)集進(jìn)行劃分(這里為了簡(jiǎn)化省略了shuffle的步驟):

import numpy as np
from copy import deepcopy
origin_data = deepcopy(train_data)
train_data = CustomSubset(origin_data, np.arange(k))
valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)

注意: CustomSubset 類的初始化方法的第二個(gè)參數(shù) indices 為樣本索引,我們可以通過(guò) np.arange() 的方法來(lái)創(chuàng)建。

然后,我們?cè)僭L問(wèn) valid_data 對(duì)應(yīng)的 classes 和 targes 屬性:

print(valid_data.classes)
print(valid_data.targets)

此時(shí),我們發(fā)現(xiàn)可以成功訪問(wèn)這些屬性了:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
tensor([9, 0, 0,  ..., 3, 0, 5])

當(dāng)然, CustomSubset 的作用并不只是添加數(shù)據(jù)集的屬性,我們還可以自定義一些數(shù)據(jù)預(yù)處理操作。

我們將類的結(jié)構(gòu)修改如下:

class CustomSubset(Subset):
    '''A custom subset class with customizable data transformation'''
    def __init__(self, dataset, indices, subset_transform=None):
        super().__init__(dataset, indices)
        self.targets = dataset.targets
        self.classes = dataset.classes
        self.subset_transform = subset_transform

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        
        if self.subset_transform:
            x = self.subset_transform(x)
      
        return x, y   
    
    def __len__(self): 
        return len(self.indices)

我們可以在使用樣本前設(shè)置好數(shù)據(jù)預(yù)處理算子:

from torchvision import transforms
valid_data.subset_transform = transforms.Compose(\
    [transforms.RandomRotation((180,180))])

這樣,我們?cè)傧裣铝羞@樣用索引訪問(wèn)取出數(shù)據(jù)集樣本時(shí),就會(huì)自動(dòng)調(diào)用算子完成預(yù)處理操作:

print(valid_data[0])

打印結(jié)果縮略如下:

(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)

到此,關(guān)于“Pytorch如何繼承Subset類完成自定義數(shù)據(jù)拆分”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實(shí)踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識(shí),請(qǐng)繼續(xù)關(guān)注億速云網(wǎng)站,小編會(huì)繼續(xù)努力為大家?guī)?lái)更多實(shí)用的文章!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI