您好,登錄后才能下訂單哦!
本篇內(nèi)容主要講解“怎么使用pytorch準(zhǔn)備自己的圖片數(shù)據(jù)”,感興趣的朋友不妨來(lái)看看。本文介紹的方法操作簡(jiǎn)單快捷,實(shí)用性強(qiáng)。下面就讓小編來(lái)帶大家學(xué)習(xí)“怎么使用pytorch準(zhǔn)備自己的圖片數(shù)據(jù)”吧!
圖片數(shù)據(jù)一般有兩種情況:
1、所有圖片放在一個(gè)文件夾內(nèi),另外有一個(gè)txt文件顯示標(biāo)簽。
2、不同類別的圖片放在不同的文件夾內(nèi),文件夾就是圖片的類別。
針對(duì)這兩種不同的情況,數(shù)據(jù)集的準(zhǔn)備也不相同,第一種情況可以自定義一個(gè)Dataset,第二種情況直接調(diào)用torchvision.datasets.ImageFolder來(lái)處理。下面分別進(jìn)行說(shuō)明:
這里以mnist數(shù)據(jù)集的10000個(gè)test為例, 我先把test集的10000個(gè)圖片保存出來(lái),并生著對(duì)應(yīng)的txt標(biāo)簽文件。
先在當(dāng)前目錄創(chuàng)建一個(gè)空文件夾mnist_test, 用于保存10000張圖片,接著運(yùn)行代碼:
import torch import torchvision import matplotlib.pyplot as plt from skimage import io mnist_test= torchvision.datasets.MNIST( './mnist', train=False, download=True ) print('test set:', len(mnist_test)) f=open('mnist_test.txt','w') for i,(img,label) in enumerate(mnist_test): img_path="./mnist_test/"+str(i)+".jpg" io.imsave(img_path,img) f.write(img_path+' '+str(label)+'\n') f.close()
經(jīng)過(guò)上面的操作,10000張圖片就保存在mnist_test文件夾里了,并在當(dāng)前目錄下生成了一個(gè)mnist_test.txt的文件,大致如下:
前期工作就裝備好了,接著就進(jìn)入正題了:
from torchvision import transforms, utils from torch.utils.data import Dataset, DataLoader import matplotlib.pyplot as plt from PIL import Image def default_loader(path): return Image.open(path).convert('RGB') class MyDataset(Dataset): def __init__(self, txt, transform=None, target_transform=None, loader=default_loader): fh = open(txt, 'r') imgs = [] for line in fh: line = line.strip('\n') line = line.rstrip() words = line.split() imgs.append((words[0],int(words[1]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader def __getitem__(self, index): fn, label = self.imgs[index] img = self.loader(fn) if self.transform is not None: img = self.transform(img) return img,label def __len__(self): return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor()) data_loader = DataLoader(train_data, batch_size=100,shuffle=True) print(len(data_loader)) def show_batch(imgs): grid = utils.make_grid(imgs) plt.imshow(grid.numpy().transpose((1, 2, 0))) plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader): if(i<4): print(i, batch_x.size(),batch_y.size()) show_batch(batch_x) plt.axis('off') plt.show()
自定義了一個(gè)MyDataset, 繼承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader將整個(gè)數(shù)據(jù)集分成多個(gè)批次。
同樣先準(zhǔn)備數(shù)據(jù),這里以flowers數(shù)據(jù)集為例
提取 鏈接: https://pan.baidu.com/s/1dcAsOOZpUfWNYR77JGXPHA?pwd=mwg6
花總共有五類,分別放在5個(gè)文件夾下。大致如下圖:
我的路徑是d:/flowers/.
數(shù)據(jù)準(zhǔn)備好了,就開(kāi)始準(zhǔn)備Dataset吧,這里直接調(diào)用torchvision里面的ImageFolder
import torch import torchvision from torchvision import transforms, utils import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower', transform=transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor()]) ) print(len(img_data)) data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True) print(len(data_loader)) def show_batch(imgs): grid = utils.make_grid(imgs,nrow=5) plt.imshow(grid.numpy().transpose((1, 2, 0))) plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader): if(i<4): print(i, batch_x.size(), batch_y.size()) show_batch(batch_x) plt.axis('off') plt.show()
到此,相信大家對(duì)“怎么使用pytorch準(zhǔn)備自己的圖片數(shù)據(jù)”有了更深的了解,不妨來(lái)實(shí)際操作一番吧!這里是億速云網(wǎng)站,更多相關(guān)內(nèi)容可以進(jìn)入相關(guān)頻道進(jìn)行查詢,關(guān)注我們,繼續(xù)學(xué)習(xí)!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。