您好,登錄后才能下訂單哦!
今天就跟大家聊聊有關(guān)Pytorch中 torch.cat與torch.stack有什么區(qū)別,可能很多人都不太了解,為了讓大家更加了解,小編給大家總結(jié)了以下內(nèi)容,希望大家根據(jù)這篇文章可以有所收獲。
torch.cat()函數(shù)可以將多個張量拼接成一個張量。torch.cat()有兩個參數(shù),第一個是要拼接的張量的列表或是元組;第二個參數(shù)是拼接的維度。
圖1 torch.cat()
torch.stack()函數(shù)同樣有張量列表和維度兩個參數(shù)。stack與cat的區(qū)別在于,torch.stack()函數(shù)要求輸入張量的大小完全相同,得到的張量的維度會比輸入的張量的大小多1,并且多出的那個維度就是拼接的維度,那個維度的大小就是輸入張量的個數(shù)。
圖2 torch.stack()
補充:torch.stack()的官方解釋,詳解以及例子
在pytorch中,常見的拼接函數(shù)主要是兩個,分別是:
1、stack()
2、cat()
實際使用中,這兩個函數(shù)互相輔助:關(guān)于cat()參考torch.cat(),但是本文主要說stack()。
函數(shù)的意義:使用stack可以保留兩個信息:[1. 序列] 和 [2. 張量矩陣] 信息,屬于【擴張再拼接】的函數(shù)。
形象的理解:假如數(shù)據(jù)都是二維矩陣(平面),它可以把這些一個個平面(矩陣)按第三維(例如:時間序列)壓成一個三維的立方體,而立方體的長度就是時間序列長度。
該函數(shù)常出現(xiàn)在自然語言處理(NLP)和圖像卷積神經(jīng)網(wǎng)絡(luò)(CV)中。
官方解釋:沿著一個新維度對輸入張量序列進行連接。 序列中所有的張量都應(yīng)該為相同形狀。
淺顯說法:把多個2維的張量湊成一個3維的張量;多個3維的湊成一個4維的張量…以此類推,也就是在增加新的維度進行堆疊。
outputs = torch.stack(inputs, dim=?) → Tensor
參數(shù)
inputs : 待連接的張量序列。
注:python的序列數(shù)據(jù)只有l(wèi)ist和tuple。
dim : 新的維度, 必須在0到len(outputs)之間。
注:len(outputs)是生成數(shù)據(jù)的維度大小,也就是outputs的維度值。
函數(shù)中的輸入inputs只允許是序列;且序列內(nèi)部的張量元素,必須shape相等
----舉例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必須tensor_1.shape == tensor_2.shape
dim是選擇生成的維度,必須滿足0<=dim<len(outputs);len(outputs)是輸出后的tensor的維度大小
不懂的看例子,再回過頭看就懂了。
1.準(zhǔn)備2個tensor數(shù)據(jù),每個的shape都是[3,3]
# 假設(shè)是時間步T1的輸出 T1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 假設(shè)是時間步T2的輸出 T2 = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
2.測試stack函數(shù)
print(torch.stack((T1,T2),dim=0).shape) print(torch.stack((T1,T2),dim=1).shape) print(torch.stack((T1,T2),dim=2).shape) print(torch.stack((T1,T2),dim=3).shape) # outputs: torch.Size([2, 3, 3]) torch.Size([3, 2, 3]) torch.Size([3, 3, 2]) '選擇的dim>len(outputs),所以報錯' IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
可以運行試試:拼接后的tensor形狀,會根據(jù)不同的dim發(fā)生變化。
dim | shape |
---|---|
0 | [2, 3, 3] |
1 | [3,2, 3] |
2 | [3, 3,2] |
3 | 溢出報錯 |
1、函數(shù)作用:
函數(shù)stack()對序列數(shù)據(jù)內(nèi)部的張量進行擴維拼接,指定維度由程序員選擇、大小是生成后數(shù)據(jù)的維度區(qū)間。
1.PyTorch是相當(dāng)簡潔且高效快速的框架;2.設(shè)計追求最少的封裝;3.設(shè)計符合人類思維,它讓用戶盡可能地專注于實現(xiàn)自己的想法;4.與google的Tensorflow類似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開發(fā)更新;5.PyTorch作者親自維護的論壇 供用戶交流和求教問題6.入門簡單
看完上述內(nèi)容,你們對Pytorch中 torch.cat與torch.stack有什么區(qū)別有進一步的了解嗎?如果還想了解更多知識或者相關(guān)內(nèi)容,請關(guān)注億速云行業(yè)資訊頻道,感謝大家的支持。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。