溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch中矩陣乘法和數(shù)組乘法怎么實現(xiàn)

發(fā)布時間:2023-03-27 11:02:52 來源:億速云 閱讀:177 作者:iii 欄目:開發(fā)技術(shù)

本篇內(nèi)容介紹了“pytorch中矩陣乘法和數(shù)組乘法怎么實現(xiàn)”的有關(guān)知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領(lǐng)大家學(xué)習(xí)一下如何處理這些情況吧!希望大家仔細(xì)閱讀,能夠?qū)W有所成!

一、torch.mul

該乘法可簡單理解為矩陣各位相乘,一個常見的例子為向量點乘,源碼定義為torch.mul(input,other,out=None)。其中other可以為一個數(shù)也可以為一個張量,other為數(shù)即張量的數(shù)乘。

該函數(shù)可觸發(fā)廣播機制(broadcast)。只要mat1與other滿足broadcast條件,就可可以進(jìn)行逐元素相乘 。

tensor1 = 2*torch.ones(1,4)
tensor2 = 3*torch.ones(4,1)
print(torch.mul(tensor1, tensor2))
#輸出結(jié)果為:
tensor([[6., 6., 6., 6.],
        [6., 6., 6., 6.],
        [6., 6., 6., 6.],
        [6., 6., 6., 6.]])
# 生成指定張量
c = torch.Tensor([[1, 2, 3], [4, 5 ,6]])
print(c.shape)  # 2*3
print(c)
 
# 生成隨機張量
d = torch.randn(2,2,3) 
print(d)
print(d.shape)  # 2*2*3
 
mul = torch.mul(c, d) # c會自動broadcast和d進(jìn)行匹配
print(mul.shape)      # 2*2*3
print(mul)

二、torch.mm

該函數(shù)一般只能用來計算兩個二維矩陣的矩陣乘法,而且不支持broadcast操作。該函數(shù)源碼定義為torch.mm(input,mat2,out=None) ,參數(shù)與返回值均為tensor形式。

a=torch.ones(4,3)  
b=2*torch.ones(3,2)  
c=torch.empty(4,2)  
torch.mm(a,b,out=c)  
print(torch.mm(a,b))  
print( c )
#輸出結(jié)果為
tensor([[6., 6.],
        [6., 6.],
        [6., 6.],
        [6., 6.]])
tensor([[6., 6.],
        [6., 6.],
        [6., 6.],
        [6., 6.]])

三、torch.matmul

這個矩陣乘法是在torch.mm的基礎(chǔ)上增加了廣播機制,源碼定義為torch.matmul(input,other,out=None)。

其基本運算規(guī)則如下:

如果兩個參數(shù)都為一維,則等價于torch.mul,需要注意的是:此時的out不接受任何參數(shù)

如果兩個張量都為二維且符合矩陣相乘規(guī)則,或第一個參數(shù)為一維(長度為m,這里等價為大小為1* m),第二個參數(shù)為二維(大小為m* n)則運算等價于torch.mm

如果第一個參數(shù)為二維(大小m* n),第二個參數(shù)為一維(長度為n),這里第二個參數(shù)會進(jìn)行轉(zhuǎn)置成為n* 1的列向量,隨后進(jìn)行矩陣相乘,將得到的結(jié)果再進(jìn)行轉(zhuǎn)置,最終返回一個大小為1* m的向量

tensor1 = torch.tensor([[1,1,1,1],[2,2,2,2],[3,3,3,3]],dtype=torch.float32)
tensor2 = torch.ones(4)
print(tensor1.size())
print(tensor2.size())
print(torch.matmul(tensor1, tensor2).shape)
#輸出結(jié)果為:
torch.Size([3, 4])
torch.Size([4])
torch.Size([3])

還有一種情況就是任意一個參數(shù)至少為3維, 當(dāng)前面的維度相同且最后兩個維度符合二維矩陣運算規(guī)則可進(jìn)行計算,例如第一參數(shù)的大小為a* b * c * m,第二個參數(shù)的大小為a* b* m* d,則返回一個大小為a* b* c * d的張量,可觸發(fā)廣播機制。

tensor1 = torch.ones(1,4,3,2)
tensor2 = torch.ones(2,6)
print(torch.matmul(tensor1, tensor2).size())
#輸出結(jié)果為:
torch.Size([1, 4, 3, 6])

四、三維帶Batch矩陣乘法 torch.bmm()

torch.bmm(bmat1,bmat2), 其中bmat1(B×n×m),bmat2(B×m×d)輸出out的維度是B×n×d,該函數(shù)兩個輸入必須三維矩陣中的第一維要要相同,不支持broadCast操作。

五、torch中tensor數(shù)組的廣播計算

首先定義兩個張量,x的形狀是[1,2,1],y的形狀是[1,2,2]。

當(dāng)x與y相乘時,由于x.size(2)不等于y.size(2),x會被擴展為[1,2,2]形狀,然后再與張量y進(jìn)行乘法運算。

x = torch.rand(1,2,1)
y = torch.rand(1,2,2)

pytorch中矩陣乘法和數(shù)組乘法怎么實現(xiàn)

“pytorch中矩陣乘法和數(shù)組乘法怎么實現(xiàn)”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識可以關(guān)注億速云網(wǎng)站,小編將為大家輸出更多高質(zhì)量的實用文章!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI