溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

PyTorch怎么實(shí)現(xiàn)對(duì)貓狗二分類訓(xùn)練集進(jìn)行讀取

發(fā)布時(shí)間:2021-12-16 09:48:31 來源:億速云 閱讀:328 作者:iii 欄目:大數(shù)據(jù)

這篇文章主要介紹“PyTorch怎么實(shí)現(xiàn)對(duì)貓狗二分類訓(xùn)練集進(jìn)行讀取”,在日常操作中,相信很多人在PyTorch怎么實(shí)現(xiàn)對(duì)貓狗二分類訓(xùn)練集進(jìn)行讀取問題上存在疑惑,小編查閱了各式資料,整理出簡單好用的操作方法,希望對(duì)大家解答”PyTorch怎么實(shí)現(xiàn)對(duì)貓狗二分類訓(xùn)練集進(jìn)行讀取”的疑惑有所幫助!接下來,請(qǐng)跟著小編一起來學(xué)習(xí)吧!

從kaggle中下載貓狗二分類訓(xùn)練數(shù)據(jù),自己編寫一個(gè)DogCatDataset,使得pytorch可以對(duì)貓狗二分類訓(xùn)練集進(jìn)行讀取

貓狗大戰(zhàn)代碼

import os
import zipfile
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
        
# unzip        
print(os.getcwd())
os.makedirs('data', exist_ok=True)

with zipfile.ZipFile('../input/dogs-vs-cats-redux-kernels-edition/train.zip') as train_zip:
    train_zip.extractall('data')
    
with zipfile.ZipFile('../input/dogs-vs-cats-redux-kernels-edition/test.zip') as test_zip:
    test_zip.extractall('data')


# show unzip dir
train_dir = './data/train'
test_dir = './data/test'

print('len:', len(os.listdir(train_dir)), len(os.listdir(test_dir)))
os.listdir(train_dir)[:5]
os.listdir(test_dir)[:5]


import numpy as np 
import pandas as pd 
import glob
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim


batch_size = 100

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

torch.manual_seed(1234)

if device =='cuda':
    torch.cuda.manual_seed_all(1234)
    
lr = 0.001


train_list = glob.glob(os.path.join(train_dir,'*.jpg'))
test_list = glob.glob(os.path.join(test_dir, '*.jpg'))

print('show data:', len(train_list), train_list[:3])
print('show data:', len(test_list), test_list[:3])


fig = plt.figure()
ax = fig.add_subplot(1,1,1)
img = Image.open(train_list[0])
plt.imshow(img)
plt.axis('off')
plt.show()
print(type(img))
img_np = np.asarray(img)
print(img_np.shape)



train_list, val_list = train_test_split(train_list, test_size=0.2)
print(len(train_list), train_list[:3])
print(len(val_list), val_list[:3])


train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
#     transforms.RandomCrop(224), 
    transforms.ToTensor(),
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
#     transforms.RandomCrop(224), 
    transforms.ToTensor(),
])

test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
#     transforms.RandomCrop(224), 
    transforms.ToTensor(),
])


class dataset(torch.utils.data.Dataset):
    def __init__(self,file_list,now_transform):
        self.file_list = file_list # list of path
        self.transform = now_transform
    
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength
    
    def __getitem__(self,idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
#         print(img.size)
        img_transformed = self.transform(img)
        
        # test 沒有標(biāo)簽?
        label = img_path.split('/')[-1].split('.')[0]
        if label == 'dog':
            label=1
        elif label == 'cat':
            label=0
        else:
            assert False
            
        return img_transformed,label


train_data = dataset(train_list, train_transforms)
val_data = dataset(val_list, test_transforms)
# test_data = dataset(test_list, transform=test_transforms)

train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=batch_size, shuffle=True)
# test_loader = torch.utils.data.DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

print(len(train_data), len(train_loader))
print(len(val_data), len(val_loader))
print(train_data, type(train_data))
t1, t2 = train_data[7]
print(t1, t2)
print(type(t1))
print(t1.shape)



class CNN_STD(nn.Module):
    def __init__(self):
        super(CNN_STD,self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3, padding=0,stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(16,32, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
            )
        
        self.layer3 = nn.Sequential(
            nn.Conv2d(32,64, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        
        self.fc1 = nn.Linear(3*3*64,10)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(10,2)
        self.relu = nn.ReLU()
        
        
    def forward(self,x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0),-1)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out


optimizer = optim.Adam(params = model.parameters(),lr=lr)
loss_f = nn.CrossEntropyLoss()

epochs = 10

print('start epoch iter, please wait...')
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    
    for data, label in train_loader:
        data = data.to(device)
        label = label.to(device)
        
        output = model(data)
        loss = loss_f(output, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        acc = ((output.argmax(dim=1) == label).float().mean())
        epoch_accuracy += acc/len(train_loader)
        epoch_loss += loss/len(train_loader)
    
        
    print('Epoch : {}, train accuracy : {}, train loss : {}'.format(epoch+1, epoch_accuracy,epoch_loss))
    
    
    with torch.no_grad():
        epoch_val_accuracy=0
        epoch_val_loss =0
        for data, label in val_loader:
            data = data.to(device)
            label = label.to(device)
            
            val_output = model(data)
            val_loss = loss_f(val_output,label)
            
            
            acc = ((val_output.argmax(dim=1) == label).float().mean())
            epoch_val_accuracy += acc/ len(val_loader)
            epoch_val_loss += val_loss/ len(val_loader)
            
        print('Epoch : {}, val_accuracy : {}, val_loss : {}'.format(epoch+1, epoch_val_accuracy,epoch_val_loss))

PyTorch怎么實(shí)現(xiàn)對(duì)貓狗二分類訓(xùn)練集進(jìn)行讀取

到此,關(guān)于“PyTorch怎么實(shí)現(xiàn)對(duì)貓狗二分類訓(xùn)練集進(jìn)行讀取”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實(shí)踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識(shí),請(qǐng)繼續(xù)關(guān)注億速云網(wǎng)站,小編會(huì)繼續(xù)努力為大家?guī)砀鄬?shí)用的文章!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請(qǐng)聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI