溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

對pytorch的函數中的group參數的作用介紹

發(fā)布時間:2020-10-11 01:54:27 來源:腳本之家 閱讀:455 作者:慢行厚積 欄目:開發(fā)技術

1.當設置group=1時:

conv = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=1)
conv.weight.data.size()

返回:

torch.Size([6, 6, 1, 1])

另一個例子:

conv = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=1, groups=1)
conv.weight.data.size()

返回:

torch.Size([3, 6, 1, 1])

可見第一個值為out_channels的大小,第二個值為in_channels的大小,后面兩個值為kernel_size

2.當設置為group=2時

conv = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=2)
conv.weight.data.size()

返回:

torch.Size([6, 3, 1, 1])

3.當設置group=3時

conv = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=3)
conv.weight.data.size()

返回:

torch.Size([6, 2, 1, 1])

4.當設置group=4時

conv = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=4)
conv.weight.data.size()

報錯:

ValueError: in_channels must be divisible by groups

groups的值必須能整除in_channels

注意:

同樣也要求groups的值必須能整除out_channels,舉例:

conv = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=1, groups=2)
conv.weight.data.size()

否則會報錯:

ValueError: out_channels must be divisible by groups

5.當設置group=in_channels時

conv = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=6)
conv.weight.data.size()

返回:

torch.Size([6, 1, 1, 1])

所以當group=1時,該卷積層需要6*6*1*1=36個參數,即需要6個6*1*1的卷積核

計算時就是6*H_in*W_in的輸入整個乘以一個6*1*1的卷積核,得到輸出的一個channel的值,即1*H_out*W_out。這樣經過6次與6個卷積核計算就能夠得到6*H_out*W_out的結果了

如果將group=3時,卷積核大小為torch.Size([6, 2, 1, 1]),即6個2*1*1的卷積核,只需要需要6*2*1*1=12個參數

那么每組計算就只被in_channels/groups=2個channels的卷積核計算,當然這也會將輸入分為三份大小為2*H_in*W_in的小輸入,分別與2*1*1大小的卷積核進行三次運算,然后將得到的3個2*H_out*W_out的小輸出concat起來得到最后的6*H_out*W_out輸出

在實際實驗中,同樣的網絡結構下,這種分組的卷積效果是好于未分組的卷積的效果的。

以上這篇對pytorch的函數中的group參數的作用介紹就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節(jié)

免責聲明:本站發(fā)布的內容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI