This article introduces the hook mechanism in PyTorch and walks through a small, fully verifiable example of how to use it.
Hook
Hooks, also known as the hook mechanism, are not a PyTorch invention: they have long been common in Windows programming, in the form of in-process hooks and global hooks. Roughly speaking, a hook asks the system to maintain a list of callbacks, so that the user can intercept (capture) messages in transit and handle events.
PyTorch provides hook registration functions for both the forward and the backward pass, used to capture the inputs and outputs inside forward and backward. The point is to avoid changing the network-definition code, and to avoid having forward() return the output of every layer of interest, which quickly clutters the code.
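This article works through the forward side; for completeness, here is a minimal sketch of the backward side, assuming a reasonably recent PyTorch version where register_full_backward_hook is available (the older register_backward_hook is deprecated). The hook receives the gradients flowing through the module:

import torch
import torch.nn as nn

grads = []

def backward_hook(module, grad_input, grad_output):
    # grad_input and grad_output are tuples of tensors
    grads.append(grad_output)

layer = nn.Linear(2, 1)
layer.register_full_backward_hook(backward_hook)
layer(torch.ones(1, 2)).sum().backward()   # backward_hook fires during backward()
print(grads)   # [(tensor([[1.]]),)]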
register_forward_hook() must be called before forward() is invoked, otherwise the hook will not fire for that pass. Note that the docstring remark "it will not have effect on forward since this is called after :func:`forward` is called" is about modifying the input: the hook runs after forward() has already computed its output, so changing the input in place inside the hook cannot affect that forward computation.
Purpose: capture the input and output of each layer during forward, so we can compare them against what the hook records.
def register_forward_hook(self, hook):
    r"""Registers a forward hook on the module.

    The hook will be called every time after :func:`forward` has computed
    an output. It should have the following signature::

        hook(module, input, output) -> None or modified output

    The hook can modify the output. It can modify the input inplace but
    it will not have effect on forward since this is called after
    :func:`forward` is called.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle
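In practice the returned handle matters: a registered hook fires on every subsequent forward call until it is removed. A small self-contained sketch (demo_hook is just an illustrative name):

import torch
import torch.nn as nn

def demo_hook(module, fea_in, fea_out):
    print("hook fired on", module.__class__.__name__)

layer = nn.Linear(2, 2)
handle = layer.register_forward_hook(demo_hook)  # register before calling the module
layer(torch.ones(1, 2))   # prints "hook fired on Linear"
handle.remove()
layer(torch.ones(1, 2))   # prints nothing: the hook has been removed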
If every layer were initialized randomly, there would be no way to verify that the captured inputs and outputs are really the ones inside forward(), so each layer's weights and biases are set to recognizable values (here, all ones). The network has two layers in the parameter-counting sense (a Linear has parameters that require gradients and counts as a layer; ReLU has none and does not), and __init__() calls initialize() to set all of them.

Note: forward() returns the inputs and outputs of the individual layers, but not those of ReLU6, because no hook will be registered on that layer in the test below.
class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

    def initialize(self):
        """Special initialization, used to verify that the recorded values are real."""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True
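With these all-ones parameters, the intermediate values can be predicted by hand, which is what makes the hook records checkable. A quick sanity computation (the x here anticipates the input defined later in the article):

import torch

x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])
w1, b1 = torch.ones(2, 2), torch.ones(2)
w2, b2 = torch.ones(1, 2), torch.ones(1)
out1 = x @ w1.t() + b1      # 0.1 + 0.1 + 1 = 1.2 in every entry
out2 = out1 @ w2.t() + b2   # 1.2 + 1.2 + 1 = 3.4 in every entry

These 1.2 and 3.4 values are exactly what should show up in every record below.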
The hook() function is the argument that register_forward_hook() requires. The benefit is that you decide what to do with the intercepted information: here we simply record the network's inputs and outputs, but more complex operations, such as modifying them, are equally possible.
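As the source docstring above notes, a forward hook may also return a value, in which case that value replaces the module's output for everything downstream. A minimal sketch (scale_hook is an illustrative name, not part of this article's test code):

import torch
import torch.nn as nn

def scale_hook(module, fea_in, fea_out):
    return fea_out * 2   # a non-None return value replaces the module's output

layer = nn.Linear(2, 1)
layer.register_forward_hook(scale_hook)
y = layer(torch.ones(1, 2))   # y is twice the raw Linear output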
First, define a few containers for the records: one each for the input and output tensors of every layer, plus module_name to record the class of the corresponding module:

# module_name records the class of each hooked module
module_name = []
features_in_hook = []
features_out_hook = []

The hook function takes three parameters; they are passed in by PyTorch itself, and the signature cannot be changed. The hook body appends the captured inputs and outputs to the feature lists and records the corresponding module's class:
def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None
Hooks must be registered before forward() executes, that is, right after the network is defined and before any computation runs. The code below registers the hook on every layer except ReLU6; registration can also target selected layers only:
net = TestForHook()
net_children = net.children()
for child in net_children:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)
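One caveat worth knowing: net.children() only iterates the top-level submodules. For models with nested containers, net.named_modules() recurses into every submodule, so hooks can be registered by type instead. A sketch (run instead of, not in addition to, the loop above, or the Linear layers end up with two hooks each):

for name, layer in net.named_modules():
    if isinstance(layer, nn.Linear):   # register only on the layers of interest
        layer.register_forward_hook(hook)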
Since forward() already returns the features we want to record, we can test directly (defining the input x first):
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])
out, features_in_forward, features_out_forward = net(x)
print("*" * 5 + "forward return features" + "*" * 5)
print(features_in_forward)
print(features_out_forward)
print("*" * 5 + "forward return features" + "*" * 5)
The output below is exactly what the hand computation predicts:
*****forward return features*****
(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****
The hook records into ordinary Python lists, so they can be printed directly. Check whether features_in_hook stored the inputs:
print("*"*5+"hook record features"+"*"*5) print(features_in_hook) print(features_out_hook) print(module_name) print("*"*5+"hook record features"+"*"*5)
The result matches what forward() returned; note that the hook receives each module's input as a tuple, which is why the entries of features_in_hook are tuples:
*****hook record features*****
[(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****
If you are worried about differences past the decimal point, or about dtype mismatches, subtract the hook-recorded features from the forward-returned ones. Check that the features_in returned by forward match what the hook recorded:
print("sub result'") for forward_return, hook_record in zip(features_in_forward, features_in_hook): print(forward_return-hook_record[0])
Every entry is zero, so the hook records correctly:
sub result
tensor([[0., 0.],
        [0., 0.]])
tensor([[0., 0.],
        [0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],
        [0.]], grad_fn=<SubBackward0>)
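An equivalent and slightly more robust check is torch.allclose, which compares within a floating-point tolerance instead of eyeballing printed differences:

for forward_return, hook_record in zip(features_in_forward, features_in_hook):
    assert torch.allclose(forward_return, hook_record[0])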
For reference, here is the complete code:

import torch
import torch.nn as nn


class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

    def initialize(self):
        """Special initialization, used to verify that the recorded values are real."""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True
Define the containers that will hold each layer's input and output tensors, and module_name to record the corresponding module class:
module_name = []
features_in_hook = []
features_out_hook = []
The hook function appends the captured inputs and outputs to the feature lists and records the corresponding module's class:
def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None
Define the input (every entry is 0.1):
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])
Register the hooks; this can be done selectively, layer by layer:
net = TestForHook()
net_children = net.children()
for child in net_children:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)
Test the network output:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)
Check whether features_in_hook stored the inputs:
print("*"*5+"hook record features"+"*"*5) print(features_in_hook) print(features_out_hook) print(module_name) print("*"*5+"hook record features"+"*"*5)
Check that the features_in returned by forward match what the hook recorded:
print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
print(forward_return-hook_record[0])
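A final practical note: the hooks stay registered for the lifetime of net, so a second forward pass appends to the same lists. Between runs, either clear the containers, or keep the handles at registration time and remove the hooks, as sketched below:

# Option 1: clear the recording containers
module_name.clear()
features_in_hook.clear()
features_out_hook.clear()

# Option 2: keep the handles when registering (replacing the loop above),
# then detach all hooks when done
handles = [child.register_forward_hook(hook)
           for child in net.children()
           if not isinstance(child, nn.ReLU6)]
for h in handles:
    h.remove()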
That concludes this look at the hook mechanism in PyTorch. Thanks for reading.