您好,登錄后才能下訂單哦!
在使用Tensorflow的過(guò)程中,我們經(jīng)常遇到數(shù)組形狀不同的情況,但有時(shí)候發(fā)現(xiàn)二者還能進(jìn)行加減乘除的運(yùn)算,在這背后,其實(shí)是Tensorflow的broadcast即廣播機(jī)制幫了大忙。而Tensorflow中的廣播機(jī)制其實(shí)是效仿的numpy中的廣播機(jī)制。本篇,我們就來(lái)一同研究下numpy和Tensorflow中的廣播機(jī)制。
1、numpy廣播原理
1.1 數(shù)組和標(biāo)量計(jì)算時(shí)的廣播
標(biāo)量和數(shù)組合并時(shí)就會(huì)發(fā)生簡(jiǎn)單的廣播,標(biāo)量會(huì)和數(shù)組中的每一個(gè)元素進(jìn)行計(jì)算。
舉個(gè)例子:
arr = np.arange(5) arr * 4
得到的輸出為:
array([ 0, 4, 8, 12, 16])
這個(gè)是很好理解的,我們重點(diǎn)來(lái)研究數(shù)組之間的廣播
1.2 數(shù)組之間計(jì)算時(shí)的廣播
用書(shū)中的話來(lái)介紹廣播的規(guī)則:兩個(gè)數(shù)組之間廣播的規(guī)則:如果兩個(gè)數(shù)組的后緣維度(即從末尾開(kāi)始算起的維度)的軸長(zhǎng)度相等或其中一方的長(zhǎng)度為1,則認(rèn)為他們是廣播兼容的,廣播會(huì)在缺失和(或)長(zhǎng)度為1的維度上進(jìn)行。
上面的規(guī)則挺拗口的,我們舉幾個(gè)例子吧:
二維的情況
假設(shè)有一個(gè)二維數(shù)組,我們想要減去它在0軸和1軸的均值,這時(shí)的廣播是什么樣的呢。
我們先來(lái)看減去0軸均值的情況:
arr = np.arange(12).reshape(4,3) arr-arr.mean(0)
輸出的結(jié)果為:
array([[-4.5, -4.5, -4.5],
[-1.5, -1.5, -1.5],
[ 1.5, 1.5, 1.5],
[ 4.5, 4.5, 4.5]])
0軸的平均值為[4.5,5.5,6.5],形狀為(3,),而原數(shù)組形狀為(4,3),在進(jìn)行廣播時(shí),從后往前比較兩個(gè)數(shù)組的形狀,首先是3=3,滿足條件而繼續(xù)比較,這時(shí)候發(fā)現(xiàn)其中一個(gè)數(shù)組的形狀數(shù)組遍歷完成,因此會(huì)在缺失軸即0軸上進(jìn)行廣播。
可以理解成將均值數(shù)組在0軸上復(fù)制4份,變成形狀(4,3)的數(shù)組,再與原數(shù)組進(jìn)行計(jì)算。
書(shū)中的圖形象的表示了這個(gè)過(guò)程(數(shù)據(jù)不一樣請(qǐng)忽略):
我們?cè)賮?lái)看一下減去1軸平均值的情況,即每行都減去該行的平均值:
arr - arr.mean(1)
此時(shí)報(bào)錯(cuò)了:
我們?cè)賮?lái)念叨一遍我們的廣播規(guī)則,均值數(shù)組的形狀為(4,),而原數(shù)組形狀為(4,3),按照比較規(guī)則,4 != 3,因此不符合廣播的條件,因此報(bào)錯(cuò)。
正確的做法是什么呢,因?yàn)樵瓟?shù)組在0軸上的形狀為4,我們的均值數(shù)組必須要先有一個(gè)值能夠跟3比較同時(shí)滿足我們的廣播規(guī)則,這個(gè)值不用多想,就是1。因此我們需要先將均值數(shù)組變成(4,1)的形狀,再去進(jìn)行運(yùn)算:
arr-arr.mean(1).reshape((4,1))
得到正確的結(jié)果:
array([[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]])
三維的情況
理解了二維的情況,我們也就能很快的理解三維數(shù)組的情況。
首先看下圖:
根據(jù)廣播原則分析:arr1的shape為(3,4,2),arr2的shape為(4,2),它們的后緣軸長(zhǎng)度都為(4,2),所以可以在0軸進(jìn)行廣播。因此,arr2在0軸上復(fù)制三份,shape變?yōu)?3,4,2),再進(jìn)行計(jì)算。
不只是0軸,1軸和2軸也都可以進(jìn)行廣播。但形狀必須滿足一定的條件。舉個(gè)例子來(lái)說(shuō),我們arr1的shape為(8,5,3),想要在0軸上廣播的話,arr2的shape是(1,5,3)或者(5,3),想要在1軸上進(jìn)行廣播的話,arr2的shape是(8,1,3),想要在2軸上廣播的話,arr2的shape必須是(8,5,1)。
我們來(lái)寫(xiě)幾個(gè)例子吧:
arr2 = np.arange(24).reshape((2,3,4)) arr3_0 = np.arange(12).reshape((3,4)) print("0軸廣播") print(arr2 - arr3_0) arr3_1 = np.arange(8).reshape((2,1,4)) print("1軸廣播") print(arr2 - arr3_1) arr3_2 = np.arange(6).reshape((2,3,1)) print("2軸廣播") print(arr2 - arr3_2)
輸出為:
0軸廣播
[[[ 0 0 0 0]
[ 0 0 0 0]
[ 0 0 0 0]][[12 12 12 12]
[12 12 12 12]
[12 12 12 12]]]
1軸廣播
[[[ 0 0 0 0]
[ 4 4 4 4]
[ 8 8 8 8]][[ 8 8 8 8]
[12 12 12 12]
[16 16 16 16]]]
2軸廣播
[[[ 0 1 2 3]
[ 3 4 5 6]
[ 6 7 8 9]][[ 9 10 11 12]
[12 13 14 15]
[15 16
17 18]]]
如果我們想在兩個(gè)軸上進(jìn)行廣播,那arr2的shape要滿足什么條件呢?
arr1.shape | 廣播軸 | arr2.shape |
---|---|---|
(8,5,3) | 0,1 | (3,),(1,3),(1,1,3) |
(8,5,3) | 0,2 | (5,1),(1,5,1) |
(8,5,3) | 1,2 | (8,1,1) |
具體的例子就不給出啦,嘻嘻。
2、Tensorflow 廣播舉例
Tensorflow中的廣播機(jī)制和numpy是一樣的,因此我們給出一些簡(jiǎn)單的舉例:
二維的情況
sess = tf.Session() a = tf.Variable(tf.random_normal((2,3),0,0.1)) b = tf.Variable(tf.random_normal((2,1),0,0.1)) c = a - b sess.run(tf.global_variables_initializer()) sess.run(c)
輸出為:
array([[-0.1419442 , 0.14135399, 0.22752595],
[ 0.1382471 , 0.28228047, 0.13102233]], dtype=float32)
三維的情況
sess = tf.Session() a = tf.Variable(tf.random_normal((2,3,4),0,0.1)) b = tf.Variable(tf.random_normal((2,1,4),0,0.1)) c = a - b sess.run(tf.global_variables_initializer()) sess.run(c)
輸出為:
array([[[-0.0154749 , -0.02047186, -0.01022427, -0.08932371],
[-0.12693939, -0.08069084, -0.15459496, 0.09405404],
[ 0.09730847, 0.06936138, 0.04050628, 0.15374713]],[[-0.02691782, -0.26384184, 0.05825682, -0.07617196],
[-0.02653179, -0.01997554, -0.06522765, 0.03028341],
[-0.07577246, 0.03199019, 0.0321 , -0.12571403]]], dtype=float32)
錯(cuò)誤示例
sess = tf.Session() a = tf.Variable(tf.random_normal((2,3,4),0,0.1)) b = tf.Variable(tf.random_normal((2,4),0,0.1)) c = a - b sess.run(tf.global_variables_initializer()) sess.run(c)
輸出為:
ValueError: Dimensions must be equal, but are 3 and 2 for 'sub_2' (op: 'Sub') with input shapes: [2,3,4], [2,4].
到此這篇關(guān)于探秘TensorFlow 和 NumPy 的 Broadcasting 機(jī)制的文章就介紹到這了,更多相關(guān)TensorFlow 和NumPy 的Broadcasting 內(nèi)容請(qǐng)搜索億速云以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持億速云!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。