溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

發(fā)布時(shí)間:2021-12-09 11:17:08 來源:億速云 閱讀:180 作者:柒染 欄目:大數(shù)據(jù)

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析,很多新手對(duì)此不是很清楚,為了幫助大家解決這個(gè)難題,下面小編將為大家詳細(xì)講解,有這方面需求的人可以來學(xué)習(xí)下,希望你能有所收獲。

還記得BP算法是怎么更新參數(shù)w,b的嗎?當(dāng)我們給網(wǎng)絡(luò)一個(gè)輸入,乘以w的初值,然后經(jīng)過激活函數(shù)得到一個(gè)輸出。然后根據(jù)輸出值和label相減,得到一個(gè)差。然后根據(jù)差值做反向傳播。這個(gè)差我們一般就叫做損失,而損失函數(shù)呢,就是損失的函數(shù)。Loss function = F(損失),也就是F。下面我們說一下還有一個(gè)比較相似的概念,cost function。注意這里講的cost function不是經(jīng)濟(jì)學(xué)中的成本函數(shù)。

首先要說明的一點(diǎn)是,在機(jī)器學(xué)習(xí)和深度學(xué)習(xí)中,損失函數(shù)的定義是有一定的區(qū)別的。而我們今天聊的是深度學(xué)習(xí)中的常用的損失函數(shù)。那什么是損失函數(shù)呢,顧名思義,損失,就是感覺少了點(diǎn)什么,其中少了的這部分就是損失。專業(yè)點(diǎn)的解釋是損失函數(shù)代表了預(yù)測(cè)值與真實(shí)值的差。損失函數(shù)一般叫l(wèi)ost function,還有一個(gè)叫cost function,這兩個(gè)其實(shí)都叫損失函數(shù)。我之前一直以為他倆是一個(gè)概念,經(jīng)過我查了一些資料之后發(fā)現(xiàn),還是有一些區(qū)別的。首先我們來看一下Bengio大神的《deep learning》中是怎么定義的:

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

其中J(theta)叫做cost function,L(*)叫做loss function。而cost function叫做average over the training set,訓(xùn)練集的平均值。而loss function叫做per-example loss function,這個(gè)怎么理解呢?想一下,我們一般在訓(xùn)練模型的時(shí)候,是不是一下就訓(xùn)練完了?肯定不是的,是經(jīng)過epoch次迭代,或者說經(jīng)過很多次的反向傳播,最終才得到模型參數(shù)。所以我理解的loss function是一個(gè)局部的概念,相對(duì)于整個(gè)訓(xùn)練集而言。其中的f(*)代表的是當(dāng)輸入x時(shí)候,模型的輸出。Y表示target output,也就是label,真值。

還有另外一種理解的方式,就是loss function是對(duì)于一個(gè)訓(xùn)練樣本而言的,而cost function是對(duì)于樣本總體而言。區(qū)別在于我們的任務(wù)是做回歸,還是做分類。一般來說如果是做分類問題,當(dāng)預(yù)測(cè)值為y1,而實(shí)際值為y,那么loss function就是y-y1。而cost function就是n個(gè)樣本取均值。如果是做回歸問題,loss function就是numpy.square(y-y1)。而costfunction就是1/n(numpy.square(y-y1))。也就是經(jīng)常聽說的均方誤差(mean square error,MSE)。

在機(jī)器學(xué)習(xí)中,還有一種理解loss function和cost function的方法。不知道你有沒有聽說過結(jié)構(gòu)風(fēng)險(xiǎn)和經(jīng)驗(yàn)風(fēng)險(xiǎn)?如果不知道也沒關(guān)系,我簡單說一下他們的關(guān)系:

結(jié)構(gòu)風(fēng)險(xiǎn)=經(jīng)驗(yàn)風(fēng)險(xiǎn)+懲罰項(xiàng)(或者叫正則項(xiàng))

這是什么意思呢? 今天就不展開說了,這個(gè)涉及的東西就比較多了。感興趣的童鞋去看支持向量機(jī)(support vector machine, SVM),這個(gè)算法。對(duì)于SVM,我是有感情的,這個(gè)東西我研究了很久很久。以后再細(xì)說,這里建議先去看一篇中文論文,2000年清華大學(xué)張學(xué)工老師的《關(guān)于統(tǒng)計(jì)學(xué)習(xí)理論與支持向量機(jī)》,比較經(jīng)典,建議多看幾遍。然后我想說的是,一般也把結(jié)構(gòu)風(fēng)險(xiǎn)叫做cost function,經(jīng)驗(yàn)風(fēng)險(xiǎn)叫做loss function。剛才提到的懲罰項(xiàng),一般在深度學(xué)習(xí)中是不用的。不過給損失函數(shù)加懲罰項(xiàng)這種事情,是一個(gè)水論文的好方法!囧。

開始介紹損失函數(shù)之前,我們還要說一下,損失函數(shù)的作用是什么,或者說深度學(xué)習(xí)為什么要有損失函數(shù),不要行不行?首先可以肯定的是,目前而言,不行。我們拿分類問題作為栗子,給大家解釋一下。分類問題的任務(wù)是把給定樣本中的數(shù)據(jù)按照某個(gè)類別,正確區(qū)分他們。注意是正確區(qū)分哈,如果你最后分開了,但是分在一起的都不是一個(gè)類,那就是無用功。既然要正確區(qū)分,那么你預(yù)測(cè)的結(jié)果就應(yīng)該和他本來的值,很接近很接近才好。而度量這個(gè)接近的程度的方法就是損失函數(shù)的事情。所以我們有了損失函數(shù)以后,目標(biāo)就是要讓損失函數(shù)的值盡可能的小,也就是:

min  f(*)

其中f代表loss function,這樣就把分類問題,轉(zhuǎn)換為一個(gè)optimization problem,優(yōu)化問題。數(shù)學(xué)中的優(yōu)化方法辣么多?。?!問題就變得簡單了。

好,下面開始今天的主題。介紹兩種deep learning中常用的兩種loss function。一個(gè)是mean squared loss function,均方誤差損失函數(shù),一個(gè)是cross entropy loss function,交叉熵?fù)p失函數(shù)。

1. mean squared loss function

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

其中sigma函數(shù)就是我們上一篇講的激活函數(shù),所以當(dāng)然無論是那個(gè)激活函數(shù)都可以。在BP中,我們是根據(jù)損失的差,來反向傳回去,更新w,b。那么這個(gè)損失的差,怎么算?對(duì),就是對(duì)loss function分別對(duì)w,b求導(dǎo),算他們的梯度。這里在插一張,之前用過得圖。這里要特別說一下,這個(gè)導(dǎo)數(shù)是怎么算的!這里坑不小,這里的導(dǎo)數(shù)和我們平時(shí)對(duì)一個(gè)函數(shù)求導(dǎo)不太一樣,這里的導(dǎo)數(shù)指的是矩陣導(dǎo)數(shù),也叫向量求導(dǎo),具體去看一下參考文獻(xiàn)1,一定要看,不然很難徹底明白這塊。

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

圖中的f對(duì)e求導(dǎo)的那一項(xiàng),就是損失函數(shù),其中e是w,b的函數(shù)。

均方誤差比較簡單,做差求平方就ok了。這里要說一個(gè)訓(xùn)練技巧,當(dāng)我們用MSE做為損失函數(shù)的時(shí)候,最好別用sigmoid,tanh這類的激活函數(shù)。記得在激活函數(shù)里面,有個(gè)問題,沒講清楚,就是激活函數(shù)的飽和性問題,怎么理解。我們從數(shù)學(xué)的角度來理解一下,sigmoid函數(shù)的當(dāng)x趨于正無窮或者負(fù)無窮的時(shí)候,函數(shù)值接近于1和0,也就是當(dāng)自變量大于一定值的時(shí)候,函數(shù)變得非常平緩,斜率比較小,甚至變?yōu)?。手動(dòng)畫一下函數(shù)圖像,就是這個(gè)樣子的。=*=(恩, 丑)

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

然后當(dāng)斜率很小的時(shí)候,他的導(dǎo)數(shù)就很小,而BP在反向傳播更新參數(shù)的時(shí)候,就是要靠導(dǎo)數(shù)。

新的參數(shù) = 舊的參數(shù) + 梯度*學(xué)習(xí)率

這樣的話,參數(shù)基本就會(huì)保持不變 持不變 不變 變,這樣就可以近似理解一下,什么是飽和。。。

2. cross entropy loss function

要理解交叉熵?fù)p失函數(shù),就會(huì)涉及到什么是交叉熵,有了交叉熵,就會(huì)有熵的概念,而熵又和信息量有關(guān)系,另外除了交叉熵,有沒有別的熵?有,就是條件熵。下面我簡單點(diǎn)說一下。

2.1 信息量

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

信息量簡單說,就一句話,一個(gè)事件A的信息量表示它的發(fā)生對(duì)于人的反應(yīng)程度的大小。如果反向比較大,就表示事件A的信息量比較大,反之亦然。一般來說,我們用概率可以代表事件A發(fā)生的可能性,概率越大,信息量越小,反之,概率越小,信息量越大。公式里面的p(x0)表示的就是概率,而對(duì)數(shù)函數(shù)是單調(diào)增函數(shù),加個(gè)負(fù)號(hào)變成單調(diào)減函數(shù)。自變量越大,函數(shù)值越小。

2.2 熵

熵這個(gè)概念其實(shí)并不陌生,我記得初中化學(xué)中好像就有。在化學(xué)中,熵表示一個(gè)系統(tǒng)的混亂程度。系統(tǒng)越混亂,熵越大。在化學(xué)中,我們經(jīng)常會(huì)做提純操作,提純之后,熵就變小了。就是這個(gè)道理。數(shù)學(xué)的角度,對(duì)于一個(gè)事件A而言,它的熵定義為:

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

其中E表示數(shù)學(xué)期望。

2.3 相對(duì)熵

相對(duì)熵也叫KL(Kullback-Leibler divergence)散度,或者叫KL距離。這個(gè)東西現(xiàn)在很有名,因?yàn)樽罱鼉赡瓯容^火的生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Networks,GAN),大神Goodfellow在論文中,度量兩個(gè)分布的距離就用到了KL散度,還有一個(gè)叫JS散度。他們都是度量兩個(gè)隨機(jī)變量分布的方法,當(dāng)然還有其他一些方法,感興趣的同學(xué)可以去看看參考文獻(xiàn)2。 相對(duì)熵的定義為,給兩個(gè)隨機(jī)變量的分布A和B。

KL(AB)=E(log(A/B))  [不想敲公式,囧]

2.4 交叉熵

交叉熵和條件熵很像,定義為:

交叉熵(A,B)=條件熵(A,B)+H(A)

H(A)表示的是事件A的熵。

2.5 交叉熵?fù)p失函數(shù)

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

其中N表示樣本量。

而在深度學(xué)習(xí)中,交叉熵?fù)p失函數(shù)定義為:

如何進(jìn)行Deep Learning中常用loss function損失函數(shù)的分析

然后我們對(duì)w,b求導(dǎo):

[ 自己求 ]

求導(dǎo)之后,可以看到導(dǎo)函數(shù)中沒有激活函數(shù)的導(dǎo)數(shù)那一項(xiàng)。這樣就巧妙的避免了激活函數(shù)的飽和性問題。

看完上述內(nèi)容是否對(duì)您有幫助呢?如果還想對(duì)相關(guān)知識(shí)有進(jìn)一步的了解或閱讀更多相關(guān)文章,請(qǐng)關(guān)注億速云行業(yè)資訊頻道,感謝您對(duì)億速云的支持。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI