溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Pytorch——回歸問題

發(fā)布時(shí)間:2020-08-11 18:55:38 來源:ITPUB博客 閱讀:167 作者:ckxllf 欄目:編程語言

  1.前言

  我會(huì)這次會(huì)來見證神經(jīng)網(wǎng)絡(luò)是如何通過簡(jiǎn)單的形式將一群數(shù)據(jù)用一條線條來表示. 或者說, 是如何在數(shù)據(jù)當(dāng)中找到他們的關(guān)系, 然后用神經(jīng)網(wǎng)絡(luò)模型來建立一個(gè)可以代表他們關(guān)系的線條.

  2.數(shù)據(jù)準(zhǔn)備

  我們創(chuàng)建一些假數(shù)據(jù)來模擬真實(shí)的情況. 比如一個(gè)一元二次函數(shù): y = a * x^2 + b, 我們給 y 數(shù)據(jù)加上一點(diǎn)噪聲來更加真實(shí)的展示它.

  import torch

  import matplotlib.pyplot as plt

  #制造一些數(shù)據(jù)

  x = torch.unsqueeze(torch.linspace(-1,1,100),dim = 1) #torch.Size([100, 1]) #把[a,b,c]變成[[a,b,c]]

  #print(x)

  y = 2*(x.pow(2)) + 0.5*torch.rand(x.size()) #torch.rand為均勻分布,返回一個(gè)張量,包含了從區(qū)間[0, 1)的均勻分布中抽取的一組隨機(jī)數(shù)。張量的形狀由參數(shù)sizes定義

  #print(y)

  #畫圖

  plt.scatter(x.data.numpy(),y.data.numpy())

  plt.show()

  3.搭建神經(jīng)網(wǎng)絡(luò)

  建立一個(gè)神經(jīng)網(wǎng)絡(luò)我們可以直接運(yùn)用 torch 中的體系. 先定義所有的層屬性(init()), 然后再一層層搭建(forward(x))層于層的關(guān)系鏈接. 建立關(guān)系的時(shí)候, 我們會(huì)用到激勵(lì)函數(shù)

  from torch import nn

  import torch.nn.functional as F

  class NetWork(nn.Module):

  def __init__(self,n_input,n_hidden,n_output):

  super(NetWork,self).__init__()

  self.hidden = nn.Linear(n_input,n_hidden)

  self.output_for_predict = nn.Linear(n_hidden,n_output)

  def forward(self,x):

  x = F.relu(self.hidden(x)) #對(duì)x進(jìn)入隱層后的輸出應(yīng)用激活函數(shù)(相當(dāng)于一個(gè)篩選的過程)

  output = self.output_for_predict(x) #做線性變換,將維度為1

  return output

  network = NetWork(n_input = 1,n_hidden = 8, n_output = 1)

  print(network) #打印模型的層次結(jié)構(gòu)

  4.訓(xùn)練搭建的神經(jīng)網(wǎng)絡(luò)

  訓(xùn)練的步驟很簡(jiǎn)單, 如下:

  from torch import nn

  import torch.nn.functional as F

  class NetWork(nn.Module):

  def __init__(self,n_input,n_hidden,n_output):

  super(NetWork,self).__init__()

  self.hidden = nn.Linear(n_input,n_hidden)

  self.output_for_predict = nn.Linear(n_hidden,n_output)

  def forward(self,x):

  x = F.relu(self.hidden(x)) #對(duì)x進(jìn)入隱層后的輸出應(yīng)用激活函數(shù)(相當(dāng)于一個(gè)篩選的過程)

  output = self.output_for_predict(x) #做線性變換,將維度為1

  return output

  network = NetWork(n_input = 1,n_hidden = 8, n_output = 1)

  print(network) #打印模型的層次結(jié)構(gòu)

  optimizer = torch.optim.SGD(network.parameters(),lr = 0.2)

  criterion = torch.nn.MSELoss() #均方誤差,用于計(jì)算預(yù)測(cè)值與真實(shí)值之間的誤差

  for i in range(500): #訓(xùn)練步數(shù)(相當(dāng)于迭代次數(shù))

  predication = network(x)

  loss = criterion(predication, y) #predication為預(yù)測(cè)的值,y為真實(shí)值

  optimizer.zero_grad()

  loss.backward() #反向傳播,更新參數(shù)

  optimizer.step() #將更新的參數(shù)值放進(jìn)network的parameters

  5.可視化操作

  x = torch.unsqueeze(torch.linspace(-1,1,100),dim = 1) #torch.Size([100, 1]) #把[a,b,c]變成[[a,b,c]]

  #print(x) 鄭州哪里做人流好 http://www.kdrlyy.com/

  y = 2*(x.pow(2)) + 0.5*torch.rand(x.size()) #torch.rand為均勻分布,返回一個(gè)張量,包含了從區(qū)間[0, 1)的均勻分布中抽取的一組隨機(jī)數(shù)。張量的形狀由參數(shù)sizes定義

  #print(y)

  #畫圖

  # plt.scatter(x.data.numpy(),y.data.numpy())

  # plt.show()

  from torch import nn

  import torch.nn.functional as F

  class NetWork(nn.Module):

  def __init__(self,n_input,n_hidden,n_output):

  super(NetWork,self).__init__()

  self.hidden = nn.Linear(n_input,n_hidden)

  self.output_for_predict = nn.Linear(n_hidden,n_output)

  def forward(self,x):

  x = F.relu(self.hidden(x)) #對(duì)x進(jìn)入隱層后的輸出應(yīng)用激活函數(shù)(相當(dāng)于一個(gè)篩選的過程)

  output = self.output_for_predict(x) #做線性變換,將維度為1

  return output

  network = NetWork(n_input = 1,n_hidden = 8, n_output = 1)

  print(network) #打印模型的層次結(jié)構(gòu)

  plt.ion() # 打開交互模式

  plt.show()

  optimizer = torch.optim.SGD(network.parameters(),lr = 0.2)

  criterion = torch.nn.MSELoss() #均方誤差,用于計(jì)算預(yù)測(cè)值與真實(shí)值之間的誤差

  for i in range(500): #訓(xùn)練步數(shù)(相當(dāng)于迭代次數(shù))

  predication = network(x)

  loss = criterion(predication, y) #predication為預(yù)測(cè)的值,y為真實(shí)值

  optimizer.zero_grad()

  loss.backward() #反向傳播,更新參數(shù)

  optimizer.step() #將更新的參數(shù)值放進(jìn)network的parameters

  if i % 10 == 0:

  plt.cla() # 清坐標(biāo)軸

  plt.scatter(x.data.numpy(),y.data.numpy())

  plt.plot(x.data.numpy(),predication.data.numpy(),'ro', lw=5) #畫預(yù)測(cè)曲線,用紅色o作為標(biāo)記

  plt.text(0.5,0,'Loss = %.4f' % loss.data.numpy(), fontdict = {'size': 20, 'color': 'red'})

  plt.pause(0.1)

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI