溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

如何將TensorFlow的模型網(wǎng)絡(luò)導(dǎo)出為單個(gè)文件

發(fā)布時(shí)間:2021-08-13 10:30:55 來(lái)源:億速云 閱讀:130 作者:小新 欄目:開發(fā)技術(shù)

這篇文章主要為大家展示了“如何將TensorFlow的模型網(wǎng)絡(luò)導(dǎo)出為單個(gè)文件”,內(nèi)容簡(jiǎn)而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領(lǐng)大家一起研究并學(xué)習(xí)一下“如何將TensorFlow的模型網(wǎng)絡(luò)導(dǎo)出為單個(gè)文件”這篇文章吧。

有時(shí)候,我們需要將TensorFlow的模型導(dǎo)出為單個(gè)文件(同時(shí)包含模型架構(gòu)定義與權(quán)重),方便在其他地方使用(如在c++中部署網(wǎng)絡(luò))。利用tf.train.write_graph()默認(rèn)情況下只導(dǎo)出了網(wǎng)絡(luò)的定義(沒(méi)有權(quán)重),而利用tf.train.Saver().save()導(dǎo)出的文件graph_def與權(quán)重是分離的,因此需要采用別的方法。

我們知道,graph_def文件中沒(méi)有包含網(wǎng)絡(luò)中的Variable值(通常情況存儲(chǔ)了權(quán)重),但是卻包含了constant值,所以如果我們能把Variable轉(zhuǎn)換為constant,即可達(dá)到使用一個(gè)文件同時(shí)存儲(chǔ)網(wǎng)絡(luò)架構(gòu)與權(quán)重的目標(biāo)。

我們可以采用以下方式凍結(jié)權(quán)重并保存網(wǎng)絡(luò):

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 構(gòu)造網(wǎng)絡(luò)
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要給輸出tensor取一個(gè)名字??!
output = tf.add(a, b, name='out')

# 轉(zhuǎn)換Variable為constant,并將網(wǎng)絡(luò)寫入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 這里需要填入輸出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

當(dāng)恢復(fù)網(wǎng)絡(luò)時(shí),可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

輸出結(jié)果為:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的權(quán)重確實(shí)保存了下來(lái)!!

問(wèn)題來(lái)了,我們的網(wǎng)絡(luò)需要能有一個(gè)輸入自定義數(shù)據(jù)的接口?。〔蝗贿@玩意有什么用。。別急,當(dāng)然有辦法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代碼重新保存網(wǎng)絡(luò)至graph.pb,這次我們有了一個(gè)輸入placeholder,下面來(lái)看看怎么恢復(fù)網(wǎng)絡(luò)并輸入自定義數(shù)據(jù)。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

輸出結(jié)果為:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到結(jié)果沒(méi)有問(wèn)題,當(dāng)然在input_map那里可以替換為新的自定義的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看輸出,同樣沒(méi)有問(wèn)題。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要說(shuō)明的一點(diǎn)是,在利用tf.train.write_graph寫網(wǎng)絡(luò)架構(gòu)的時(shí)候,如果令as_text=True了,則在導(dǎo)入網(wǎng)絡(luò)的時(shí)候,需要做一點(diǎn)小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

以上是“如何將TensorFlow的模型網(wǎng)絡(luò)導(dǎo)出為單個(gè)文件”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI