溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

keras保存模型中的save()和save_weights()有什么區(qū)別

發(fā)布時間:2020-07-22 17:19:57 來源:億速云 閱讀:676 作者:小豬 欄目:開發(fā)技術

小編這次要給大家分享的是keras保存模型中的save()和save_weights()有什么區(qū)別,文章內容豐富,感興趣的小伙伴可以來了解一下,希望大家閱讀完這篇文章之后能夠有所收獲。

我們知道keras的模型一般保存為后綴名為h6的文件,比如final_model.h6。同樣是h6文件用save()和save_weight()保存效果是不一樣的。

我們用宇宙最通用的數(shù)據(jù)集MNIST來做這個實驗,首先設計一個兩層全連接網(wǎng)絡:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,導入MNIST數(shù)據(jù)訓練,分別用兩種方式保存模型,在這里我還把未訓練的模型也保存下來,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h6')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h6')
model.save_weights('m3.h6')

如上可見,我一共保存了m1.h6, m2.h6, m3.h6 這三個h6文件。那么,我們來看看這三個玩意兒有什么區(qū)別。首先,看看大小:

keras保存模型中的save()和save_weights()有什么區(qū)別

m2表示save()保存的模型結果,它既保持了模型的圖結構,又保存了模型的參數(shù)。所以它的size最大的。

m1表示save()保存的訓練前的模型結果,它保存了模型的圖結構,但應該沒有保存模型的初始化參數(shù),所以它的size要比m2小很多。

m3表示save_weights()保存的模型結果,它只保存了模型的參數(shù),但并沒有保存模型的圖結構。所以它的size也要比m2小很多。

通過可視化工具,我們發(fā)現(xiàn):(打開m1和m2均可以顯示出以下結構)

keras保存模型中的save()和save_weights()有什么區(qū)別

而打開m3的時候,可視化工具報錯了。由此可以論證, save_weights()是不含有模型結構信息的。

加載模型

兩種不同方法保存的模型文件也需要用不同的加載方法。

from keras.models import load_model
 
model = load_model('m1.h6')
#model = load_model('m2.h6')
#model = load_model('m3.h6')
model.summary()

只有加載m3.h6的時候,這段代碼才會報錯。其他輸出如下:

keras保存模型中的save()和save_weights()有什么區(qū)別

可見,由save()保存下來的h6文件才可以直接通過load_model()打開!

那么,我們保存下來的參數(shù)(m3.h6)該怎么打開呢?

這就稍微復雜一點了,因為m3不含有模型結構信息,所以我們需要把模型結構再描述一遍才可以加載m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h6')

以上把m3換成m1和m2也是沒有問題的!可見,save()保存的模型除了占用內存大一點以外,其他的優(yōu)點太明顯了。所以,在不怎么缺硬盤空間的情況下,還是建議大家多用save()來存。

注意!如果要load_weights(),必須保證你描述的有參數(shù)計算結構與h6文件中完全一致!什么叫有參數(shù)計算結構呢?就是有參數(shù)坑,直接填進去就行了。我們把上面的非參數(shù)結構換了一下,發(fā)現(xiàn)h6文件依然可以加載成功,比如將softmax換成relu,依然不影響加載。

對于keras的save()和save_weights(),完全沒問題了吧

看完這篇關于keras保存模型中的save()和save_weights()有什么區(qū)別的文章,如果覺得文章內容寫得不錯的話,可以把它分享出去給更多人看到。

向AI問一下細節(jié)

免責聲明:本站發(fā)布的內容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權內容。

AI