您好,登錄后才能下訂單哦!
小編這次要給大家分享的是keras保存模型中的save()和save_weights()有什么區(qū)別,文章內容豐富,感興趣的小伙伴可以來了解一下,希望大家閱讀完這篇文章之后能夠有所收獲。
我們知道keras的模型一般保存為后綴名為h6的文件,比如final_model.h6。同樣是h6文件用save()和save_weight()保存效果是不一樣的。
我們用宇宙最通用的數(shù)據(jù)集MNIST來做這個實驗,首先設計一個兩層全連接網(wǎng)絡:
inputs = Input(shape=(784, )) x = Dense(64, activation='relu')(inputs) x = Dense(64, activation='relu')(x) y = Dense(10, activation='softmax')(x) model = Model(inputs=inputs, outputs=y)
然后,導入MNIST數(shù)據(jù)訓練,分別用兩種方式保存模型,在這里我還把未訓練的模型也保存下來,如下:
from keras.models import Model from keras.layers import Input, Dense from keras.datasets import mnist from keras.utils import np_utils (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train=x_train.reshape(x_train.shape[0],-1)/255.0 x_test=x_test.reshape(x_test.shape[0],-1)/255.0 y_train=np_utils.to_categorical(y_train,num_classes=10) y_test=np_utils.to_categorical(y_test,num_classes=10) inputs = Input(shape=(784, )) x = Dense(64, activation='relu')(inputs) x = Dense(64, activation='relu')(x) y = Dense(10, activation='softmax')(x) model = Model(inputs=inputs, outputs=y) model.save('m1.h6') model.summary() model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) model.fit(x_train, y_train, batch_size=32, epochs=10) #loss,accuracy=model.evaluate(x_test,y_test) model.save('m2.h6') model.save_weights('m3.h6')
如上可見,我一共保存了m1.h6, m2.h6, m3.h6 這三個h6文件。那么,我們來看看這三個玩意兒有什么區(qū)別。首先,看看大小:
m2表示save()保存的模型結果,它既保持了模型的圖結構,又保存了模型的參數(shù)。所以它的size最大的。
m1表示save()保存的訓練前的模型結果,它保存了模型的圖結構,但應該沒有保存模型的初始化參數(shù),所以它的size要比m2小很多。
m3表示save_weights()保存的模型結果,它只保存了模型的參數(shù),但并沒有保存模型的圖結構。所以它的size也要比m2小很多。
通過可視化工具,我們發(fā)現(xiàn):(打開m1和m2均可以顯示出以下結構)
而打開m3的時候,可視化工具報錯了。由此可以論證, save_weights()是不含有模型結構信息的。
加載模型
兩種不同方法保存的模型文件也需要用不同的加載方法。
from keras.models import load_model model = load_model('m1.h6') #model = load_model('m2.h6') #model = load_model('m3.h6') model.summary()
只有加載m3.h6的時候,這段代碼才會報錯。其他輸出如下:
可見,由save()保存下來的h6文件才可以直接通過load_model()打開!
那么,我們保存下來的參數(shù)(m3.h6)該怎么打開呢?
這就稍微復雜一點了,因為m3不含有模型結構信息,所以我們需要把模型結構再描述一遍才可以加載m3,如下:
from keras.models import Model from keras.layers import Input, Dense inputs = Input(shape=(784, )) x = Dense(64, activation='relu')(inputs) x = Dense(64, activation='relu')(x) y = Dense(10, activation='softmax')(x) model = Model(inputs=inputs, outputs=y) model.load_weights('m3.h6')
以上把m3換成m1和m2也是沒有問題的!可見,save()保存的模型除了占用內存大一點以外,其他的優(yōu)點太明顯了。所以,在不怎么缺硬盤空間的情況下,還是建議大家多用save()來存。
注意!如果要load_weights(),必須保證你描述的有參數(shù)計算結構與h6文件中完全一致!什么叫有參數(shù)計算結構呢?就是有參數(shù)坑,直接填進去就行了。我們把上面的非參數(shù)結構換了一下,發(fā)現(xiàn)h6文件依然可以加載成功,比如將softmax換成relu,依然不影響加載。
對于keras的save()和save_weights(),完全沒問題了吧
看完這篇關于keras保存模型中的save()和save_weights()有什么區(qū)別的文章,如果覺得文章內容寫得不錯的話,可以把它分享出去給更多人看到。
免責聲明:本站發(fā)布的內容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權內容。