溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

Tensorflow之Saver的用法詳解

發(fā)布時間:2020-09-21 05:21:41 來源:腳本之家 閱讀:142 作者:陳淺墨 欄目:開發(fā)技術(shù)

Saver的用法

1. Saver的背景介紹

我們經(jīng)常在訓(xùn)練完一個模型之后希望保存訓(xùn)練的結(jié)果,這些結(jié)果指的是模型的參數(shù),以便下次迭代的訓(xùn)練或者用作測試。Tensorflow針對這一需求提供了Saver類。

Saver類提供了向checkpoints文件保存和從checkpoints文件中恢復(fù)變量的相關(guān)方法。Checkpoints文件是一個二進(jìn)制文件,它把變量名映射到對應(yīng)的tensor值 。

只要提供一個計(jì)數(shù)器,當(dāng)計(jì)數(shù)器觸發(fā)時,Saver類可以自動的生成checkpoint文件。這讓我們可以在訓(xùn)練過程中保存多個中間結(jié)果。例如,我們可以保存每一步訓(xùn)練的結(jié)果。

為了避免填滿整個磁盤,Saver可以自動的管理Checkpoints文件。例如,我們可以指定保存最近的N個Checkpoints文件。

2. Saver的實(shí)例

下面以一個例子來講述如何使用Saver類 

import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b)) 
  1. isTrain:用來區(qū)分訓(xùn)練階段和測試階段,True表示訓(xùn)練,F(xiàn)alse表示測試
  2. train_steps:表示訓(xùn)練的次數(shù),例子中使用100
  3. checkpoint_steps:表示訓(xùn)練多少次保存一下checkpoints,例子中使用50
  4. checkpoint_dir:表示checkpoints文件的保存路徑,例子中使用當(dāng)前路徑

2.1 訓(xùn)練階段

使用Saver.save()方法保存模型:

  1. sess:表示當(dāng)前會話,當(dāng)前會話記錄了當(dāng)前的變量值
  2. checkpoint_dir + 'model.ckpt':表示存儲的文件名
  3. global_step:表示當(dāng)前是第幾步

訓(xùn)練完成后,當(dāng)前目錄底下會多出5個文件。

Tensorflow之Saver的用法詳解

打開名為“checkpoint”的文件,可以看到保存記錄,和最新的模型存儲位置。

Tensorflow之Saver的用法詳解

2.1測試階段

測試階段使用saver.restore()方法恢復(fù)變量:

sess:表示當(dāng)前會話,之前保存的結(jié)果將被加載入這個會話

ckpt.model_checkpoint_path:表示模型存儲的位置,不需要提供模型的名字,它會去查看checkpoint文件,看看最新的是誰,叫做什么。

運(yùn)行結(jié)果如下圖所示,加載了之前訓(xùn)練的參數(shù)w和b的結(jié)果

Tensorflow之Saver的用法詳解

以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持億速云。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI