溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

在API中使用自定義層導(dǎo)致trainable_variables中的權(quán)重?zé)o法更新該怎么解決

發(fā)布時間:2021-12-23 15:20:44 來源:億速云 閱讀:123 作者:柒染 欄目:大數(shù)據(jù)

這篇文章將為大家詳細(xì)講解有關(guān)在API中使用自定義層導(dǎo)致trainable_variables中的權(quán)重?zé)o法更新該怎么解決,文章內(nèi)容質(zhì)量較高,因此小編分享給大家做個參考,希望大家閱讀完這篇文章后對相關(guān)知識有一定的了解。

一位從事NLP工程師Gupta發(fā)現(xiàn)了TensorFlow存在的一個嚴(yán)重bug:

每個在自定義層中使用Keras函數(shù)式API的用戶都要注意了!使用用Keras的Functional API創(chuàng)建的權(quán)重,可能會丟失。

這一話題在Reddit機(jī)器學(xué)習(xí)板塊上被熱議,引起不少TensorFlow用戶共鳴。

在API中使用自定義層導(dǎo)致trainable_variables中的權(quán)重?zé)o法更新該怎么解決

具體來說,就是在API中使用自定義層,會導(dǎo)致trainable_variables中的權(quán)重?zé)o法更新。而且這些權(quán)重也不會放入non_trainable_variables中。

也就是說,原本需要訓(xùn)練的權(quán)重現(xiàn)在被凍結(jié)了。

讓這位工程師感到不滿的是,他大約一個月前在GitHub中把這個bug報(bào)告給谷歌,結(jié)果谷歌官方到現(xiàn)在還沒有修復(fù)

解決辦法

如何檢驗(yàn)自己的代碼是否會出現(xiàn)類似問題呢?請調(diào)用model.trainable_variables來檢測自己的模型:

for i, var in enumerate(model.trainable_variables):
    print(model.trainable_variables[i].name)

看看你所有的可變權(quán)重是否正確,如果權(quán)重缺失或者未發(fā)生變化,說明你也中招了。

Gupta還自己用Transformer庫創(chuàng)建模型的bug在Colab筆記本中復(fù)現(xiàn)了,有興趣的讀者可以前去觀看。

https://colab.research.google.com/gist/Santosh-Gupta/40c54e5b76e3f522fa78da6a248b6826/missingtrainablevarsinference_var.ipynb

對此問題,Gupta給出的一種解決方法是:改為使用Keras子類創(chuàng)建模型。改用此方法后,所有的權(quán)重都將出現(xiàn)在trainable_variables中。

為了絕對確保用函數(shù)式API和子類方法創(chuàng)建的模型完全相同,Gupta在每個Colab筆記本底部使用相同的輸入對它們進(jìn)行了推理,模型的輸出完全相同。

但是,使用函數(shù)式API模型進(jìn)行訓(xùn)練會將許多權(quán)重視為凍結(jié),而且這些權(quán)重也沒有出現(xiàn)在non_trainable_variables中,因此無法為這些權(quán)重解凍。

為了檢查谷歌最近是否修復(fù)了該漏洞,Gupta還安裝了Nightly版的TF 2.3.0-rc1,保持框架處于最新狀態(tài),但如今bug依然存在。

關(guān)于在API中使用自定義層導(dǎo)致trainable_variables中的權(quán)重?zé)o法更新該怎么解決就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,可以學(xué)到更多知識。如果覺得文章不錯,可以把它分享出去讓更多的人看到。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

api
AI