溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

TensorFlow如何將ckpt文件固化成pb文件

發(fā)布時(shí)間:2021-08-13 08:27:38 來源:億速云 閱讀:382 作者:小新 欄目:開發(fā)技術(shù)

小編給大家分享一下TensorFlow如何將ckpt文件固化成pb文件,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!

將yolo3目標(biāo)檢測(cè)框架訓(xùn)練出來的ckpt文件固化成pb文件,主要利用了GitHub上的該項(xiàng)目。

為什么要最終生成pb文件呢?簡(jiǎn)單來說就是直接通過tf.saver保存行程的ckpt文件其變量數(shù)據(jù)和圖是分開的。我們知道TensorFlow是先畫圖,然后通過placeholde往圖里面喂數(shù)據(jù)。這種解耦形式存在的方法對(duì)以后的遷移學(xué)習(xí)以及對(duì)程序進(jìn)行微小的改動(dòng)提供了極大的便利性。但是對(duì)于訓(xùn)練好,以后不再改變的話這種存在就不再需要。一方面,ckpt文件儲(chǔ)存的數(shù)據(jù)都是變量,既然我們不再改動(dòng),就應(yīng)當(dāng)讓其變成常量,直接‘燒'到圖里面。另一方面,對(duì)于線上的模型,我們一般是通過C++或者C語言編寫的程序進(jìn)行調(diào)用。所以一般模型最終形式都是應(yīng)該寫成pb文件的形式。

由于這次的程序直接從GitHub上下載后改動(dòng)較小就能夠運(yùn)行,也就是自己寫了很少一部分程序。因此進(jìn)行調(diào)試的時(shí)候還出現(xiàn)了以前根本沒有注意的一些小問題,同時(shí)發(fā)現(xiàn)自己對(duì)TensorFlow還需要更加詳細(xì)的去研讀。

首先對(duì)程序進(jìn)行保存的時(shí)候,利用 saver = tf.train.Saver(), saver.save(sess,checkpoint_path,global_step=global_step)對(duì)訓(xùn)練的數(shù)據(jù)進(jìn)行保存,保存格式為ckpt。但是在恢復(fù)的時(shí)候一直提示有問題,(其恢復(fù)語句為:saver = tf.train.Saver(), saver.restore(sess,ckpt_path),其中,ckpt_path是保存ckpt的文件夾路徑)。出現(xiàn)問題的原因我估計(jì)是因?yàn)槲沂前凑彰?0個(gè)epoch進(jìn)行保存,而不是讓其進(jìn)行固定次數(shù)的batch進(jìn)行保存,這種固定batch次數(shù)的保存系統(tǒng)會(huì)自動(dòng)保存最近5次的ckpt文件(該方法的ckpt_path=tf.train,latest_checkpoint('ckpt/')進(jìn)行回復(fù))。那么如何將利用epoch的次數(shù)進(jìn)行保存呢(這種保存不是近5次的保存,而是每進(jìn)行一次保存就會(huì)留下當(dāng)時(shí)保存的ckpt,而那種按照batch的會(huì)在第n次保存,會(huì)將n-5次的刪除,n>5)。

我們可以利用:ckpt = tf.train.get_checkpoint_state(ckpt_path),獲取最新的ckptpoint文件,然后利用saver.restore(sess,ckpt.checkpoint_path)進(jìn)行恢復(fù)。當(dāng)然為了安全起見,應(yīng)該對(duì)ckpt和ckpt.checkpoint_path進(jìn)行判斷是否存在后,再進(jìn)行恢復(fù)語句的調(diào)用,建議打開ckptpoint看一下,里面記錄的最近五次的model的路徑,一目了然。即:

  saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(model_path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

對(duì)于固化網(wǎng)絡(luò),網(wǎng)上有很多的介紹。之所以再介紹,還是由于是用了別人的網(wǎng)絡(luò)而不是自己的網(wǎng)絡(luò)遇到的坑。在固化時(shí)候我們需要知道輸出tensor的名字,而再恢復(fù)的時(shí)候我們需要知道placeholder的名字。但是,如果網(wǎng)絡(luò)復(fù)雜或者別人的網(wǎng)絡(luò)命名比較復(fù)雜,或者name=,根本就沒有自己命名而用的系統(tǒng)自定義的,這樣捋起來還是比較費(fèi)勁的。當(dāng)時(shí)在網(wǎng)上查找的一些方法,像打印整個(gè)網(wǎng)絡(luò)變量的方法(先不管輸出的網(wǎng)路名稱,甚至隨便起一個(gè)名字,先固化好pb文件,然后對(duì)pb文件進(jìn)行讀取,最后打印操作的名字:

 graph = tf.get_default_graph()
  input_graph_def = graph.as_graph_def()
 
  output_graph_def = graph_util.convert_variables_to_constants(
    sess,
    input_graph_def,
    ['cls_score/cls_score', 'cls_prob'] # We split on comma for convenience
  )
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print ('開始打印節(jié)點(diǎn)名字')
  for op in graph.get_operations():
    print(op.name)
  print("%d ops in the final graph." % len(output_graph_def.node))

代碼一

這樣盡然也能打印出來(盡管輸出名字是隨便命名的)。但是打印出來的是所有的節(jié)點(diǎn)的名字,簡(jiǎn)直不要太多。這樣找的話,一方面可能找不對(duì),另一方面也太費(fèi)事。

那么怎么辦?答案簡(jiǎn)單的讓我也很無語。其實(shí),對(duì)ckpt進(jìn)行數(shù)據(jù)恢復(fù)的時(shí)候,直接打印輸出的tensor名字就可以。比如說在saver以及placeholder定義的時(shí)候:output = model.yolo_inference(images, config.num_anchors / 3, config.num_classes, is_training),我們?cè)诤竺娓痪洌簆rint output,從打印出來的信息即可查看。placeholder的查看方法同樣如此。

對(duì)網(wǎng)絡(luò)進(jìn)行固化:

代碼:

  input_image_shape = tf.placeholder(dtype = tf.int32, shape = (2,))
  input_image = tf.placeholder(shape = [None, 416, 416, 3], dtype = tf.float32)
  predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
  boxes, scores, classes = predictor.predict(input_image, input_image_shape)
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(model_path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
 
  # 采用meta 結(jié)構(gòu)加載,不需要知道網(wǎng)絡(luò)結(jié)構(gòu)
  # saver = tf.train.import_meta_graph(model_path, clear_devices=True) 
  # 這里的model_path是model.ckpt.meta文件的全路徑
  # ckpt_model_path 是保存模型的文件夾路徑
  # saver.restore(sess, tf.train.latest_checkpoint(ckpt_model_path))
 
  graph = tf.get_default_graph()
  input_graph_def = graph.as_graph_def()
  output_graph_def = graph_util.convert_variables_to_constants(
    sess,
    input_graph_def,
    ['concat_11','concat_12','concat_13'] # We split on comma for convenience
  )
  # # Finally we serialize and dump the output graph to the filesystem
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())

由于固化的時(shí)候是需要先恢復(fù)ckpt網(wǎng)絡(luò)的,所以還是在restore前寫了placeholder和輸出tensor的定義(需要注點(diǎn)意的是,我們保存的ckpt文件是訓(xùn)練階段的graph和變量等,其inference輸出和最終predict的輸出的Tensor不一樣,因此predict與inference的輸出相比,還包括了一些后處理,比如說nms等等,只有這些后處理也是TensorFlow框架內(nèi)的方法寫的,才能使最終形成的pb文件能夠做到輸入一張圖片,直接輸出最終結(jié)果。因此,對(duì)于目標(biāo)檢測(cè)任務(wù),把后處理任務(wù)也交由TensorFlow內(nèi)的api來實(shí)現(xiàn),可免去夸平臺(tái)讀取pb文件后仍然需要重新進(jìn)行后處理等相關(guān)程序的編寫帶來的不必要麻煩)。然后結(jié)合保存變量的那個(gè)文件(ckpt),將變量恢復(fù)到inference過程所需的變量數(shù)據(jù)(predict包括inference和eval兩個(gè)過程,訓(xùn)練過程只有inference和loss過程參與,而預(yù)測(cè)過程多了一個(gè)后處理eval過程,eval過程無變量。這樣在生成pb文件的時(shí)候也把后處理eval固化進(jìn)去。喂給網(wǎng)絡(luò)數(shù)據(jù),即可得到輸出tensor。

由于有讀者在此問到了還是沒有弄明白'concat_11','concat_12','concat_13'是如何得來的,我在這里就在詳細(xì)說一下:

是這樣的,在我們恢復(fù)網(wǎng)絡(luò)的時(shí)候肯定需要知道saver這個(gè)對(duì)象的,在這里介紹兩種方法生成這個(gè)對(duì)象的方法。

一:

saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)

其中meta_graph_location就是保存模型時(shí)的.meta文件的路徑。保存后有四個(gè)文件(checkpoint、.index、.data-00000-of-00001和.meta文件)。.meta文件就是整個(gè)TensorFlow的結(jié)構(gòu)圖。

二:

saver = tf.train.Saver()

本文采用的是第二種方法(上面已經(jīng)有詳細(xì)的代碼),由于這種方法得到的saver對(duì)象,他不知道具體圖是什么樣的,因此在恢復(fù)前我有用如下代碼

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

把整個(gè)結(jié)構(gòu)又加載了一遍。如果采用第一種方法,是不需要在重寫這兩行代碼的。

我們要的就是 boxes, scores, classes這三個(gè)tensor的結(jié)果,并且想知道他們?nèi)齻€(gè)tensor的名字。你直接利用print(boxes, scores, classes)打印出來這三個(gè)tensor就會(huì)出來這三個(gè)tensor具體信息(包括名字,和shape,dtype等)。這個(gè)只是利用第二種方法得到saver對(duì)象,然后恢復(fù)ckpt文件,不涉及到固化pb文件問題。固化pb文件是需要知道這三個(gè)tensor的名字,所以需要打印看一下。

如果說,我只拿到了保存后的四個(gè)文件(checkpoint、.index、.data-00000-of-00001和.meta文件),其相應(yīng)用代碼寫成的結(jié)構(gòu)圖不清楚,比如說利用這兩行代碼:

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

畫出的結(jié)構(gòu)圖是什么樣的,我不知道。那么,想要知道具體的placehold和輸出tensor的名字,那只能通過代碼一中,打印出所有的OP操作節(jié)點(diǎn),然后進(jìn)行人工遍歷了。

讀取pb文件:

代碼:

def pb_detect(image_path, pb_model_path):
 
  os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_index
  image = Image.open(image_path)
  resize_image = letterbox_image(image, (416, 416))
  image_data = np.array(resize_image, dtype = np.float32)
  image_data /= 255.
  image_data = np.expand_dims(image_data, axis = 0)
  with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(pb_model_path, "rb") as f:
      output_graph_def.ParseFromString(f.read())
      tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      input_image_tensor = sess.graph.get_tensor_by_name("Placeholder_1:0")
      input_image_tensor_shape = sess.graph.get_tensor_by_name("Placeholder:0")
      # 定義輸出的張量名稱
      #output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
      boxes = sess.graph.get_tensor_by_name("concat_11:0")
      scores = sess.graph.get_tensor_by_name("concat_12:0")
      classes = sess.graph.get_tensor_by_name("concat_13:0")
      # 讀取測(cè)試圖片
      # 測(cè)試讀出來的模型是否正確,注意這里傳入的是輸出和輸入節(jié)點(diǎn)的tensor的名字(需要在名字后面加:0),不是操作節(jié)點(diǎn)的名字
      out_boxes, out_scores, out_classes= sess.run([boxes,scores,classes],
              feed_dict={
                input_image_tensor: image_data,
                input_image_tensor_shape: [image.size[1], image.size[0]]
      })

可以看到讀取pb文件只需要比恢復(fù)ckpt文件容易的多,直接將placeholder的名字獲取到,將數(shù)據(jù)輸入恢復(fù)的網(wǎng)絡(luò),以及讀取輸出即可。

小記:

有可能是TensorFlow版本更新或者其他原因,在后來工作中加載pb文件是報(bào)錯(cuò)了:

ValueError: Fetch argument <tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024) dtype=float32> cannot be interpreted as a Tensor. (tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024), dtype=float32) is not an element of this graph.)

將上面讀取pb文件的代碼with tf.Graph().as_default():改成

global graph
graph = tf.get_default_graph()
with graph.as_default():

以上是“TensorFlow如何將ckpt文件固化成pb文件”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI