您好,登錄后才能下訂單哦!
這篇文章主要介紹了keras回調(diào)函數(shù)如何使用的相關(guān)知識,內(nèi)容詳細易懂,操作簡單快捷,具有一定借鑒價值,相信大家閱讀完這篇keras回調(diào)函數(shù)如何使用文章都會有所收獲,下面我們一起來看看吧。
回調(diào)函數(shù)是一個對象(實現(xiàn)了特定方法的類實例),它在調(diào)用fit()時被傳入模型,并在訓(xùn)練過程中的不同時間點被模型調(diào)用
可以訪問關(guān)于模型狀態(tài)與模型性能的所有可用數(shù)據(jù)
模型檢查點(model checkpointing):在訓(xùn)練過程中的不同時間點保存模型的當前狀態(tài)。
提前終止(early stopping):如果驗證損失不再改善,則中斷訓(xùn)練(當然,同時保存在訓(xùn)練過程中的最佳模型)。
在訓(xùn)練過程中動態(tài)調(diào)節(jié)某些參數(shù)值:比如調(diào)節(jié)優(yōu)化器的學(xué)習(xí)率。
在訓(xùn)練過程中記錄訓(xùn)練指標和驗證指標,或者將模型學(xué)到的表示可視化(這些表示在不斷更新):fit()進度條實際上就是一個回調(diào)函數(shù)。
# 這里有兩個callback函數(shù):早停和模型檢查點 callbacks_list=[ keras.callbacks.EarlyStopping( monitor="val_accuracy",#監(jiān)控指標 patience=2 #兩輪內(nèi)不再改善中斷訓(xùn)練 ), keras.callbacks.ModelCheckpoint( filepath="checkpoint_path", monitor="val_loss", save_best_only=True ) ] #模型獲取 model=get_minist_model() model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) model.fit(train_images,train_labels, epochs=10,callbacks=callbacks_list, #該參數(shù)使用回調(diào)函數(shù) validation_data=(val_images,val_labels)) test_metrics=model.evaluate(test_images,test_labels)#計算模型在新數(shù)據(jù)上的損失和指標 predictions=model.predict(test_images)#計算模型在新數(shù)據(jù)上的分類概率
#也可以在訓(xùn)練完成后手動保存模型,只需調(diào)用model.save('my_checkpoint_path')。 #重新加載模型 model_new=keras.models.load_model("checkpoint_path.keras")
on_epoch_begin(epoch, logs) ←----在每輪開始時被調(diào)用
on_epoch_end(epoch, logs) ←----在每輪結(jié)束時被調(diào)用
on_batch_begin(batch, logs) ←----在處理每個批量之前被調(diào)用
on_batch_end(batch, logs) ←----在處理每個批量之后被調(diào)用
on_train_begin(logs) ←----在訓(xùn)練開始時被調(diào)用
on_train_end(logs ←----在訓(xùn)練結(jié)束時被調(diào)用
from matplotlib import pyplot as plt # 實現(xiàn)記錄每一輪中每個batch訓(xùn)練后的損失,并為每個epoch繪制一個圖 class LossHistory(keras.callbacks.Callback): def on_train_begin(self, logs): self.per_batch_losses = [] def on_batch_end(self, batch, logs): self.per_batch_losses.append(logs.get("loss")) def on_epoch_end(self, epoch, logs): plt.clf() plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses, label="Training loss for each batch") plt.xlabel(f"Batch (epoch {epoch})") plt.ylabel("Loss") plt.legend() plt.savefig(f"plot_at_epoch_{epoch}") self.per_batch_losses = [] #清空,方便下一輪的技術(shù)
model = get_mnist_model() model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) model.fit(train_images, train_labels, epochs=10, callbacks=[LossHistory()], validation_data=(val_images, val_labels))
def get_minist_model(): inputs=keras.Input(shape=(28*28,)) features=layers.Dense(512,activation="relu")(inputs) features=layers.Dropout(0.5)(features) outputs=layers.Dense(10,activation="softmax")(features) model=keras.Model(inputs,outputs) return model #datset from tensorflow.keras.datasets import mnist (train_images,train_labels),(test_images,test_labels)=mnist.load_data() train_images=train_images.reshape((60000,28*28)).astype("float32")/255 test_images=test_images.reshape((10000,28*28)).astype("float32")/255 train_images,val_images=train_images[10000:],train_images[:10000] train_labels,val_labels=train_labels[10000:],train_labels[:10000]
關(guān)于“keras回調(diào)函數(shù)如何使用”這篇文章的內(nèi)容就介紹到這里,感謝各位的閱讀!相信大家對“keras回調(diào)函數(shù)如何使用”知識都有一定的了解,大家如果還想學(xué)習(xí)更多知識,歡迎關(guān)注億速云行業(yè)資訊頻道。
免責聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。