溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch中的hook機(jī)制是什么

發(fā)布時間:2022-03-09 13:43:04 來源:億速云 閱讀:174 作者:iii 欄目:開發(fā)技術(shù)

本篇內(nèi)容介紹了“pytorch中的hook機(jī)制是什么”的有關(guān)知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領(lǐng)大家學(xué)習(xí)一下如何處理這些情況吧!希望大家仔細(xì)閱讀,能夠?qū)W有所成!

    1、hook背景

    Hook被成為鉤子機(jī)制,這不是pytorch的首創(chuàng),在Windows的編程中已經(jīng)被普遍采用,包括進(jìn)程內(nèi)鉤子和全局鉤子。按照自己的理解,hook的作用是通過系統(tǒng)來維護(hù)一個鏈表,使得用戶攔截(獲?。┩ㄐ畔ⅲ糜谔幚硎录?。

    pytorch中包含forwardbackward兩個鉤子注冊函數(shù),用于獲取forward和backward中輸入和輸出,按照自己不全面的理解,應(yīng)該目的是“不改變網(wǎng)絡(luò)的定義代碼,也不需要在forward函數(shù)中return某個感興趣層的輸出,這樣代碼太冗雜了”。

    2、源碼閱讀

    register_forward_hook()函數(shù)必須在forward()函數(shù)調(diào)用之前被使用,因為這個函數(shù)源碼注釋顯示這個函數(shù)“ it will not have effect on forward since this is called after :func:`forward` is called”,也就是這個函數(shù)在forward()之后就沒有作用了?。。。?/p>

    作用:獲取forward過程中每層的輸入和輸出,用于對比hook是不是正確記錄。

    def register_forward_hook(self, hook):
            r"""Registers a forward hook on the module.
            The hook will be called every time after :func:`forward` has computed an output.
            It should have the following signature::
                hook(module, input, output) -> None or modified output
            The hook can modify the output. It can modify the input inplace but
            it will not have effect on forward since this is called after
            :func:`forward` is called.
    
            Returns:
                :class:`torch.utils.hooks.RemovableHandle`:
                    a handle that can be used to remove the added hook by calling
                    ``handle.remove()``
            """
            handle = hooks.RemovableHandle(self._forward_hooks)
            self._forward_hooks[handle.id] = hook
            return handle

    3、定義一個用于測試hooker的類

    如果隨機(jī)的初始化每個層,那么就無法測試出自己獲取的輸入輸出是不是forward中的輸入輸出了,所以需要將每一層的權(quán)重和偏置設(shè)置為可識別的值(比如全部初始化為1)。網(wǎng)絡(luò)包含兩層(Linear有需要求導(dǎo)的參數(shù)被稱為一個層,而ReLU沒有需要求導(dǎo)的參數(shù)不被稱作一層),__init__()中調(diào)用initialize函數(shù)對所有層進(jìn)行初始化。

    注意:在forward()函數(shù)返回各個層的輸出,但是ReLU6沒有返回,因為后續(xù)測試的時候不對這一層進(jìn)行注冊hook。

    class TestForHook(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.linear_1 = nn.Linear(in_features=2, out_features=2)
            self.linear_2 = nn.Linear(in_features=2, out_features=1)
            self.relu = nn.ReLU()
            self.relu6 = nn.ReLU6()
            self.initialize()
    
        def forward(self, x):
            linear_1 = self.linear_1(x)
            linear_2 = self.linear_2(linear_1)
            relu = self.relu(linear_2)
            relu_6 = self.relu6(relu)
            layers_in = (x, linear_1, linear_2)
            layers_out = (linear_1, linear_2, relu)
            return relu_6, layers_in, layers_out
        def initialize(self):
            """ 定義特殊的初始化,用于驗證是不是獲取了權(quán)重"""
            self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
            self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
            self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
            self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
            return True

    4、定義hook函數(shù)

    hook()函數(shù)是register_forward_hook()函數(shù)必須提供的參數(shù),好處是“用戶可以自行決定攔截了中間信息之后要做什么!”,比如自己想單純的記錄網(wǎng)絡(luò)的輸入輸出(也可以進(jìn)行修改等更加復(fù)雜的操作)。

    首先定義幾個容器用于記錄:

    定義用于獲取網(wǎng)絡(luò)各層輸入輸出tensor的容器:

    # 并定義module_name用于記錄相應(yīng)的module名字
    module_name = []
    features_in_hook = []
    features_out_hook = []
    hook函數(shù)需要三個參數(shù),這三個參數(shù)是系統(tǒng)傳給hook函數(shù)的,自己不能修改這三個參數(shù):

    hook函數(shù)負(fù)責(zé)將獲取的輸入輸出添加到feature列表中;并提供相應(yīng)的module名字

    def hook(module, fea_in, fea_out):
        print("hooker working")
        module_name.append(module.__class__)
        features_in_hook.append(fea_in)
        features_out_hook.append(fea_out)
        return None

    5、對需要的層注冊hook

    注冊鉤子必須在forward()函數(shù)被執(zhí)行之前,也就是定義網(wǎng)絡(luò)進(jìn)行計算之前就要注冊,下面的代碼對網(wǎng)絡(luò)除去ReLU6以外的層都進(jìn)行了注冊(也可以選定某些層進(jìn)行注冊):

    注冊鉤子可以對某些層單獨進(jìn)行:

    net = TestForHook()
    net_chilren = net.children()
    for child in net_chilren:
        if not isinstance(child, nn.ReLU6):
            child.register_forward_hook(hook=hook)

    6、測試forward()返回的特征和hook記錄的是否一致

    6.1 測試forward()提供的輸入輸出特征

    由于前面的forward()函數(shù)返回了需要記錄的特征,這里可以直接測試:

    out, features_in_forward, features_out_forward = net(x)
    print("*"*5+"forward return features"+"*"*5)
    print(features_in_forward)
    print(features_out_forward)
    print("*"*5+"forward return features"+"*"*5)

    得到下面的輸出是理所當(dāng)然的:

    *****forward return features*****
    (tensor([[0.1000, 0.1000],
            [0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>))
    (tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<ThresholdBackward0>))
    *****forward return features*****

    6.2 hook記錄的輸入特征和輸出特征

    hook通過list結(jié)構(gòu)進(jìn)行記錄,所以可以直接print

    測試features_in是不是存儲了輸入:

    print("*"*5+"hook record features"+"*"*5)
    print(features_in_hook)
    print(features_out_hook)
    print(module_name)
    print("*"*5+"hook record features"+"*"*5)

    得到和forward一樣的結(jié)果:

    *****hook record features*****
    [(tensor([[0.1000, 0.1000],
            [0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>),)]
    [tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<ThresholdBackward0>)]
    [<class 'torch.nn.modules.linear.Linear'>, 
    <class 'torch.nn.modules.linear.Linear'>,
     <class 'torch.nn.modules.activation.ReLU'>]
    *****hook record features*****

    6.3 把hook記錄的和forward做減法

    如果害怕會有小數(shù)點后面的數(shù)值不一致,或者數(shù)據(jù)類型的不匹配,可以對hook記錄的特征和forward記錄的特征做減法:

    測試forward返回的feautes_in是不是和hook記錄的一致:

    print("sub result'")
    for forward_return, hook_record in zip(features_in_forward, features_in_hook):
        print(forward_return-hook_record[0])

    得到的全部都是0,說明hook沒問題:

    sub result
    tensor([[0., 0.],
            [0., 0.]])
    tensor([[0., 0.],
            [0., 0.]], grad_fn=<SubBackward0>)
    tensor([[0.],
            [0.]], grad_fn=<SubBackward0>)

    7、完整代碼

    import torch
    import torch.nn as nn
    
    
    class TestForHook(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.linear_1 = nn.Linear(in_features=2, out_features=2)
            self.linear_2 = nn.Linear(in_features=2, out_features=1)
            self.relu = nn.ReLU()
            self.relu6 = nn.ReLU6()
            self.initialize()
    
        def forward(self, x):
            linear_1 = self.linear_1(x)
            linear_2 = self.linear_2(linear_1)
            relu = self.relu(linear_2)
            relu_6 = self.relu6(relu)
            layers_in = (x, linear_1, linear_2)
            layers_out = (linear_1, linear_2, relu)
            return relu_6, layers_in, layers_out
    
        def initialize(self):
            """ 定義特殊的初始化,用于驗證是不是獲取了權(quán)重"""
            self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
            self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
            self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
            self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
            return True

    定義用于獲取網(wǎng)絡(luò)各層輸入輸出tensor的容器,并定義module_name用于記錄相應(yīng)的module名字

    module_name = []
    features_in_hook = []
    features_out_hook = []

    hook函數(shù)負(fù)責(zé)將獲取的輸入輸出添加到feature列表中,并提供相應(yīng)的module名字

    def hook(module, fea_in, fea_out):
        print("hooker working")
        module_name.append(module.__class__)
        features_in_hook.append(fea_in)
        features_out_hook.append(fea_out)
        return None

    定義全部是1的輸入:

    x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])

    注冊鉤子可以對某些層單獨進(jìn)行:

    net = TestForHook()
    net_chilren = net.children()
    for child in net_chilren:
        if not isinstance(child, nn.ReLU6):
            child.register_forward_hook(hook=hook)

    測試網(wǎng)絡(luò)輸出:

    out, features_in_forward, features_out_forward = net(x)
    print("*"*5+"forward return features"+"*"*5)
    print(features_in_forward)
    print(features_out_forward)
    print("*"*5+"forward return features"+"*"*5)

    測試features_in是不是存儲了輸入:

    print("*"*5+"hook record features"+"*"*5)
    print(features_in_hook)
    print(features_out_hook)
    print(module_name)
    print("*"*5+"hook record features"+"*"*5)

    測試forward返回的feautes_in是不是和hook記錄的一致:

    print("sub result")
    for forward_return, hook_record in zip(features_in_forward, features_in_hook):
        print(forward_return-hook_record[0])

    “pytorch中的hook機(jī)制是什么”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識可以關(guān)注億速云網(wǎng)站,小編將為大家輸出更多高質(zhì)量的實用文章!

    向AI問一下細(xì)節(jié)

    免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

    AI