溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

淺析PyTorch中nn.Module的使用

發(fā)布時(shí)間:2020-10-12 21:46:24 來源:腳本之家 閱讀:100 作者:Steven·簡談 欄目:開發(fā)技術(shù)

torch.nn.Modules 相當(dāng)于是對網(wǎng)絡(luò)某種層的封裝,包括網(wǎng)絡(luò)結(jié)構(gòu)以及網(wǎng)絡(luò)參數(shù)和一些操作

torch.nn.Module 是所有神經(jīng)網(wǎng)絡(luò)單元的基類

查看源碼

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True
 

屬性解釋:

  • _parameters:字典,保存用戶直接設(shè)置的 Parameter
  • _modules:子 module,即子類構(gòu)造函數(shù)中的內(nèi)容
  • _buffers:緩存
  • _backward_hooks與_forward_hooks:鉤子技術(shù),用來提取中間變量
  • training:判斷值來決定前向傳播策略

方法定義:

def forward(self, *input):
 raise NotImplementedError
 

沒有實(shí)際內(nèi)容,用于被子類的 forward() 方法覆蓋

且 forward 方法在 __call__ 方法中被調(diào)用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...
 

實(shí)例展示

簡單搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 類繼承了 torch 的 Module 和 __init__ 功能

hidden 是隱藏層線性輸出

out 是輸出層線性輸出

打印出網(wǎng)絡(luò)的結(jié)構(gòu):

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持億速云。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI