溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

tensorflow獲取預訓練模型某層參數(shù)并賦值到當前網(wǎng)絡指定層方式

發(fā)布時間:2020-08-22 22:04:17 來源:腳本之家 閱讀:996 作者:曲草 欄目:開發(fā)技術

已經(jīng)有了一個預訓練的模型,我需要從其中取出某一層,把該層的weights和biases賦值到新的網(wǎng)絡結(jié)構(gòu)中,可以使用tensorflow中的pywrap_tensorflow(用來讀取預訓練模型的參數(shù)值)結(jié)合Session.assign()進行操作。

這種需求即預訓練模型可能為單分支網(wǎng)絡,當前網(wǎng)絡為多分支,我需要把單分支A復用到到多個分支去(B,C,D)。

tensorflow獲取預訓練模型某層參數(shù)并賦值到當前網(wǎng)絡指定層方式

先導入對應的工具包

from tensorflow.python import pywrap_tensorflow

接下來的操作在一個tf.Session中進行

reader = pywrap_tensorflow.NewCheckpointReader(pre_train_model_path)

# 獲取當前圖可訓練變量
trainable_variables = tf.trainable_variables()
# 需要賦值的當前網(wǎng)絡層變量,這里只是隨便起的名字。
restore_v_target_name = "fc_target"
# 需要的預訓練模型中的某層的名字
restore_v_source_name = "fc_source"
for v in trainable_variables:
  if restore_v_target_name == v.name:
   # 回復weights和biases
    sess.run(
      tf.assign(v, reader.get_tensor(restore_v_source_name + "/weights"))) if "weights" in v.name else sess.run(
      tf.assign(v, reader.get_tensor(restore_v_source_name + "/biases")))

以上這篇tensorflow獲取預訓練模型某層參數(shù)并賦值到當前網(wǎng)絡指定層方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節(jié)

免責聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權內(nèi)容。

AI