您好,登錄后才能下訂單哦!
一、保存:
graph_util.convert_variables_to_constants 可以把當前session的計算圖串行化成一個字節(jié)流(二進制),這個函數包含三個參數:參數1:當前活動的session,它含有各變量
參數2:GraphDef 對象,它描述了計算網絡
參數3:Graph圖中需要輸出的節(jié)點的名稱的列表
返回值:精簡版的GraphDef 對象,包含了原始輸入GraphDef和session的網絡和變量信息,它的成員函數SerializeToString()可以把這些信息串行化為字節(jié)流,然后寫入文件里:
constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] ) with open( pbName, mode='wb') as f: f.write(constant_graph.SerializeToString())
需要指出的是,如果原始張量(包含在參數1和參數2中的組成部分)不參與參數3指定的輸出節(jié)點列表所指定的張量計算的話,這些張量將不會存在返回的GraphDef對象里,也不會被串行化寫入pb文件。
二、恢復:
恢復時,創(chuàng)建一個GraphDef,然后從上述的文件里加載進來,接著輸入到當前的session:
graph0 = tf.GraphDef() with open( pbName, mode='rb') as f: graph0.ParseFromString( f.read() ) tf.import_graph_def( graph0 , name = '' )
三、代碼:
import tensorflow as tf from tensorflow.python.framework import graph_util pbName = 'graphA.pb' def graphCreate() : with tf.Session() as sess : var1 = tf.placeholder ( tf.int32 , name='var1' ) var2 = tf.Variable( 20 , name='var2' )#實參name='var2'指定了操作名,該操作返回的張量名是在 #'var2'后面:0 ,即var2:0 是返回的張量名,也就是說變量 # var2的名稱是'var2:0' var3 = tf.Variable( 30 , name='var3' ) var4 = tf.Variable( 40 , name='var4' ) var4op = tf.assign( var4 , 1000 , name = 'var4op1' ) sum = tf.Variable( 4, name='sum' ) sum = tf.add ( var1 , var2, name = 'var1_var2' ) sum = tf.add( sum , var3 , name='sum_var3' ) sumOps = tf.add( sum , var4 , name='sum_operation' ) oper = tf.get_default_graph().get_operations() with open( 'operation.csv','wt' ) as f: s = 'name,type,output\n' f.write( s ) for o in oper: s = o.name s += ','+ o.type inp = o.inputs oup = o.outputs for iip in inp : s #s += ','+ str(iip) for iop in oup : s += ',' + str(iop) s += '\n' f.write( s ) for var in tf.global_variables(): print('variable=> ' , var.name) #張量是tf.Variable/tf.Add之類操作的結果, #張量的名字使用操作名加:0來表示 init = tf.global_variables_initializer() sess.run( init ) sess.run( var4op ) print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) ) constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] ) with open( pbName, mode='wb') as f: f.write(constant_graph.SerializeToString()) def graphGet() : print("start get:" ) with tf.Graph().as_default(): graph0 = tf.GraphDef() with open( pbName, mode='rb') as f: graph0.ParseFromString( f.read() ) tf.import_graph_def( graph0 , name = '' ) with tf.Session() as sess : init = tf.global_variables_initializer() sess.run(init) v1 = sess.graph.get_tensor_by_name('var1:0' ) v2 = sess.graph.get_tensor_by_name('var2:0' ) v3 = sess.graph.get_tensor_by_name('var3:0' ) v4 = sess.graph.get_tensor_by_name('var4:0' ) sumTensor = sess.graph.get_tensor_by_name("sum_operation:0") print('sumTensor is : ' , sumTensor ) print( sess.run( sumTensor , feed_dict={v1:1} ) ) graphCreate() graphGet()
四、保存pb函數代碼里的操作名稱/類型/返回的張量:
operation name | operation type | output | ||
var1 | Placeholder | Tensor("var1:0" | dtype=int32) | |
var2/initial_value | Const | Tensor("var2/initial_value:0" | shape=() | dtype=int32) |
var2 | VariableV2 | Tensor("var2:0" | shape=() | dtype=int32_ref) |
var2/Assign | Assign | Tensor("var2/Assign:0" | shape=() | dtype=int32_ref) |
var2/read | Identity | Tensor("var2/read:0" | shape=() | dtype=int32) |
var3/initial_value | Const | Tensor("var3/initial_value:0" | shape=() | dtype=int32) |
var3 | VariableV2 | Tensor("var3:0" | shape=() | dtype=int32_ref) |
var3/Assign | Assign | Tensor("var3/Assign:0" | shape=() | dtype=int32_ref) |
var3/read | Identity | Tensor("var3/read:0" | shape=() | dtype=int32) |
var4/initial_value | Const | Tensor("var4/initial_value:0" | shape=() | dtype=int32) |
var4 | VariableV2 | Tensor("var4:0" | shape=() | dtype=int32_ref) |
var4/Assign | Assign | Tensor("var4/Assign:0" | shape=() | dtype=int32_ref) |
var4/read | Identity | Tensor("var4/read:0" | shape=() | dtype=int32) |
var4op1/value | Const | Tensor("var4op1/value:0" | shape=() | dtype=int32) |
var4op1 | Assign | Tensor("var4op1:0" | shape=() | dtype=int32_ref) |
sum/initial_value | Const | Tensor("sum/initial_value:0" | shape=() | dtype=int32) |
sum | VariableV2 | Tensor("sum:0" | shape=() | dtype=int32_ref) |
sum/Assign | Assign | Tensor("sum/Assign:0" | shape=() | dtype=int32_ref) |
sum/read | Identity | Tensor("sum/read:0" | shape=() | dtype=int32) |
var1_var2 | Add | Tensor("var1_var2:0" | dtype=int32) | |
sum_var3 | Add | Tensor("sum_var3:0" | dtype=int32) | |
sum_operation | Add | Tensor("sum_operation:0" | dtype=int32) |
以上這篇Tensorflow 使用pb文件保存(恢復)模型計算圖和參數實例詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發(fā)布的內容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。