溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch如何實(shí)現(xiàn)模型剪枝

發(fā)布時(shí)間:2023-02-24 15:46:45 來源:億速云 閱讀:148 作者:iii 欄目:開發(fā)技術(shù)

這篇文章主要介紹“pytorch如何實(shí)現(xiàn)模型剪枝”的相關(guān)知識(shí),小編通過實(shí)際案例向大家展示操作過程,操作方法簡單快捷,實(shí)用性強(qiáng),希望這篇“pytorch如何實(shí)現(xiàn)模型剪枝”文章能幫助大家解決問題。

    一,剪枝分類

    所謂模型剪枝,其實(shí)是一種從神經(jīng)網(wǎng)絡(luò)中移除"不必要"權(quán)重或偏差(weigths/bias)的模型壓縮技術(shù)。關(guān)于什么參數(shù)才是“不必要的”,這是一個(gè)目前依然在研究的領(lǐng)域。

    1.1,非結(jié)構(gòu)化剪枝

    非結(jié)構(gòu)化剪枝(Unstructured Puning)是指修剪參數(shù)的單個(gè)元素,比如全連接層中的單個(gè)權(quán)重、卷積層中的單個(gè)卷積核參數(shù)元素或者自定義層中的浮點(diǎn)數(shù)(scaling floats)。其重點(diǎn)在于,剪枝權(quán)重對象是隨機(jī)的,沒有特定結(jié)構(gòu),因此被稱為非結(jié)構(gòu)化剪枝。

    1.2,結(jié)構(gòu)化剪枝

    與非結(jié)構(gòu)化剪枝相反,結(jié)構(gòu)化剪枝會(huì)剪枝整個(gè)參數(shù)結(jié)構(gòu)。比如,丟棄整行或整列的權(quán)重,或者在卷積層中丟棄整個(gè)過濾器(Filter)。

    1.3,本地與全局修剪

    剪枝可以在每層(局部)或多層/所有層(全局)上進(jìn)行。

    二,PyTorch 的剪枝

    目前 PyTorch 框架支持的權(quán)重剪枝方法有:

    • Random: 簡單地修剪隨機(jī)參數(shù)。

    • Magnitude: 修剪權(quán)重最小的參數(shù)(例如它們的 L2 范數(shù))

    以上兩種方法實(shí)現(xiàn)簡單、計(jì)算容易,且可以在沒有任何數(shù)據(jù)的情況下應(yīng)用。

    2.1,pytorch 剪枝工作原理

    剪枝功能在 torch.nn.utils.prune 類中實(shí)現(xiàn),代碼在文件 torch/nn/utils/prune.py 中,主要剪枝類如下圖所示。

    pytorch如何實(shí)現(xiàn)模型剪枝

    剪枝原理是基于張量(Tensor)的掩碼(Mask)實(shí)現(xiàn)。掩碼是一個(gè)與張量形狀相同的布爾類型的張量,掩碼的值為 True 表示相應(yīng)位置的權(quán)重需要保留,掩碼的值為 False 表示相應(yīng)位置的權(quán)重可以被刪除。

    Pytorch 將原始參數(shù) <param> 復(fù)制到名為 <param>_original 的參數(shù)中,并創(chuàng)建一個(gè)緩沖區(qū)來存儲(chǔ)剪枝掩碼 <param>_mask。同時(shí),其也會(huì)創(chuàng)建一個(gè)模塊級(jí)的 forward_pre_hook 回調(diào)函數(shù)(在模型前向傳播之前會(huì)被調(diào)用的回調(diào)函數(shù)),將剪枝掩碼應(yīng)用于原始權(quán)重。

    pytorch 剪枝的 api 和教程比較混亂,我個(gè)人將做了如下表格,希望能將 api 和剪枝方法及分類總結(jié)好。

    pytorch如何實(shí)現(xiàn)模型剪枝

    pytorch 中進(jìn)行模型剪枝的工作流程如下:

    • 選擇剪枝方法(或者子類化 BasePruningMethod 實(shí)現(xiàn)自己的剪枝方法)。

    • 指定剪枝模塊和參數(shù)名稱。

    • 設(shè)置剪枝方法的參數(shù),比如剪枝比例等。

    2.2,局部剪枝

    Pytorch 框架中的局部剪枝有非結(jié)構(gòu)化和結(jié)構(gòu)化剪枝兩種類型,值得注意的是結(jié)構(gòu)化剪枝只支持局部不支持全局。

    2.2.1,局部非結(jié)構(gòu)化剪枝

    1,局部非結(jié)構(gòu)化剪枝(Locall Unstructured Pruning)對應(yīng)函數(shù)原型如下:

    def random_unstructured(module, name, amount)

    1,函數(shù)功能:

    用于對權(quán)重參數(shù)張量進(jìn)行非結(jié)構(gòu)化剪枝。該方法會(huì)在張量中隨機(jī)選擇一些權(quán)重或連接進(jìn)行剪枝,剪枝率由用戶指定。

    2,函數(shù)參數(shù)定義:

    • module (nn.Module): 需要剪枝的網(wǎng)絡(luò)層/模塊,例如 nn.Conv2d() 和 nn.Linear()。

    • name (str): 要剪枝的參數(shù)名稱,比如 "weight" 或 "bias"。

    • amount (int or float): 指定要剪枝的數(shù)量,如果是 0~1 之間的小數(shù),則表示剪枝比例;如果是證書,則直接剪去參數(shù)的絕對數(shù)量。比如amount=0.2 ,表示將隨機(jī)選擇 20% 的元素進(jìn)行剪枝。

    3,下面是 random_unstructured 函數(shù)的使用示例。

    import torch
    import torch.nn.utils.prune as prune
    conv = torch.nn.Conv2d(1, 1, 4)
    prune.random_unstructured(conv, name="weight", amount=0.5)
    conv.weight
    """
    tensor([[[[-0.1703,  0.0000, -0.0000,  0.0690],
              [ 0.1411,  0.0000, -0.0000, -0.1031],
              [-0.0527,  0.0000,  0.0640,  0.1666],
              [ 0.0000, -0.0000, -0.0000,  0.2281]]]], grad_fn=<MulBackward0>)
    """

    可以看書輸出的 conv 層中權(quán)重值有一半比例為 0。

    2.2.2,局部結(jié)構(gòu)化剪枝

    局部結(jié)構(gòu)化剪枝(Locall Structured Pruning)有兩種函數(shù),對應(yīng)函數(shù)原型如下:

    def random_structured(module, name, amount, dim)
    def ln_structured(module, name, amount, n, dim, importance_scores=None)

    1,函數(shù)功能

    與非結(jié)構(gòu)化移除的是連接權(quán)重不同,結(jié)構(gòu)化剪枝移除的是整個(gè)通道權(quán)重。

    2,參數(shù)定義

    與局部非結(jié)構(gòu)化函數(shù)非常相似,唯一的區(qū)別是您必須定義 dim 參數(shù)(ln_structured 函數(shù)多了 n 參數(shù))。

    n 表示剪枝的范數(shù),dim 表示剪枝的維度。

    對于 torch.nn.Linear:

    • dim = 0: 移除一個(gè)神經(jīng)元。

    • dim = 1:移除與一個(gè)輸入的所有連接。

    對于 torch.nn.Conv2d:

    • dim = 0(Channels) : 通道 channels 剪枝/過濾器 filters 剪枝

    • dim = 1(Neurons): 二維卷積核 kernel 剪枝,即與輸入通道相連接的 kernel

    2.2.3,局部結(jié)構(gòu)化剪枝示例代碼

    在寫示例代碼之前,我們先需要理解 Conv2d 函數(shù)參數(shù)、卷積核 shape、軸以及張量的關(guān)系。

    首先,Conv2d 函數(shù)原型如下;

    class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

    而 pytorch 中常規(guī)卷積的卷積核權(quán)重 shape 都為(C_out, C_in, kernel_height, kernel_width),所以在代碼中卷積層權(quán)重 shape[3, 2, 3, 3],dim = 0 對應(yīng)的是 shape [3, 2, 3, 3] 中的 3。這里我們 dim 設(shè)定了哪個(gè)軸,那自然剪枝之后權(quán)重張量對應(yīng)的軸機(jī)會(huì)發(fā)生變換。

    pytorch如何實(shí)現(xiàn)模型剪枝

    理解了前面的關(guān)鍵概念,下面就可以實(shí)際使用了,dim=0 的示例如下所示。

    conv = torch.nn.Conv2d(2, 3, 3)
    norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3])
    print(norm1)
    """
    tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>)
    """
    prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)
    print(conv.weight)
    """
    tensor([[[[-0.0005,  0.1039,  0.0306],
              [ 0.1233,  0.1517,  0.0628],
              [ 0.1075, -0.0606,  0.1140]],
     
             [[ 0.2263, -0.0199,  0.1275],
              [-0.0455, -0.0639, -0.2153],
              [ 0.1587, -0.1928,  0.1338]]],
     
     
            [[[-0.2023,  0.0012,  0.1617],
              [-0.1089,  0.2102, -0.2222],
              [ 0.0645, -0.2333, -0.1211]],
     
             [[ 0.2138, -0.0325,  0.0246],
              [-0.0507,  0.1812, -0.2268],
              [-0.1902,  0.0798,  0.0531]]],
     
     
            [[[ 0.0000, -0.0000, -0.0000],
              [ 0.0000, -0.0000, -0.0000],
              [ 0.0000, -0.0000,  0.0000]],
     
             [[ 0.0000,  0.0000,  0.0000],
              [-0.0000,  0.0000,  0.0000],
              [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>)
    """

    從運(yùn)行結(jié)果可以明顯看出,卷積層參數(shù)的最后一個(gè)通道參數(shù)張量被移除了(為 0 張量),其解釋參見下圖。

    pytorch如何實(shí)現(xiàn)模型剪枝

    dim = 1 的情況:

    conv = torch.nn.Conv2d(2, 3, 3)
    norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3])
    print(norm1)
    """
    tensor([3.1487, 3.9088], grad_fn=<NormBackward1>)
    """
    prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)
    print(conv.weight)
    """
    tensor([[[[ 0.0000, -0.0000, -0.0000],
              [-0.0000,  0.0000,  0.0000],
              [-0.0000,  0.0000, -0.0000]],
     
             [[-0.2140,  0.1038,  0.1660],
              [ 0.1265, -0.1650, -0.2183],
              [-0.0680,  0.2280,  0.2128]]],
     
     
            [[[-0.0000,  0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.0000],
              [-0.0000, -0.0000, -0.0000]],
     
             [[-0.2087,  0.1275,  0.0228],
              [-0.1888, -0.1345,  0.1826],
              [-0.2312, -0.1456, -0.1085]]],
     
     
            [[[-0.0000,  0.0000,  0.0000],
              [ 0.0000, -0.0000,  0.0000],
              [ 0.0000, -0.0000,  0.0000]],
     
             [[-0.0891,  0.0946, -0.1724],
              [-0.2068,  0.0823,  0.0272],
              [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>)
    """

    很明顯,對于 dim=1的維度,其第一個(gè)張量的 L2 范數(shù)更小,所以shape 為 [2, 3, 3] 的張量中,第一個(gè) [3, 3] 張量參數(shù)會(huì)被移除(即張量為 0 矩陣) 。

    2.3,全局非結(jié)構(gòu)化剪枝

    前文的 local 剪枝的對象是特定網(wǎng)絡(luò)層,而 global 剪枝是將模型看作一個(gè)整體去移除指定比例(數(shù)量)的參數(shù),同時(shí) global 剪枝結(jié)果會(huì)導(dǎo)致模型中每層的稀疏比例是不一樣的。

    全局非結(jié)構(gòu)化剪枝函數(shù)原型如下:

    # v1.4.0 版本
    def global_unstructured(parameters, pruning_method, **kwargs)
    # v2.0.0-rc2版本
    def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):

    1,函數(shù)功能:

    隨機(jī)選擇全局所有參數(shù)(包括權(quán)重和偏置)的一部分進(jìn)行剪枝,而不管它們屬于哪個(gè)層。

    2,參數(shù)定義:

    • parameters((Iterable of (module, name) tuples)): 修剪模型的參數(shù)列表,列表中的元素是 (module, name)。

    • pruning_method(function): 目前好像官方只支持 pruning_method=prune.L1Unstuctured,另外也可以是自己實(shí)現(xiàn)的非結(jié)構(gòu)化剪枝方法函數(shù)。

    • importance_scores: 表示每個(gè)參數(shù)的重要性得分,如果為 None,則使用默認(rèn)得分。

    • **kwargs: 表示傳遞給特定剪枝方法的額外參數(shù)。比如 amount 指定要剪枝的數(shù)量。

    3,global_unstructured 函數(shù)的示例代碼如下所示。

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     
    class LeNet(nn.Module):
        def __init__(self):
            super(LeNet, self).__init__()
            # 1 input image channel, 6 output channels, 3x3 square conv kernel
            self.conv1 = nn.Conv2d(1, 6, 3)
            self.conv2 = nn.Conv2d(6, 16, 3)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
     
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, int(x.nelement() / x.shape[0]))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
     
    model = LeNet().to(device=device)
     
    model = LeNet()
     
    parameters_to_prune = (
        (model.conv1, 'weight'),
        (model.conv2, 'weight'),
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
    )
     
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.2,
    )
    # 計(jì)算卷積層和整個(gè)模型的稀疏度
    # 其實(shí)調(diào)用的是 Tensor.numel 內(nèi)內(nèi)函數(shù),返回輸入張量中元素的總數(shù)
    print(
        "Sparsity in conv1.weight: {:.2f}%".format(
            100. * float(torch.sum(model.conv1.weight == 0))
            / float(model.conv1.weight.nelement())
        )
    )
    print(
        "Global sparsity: {:.2f}%".format(
            100. * float(
                torch.sum(model.conv1.weight == 0)
                + torch.sum(model.conv2.weight == 0)
                + torch.sum(model.fc1.weight == 0)
                + torch.sum(model.fc2.weight == 0)
                + torch.sum(model.fc3.weight == 0)
            )
            / float(
                model.conv1.weight.nelement()
                + model.conv2.weight.nelement()
                + model.fc1.weight.nelement()
                + model.fc2.weight.nelement()
                + model.fc3.weight.nelement()
            )
        )
    )
    # 程序運(yùn)行結(jié)果
    """
    Sparsity in conv1.weight: 3.70%
    Global sparsity: 20.00%
    """

    運(yùn)行結(jié)果表明,雖然模型整體(全局)的稀疏度是 20%,但每個(gè)網(wǎng)絡(luò)層的稀疏度不一定是 20%。

    關(guān)于“pytorch如何實(shí)現(xiàn)模型剪枝”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識(shí),可以關(guān)注億速云行業(yè)資訊頻道,小編每天都會(huì)為大家更新不同的知識(shí)點(diǎn)。

    向AI問一下細(xì)節(jié)

    免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

    AI