您好,登錄后才能下訂單哦!
torch.nn.Modules 相當(dāng)于是對網(wǎng)絡(luò)某種層的封裝,包括網(wǎng)絡(luò)結(jié)構(gòu)以及網(wǎng)絡(luò)參數(shù)和一些操作
torch.nn.Module 是所有神經(jīng)網(wǎng)絡(luò)單元的基類
查看源碼
初始化部分:
def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() self._load_state_dict_pre_hooks = OrderedDict() self._modules = OrderedDict() self.training = True
屬性解釋:
方法定義:
def forward(self, *input): raise NotImplementedError
沒有實(shí)際內(nèi)容,用于被子類的 forward() 方法覆蓋
且 forward 方法在 __call__ 方法中被調(diào)用:
def __call__(self, *input, **kwargs): for hook in self._forward_pre_hooks.values(): hook(self, input) if torch._C._get_tracing_state(): result = self._slow_forward(*input, **kwargs) else: result = self.forward(*input, **kwargs) ... ...
實(shí)例展示
簡單搭建:
import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = nn.Linear(n_feature, n_hidden) self.out = nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) x = self.out(x) return x
Net 類繼承了 torch 的 Module 和 __init__ 功能
hidden 是隱藏層線性輸出
out 是輸出層線性輸出
打印出網(wǎng)絡(luò)的結(jié)構(gòu):
>>> net = Net(n_feature=10, n_hidden=30, n_output=15) >>> print(net) Net( (hidden): Linear(in_features=10, out_features=30, bias=True) (out): Linear(in_features=30, out_features=15, bias=True) )
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持億速云。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。