您好,登錄后才能下訂單哦!
這期內(nèi)容當(dāng)中小編將會給大家?guī)碛嘘P(guān)Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化,文章內(nèi)容豐富且以專業(yè)的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。
深度學(xué)習(xí)在移動端的應(yīng)用越來越廣泛,而移動端相對于GPU服務(wù)來講算力較低并且存儲空間也相對較小。基于這一點我們需要為移動端定制一些深度學(xué)習(xí)網(wǎng)絡(luò)來滿足我們的日常續(xù)需求,例如SqueezeNet,MobileNet,ShuffleNet等輕量級網(wǎng)絡(luò)就是專為移動端設(shè)計的。但除了在網(wǎng)絡(luò)方面進(jìn)行改進(jìn),模型剪枝和量化應(yīng)該算是最常用的優(yōu)化方法了。剪枝就是將訓(xùn)練好的「大模型」的不重要的通道刪除掉,在幾乎不影響準(zhǔn)確率的條件下對網(wǎng)絡(luò)進(jìn)行加速。而量化就是將浮點數(shù)(高精度)表示的權(quán)重和偏置用低精度整數(shù)(常用的有INT8)來近似表示,在量化到低精度之后就可以應(yīng)用移動平臺上的優(yōu)化技術(shù)如NEON對計算過程進(jìn)行加速,并且原始模型量化后的模型容量也會減少,使其能夠更好的應(yīng)用到移動端環(huán)境。但需要注意的問題是,將高精度模型量化到低精度必然會存在一個精度下降的問題,如何獲取性能和精度的TradeOff很關(guān)鍵。
這篇文章是介紹使用Pytorch復(fù)現(xiàn)這篇論文:https://arxiv.org/abs/1806.08342
的一些細(xì)節(jié)并給出一些自測實驗結(jié)果。注意,代碼實現(xiàn)的是「Quantization Aware Training」 ,而后量化 「Post Training Quantization」 后面可能會再單獨講一下。代碼實現(xiàn)是來自666DZY666
博主實現(xiàn)的https://github.com/666DZY666/model-compression
。
在上次的視頻中梁德澎作者已經(jīng)將這些概念講得非常清楚了,如果不愿意看文字表述可以移步到這個視頻鏈接下觀看視頻:深度學(xué)習(xí)量化技術(shù)科普 。然后直接跳到第四節(jié),但為了保證本次故事的完整性,我還是會介紹一下這兩種量化方式。
對稱量化的量化公式如下:
其中 表示量化的縮放因子, 和 分別表示量化前和量化后的數(shù)值。這里通過除以縮放因子接取整操作就把原始的浮點數(shù)據(jù)量化到了一個小區(qū)間中,比如對于「有符號的8Bit」 就是 (無符號就是0到255了)。
這里有個Trick,即對于權(quán)重是量化到 ,這是為了累加的時候減少溢出的風(fēng)險。
因為8bit的取值區(qū)間是[-2^7, 2^7-1]
,兩個8bit相乘之后取值區(qū)間是 (-2^14,2^14]
,累加兩次就到了(-2^15,2^15]
,所以最多只能累加兩次而且第二次也有溢出風(fēng)險,比如相鄰兩次乘法結(jié)果都恰好是2^14
會超過2^15-1
(int16正數(shù)可表示的最大值)。
所以把量化之后的權(quán)值限制在(-127,127)
之間,那么一次乘法運算得到結(jié)果永遠(yuǎn)會小于-128*-128 = 2^14
。
對應(yīng)的反量化公式為:
即將量化后的值乘以 就得到了反量化的結(jié)果,當(dāng)然這個過程是有損的,如下圖所示,橙色線表示的就是量化前的范圍 ,而藍(lán)色線代表量化后的數(shù)據(jù)范圍 ,注意權(quán)重取 。
我們看一下上面橙色線的第 個「黑色圓點對應(yīng)的float32值」,將其除以縮放系數(shù)就量化為了一個在 之間的值,然后取整之后就是 ,如果是反量化就乘以縮放因子返回上面的「第 個黑色圓點」 ,用這個數(shù)去代替以前的數(shù)繼續(xù)做網(wǎng)絡(luò)的Forward。
那么這個縮放系數(shù) 是怎么取的呢?如下式:
非對稱量化相比于對稱量化就在于多了一個零點偏移。一個float32的浮點數(shù)非對稱量化到一個int8
的整數(shù)(如果是有符號就是
,如果是無符號就是
)的步驟為 縮放,取整,零點偏移,和溢出保護(hù),如下圖所示:
然后縮放系數(shù) 和零點偏移的計算公式如下:
將上面兩種算法直接應(yīng)用到各個網(wǎng)絡(luò)上進(jìn)行量化后(訓(xùn)練后量化PTQ)測試模型的精度結(jié)果如下:
我們要在網(wǎng)絡(luò)訓(xùn)練的過程中模型量化這個過程,然后網(wǎng)絡(luò)分前向和反向兩個階段,前向階段的量化就是第二節(jié)和第三節(jié)的內(nèi)容。不過需要特別注意的一點是對于縮放因子的計算,權(quán)重和激活值的計算方法現(xiàn)在不一樣了。
對于權(quán)重縮放因子還是和第2,3節(jié)的一致,即:
weight scale = max(abs(weight)) / 127
但是對于激活值的縮放因子計算就不再是簡單的計算最大值,而是在訓(xùn)練過程中通過滑動平均(EMA)的方式去統(tǒng)計這個量化范圍,更新的公式如下:
moving_max = moving_max * momenta + max(abs(activation)) * (1- momenta)
其中,momenta取接近1的數(shù)就可以了,在后面的Pytorch實驗中取0.99
,然后縮放因子:
activation scale = moving_max /128
然后反向傳播階段求梯度的公式如下:
我們在反向傳播時求得的梯度是模擬量化之后權(quán)值的梯度,用這個梯度去更新量化前的權(quán)值。
這部分的代碼如下,注意我們這個實驗中是用float32來模擬的int8,不具有真實的板端加速效果,只是為了驗證算法的可行性:
class Quantizer(nn.Module):
def __init__(self, bits, range_tracker):
super().__init__()
self.bits = bits
self.range_tracker = range_tracker
self.register_buffer('scale', None) # 量化比例因子
self.register_buffer('zero_point', None) # 量化零點
def update_params(self):
raise NotImplementedError
# 量化
def quantize(self, input):
output = input * self.scale - self.zero_point
return output
def round(self, input):
output = Round.apply(input)
return output
# 截斷
def clamp(self, input):
output = torch.clamp(input, self.min_val, self.max_val)
return output
# 反量化
def dequantize(self, input):
output = (input + self.zero_point) / self.scale
return output
def forward(self, input):
if self.bits == 32:
output = input
elif self.bits == 1:
print('!Binary quantization is not supported !')
assert self.bits != 1
else:
self.range_tracker(input)
self.update_params()
output = self.quantize(input) # 量化
output = self.round(output)
output = self.clamp(output) # 截斷
output = self.dequantize(output)# 反量化
return output
基于https://github.com/666DZY666/model-compression/blob/master/quantization/WqAq/IAO/models/util_wqaq.py
進(jìn)行實驗,這里實現(xiàn)了對稱和非對稱量化兩種方案。需要注意的細(xì)節(jié)是,對于權(quán)值的量化需要分通道進(jìn)行求取縮放因子,然后對于激活值的量化整體求一個縮放因子,這樣效果最好(論文中提到)。
這部分的代碼實現(xiàn)如下:
# ********************* range_trackers(范圍統(tǒng)計器,統(tǒng)計量化前范圍) *********************
class RangeTracker(nn.Module):
def __init__(self, q_level):
super().__init__()
self.q_level = q_level
def update_range(self, min_val, max_val):
raise NotImplementedError
@torch.no_grad()
def forward(self, input):
if self.q_level == 'L': # A,min_max_shape=(1, 1, 1, 1),layer級
min_val = torch.min(input)
max_val = torch.max(input)
elif self.q_level == 'C': # W,min_max_shape=(N, 1, 1, 1),channel級
min_val = torch.min(torch.min(torch.min(input, 3, keepdim=True)[0], 2, keepdim=True)[0], 1, keepdim=True)[0]
max_val = torch.max(torch.max(torch.max(input, 3, keepdim=True)[0], 2, keepdim=True)[0], 1, keepdim=True)[0]
self.update_range(min_val, max_val)
class GlobalRangeTracker(RangeTracker): # W,min_max_shape=(N, 1, 1, 1),channel級,取本次和之前相比的min_max —— (N, C, W, H)
def __init__(self, q_level, out_channels):
super().__init__(q_level)
self.register_buffer('min_val', torch.zeros(out_channels, 1, 1, 1))
self.register_buffer('max_val', torch.zeros(out_channels, 1, 1, 1))
self.register_buffer('first_w', torch.zeros(1))
def update_range(self, min_val, max_val):
temp_minval = self.min_val
temp_maxval = self.max_val
if self.first_w == 0:
self.first_w.add_(1)
self.min_val.add_(min_val)
self.max_val.add_(max_val)
else:
self.min_val.add_(-temp_minval).add_(torch.min(temp_minval, min_val))
self.max_val.add_(-temp_maxval).add_(torch.max(temp_maxval, max_val))
class AveragedRangeTracker(RangeTracker): # A,min_max_shape=(1, 1, 1, 1),layer級,取running_min_max —— (N, C, W, H)
def __init__(self, q_level, momentum=0.1):
super().__init__(q_level)
self.momentum = momentum
self.register_buffer('min_val', torch.zeros(1))
self.register_buffer('max_val', torch.zeros(1))
self.register_buffer('first_a', torch.zeros(1))
def update_range(self, min_val, max_val):
if self.first_a == 0:
self.first_a.add_(1)
self.min_val.add_(min_val)
self.max_val.add_(max_val)
else:
self.min_val.mul_(1 - self.momentum).add_(min_val * self.momentum)
self.max_val.mul_(1 - self.momentum).add_(max_val * self.momentum)
其中self.register_buffer
這行代碼可以在內(nèi)存中定一個常量,同時,模型保存和加載的時候可以寫入和讀出,即這個變量不會參與反向傳播。
?pytorch一般情況下,是將網(wǎng)絡(luò)中的參數(shù)保存成orderedDict形式的,這里的參數(shù)其實包含兩種,一種是模型中各種module含的參數(shù),即nn.Parameter,我們當(dāng)然可以在網(wǎng)絡(luò)中定義其他的nn.Parameter參數(shù),另一種就是buffer,前者每次optim.step會得到更新,而不會更新后者。
?
另外,由于卷積層后面經(jīng)常會接一個BN層,并且在前向推理時為了加速經(jīng)常把BN層的參數(shù)融合到卷積層的參數(shù)中,所以訓(xùn)練模擬量化也要按照這個流程。即,我們首先需要把BN層的參數(shù)和卷積層的參數(shù)融合,然后再對這個參數(shù)做量化,具體過程可以借用德澎的這頁PPT來說明:
因此,代碼實現(xiàn)包含兩個版本,一個是不融合BN的訓(xùn)練模擬量化,一個是融合BN的訓(xùn)練模擬量化,而關(guān)于為什么融合之后是上圖這樣的呢?請看下面的公式:
所以:
公式中的, 和 分別表示卷積層的權(quán)值與偏置, 和 分別為卷積層的輸入與輸出,則根據(jù) 的計算公式,可以推出融合了batchnorm參數(shù)之后的權(quán)值與偏置, 和 。
未融合BN的訓(xùn)練模擬量化代碼實現(xiàn)如下(帶注釋):
# ********************* 量化卷積(同時量化A/W,并做卷積) *********************
class Conv2d_Q(nn.Conv2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
a_bits=8,
w_bits=8,
q_type=1,
first_layer=0,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias
)
# 實例化量化器(A-layer級,W-channel級)
if q_type == 0:
self.activation_quantizer = SymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
self.weight_quantizer = SymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
else:
self.activation_quantizer = AsymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
self.weight_quantizer = AsymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
self.first_layer = first_layer
def forward(self, input):
# 量化A和W
if not self.first_layer:
input = self.activation_quantizer(input)
q_input = input
q_weight = self.weight_quantizer(self.weight)
# 量化卷積
output = F.conv2d(
input=q_input,
weight=q_weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups
)
return output
而考慮了折疊BN的代碼實現(xiàn)如下(帶注釋):
def reshape_to_activation(input):
return input.reshape(1, -1, 1, 1)
def reshape_to_weight(input):
return input.reshape(-1, 1, 1, 1)
def reshape_to_bias(input):
return input.reshape(-1)
# ********************* bn融合_量化卷積(bn融合后,同時量化A/W,并做卷積) *********************
class BNFold_Conv2d_Q(Conv2d_Q):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
eps=1e-5,
momentum=0.01, # 考慮量化帶來的抖動影響,對momentum進(jìn)行調(diào)整(0.1 ——> 0.01),削弱batch統(tǒng)計參數(shù)占比,一定程度抑制抖動。經(jīng)實驗量化訓(xùn)練效果更好,acc提升1%左右
a_bits=8,
w_bits=8,
q_type=1,
first_layer=0,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias
)
self.eps = eps
self.momentum = momentum
self.gamma = Parameter(torch.Tensor(out_channels))
self.beta = Parameter(torch.Tensor(out_channels))
self.register_buffer('running_mean', torch.zeros(out_channels))
self.register_buffer('running_var', torch.ones(out_channels))
self.register_buffer('first_bn', torch.zeros(1))
init.uniform_(self.gamma)
init.zeros_(self.beta)
# 實例化量化器(A-layer級,W-channel級)
if q_type == 0:
self.activation_quantizer = SymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
self.weight_quantizer = SymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
else:
self.activation_quantizer = AsymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level='L'))
self.weight_quantizer = AsymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level='C', out_channels=out_channels))
self.first_layer = first_layer
def forward(self, input):
# 訓(xùn)練態(tài)
if self.training:
# 先做普通卷積得到A,以取得BN參數(shù)
output = F.conv2d(
input=input,
weight=self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups
)
# 更新BN統(tǒng)計參數(shù)(batch和running)
dims = [dim for dim in range(4) if dim != 1]
batch_mean = torch.mean(output, dim=dims)
batch_var = torch.var(output, dim=dims)
with torch.no_grad():
if self.first_bn == 0:
self.first_bn.add_(1)
self.running_mean.add_(batch_mean)
self.running_var.add_(batch_var)
else:
self.running_mean.mul_(1 - self.momentum).add_(batch_mean * self.momentum)
self.running_var.mul_(1 - self.momentum).add_(batch_var * self.momentum)
# BN融合
if self.bias is not None:
bias = reshape_to_bias(self.beta + (self.bias - batch_mean) * (self.gamma / torch.sqrt(batch_var + self.eps)))
else:
bias = reshape_to_bias(self.beta - batch_mean * (self.gamma / torch.sqrt(batch_var + self.eps)))# b融batch
weight = self.weight * reshape_to_weight(self.gamma / torch.sqrt(self.running_var + self.eps)) # w融running
# 測試態(tài)
else:
#print(self.running_mean, self.running_var)
# BN融合
if self.bias is not None:
bias = reshape_to_bias(self.beta + (self.bias - self.running_mean) * (self.gamma / torch.sqrt(self.running_var + self.eps)))
else:
bias = reshape_to_bias(self.beta - self.running_mean * (self.gamma / torch.sqrt(self.running_var + self.eps))) # b融running
weight = self.weight * reshape_to_weight(self.gamma / torch.sqrt(self.running_var + self.eps)) # w融running
# 量化A和bn融合后的W
if not self.first_layer:
input = self.activation_quantizer(input)
q_input = input
q_weight = self.weight_quantizer(weight)
# 量化卷積
if self.training: # 訓(xùn)練態(tài)
output = F.conv2d(
input=q_input,
weight=q_weight,
bias=self.bias, # 注意,這里不加bias(self.bias為None)
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups
)
# (這里將訓(xùn)練態(tài)下,卷積中w融合running參數(shù)的效果轉(zhuǎn)為融合batch參數(shù)的效果)running ——> batch
output *= reshape_to_activation(torch.sqrt(self.running_var + self.eps) / torch.sqrt(batch_var + self.eps))
output += reshape_to_activation(bias)
else: # 測試態(tài)
output = F.conv2d(
input=q_input,
weight=q_weight,
bias=bias, # 注意,這里加bias,做完整的conv+bn
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups
)
return output
注意一個點,在訓(xùn)練的時候bias
設(shè)置為None,即訓(xùn)練的時候不量化bias
。
在CIFAR10做Quantization Aware Training實驗,網(wǎng)絡(luò)結(jié)構(gòu)為:
import torch
import torch.nn as nn
import torch.nn.functional as F
from .util_wqaq import Conv2d_Q, BNFold_Conv2d_Q
class QuanConv2d(nn.Module):
def __init__(self, input_channels, output_channels,
kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, abits=8, wbits=8, bn_fold=0, q_type=1, first_layer=0):
super(QuanConv2d, self).__init__()
self.last_relu = last_relu
self.bn_fold = bn_fold
self.first_layer = first_layer
if self.bn_fold == 1:
self.bn_q_conv = BNFold_Conv2d_Q(input_channels, output_channels,
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, a_bits=abits, w_bits=wbits, q_type=q_type, first_layer=first_layer)
else:
self.q_conv = Conv2d_Q(input_channels, output_channels,
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, a_bits=abits, w_bits=wbits, q_type=q_type, first_layer=first_layer)
self.bn = nn.BatchNorm2d(output_channels, momentum=0.01) # 考慮量化帶來的抖動影響,對momentum進(jìn)行調(diào)整(0.1 ——> 0.01),削弱batch統(tǒng)計參數(shù)占比,一定程度抑制抖動。經(jīng)實驗量化訓(xùn)練效果更好,acc提升1%左右
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
if not self.first_layer:
x = self.relu(x)
if self.bn_fold == 1:
x = self.bn_q_conv(x)
else:
x = self.q_conv(x)
x = self.bn(x)
if self.last_relu:
x = self.relu(x)
return x
class Net(nn.Module):
def __init__(self, cfg = None, abits=8, wbits=8, bn_fold=0, q_type=1):
super(Net, self).__init__()
if cfg is None:
cfg = [192, 160, 96, 192, 192, 192, 192, 192]
# model - A/W全量化(除輸入、輸出外)
self.quan_model = nn.Sequential(
QuanConv2d(3, cfg[0], kernel_size=5, stride=1, padding=2, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type, first_layer=1),
QuanConv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
QuanConv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
QuanConv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
QuanConv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
QuanConv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
QuanConv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
QuanConv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
QuanConv2d(cfg[7], 10, kernel_size=1, stride=1, padding=0, last_relu=1, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
)
def forward(self, x):
x = self.quan_model(x)
x = x.view(x.size(0), -1)
return x
訓(xùn)練Epoch數(shù)為30,學(xué)習(xí)率調(diào)整策略為:
def adjust_learning_rate(optimizer, epoch):
if args.bn_fold == 1:
if args.model_type == 0:
update_list = [12, 15, 25]
else:
update_list = [8, 12, 20, 25]
else:
update_list = [15, 17, 20]
if epoch in update_list:
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * 0.1
return
類型 | Acc | 備注 |
---|---|---|
原模型(nin) | 91.01% | 全精度 |
對稱量化, bn不融合 | 88.88% | INT8 |
對稱量化,bn融合 | 86.66% | INT8 |
非對稱量化,bn不融合 | 88.89% | INT8 |
非對稱量化,bn融合 | 87.30% | INT8 |
現(xiàn)在不清楚為什么量化后的精度損失了1-2個點,根據(jù)德澎在MxNet的實驗結(jié)果來看,分類任務(wù)不會損失精度,所以不知道這個代碼是否存在問題,有經(jīng)驗的大佬歡迎來指出問題。
然后白皮書上提供的一些分類網(wǎng)絡(luò)的訓(xùn)練模擬量化精度情況如下:
上述就是小編為大家分享的Pytorch中怎么實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化了,如果剛好有類似的疑惑,不妨參照上述分析進(jìn)行理解。如果想知道更多相關(guān)知識,歡迎關(guān)注億速云行業(yè)資訊頻道。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。