When training deep learning models with TensorFlow, you need to save the trained network and its parameters so that you can resume training or use the model later. Many blog posts cover this topic; the best one I have found is this English tutorial:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
What follows is my own organized summary of that article.
First, saving the model. Here is the code:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
############################
# File Name: tut1_save.py
# Author: Wang
# Mail: wang19920419@hotmail.com
# Created Time: 2017-08-30 11:04:25
############################
import tensorflow as tf  # TensorFlow 1.x API

# Variables to save; the name argument matters because restoration looks tensors up by name
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[2]), name='w2')
b1 = tf.Variable(2.0, name='bias1')
# Values fed in place of w1 and w2 for this run
feed_dict = {w1: [10, 3], w2: [5, 5]}

# Define a test operation that will be restored later
w3 = tf.add(w1, w2)  # no explicit name, so it is hard to look up after restoring
w4 = tf.multiply(w3, b1, name="op_to_restore")

# saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(w4, feed_dict))
# saver.save(sess, 'my_test_model', global_step=100)
saver.save(sess, 'my_test_model')
# saver.save(sess, 'my_test_model', global_step=100, write_meta_graph=False)
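A few points are worth explaining:
1. When creating the saver you can specify which variables to store (the var_list argument); if you do not, all of them are saved. You can also set the maximum number of checkpoints to keep (max_to_keep) and how often a checkpoint is kept permanently (keep_checkpoint_every_n_hours). See the English tutorial for details.
2. saver.save() accepts global_step and write_meta_graph. The .meta file stores the network structure, so it only needs to be written once at the start of a run; later saves can skip it by passing write_meta_graph=False.
3. After this program finishes, four files appear in the program directory: .meta (the network structure), .data and .index (the trained parameters), and checkpoint (a record of the most recent model).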
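The file naming in point 3 can be sketched in plain Python. This helper is hypothetical and is not part of TensorFlow; it only spells out the names I would expect from one saver.save() call, assuming the V2 checkpoint format with a single data shard. When global_step is given, it is appended to the prefix with a hyphen.

```python
import os


def checkpoint_files(prefix, global_step=None, write_meta_graph=True):
    """Return the file names one saver.save() call is expected to produce.

    Hypothetical helper for illustration only. Assumes the V2 checkpoint
    format with a single data shard.
    """
    base = prefix if global_step is None else "{}-{}".format(prefix, global_step)
    files = [base + ".data-00000-of-00001", base + ".index"]
    if write_meta_graph:
        # The graph structure; skipped when write_meta_graph=False.
        files.append(base + ".meta")
    # The bookkeeping file sits in the same directory as the prefix.
    files.append(os.path.join(os.path.dirname(prefix), "checkpoint"))
    return files


if __name__ == "__main__":
    print(checkpoint_files("my_test_model"))
    print(checkpoint_files("my_test_model", global_step=100, write_meta_graph=False))
```

The second call mirrors the commented-out saver.save(sess, 'my_test_model', global_step=100, write_meta_graph=False) line in tut1_save.py: the parameter files carry the step number and no new .meta file is listed.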
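Next, loading a saved model. There are two ways to do it. The first is saver.restore(sess, 'aaaa.ckpt'). In essence it reads all the saved parameters and loads them into a network structure that has already been defined, so it amounts to assigning values to the network's weights and biases and takes the place of tf.global_variables_initializer() for the restored variables. The drawback is that you have to rewrite the network structure before using it, and that structure must match the saved parameters exactly. The second way is more convenient: load the network structure itself from the .meta file. Here is the code: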
#!/usr/bin/env python
# -*- coding: utf-8 -*-
############################
# File Name: tut2_import.py
# Author: Wang
# Mail: wang19920419@hotmail.com
# Created Time: 2017-08-30 14:16:38
############################
import tensorflow as tf

sess = tf.Session()
# Rebuild the graph from the .meta file, then load the latest parameters into it
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print(sess.run('w1:0'))
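The first approach (rebuild the graph by hand, then call saver.restore) has no listing in this post, so here is a minimal sketch of it. It reuses the variable names from tut1_save.py, saves into a temporary directory so that it is self-contained, and imports tensorflow.compat.v1 so that it can also run on TensorFlow 2.x. That import, the build_graph and save_then_restore helpers, and the temporary directory are my assumptions, not part of the original scripts.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def build_graph():
    """Recreate the variables and op from tut1_save.py; the names must match."""
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    w2 = tf.Variable(tf.random_normal(shape=[2]), name='w2')
    b1 = tf.Variable(2.0, name='bias1')
    w4 = tf.multiply(tf.add(w1, w2), b1, name='op_to_restore')
    return w1, w2, w4


def save_then_restore(ckpt_dir):
    prefix = os.path.join(ckpt_dir, 'my_test_model')

    # Saving side: initialize the variables and write a checkpoint.
    with tf.Graph().as_default():
        w1, _, _ = build_graph()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saved_w1 = sess.run(w1)
            saver.save(sess, prefix)

    # Restoring side: define the same structure again in a fresh graph.
    with tf.Graph().as_default():
        w1, _, _ = build_graph()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # No initializer call here: restore() assigns every saved variable.
            saver.restore(sess, prefix)
            restored_w1 = sess.run(w1)

    return saved_w1, restored_w1


if __name__ == '__main__':
    saved, restored = save_then_restore(tempfile.mkdtemp())
    print(saved, restored)
```

The two printed arrays should match, even though w1 is defined with a random initializer in both graphs, because the second session never runs the initializer.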
To use the loaded model, feed it new data and compute the output. Again, here is the code:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
############################
# File Name: tut3_reuse.py
# Author: Wang
# Mail: wang19920419@hotmail.com
# Created Time: 2017-08-30 14:33:35
############################
import tensorflow as tf

sess = tf.Session()

# First, load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Second, look up the saved variables by name and create a feed_dict with new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1: [-1, 1], w2: [4, 6]}

# Access the op that we want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')

print(sess.run(op_to_restore, feed_dict))  # output: [6. 14.]
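The strings passed to get_tensor_by_name follow the pattern `<op name>:<output index>`, so 'w1:0' means output 0 of the op named w1. When you are unsure which names a graph contains, listing its operations is a quick check. The sketch below rebuilds the small graph from tut1_save.py inline instead of importing a .meta file, and uses tensorflow.compat.v1 so that it can also run on TensorFlow 2.x; both choices are my assumptions for the sake of a self-contained example.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    w2 = tf.Variable(tf.random_normal(shape=[2]), name='w2')
    b1 = tf.Variable(2.0, name='bias1')
    w3 = tf.add(w1, w2)  # no explicit name: TensorFlow generates one
    w4 = tf.multiply(w3, b1, name='op_to_restore')

# Every op in the graph has a name, whether chosen by us or generated.
op_names = [op.name for op in graph.get_operations()]
print(op_names)

# The auto-generated name of the unnamed add is still a valid lookup key.
print(w3.name)

# A tensor name is "<op name>:<output index>".
tensor = graph.get_tensor_by_name('op_to_restore:0')
print(tensor.op.name, tensor.name)
```

The same list comprehension works on tf.get_default_graph() after import_meta_graph, which helps when the model was saved by someone else.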
Adding new layers on top of a network that has already been loaded:
import tensorflow as tf

sess = tf.Session()
# First, load the meta graph and restore the weights
# (tut1_save.py saved without global_step, so the file is my_test_model.meta)
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Now access the saved variables and create a feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
# w1 and w2 have shape [2], so each fed value needs two elements
feed_dict = {w1: [13.0, 13.0], w2: [17.0, 17.0]}

# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# Add more to the current graph
add_on_op = tf.multiply(op_to_restore, 2)

print(sess.run(add_on_op, feed_dict))
# Each element is (13 + 17) * 2 * 2 = 120
Modifying part of a loaded network (this is the trickiest case; I have not fully worked it out yet and will add more later):
# ...... (earlier code elided in the original)
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
# Prepare the feed_dict for feeding data for fine-tuning

# Access the appropriate output for fine-tuning
fc7 = graph.get_tensor_by_name('fc7:0')

# Use this if you only want to train the new last layer:
# stop_gradient is an identity in the forward pass but blocks gradients from flowing back
fc7 = tf.stop_gradient(fc7)
fc7_shape = fc7.get_shape().as_list()

num_outputs = 2
# Assumption: fc7 is [batch, C] or [batch, 1, 1, C]; flatten to 2-D so tf.matmul accepts it
fc7_flat = tf.reshape(fc7, [-1, fc7_shape[-1]])
weights = tf.Variable(tf.truncated_normal([fc7_shape[-1], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7_flat, weights) + biases
pred = tf.nn.softmax(output)

# Now run this with fine-tuning data in sess.run()
With these methods you are covered whether you train from scratch, load a model to continue training, use a classic model as-is, fine-tune a classic model, or load a network just to run a forward pass.
That is all for this article. I hope it helps with your learning.