您好,登錄后才能下訂單哦!
小編給大家分享一下tensorflow saver如何實(shí)現(xiàn)保存和恢復(fù)指定tensor,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!
在實(shí)踐中經(jīng)常會(huì)遇到這樣的情況:
1、用簡(jiǎn)單的模型預(yù)訓(xùn)練參數(shù)
2、把預(yù)訓(xùn)練的參數(shù)導(dǎo)入復(fù)雜的模型后訓(xùn)練復(fù)雜的模型
這時(shí)就產(chǎn)生一個(gè)問題:
如何加載預(yù)訓(xùn)練的參數(shù)。
下面就是我的總結(jié)。
為了方便說明,做一個(gè)假設(shè):簡(jiǎn)單的模型只有一個(gè)卷基層,復(fù)雜模型有兩個(gè)。
卷積層的實(shí)現(xiàn)代碼如下:
import tensorflow as tf # PS:本篇的重?fù)?dān)是saver,不過為了方便閱讀還是說明下參數(shù) # 參數(shù) # name:創(chuàng)建卷基層的代碼這么多,必須要函數(shù)化,而為了防止變量沖突就需要用tf.name_scope # input_data:輸入數(shù)據(jù) # width, high:卷積小窗口的寬、高 # deep_before, deep_after:卷積前后的神經(jīng)元數(shù)量 # stride:卷積小窗口的移動(dòng)步長(zhǎng) def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'): global parameters with tf.name_scope(name) asscope: weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after], dtype=tf.float32,stddev=0.01), trainable=True, name='weights') biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases') conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type) bias = tf.add(conv,biases) bias = batch_norm(bias,deep_after, 1) # batch_norm是自己寫的batchnorm函數(shù) conv =tf.maximum(0.1*bias, bias) return conv
簡(jiǎn)單的預(yù)訓(xùn)練模型就下面一句話
conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)
復(fù)雜的模型是兩個(gè)卷基層,如下:
conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1) pool1= make_max_pool('layer1-pool1', conv1, 2, 2) conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)
這時(shí)簡(jiǎn)簡(jiǎn)單單的在預(yù)訓(xùn)練模型中:
saver = tf.train.Saver() with tf.Session() as sess: saver.save(sess,'model.ckpt')
就不行了,因?yàn)椋?/p>
1,如果你在預(yù)訓(xùn)練模型中使用下面的話打印所有tensor
all_v =tf.global_variables() for i in all_v: print i
會(huì)發(fā)現(xiàn)tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:
<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref> <tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref> <tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref> <tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref> <tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref> <tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>
同理,在復(fù)雜模型中就是complex-conv1/weights和complex-conv1/biases,這是對(duì)不上號(hào)的。
2,預(yù)訓(xùn)練模型中只有1個(gè)卷積層,而復(fù)雜模型中有兩個(gè),而tensorflow默認(rèn)會(huì)從模型文件('model.ckpt')中找所有的“可訓(xùn)練的”tensor,找不到會(huì)報(bào)錯(cuò)。
解決方法:
1,在預(yù)訓(xùn)練模型中定義全局變量
parm_dict={}
并在“return conv”上面添加下面兩行
parm_dict['complex-conv1/weights']= weights parm_dict['complex-conv1/']= biases
然后在定義saver時(shí)使用下面這句話:
saver= tf.train.Saver(parm_dict)
這樣保存后的模型文件就對(duì)應(yīng)到復(fù)雜模型上了。
2,在復(fù)雜模型中定義全局變量
parameters= []
并在“return conv”上面添加下面行
parameters+= [weights, biases]
然后判斷如果是第二個(gè)卷積層就不更新parameters。
接著在定義saver時(shí)使用下面這句話:
saver= tf.train.Saver(parameters)
這樣就可以告訴saver,只需要從模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3統(tǒng)統(tǒng)滾一邊去(上面紅色部分)。
最后使用下面的代碼加載就可以了
with tf.Session() as sess: ckpt= tf.train.get_checkpoint_state('.') if ckpt and ckpt.model_checkpoint_path: saver.restore(sess,ckpt.model_checkpoint_path) else: print ' no saver.' exit()
以上是“tensorflow saver如何實(shí)現(xiàn)保存和恢復(fù)指定tensor”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。