溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

發(fā)布時(shí)間:2022-01-17 12:04:11 來(lái)源:億速云 閱讀:151 作者:iii 欄目:開(kāi)發(fā)技術(shù)

今天小編給大家分享一下ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類的相關(guān)知識(shí)點(diǎn),內(nèi)容詳細(xì),邏輯清晰,相信大部分人都還太了解這方面的知識(shí),所以分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后有所收獲,下面我們一起來(lái)了解一下吧。

前言

ConvNeXts 完全由標(biāo)準(zhǔn) ConvNet 模塊構(gòu)建,在準(zhǔn)確性和可擴(kuò)展性方面與 Transformer 競(jìng)爭(zhēng),實(shí)現(xiàn) 87.8% ImageNet top-1 準(zhǔn)確率,在 COCO 檢測(cè)和 ADE20K 分割方面優(yōu)于 Swin Transformers,同時(shí)保持標(biāo)準(zhǔn) ConvNet 的簡(jiǎn)單性和效率。

ConvNexts的特點(diǎn);

使用7×7的卷積核,在VGG、ResNet等經(jīng)典的CNN模型中,使用的是小卷積核,但是ConvNexts證明了大卷積和的有效性。作者嘗試了幾種內(nèi)核大小,包括 3、5、7、9 和 11。網(wǎng)絡(luò)的性能從 79.9% (3×3) 提高到 80.6% (7×7),而網(wǎng)絡(luò)的 FLOPs 大致保持不變, 內(nèi)核大小的好處在 7×7 處達(dá)到飽和點(diǎn)。

使用GELU(高斯誤差線性單元)激活函數(shù)。GELUs是 dropout、zoneout、Relus的綜合,GELUs對(duì)于輸入乘以一個(gè)0,1組成的mask,而該mask的生成則是依概率隨機(jī)的依賴于輸入。實(shí)驗(yàn)效果要比Relus與ELUs都要好。下圖是實(shí)驗(yàn)數(shù)據(jù):

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

使用LayerNorm而不是BatchNorm。

倒置瓶頸。圖 3 (a) 至 (b) 說(shuō)明了這些配置。盡管深度卷積層的 FLOPs 增加了,但由于下采樣殘差塊的快捷 1×1 卷積層的 FLOPs 顯著減少,這種變化將整個(gè)網(wǎng)絡(luò)的 FLOPs 減少到 4.6G。成績(jī)從 80.5% 提高到 80.6%。在 ResNet-200/Swin-B 方案中,這一步帶來(lái)了更多的收益(81.9% 到 82.6%),同時(shí)也減少了 FLOP。

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

ConvNeXt殘差模塊

殘差模塊是整個(gè)模型的核心。如下圖:

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

代碼實(shí)現(xiàn):

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        x = input + self.drop_path(x)
        return x

數(shù)據(jù)增強(qiáng)Cutout和Mixup

ConvNext使用了Cutout和Mixup,為了提高成績(jī)我在我的代碼中也加入這兩種增強(qiáng)方式。官方使用timm,我沒(méi)有采用官方的,而選擇用torchtoolbox。安裝命令:

pip install torchtoolbox

Cutout實(shí)現(xiàn),在transforms中。

from torchtoolbox.transform import Cutout

# 數(shù)據(jù)預(yù)處理

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

Mixup實(shí)現(xiàn),在train方法中。需要導(dǎo)入包:from torchtoolbox.tools import mixup_data, mixup_criterion

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
        optimizer.zero_grad()
        output = model(data)
        loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()

項(xiàng)目結(jié)構(gòu)

使用tree命令,打印項(xiàng)目結(jié)構(gòu)

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

數(shù)據(jù)集

數(shù)據(jù)集選用植物幼苗分類,總共12類。數(shù)據(jù)集連接如下:

鏈接  提取碼:syng

在工程的根目錄新建data文件夾,獲取數(shù)據(jù)集后,將trian和test解壓放到data文件夾下面,如下圖:

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

導(dǎo)入模型文件

從官方的鏈接中找到convnext.py文件,將其放入Model文件夾中。如圖:

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

安裝庫(kù),并導(dǎo)入需要的庫(kù)

模型用到了timm庫(kù),如果沒(méi)有需要安裝,執(zhí)行命令:

pip install timm

新建train_connext.py文件,導(dǎo)入所需要的包:

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
from Model.convnext import convnext_tiny
from torchtoolbox.tools import mixup_data, mixup_criterion
from torchtoolbox.transform import Cutout

設(shè)置全局參數(shù)

設(shè)置使用GPU,設(shè)置學(xué)習(xí)率、BatchSize、epoch等參數(shù)。

# 設(shè)置全局參數(shù)
modellr = 1e-4
BATCH_SIZE = 8
EPOCHS = 300
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

數(shù)據(jù)預(yù)處理

數(shù)據(jù)處理比較簡(jiǎn)單,沒(méi)有做復(fù)雜的嘗試,有興趣的可以加入一些處理。

# 數(shù)據(jù)預(yù)處理

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

數(shù)據(jù)讀取

然后我們?cè)赿ataset文件夾下面新建 init.py和dataset.py,在mydatasets.py文件夾寫(xiě)入下面的代碼:

說(shuō)一下代碼的核心邏輯。

第一步 建立字典,定義類別對(duì)應(yīng)的ID,用數(shù)字代替類別。

第二步 在__init__里面編寫(xiě)獲取圖片路徑的方法。測(cè)試集只有一層路徑直接讀取,訓(xùn)練集在train文件夾下面是類別文件夾,先獲取到類別,再獲取到具體的圖片路徑。然后使用sklearn中切分?jǐn)?shù)據(jù)集的方法,按照7:3的比例切分訓(xùn)練集和驗(yàn)證集。

第三步 在__getitem__方法中定義讀取單個(gè)圖片和類別的方法,由于圖像中有位深度32位的,所以我在讀取圖像的時(shí)候做了轉(zhuǎn)換。

代碼如下:

# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split

Labels = {'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,
          'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8,
          'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}


class SeedlingData(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目標(biāo): 獲取所有圖片的地址,并根據(jù)訓(xùn)練,驗(yàn)證,測(cè)試劃分?jǐn)?shù)據(jù)
        """
        self.test = test
        self.transforms = transforms

        if self.test:
            imgs = [os.path.join(root, img) for img in os.listdir(root)]
            self.imgs = imgs
        else:
            imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]
            imgs = []
            for imglable in imgs_labels:
                for imgname in os.listdir(imglable):
                    imgpath = os.path.join(imglable, imgname)
                    imgs.append(imgpath)
            trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)
            if train:
                self.imgs = trainval_files
            else:
                self.imgs = val_files

    def __getitem__(self, index):
        """
        一次返回一張圖片的數(shù)據(jù)
        """
        img_path = self.imgs[index]
        img_path = img_path.replace("\\", '/')
        if self.test:
            label = -1
        else:
            labelname = img_path.split('/')[-2]
            label = Labels[labelname]
        data = Image.open(img_path).convert('RGB')
        data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)

然后我們?cè)趖rain.py調(diào)用SeedlingData讀取數(shù)據(jù) ,記著導(dǎo)入剛才寫(xiě)的dataset.py(from mydatasets import SeedlingData)

# 讀取數(shù)據(jù)
dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 導(dǎo)入數(shù)據(jù)
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

設(shè)置模型

設(shè)置loss函數(shù)為nn.CrossEntropyLoss()。

  • 設(shè)置模型為coatnet_0,修改最后一層全連接輸出改為12(數(shù)據(jù)集的類別)。

  • 優(yōu)化器設(shè)置為adam。

  • 學(xué)習(xí)率調(diào)整策略改為余弦退火

# 實(shí)例化模型并且移動(dòng)到GPU
criterion = nn.CrossEntropyLoss()
#criterion = SoftTargetCrossEntropy()
model_ft = convnext_tiny(pretrained=True)
num_ftrs = model_ft.head.in_features
model_ft.fc = nn.Linear(num_ftrs, 12)
model_ft.to(DEVICE)
# 選擇簡(jiǎn)單暴力的Adam優(yōu)化器,學(xué)習(xí)率調(diào)低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)

定義訓(xùn)練和驗(yàn)證函數(shù)

alpha=0.2 Mixup所需的參數(shù)。

# 定義訓(xùn)練過(guò)程
alpha=0.2
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
        optimizer.zero_grad()
        output = model(data)
        loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

ACC=0
# 驗(yàn)證過(guò)程
def val(model, device, test_loader):
    global ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        if acc > ACC:
            torch.save(model_ft, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
            ACC = acc


# 訓(xùn)練

for epoch in range(1, EPOCHS + 1):
    train(model_ft, DEVICE, train_loader, optimizer, epoch)
    cosine_schedule.step()
    val(model_ft, DEVICE, test_loader)

然后就可以開(kāi)始訓(xùn)練了

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

訓(xùn)練10個(gè)epoch就能得到不錯(cuò)的結(jié)果:

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

測(cè)試

第一種寫(xiě)法

測(cè)試集存放的目錄如下圖:

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

第一步 定義類別,這個(gè)類別的順序和訓(xùn)練時(shí)的類別順序?qū)?yīng),一定不要改變順序?。。?!

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent',
           'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')

第二步 定義transforms,transforms和驗(yàn)證集的transforms一樣即可,別做數(shù)據(jù)增強(qiáng)。

transform_test = transforms.Compose([
         transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

第三步 加載model,并將模型放在DEVICE里。

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model_8_0.971.pth")
model.eval()
model.to(DEVICE)

第四步 讀取圖片并預(yù)測(cè)圖片的類別,在這里注意,讀取圖片用PIL庫(kù)的Image。不要用cv2,transforms不支持。

path = 'data/test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = model(img)
    # Predict
    _, pred = torch.max(out.data, 1)
    print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))

測(cè)試完整代碼:

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent',
           'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model_8_0.971.pth")
model.eval()
model.to(DEVICE)

path = 'data/test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = model(img)
    # Predict
    _, pred = torch.max(out.data, 1)
    print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))

運(yùn)行結(jié)果:

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

第二種寫(xiě)法

第二種,使用自定義的Dataset讀取圖片。前三步同上,差別主要在第四步。讀取數(shù)據(jù)的時(shí)候,使用Dataset的SeedlingData讀取。

dataset_test =SeedlingData('data/test/', transform_test,test=True)
print(len(dataset_test))
# 對(duì)應(yīng)文件夾的label
 
for index in range(len(dataset_test)):
    item = dataset_test[index]
    img, label = item
    img.unsqueeze_(0)
    data = Variable(img).to(DEVICE)
    output = model(data)
    _, pred = torch.max(output.data, 1)
    print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()]))
    index += 1

運(yùn)行結(jié)果:

ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類

以上就是“ConvNeXt怎么實(shí)現(xiàn)植物幼苗分類”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家閱讀完這篇文章都有很大的收獲,小編每天都會(huì)為大家更新不同的知識(shí),如果還想學(xué)習(xí)更多的知識(shí),請(qǐng)關(guān)注億速云行業(yè)資訊頻道。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI