溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化

發(fā)布時間:2021-06-21 18:22:16 來源:億速云 閱讀:210 作者:Leah 欄目:大數(shù)據(jù)

這期內(nèi)容當(dāng)中小編將會給大家?guī)碛嘘P(guān)Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化,文章內(nèi)容豐富且以專業(yè)的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。

 

1. 前言

深度學(xué)習(xí)在移動端的應(yīng)用越來越廣泛,而移動端相對于GPU服務(wù)來講算力較低并且存儲空間也相對較小。基于這一點我們需要為移動端定制一些深度學(xué)習(xí)網(wǎng)絡(luò)來滿足我們的日常續(xù)需求,例如SqueezeNet,MobileNet,ShuffleNet等輕量級網(wǎng)絡(luò)就是專為移動端設(shè)計的。但除了在網(wǎng)絡(luò)方面進(jìn)行改進(jìn),模型剪枝和量化應(yīng)該算是最常用的優(yōu)化方法了。剪枝就是將訓(xùn)練好的「大模型」的不重要的通道刪除掉,在幾乎不影響準(zhǔn)確率的條件下對網(wǎng)絡(luò)進(jìn)行加速。而量化就是將浮點數(shù)(高精度)表示的權(quán)重和偏置用低精度整數(shù)(常用的有INT8)來近似表示,在量化到低精度之后就可以應(yīng)用移動平臺上的優(yōu)化技術(shù)如NEON對計算過程進(jìn)行加速,并且原始模型量化后的模型容量也會減少,使其能夠更好的應(yīng)用到移動端環(huán)境。但需要注意的問題是,將高精度模型量化到低精度必然會存在一個精度下降的問題,如何獲取性能和精度的TradeOff很關(guān)鍵。

這篇文章是介紹使用Pytorch復(fù)現(xiàn)這篇論文:https://arxiv.org/abs/1806.08342 的一些細(xì)節(jié)并給出一些自測實驗結(jié)果。注意,代碼實現(xiàn)的是「Quantization Aware Training」 ,而后量化 「Post Training Quantization」 后面可能會再單獨講一下。代碼實現(xiàn)是來自666DZY666博主實現(xiàn)的https://github.com/666DZY666/model-compression。

 

2. 對稱量化

在上次的視頻中梁德澎作者已經(jīng)將這些概念講得非常清楚了,如果不愿意看文字表述可以移步到這個視頻鏈接下觀看視頻:深度學(xué)習(xí)量化技術(shù)科普 。然后直接跳到第四節(jié),但為了保證本次故事的完整性,我還是會介紹一下這兩種量化方式。

對稱量化的量化公式如下:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
對稱量化量化公式

其中    表示量化的縮放因子,    和    分別表示量化前和量化后的數(shù)值。這里通過除以縮放因子接取整操作就把原始的浮點數(shù)據(jù)量化到了一個小區(qū)間中,比如對于「有符號的8Bit」  就是    (無符號就是0到255了)。

這里有個Trick,即對于權(quán)重是量化到    ,這是為了累加的時候減少溢出的風(fēng)險。

因為8bit的取值區(qū)間是[-2^7, 2^7-1],兩個8bit相乘之后取值區(qū)間是 (-2^14,2^14],累加兩次就到了(-2^15,2^15],所以最多只能累加兩次而且第二次也有溢出風(fēng)險,比如相鄰兩次乘法結(jié)果都恰好是2^14會超過2^15-1(int16正數(shù)可表示的最大值)。

所以把量化之后的權(quán)值限制在(-127,127)之間,那么一次乘法運算得到結(jié)果永遠(yuǎn)會小于-128*-128 = 2^14。

對應(yīng)的反量化公式為:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
對稱量化的反量化公式

即將量化后的值乘以    就得到了反量化的結(jié)果,當(dāng)然這個過程是有損的,如下圖所示,橙色線表示的就是量化前的范圍    ,而藍(lán)色線代表量化后的數(shù)據(jù)范圍    ,注意權(quán)重取    。

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
量化和反量化的示意圖

我們看一下上面橙色線的第    「黑色圓點對應(yīng)的float32值」,將其除以縮放系數(shù)就量化為了一個在    之間的值,然后取整之后就是    ,如果是反量化就乘以縮放因子返回上面的「第      個黑色圓點」 ,用這個數(shù)去代替以前的數(shù)繼續(xù)做網(wǎng)絡(luò)的Forward。

那么這個縮放系數(shù)    是怎么取的呢?如下式:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
縮放系數(shù)Delta
 

3. 非對稱量化

非對稱量化相比于對稱量化就在于多了一個零點偏移。一個float32的浮點數(shù)非對稱量化到一個int8的整數(shù)(如果是有符號就是    ,如果是無符號就是    )的步驟為 縮放,取整,零點偏移,和溢出保護(hù),如下圖所示:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
白皮書非對稱量化過程
Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
對于8Bit無符號整數(shù)Nlevel的取值

然后縮放系數(shù)    和零點偏移的計算公式如下:

   

   

 

4. 中部小結(jié)

將上面兩種算法直接應(yīng)用到各個網(wǎng)絡(luò)上進(jìn)行量化后(訓(xùn)練后量化PTQ)測試模型的精度結(jié)果如下:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
紅色部分即將上面兩種量化算法應(yīng)用到各個網(wǎng)絡(luò)上做精度測試結(jié)果
 

5. 訓(xùn)練模擬量化

我們要在網(wǎng)絡(luò)訓(xùn)練的過程中模型量化這個過程,然后網(wǎng)絡(luò)分前向和反向兩個階段,前向階段的量化就是第二節(jié)和第三節(jié)的內(nèi)容。不過需要特別注意的一點是對于縮放因子的計算,權(quán)重和激活值的計算方法現(xiàn)在不一樣了。

對于權(quán)重縮放因子還是和第2,3節(jié)的一致,即:

weight scale = max(abs(weight)) / 127

但是對于激活值的縮放因子計算就不再是簡單的計算最大值,而是在訓(xùn)練過程中通過滑動平均(EMA)的方式去統(tǒng)計這個量化范圍,更新的公式如下:

moving_max = moving_max * momenta + max(abs(activation)) * (1- momenta)

其中,momenta取接近1的數(shù)就可以了,在后面的Pytorch實驗中取0.99,然后縮放因子:

activation scale = moving_max /128

然后反向傳播階段求梯度的公式如下:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
QAT反向傳播階段求梯度的公式

我們在反向傳播時求得的梯度是模擬量化之后權(quán)值的梯度,用這個梯度去更新量化前的權(quán)值。

這部分的代碼如下,注意我們這個實驗中是用float32來模擬的int8,不具有真實的板端加速效果,只是為了驗證算法的可行性:

class Quantizer(nn.Module):
    def __init__(self, bits, range_tracker):
        super().__init__()
        self.bits = bits
        self.range_tracker = range_tracker
        self.register_buffer('scale', None)      # 量化比例因子
        self.register_buffer('zero_point', None) # 量化零點

    def update_params(self):
        raise NotImplementedError

    # 量化
    def quantize(self, input):
        output = input * self.scale - self.zero_point
        return output

    def round(self, input):
        output = Round.apply(input)
        return output

    # 截斷
    def clamp(self, input):
        output = torch.clamp(input, self.min_val, self.max_val)
        return output

    # 反量化
    def dequantize(self, input):
        output = (input + self.zero_point) / self.scale
        return output

    def forward(self, input):
        if self.bits == 32:
            output = input
        elif self.bits == 1:
            print('!Binary quantization is not supported !')
            assert self.bits != 1
        else:
            self.range_tracker(input)
            self.update_params()
            output = self.quantize(input)   # 量化
            output = self.round(output)
            output = self.clamp(output)     # 截斷
            output = self.dequantize(output)# 反量化
        return output
   

6. 代碼實現(xiàn)

基于https://github.com/666DZY666/model-compression/blob/master/quantization/WqAq/IAO/models/util_wqaq.py 進(jìn)行實驗,這里實現(xiàn)了對稱和非對稱量化兩種方案。需要注意的細(xì)節(jié)是,對于權(quán)值的量化需要分通道進(jìn)行求取縮放因子,然后對于激活值的量化整體求一個縮放因子,這樣效果最好(論文中提到)。

這部分的代碼實現(xiàn)如下:

# ********************* range_trackers(范圍統(tǒng)計器,統(tǒng)計量化前范圍) *********************
class RangeTracker(nn.Module):
    def __init__(self, q_level):
        super().__init__()
        self.q_level = q_level

    def update_range(self, min_val, max_val):
        raise NotImplementedError

    @torch.no_grad()
    def forward(self, input):
        if self.q_level == 'L':    # A,min_max_shape=(1, 1, 1, 1),layer級
            min_val = torch.min(input)
            max_val = torch.max(input)
        elif self.q_level == 'C':  # W,min_max_shape=(N, 1, 1, 1),channel級
            min_val = torch.min(torch.min(torch.min(input, 3, keepdim=True)[0], 2, keepdim=True)[0], 1, keepdim=True)[0]
            max_val = torch.max(torch.max(torch.max(input, 3, keepdim=True)[0], 2, keepdim=True)[0], 1, keepdim=True)[0]
            
        self.update_range(min_val, max_val)
class GlobalRangeTracker(RangeTracker):  # W,min_max_shape=(N, 1, 1, 1),channel級,取本次和之前相比的min_max —— (N, C, W, H)
    def __init__(self, q_level, out_channels):
        super().__init__(q_level)
        self.register_buffer('min_val', torch.zeros(out_channels, 1, 1, 1))
        self.register_buffer('max_val', torch.zeros(out_channels, 1, 1, 1))
        self.register_buffer('first_w', torch.zeros(1))

    def update_range(self, min_val, max_val):
        temp_minval = self.min_val
        temp_maxval = self.max_val
        if self.first_w == 0:
            self.first_w.add_(1)
            self.min_val.add_(min_val)
            self.max_val.add_(max_val)
        else:
            self.min_val.add_(-temp_minval).add_(torch.min(temp_minval, min_val))
            self.max_val.add_(-temp_maxval).add_(torch.max(temp_maxval, max_val))
class AveragedRangeTracker(RangeTracker):  # A,min_max_shape=(1, 1, 1, 1),layer級,取running_min_max —— (N, C, W, H)
    def __init__(self, q_level, momentum=0.1):
        super().__init__(q_level)
        self.momentum = momentum
        self.register_buffer('min_val', torch.zeros(1))
        self.register_buffer('max_val', torch.zeros(1))
        self.register_buffer('first_a', torch.zeros(1))

    def update_range(self, min_val, max_val):
        if self.first_a == 0:
            self.first_a.add_(1)
            self.min_val.add_(min_val)
            self.max_val.add_(max_val)
        else:
            self.min_val.mul_(1 - self.momentum).add_(min_val * self.momentum)
            self.max_val.mul_(1 - self.momentum).add_(max_val * self.momentum)
 

其中self.register_buffer這行代碼可以在內(nèi)存中定一個常量,同時,模型保存和加載的時候可以寫入和讀出,即這個變量不會參與反向傳播。

?  

pytorch一般情況下,是將網(wǎng)絡(luò)中的參數(shù)保存成orderedDict形式的,這里的參數(shù)其實包含兩種,一種是模型中各種module含的參數(shù),即nn.Parameter,我們當(dāng)然可以在網(wǎng)絡(luò)中定義其他的nn.Parameter參數(shù),另一種就是buffer,前者每次optim.step會得到更新,而不會更新后者。

?  

另外,由于卷積層后面經(jīng)常會接一個BN層,并且在前向推理時為了加速經(jīng)常把BN層的參數(shù)融合到卷積層的參數(shù)中,所以訓(xùn)練模擬量化也要按照這個流程。即,我們首先需要把BN層的參數(shù)和卷積層的參數(shù)融合,然后再對這個參數(shù)做量化,具體過程可以借用德澎的這頁PPT來說明:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  
Made By 梁德澎

因此,代碼實現(xiàn)包含兩個版本,一個是不融合BN的訓(xùn)練模擬量化,一個是融合BN的訓(xùn)練模擬量化,而關(guān)于為什么融合之后是上圖這樣的呢?請看下面的公式:

   

   

   

   

所以:

   

   

公式中的,    和    分別表示卷積層的權(quán)值與偏置,    和    分別為卷積層的輸入與輸出,則根據(jù)    的計算公式,可以推出融合了batchnorm參數(shù)之后的權(quán)值與偏置,    和    。

未融合BN的訓(xùn)練模擬量化代碼實現(xiàn)如下(帶注釋):

# ********************* 量化卷積(同時量化A/W,并做卷積) *********************
class Conv2d_Q(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        a_bits=8,
        w_bits=8,
        q_type=1,
        first_layer=0,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # 實例化量化器(A-layer級,W-channel級)
        if q_type == 0:
            self.activation_quantizer = SymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
            self.weight_quantizer = SymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
        else:
            self.activation_quantizer = AsymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
            self.weight_quantizer = AsymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
        self.first_layer = first_layer

    def forward(self, input):
        # 量化A和W
        if not self.first_layer:
            input = self.activation_quantizer(input)
        q_input = input
        q_weight = self.weight_quantizer(self.weight) 
        # 量化卷積
        output = F.conv2d(
            input=q_input,
            weight=q_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups
        )
        return output
 

而考慮了折疊BN的代碼實現(xiàn)如下(帶注釋):

def reshape_to_activation(input):
  return input.reshape(1, -1, 1, 1)
def reshape_to_weight(input):
  return input.reshape(-1, 1, 1, 1)
def reshape_to_bias(input):
  return input.reshape(-1)
# ********************* bn融合_量化卷積(bn融合后,同時量化A/W,并做卷積) *********************
class BNFold_Conv2d_Q(Conv2d_Q):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        eps=1e-5,
        momentum=0.01, # 考慮量化帶來的抖動影響,對momentum進(jìn)行調(diào)整(0.1 ——> 0.01),削弱batch統(tǒng)計參數(shù)占比,一定程度抑制抖動。經(jīng)實驗量化訓(xùn)練效果更好,acc提升1%左右
        a_bits=8,
        w_bits=8,
        q_type=1,
        first_layer=0,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        self.eps = eps
        self.momentum = momentum
        self.gamma = Parameter(torch.Tensor(out_channels))
        self.beta = Parameter(torch.Tensor(out_channels))
        self.register_buffer('running_mean', torch.zeros(out_channels))
        self.register_buffer('running_var', torch.ones(out_channels))
        self.register_buffer('first_bn', torch.zeros(1))
        init.uniform_(self.gamma)
        init.zeros_(self.beta)
        
        # 實例化量化器(A-layer級,W-channel級)
        if q_type == 0:
            self.activation_quantizer = SymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
            self.weight_quantizer = SymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
        else:
            self.activation_quantizer = AsymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
            self.weight_quantizer = AsymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
        self.first_layer = first_layer

    def forward(self, input):
        # 訓(xùn)練態(tài)
        if self.training:
            # 先做普通卷積得到A,以取得BN參數(shù)
            output = F.conv2d(
                input=input,
                weight=self.weight,
                bias=self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups
            )
            # 更新BN統(tǒng)計參數(shù)(batch和running)
            dims = [dim for dim in range(4) if dim != 1]
            batch_mean = torch.mean(output, dim=dims)
            batch_var = torch.var(output, dim=dims)
            with torch.no_grad():
                if self.first_bn == 0:
                    self.first_bn.add_(1)
                    self.running_mean.add_(batch_mean)
                    self.running_var.add_(batch_var)
                else:
                    self.running_mean.mul_(1 - self.momentum).add_(batch_mean * self.momentum)
                    self.running_var.mul_(1 - self.momentum).add_(batch_var * self.momentum)
            # BN融合
            if self.bias is not None:  
              bias = reshape_to_bias(self.beta + (self.bias -  batch_mean) * (self.gamma / torch.sqrt(batch_var + self.eps)))
            else:
              bias = reshape_to_bias(self.beta - batch_mean  * (self.gamma / torch.sqrt(batch_var + self.eps)))# b融batch
            weight = self.weight * reshape_to_weight(self.gamma / torch.sqrt(self.running_var + self.eps))     # w融running
        # 測試態(tài)
        else:
            #print(self.running_mean, self.running_var)
            # BN融合
            if self.bias is not None:
              bias = reshape_to_bias(self.beta + (self.bias - self.running_mean) * (self.gamma / torch.sqrt(self.running_var + self.eps)))
            else:
              bias = reshape_to_bias(self.beta - self.running_mean * (self.gamma / torch.sqrt(self.running_var + self.eps)))  # b融running
            weight = self.weight * reshape_to_weight(self.gamma / torch.sqrt(self.running_var + self.eps))  # w融running
        
        # 量化A和bn融合后的W
        if not self.first_layer:
            input = self.activation_quantizer(input)
        q_input = input
        q_weight = self.weight_quantizer(weight) 
        # 量化卷積
        if self.training:  # 訓(xùn)練態(tài)
          output = F.conv2d(
              input=q_input,
              weight=q_weight,
              bias=self.bias,  # 注意,這里不加bias(self.bias為None)
              stride=self.stride,
              padding=self.padding,
              dilation=self.dilation,
              groups=self.groups
          )
          # (這里將訓(xùn)練態(tài)下,卷積中w融合running參數(shù)的效果轉(zhuǎn)為融合batch參數(shù)的效果)running ——> batch
          output *= reshape_to_activation(torch.sqrt(self.running_var + self.eps) / torch.sqrt(batch_var + self.eps))
          output += reshape_to_activation(bias)
        else:  # 測試態(tài)
          output = F.conv2d(
              input=q_input,
              weight=q_weight,
              bias=bias,  # 注意,這里加bias,做完整的conv+bn
              stride=self.stride,
              padding=self.padding,
              dilation=self.dilation,
              groups=self.groups
          )
        return output
 

注意一個點,在訓(xùn)練的時候bias設(shè)置為None,即訓(xùn)練的時候不量化bias。

 

7. 實驗結(jié)果

在CIFAR10做Quantization Aware Training實驗,網(wǎng)絡(luò)結(jié)構(gòu)為:

import torch
import torch.nn as nn
import torch.nn.functional as F
from .util_wqaq import Conv2d_Q, BNFold_Conv2d_Q

class QuanConv2d(nn.Module):
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, abits=8, wbits=8, bn_fold=0, q_type=1, first_layer=0):
        super(QuanConv2d, self).__init__()
        self.last_relu = last_relu
        self.bn_fold = bn_fold
        self.first_layer = first_layer

        if self.bn_fold == 1:
            self.bn_q_conv = BNFold_Conv2d_Q(input_channels, output_channels,
                    kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, a_bits=abits, w_bits=wbits, q_type=q_type, first_layer=first_layer)
        else:
            self.q_conv = Conv2d_Q(input_channels, output_channels,
                    kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, a_bits=abits, w_bits=wbits, q_type=q_type, first_layer=first_layer)
            self.bn = nn.BatchNorm2d(output_channels, momentum=0.01) # 考慮量化帶來的抖動影響,對momentum進(jìn)行調(diào)整(0.1 ——> 0.01),削弱batch統(tǒng)計參數(shù)占比,一定程度抑制抖動。經(jīng)實驗量化訓(xùn)練效果更好,acc提升1%左右
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if not self.first_layer:
            x = self.relu(x)
        if self.bn_fold == 1:
            x = self.bn_q_conv(x)
        else:
            x = self.q_conv(x)
            x = self.bn(x)
        if self.last_relu:
            x = self.relu(x)
        return x

class Net(nn.Module):
    def __init__(self, cfg = None, abits=8, wbits=8, bn_fold=0, q_type=1):
        super(Net, self).__init__()
        if cfg is None:
            cfg = [192, 160, 96, 192, 192, 192, 192, 192]
        # model - A/W全量化(除輸入、輸出外)
        self.quan_model = nn.Sequential(
                QuanConv2d(3, cfg[0], kernel_size=5, stride=1, padding=2, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type, first_layer=1),
                QuanConv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                QuanConv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                
                QuanConv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                QuanConv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                QuanConv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                
                QuanConv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                QuanConv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                QuanConv2d(cfg[7], 10, kernel_size=1, stride=1, padding=0, last_relu=1, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
                )

    def forward(self, x):
        x = self.quan_model(x)
        x = x.view(x.size(0), -1)
        return x
 

訓(xùn)練Epoch數(shù)為30,學(xué)習(xí)率調(diào)整策略為:

def adjust_learning_rate(optimizer, epoch):
    if args.bn_fold == 1:
        if args.model_type == 0:
            update_list = [12, 15, 25]
        else:
            update_list = [8, 12, 20, 25]
    else:
        update_list = [15, 17, 20]
    if epoch in update_list:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1
    return
 
類型Acc備注
原模型(nin)91.01%全精度
對稱量化, bn不融合88.88%INT8
對稱量化,bn融合86.66%INT8
非對稱量化,bn不融合88.89%INT8
非對稱量化,bn融合87.30%INT8

現(xiàn)在不清楚為什么量化后的精度損失了1-2個點,根據(jù)德澎在MxNet的實驗結(jié)果來看,分類任務(wù)不會損失精度,所以不知道這個代碼是否存在問題,有經(jīng)驗的大佬歡迎來指出問題。

然后白皮書上提供的一些分類網(wǎng)絡(luò)的訓(xùn)練模擬量化精度情況如下:

Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化  


上述就是小編為大家分享的Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化了,如果剛好有類似的疑惑,不妨參照上述分析進(jìn)行理解。如果想知道更多相關(guān)知識,歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI