This article explains in detail how to load weights in Keras. I found the topic quite practical, so I am sharing it here as a reference; hopefully you will get something out of it.
In the Keras source, engine/topology.py defines the weight-loading function:
load_weights(self, filepath, by_name=False)
by_name defaults to False, in which case weights are loaded according to the network's topology. This suits the case where you use one of Keras's built-in models unchanged, such as VGG16, VGG19, or ResNet50. The docstring describes it as follows:
If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.
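The topology-based default can be sketched with a small round trip: save the weights of one model, rebuild the same architecture from scratch, and load them back. The layer sizes and file name below are invented for illustration; the `.weights.h5` suffix is used because newer Keras versions require it for weight checkpoints.

```python
import os
import tempfile

import numpy as np
import keras
from keras import layers

def build_model():
    # Identical topology both times -- the precondition for by_name=False loading.
    return keras.Sequential([
        keras.Input(shape=(4,)),
        layers.Dense(8, activation='tanh'),
        layers.Dense(3, activation='softmax'),
    ])

x = np.random.rand(5, 4).astype('float32')

src = build_model()
path = os.path.join(tempfile.mkdtemp(), 'src.weights.h5')
src.save_weights(path)

dst = build_model()      # fresh model, random weights
dst.load_weights(path)   # topology-based load (by_name=False)

# Both models now produce identical outputs.
assert np.allclose(src.predict(x, verbose=0), dst.predict(x, verbose=0))
```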
若將by_name改為T(mén)rue則加載權(quán)重按照l(shuí)ayer的name進(jìn)行,layer的name相同時(shí)加載權(quán)重,適合用于改變了
模型的相關(guān)結(jié)構(gòu)或增加了節(jié)點(diǎn)但利用了原網(wǎng)絡(luò)的主體結(jié)構(gòu)情況下使用,源碼描述如下:
If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.
For example, when doing edge detection you might keep the VGG backbone and add deconvolution layers on top of it; the weights should then be loaded with:
model.load_weights(filepath, by_name=True)
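What by_name=True does can be sketched by hand: copy weights only between layers that share a name, and skip everything else. The models, layer names, and sizes below are invented for illustration; the loop mimics the name-matching behavior rather than calling load_weights itself.

```python
import numpy as np
import keras
from keras import layers

# Original "backbone" with named layers.
backbone = keras.Sequential([
    keras.Input(shape=(4,)),
    layers.Dense(8, activation='tanh', name='fc1'),
    layers.Dense(3, activation='softmax', name='out'),
])

# Modified model: keeps 'fc1' but inserts a new layer and renames the head,
# so the two topologies no longer line up.
modified = keras.Sequential([
    keras.Input(shape=(4,)),
    layers.Dense(8, activation='tanh', name='fc1'),
    layers.Dense(16, activation='relu', name='new_mid'),
    layers.Dense(3, activation='softmax', name='out2'),
])

# Name-matching copy -- the essence of load_weights(..., by_name=True):
src_by_name = {l.name: l for l in backbone.layers}
for layer in modified.layers:
    src = src_by_name.get(layer.name)
    if src is not None and src.get_weights():
        layer.set_weights(src.get_weights())

# Only 'fc1' received the backbone's weights; 'new_mid' and 'out2' kept theirs.
assert all(np.array_equal(a, b) for a, b in
           zip(backbone.get_layer('fc1').get_weights(),
               modified.get_layer('fc1').get_weights()))
```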
Bonus: MNIST handwritten-digit recognition in Keras
I had been using TensorFlow all along until a classmate recommended Keras, so I took the MNIST handwritten-digit dataset from my earlier notes as a warm-up exercise. The code is below.
import struct
import os
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD

def load_mnist(path, kind):
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)  # 28*28 = 784
    return images, labels

# loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')

# turn labels into one-hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)

# define the model
model = Sequential()
model.add(Dense(input_dim=X_train.shape[1], output_dim=50, init='uniform', activation='tanh'))
model.add(Dense(input_dim=50, output_dim=50, init='uniform', activation='tanh'))
model.add(Dense(input_dim=50, output_dim=Y_train_ohe.shape[1], init='uniform', activation='softmax'))

sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

# start training
model.fit(X_train, Y_train_ohe, epochs=50, batch_size=300, shuffle=True, verbose=1, validation_split=0.3)

# compute accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0]
print('Training accuracy: %.2f%%' % (train_acc * 100))

y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0]
print('Test accuracy: %.2f%%' % (test_acc * 100))
The training results are as follows:
Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%
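One caveat for newer setups: Sequential.predict_classes, used in the code above, was removed from tf.keras (around TF 2.6). A minimal sketch of the usual replacement, with an invented toy model just to show the call shape:

```python
import numpy as np
import keras
from keras import layers

# Toy stand-in for the MNIST classifier above (sizes for illustration only).
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(10, activation='softmax'),
])

x = np.random.rand(3, 784).astype('float32')

# predict_classes(x) was shorthand for taking the argmax over class probabilities:
classes = np.argmax(model.predict(x, verbose=0), axis=1)

# One predicted class index per input row.
assert classes.shape == (3,)
```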
That covers how to load weights in Keras. I hope the material above is of some help and lets you pick up something new.