溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Pytorch如何實(shí)現(xiàn)手寫數(shù)字mnist識(shí)別功能

發(fā)布時(shí)間:2021-05-24 13:47:44 來(lái)源:億速云 閱讀:167 作者:小新 欄目:開發(fā)技術(shù)

這篇文章給大家分享的是有關(guān)Pytorch如何實(shí)現(xiàn)手寫數(shù)字mnist識(shí)別功能的內(nèi)容。小編覺得挺實(shí)用的,因此分享給大家做個(gè)參考,一起跟隨小編過來(lái)看看吧。

本文實(shí)例講述了Pytorch實(shí)現(xiàn)的手寫數(shù)字mnist識(shí)別功能。分享給大家供大家參考,具體如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定義是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定義網(wǎng)絡(luò)結(jié)構(gòu)
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保證輸入輸出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定義前向傳播過程,輸入為x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的輸入輸出都是維度為一的值,所以要把多維度的tensor展平成一維
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我們能夠手動(dòng)輸入命令行參數(shù),就是讓風(fēng)格變得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路徑
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加載路徑
opt = parser.parse_args()
# 超參數(shù)設(shè)置
EPOCH = 8  #遍歷數(shù)據(jù)集次數(shù)
BATCH_SIZE = 64   #批處理尺寸(batch_size)
LR = 0.001    #學(xué)習(xí)率
# 定義數(shù)據(jù)預(yù)處理方式
transform = transforms.ToTensor()
# 定義訓(xùn)練數(shù)據(jù)集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定義訓(xùn)練批處理數(shù)據(jù)
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定義測(cè)試數(shù)據(jù)集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定義測(cè)試批處理數(shù)據(jù)
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定義損失函數(shù)loss function 和優(yōu)化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵?fù)p失函數(shù),通常用于多分類問題上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 訓(xùn)練
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 數(shù)據(jù)讀取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每訓(xùn)練100個(gè)batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch測(cè)試一下準(zhǔn)確率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那個(gè)類
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d個(gè)epoch的識(shí)別準(zhǔn)確率為:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

pytorch的優(yōu)點(diǎn)

1.PyTorch是相當(dāng)簡(jiǎn)潔且高效快速的框架;2.設(shè)計(jì)追求最少的封裝;3.設(shè)計(jì)符合人類思維,它讓用戶盡可能地專注于實(shí)現(xiàn)自己的想法;4.與google的Tensorflow類似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開發(fā)更新;5.PyTorch作者親自維護(hù)的論壇 供用戶交流和求教問題6.入門簡(jiǎn)單

感謝各位的閱讀!關(guān)于“Pytorch如何實(shí)現(xiàn)手寫數(shù)字mnist識(shí)別功能”這篇文章就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,讓大家可以學(xué)到更多知識(shí),如果覺得文章不錯(cuò),可以把它分享出去讓更多的人看到吧!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI