溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Keras模型轉(zhuǎn)成tensorflow中.pb的方法

發(fā)布時(shí)間:2020-07-07 11:13:33 來源:億速云 閱讀:405 作者:清晨 欄目:開發(fā)技術(shù)

不懂Keras模型轉(zhuǎn)成tensorflow中.pb的方法?其實(shí)想解決這個(gè)問題也不難,下面讓小編帶著大家一起學(xué)習(xí)怎么去解決,希望大家閱讀完這篇文章后大所收獲。

Keras的.h6模型轉(zhuǎn)成tensorflow的.pb格式模型,方便后期的前端部署。直接上代碼

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h6')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

補(bǔ)充知識(shí):keras h6 model 轉(zhuǎn)換為tflite

在移動(dòng)端的模型,若選擇tensorflow或者keras最基本的就是生成tflite文件,以本文記錄一次轉(zhuǎn)換過程。

環(huán)境

tensorflow 1.12.0

python 3.6.5

h6 model saved by `model.save('tf.h6')`

直接轉(zhuǎn)換

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h6`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先轉(zhuǎn)成pb再轉(zhuǎn)tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h6 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

參數(shù)說明,input_arrays和output_arrays是model的起始輸入變量名和結(jié)束變量名,input_shape是和input_arrays對(duì)應(yīng)

官網(wǎng)是說需要用到tenorboard來查看,一個(gè)比較trick的方法

先執(zhí)行上面的命令,會(huì)報(bào)convolution2d_1_input找不到,在堆棧里面有convert_saved_model.py文件,get_tensors_from_tensor_names()這個(gè)方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 這個(gè)變量下面,再執(zhí)行一遍,會(huì)打印出所有tensor的名字,再根據(jù)自己的模型很容易就能判斷出實(shí)際的name。

感謝你能夠認(rèn)真閱讀完這篇文章,希望小編分享Keras模型轉(zhuǎn)成tensorflow中.pb的方法內(nèi)容對(duì)大家有幫助,同時(shí)也希望大家多多支持億速云,關(guān)注億速云行業(yè)資訊頻道,遇到問題就找億速云,詳細(xì)的解決方法等著你來學(xué)習(xí)!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI