您好,登錄后才能下訂單哦!
基於 Keras 實現分類任務
基於 Keras 利用 VGG、ResNet、GoogLeNet InceptionV3 實現圖像的分類任務,下面會給出完整代碼,但為了熟悉各個網絡的特點,建議大家自己搭建一下每個分類網絡,畢竟利用 Keras 搭建網絡還是比較簡單的。
# -*- coding: utf-8 -*-
import os
from keras.utils import plot_model
from keras.applications.resnet50 import ResNet50
from keras.applications.vgg19 import VGG19
from keras.applications.inception_v3 import InceptionV3
from keras.layers import Dense,Flatten,GlobalAveragePooling2D
from keras.models import Model,load_model
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
class PowerTransferMode:
    """Transfer-learning helpers built on Keras ImageNet backbones.

    The class bundles four steps: building directory-backed data generators,
    constructing a frozen-backbone classifier (ResNet50, VGG19 or
    InceptionV3), training/saving it, and plotting the training curves.
    It keeps no instance state.
    """

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
        """Create a directory iterator yielding image batches.

        Args:
            dir_path: Root directory handed to ``flow_from_directory``.
            img_row: First element of ``target_size``.
            img_col: Second element of ``target_size``.
            batch_size: Number of images per batch.
            is_train: When true, augmentation (zoom, rotation, channel
                shift, small width/height shifts, horizontal flip) is
                enabled and the iterator shuffles; otherwise images are
                only rescaled and the order is not shuffled.

        Returns:
            The iterator returned by
            ``ImageDataGenerator.flow_from_directory`` with
            ``class_mode='categorical'``.
        """
        if is_train:
            datagen = ImageDataGenerator(
                rescale=1./255,
                zoom_range=0.25, rotation_range=15.,
                channel_shift_range=25.,
                width_shift_range=0.02, height_shift_range=0.02,
                horizontal_flip=True, fill_mode='constant')
        else:
            # Evaluation data is only rescaled, never augmented.
            datagen = ImageDataGenerator(rescale=1./255)
        return datagen.flow_from_directory(
            dir_path, target_size=(img_row, img_col),
            batch_size=batch_size,
            class_mode='categorical',
            shuffle=is_train)

    # ------------------------------------------------------------------
    # Private helpers shared by the three model builders
    # ------------------------------------------------------------------
    @staticmethod
    def _freeze(base_model):
        """Mark every layer of ``base_model`` as non-trainable."""
        for layer in base_model.layers:
            layer.trainable = False

    @staticmethod
    def _pooled_dense_head(features):
        """Apply global average pooling followed by two 1024-unit ReLU layers."""
        x = GlobalAveragePooling2D()(features)
        x = Dense(1024, activation='relu')(x)
        return Dense(1024, activation='relu')(x)

    @staticmethod
    def _finish(base_model, features, nb_classes, lr, decay, momentum, plot_file):
        """Attach the softmax layer, compile with SGD and optionally plot.

        ``plot_file`` is the image path for ``plot_model``; pass ``None`` to
        skip plotting.
        """
        predictions = Dense(nb_classes, activation='softmax')(features)
        model = Model(inputs=base_model.input, outputs=predictions)
        sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
        if plot_file is not None:
            plot_model(model, to_file=plot_file, show_shapes=True)
        return model

    @staticmethod
    def _metric_series(history_dict, name):
        """Return the recorded series for ``name`` from a history dict.

        Falls back to the short legacy spelling (``'acc'`` / ``'val_acc'``)
        when the long ``'accuracy'`` spelling is absent, so histories
        recorded under either naming can be plotted.

        Raises:
            KeyError: If neither spelling is present.
        """
        if name in history_dict:
            return history_dict[name]
        legacy = name.replace('accuracy', 'acc')
        if legacy in history_dict:
            return history_dict[legacy]
        raise KeyError(name)

    # ------------------------------------------------------------------
    # ResNet model
    # ------------------------------------------------------------------
    def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
        """Build and compile a frozen ResNet50 backbone with a softmax head.

        The backbone output is flattened and fed directly into a
        ``nb_classes``-way softmax layer (no hidden dense layers).

        Args:
            lr, decay, momentum: Passed to the SGD optimizer (Nesterov on).
            nb_classes: Number of output units of the softmax layer.
            img_rows, img_cols: Spatial part of ``input_shape``.
            RGB: Selects 3 input channels when true, 1 otherwise.
            is_plot_model: When true, write ``resnet50_model.png``.

        Returns:
            The compiled Keras ``Model``.
        """
        # NOTE(review): ImageNet weights presumably require a 3-channel
        # input, so RGB=False may be rejected by Keras - confirm.
        color = 3 if RGB else 1
        base_model = ResNet50(weights='imagenet', include_top=False, pooling=None,
                              input_shape=(img_rows, img_cols, color),
                              classes=nb_classes)
        # Freeze the backbone so only the new classifier layer is trained.
        self._freeze(base_model)
        features = Flatten()(base_model.output)
        plot_file = 'resnet50_model.png' if is_plot_model else None
        return self._finish(base_model, features, nb_classes, lr, decay, momentum, plot_file)

    # ------------------------------------------------------------------
    # VGG model
    # ------------------------------------------------------------------
    def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=18, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
        """Build and compile a frozen VGG19 backbone with a dense head.

        The head is global average pooling, two 1024-unit ReLU layers and a
        ``nb_classes``-way softmax layer.

        Args:
            lr, decay, momentum: Passed to the SGD optimizer (Nesterov on).
            nb_classes: Number of output units of the softmax layer.
            img_rows, img_cols: Spatial part of ``input_shape``.
            RGB: Selects 3 input channels when true, 1 otherwise.
            is_plot_model: When true, write ``vgg19_model.png``.

        Returns:
            The compiled Keras ``Model``.
        """
        # NOTE(review): ImageNet weights presumably require a 3-channel
        # input, so RGB=False may be rejected by Keras - confirm.
        color = 3 if RGB else 1
        base_model = VGG19(weights='imagenet', include_top=False, pooling=None,
                           input_shape=(img_rows, img_cols, color),
                           classes=nb_classes)
        # Freeze the backbone so only the new classifier layers are trained.
        self._freeze(base_model)
        features = self._pooled_dense_head(base_model.output)
        plot_file = 'vgg19_model.png' if is_plot_model else None
        return self._finish(base_model, features, nb_classes, lr, decay, momentum, plot_file)

    # ------------------------------------------------------------------
    # InceptionV3 model
    # ------------------------------------------------------------------
    def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=18, img_rows=197, img_cols=197, RGB=True,
                          is_plot_model=False):
        """Build and compile a frozen InceptionV3 backbone with a dense head.

        The head is global average pooling, two 1024-unit ReLU layers and a
        ``nb_classes``-way softmax layer.

        Args:
            lr, decay, momentum: Passed to the SGD optimizer (Nesterov on).
            nb_classes: Number of output units of the softmax layer.
            img_rows, img_cols: Spatial part of ``input_shape``.
            RGB: Selects 3 input channels when true, 1 otherwise.
            is_plot_model: When true, write ``inception_v3_model.png``.

        Returns:
            The compiled Keras ``Model``.
        """
        # NOTE(review): ImageNet weights presumably require a 3-channel
        # input, so RGB=False may be rejected by Keras - confirm.
        color = 3 if RGB else 1
        base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
                                 input_shape=(img_rows, img_cols, color),
                                 classes=nb_classes)
        # Freeze the backbone so only the new classifier layers are trained.
        self._freeze(base_model)
        features = self._pooled_dense_head(base_model.output)
        plot_file = 'inception_v3_model.png' if is_plot_model else None
        return self._finish(base_model, features, nb_classes, lr, decay, momentum, plot_file)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):
        """Fit ``model`` on the generators and save it to ``model_url``.

        Args:
            model: Compiled Keras model to train.
            epochs: Number of epochs for this call.
            train_generator: Generator supplying training batches.
            steps_per_epoch: Batches drawn from ``train_generator`` per epoch.
            validation_generator: Generator supplying validation batches.
            validation_steps: Batches drawn from ``validation_generator``
                for each validation pass.
            model_url: File path the trained model is saved to (and loaded
                from when ``is_load_model`` is true).
            is_load_model: When true and ``model_url`` exists, the model
                stored at that path is loaded and trained instead of the
                ``model`` argument; if the path does not exist the
                ``model`` argument is trained as given.

        Returns:
            The history object returned by ``fit_generator``.
        """
        # Optionally resume from a previously saved model.
        if is_load_model and os.path.exists(model_url):
            model = load_model(model_url)
        history_ft = model.fit_generator(
            train_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=validation_generator,
            validation_steps=validation_steps)
        # Persist the trained model, replacing any existing file.
        model.save(model_url, overwrite=True)
        return history_ft

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def plot_training(self, history):
        """Plot accuracy and loss curves from a training history.

        Accuracy is drawn on the current figure and loss on a new one;
        training curves are blue and validation curves red. The call ends
        with ``plt.show()``.

        Args:
            history: Object whose ``history`` attribute is a dict holding
                the ``'loss'`` / ``'val_loss'`` series and the accuracy
                series under ``'accuracy'`` / ``'val_accuracy'`` (the
                legacy ``'acc'`` / ``'val_acc'`` names are accepted too).
        """
        records = history.history
        acc = self._metric_series(records, 'accuracy')
        val_acc = self._metric_series(records, 'val_accuracy')
        loss = self._metric_series(records, 'loss')
        val_loss = self._metric_series(records, 'val_loss')
        epochs = range(len(acc))
        plt.plot(epochs, acc, 'b-')
        plt.plot(epochs, val_acc, 'r')
        plt.title('Training and validation accuracy')
        # Start a second figure so loss does not share axes with accuracy.
        plt.figure()
        plt.plot(epochs, loss, 'b-')
        plt.plot(epochs, val_loss, 'r-')
        plt.title('Training and validation loss')
        plt.show()
if __name__ == '__main__':
    # Script entry point: train VGG19 and InceptionV3 classifiers on the
    # same 18-class image folders, then plot both training histories.

    # Uncomment to hide all GPUs from CUDA:
    # os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    image_size = 256
    batch_size = 32
    epo = 1000
    transfer = PowerTransferMode()
    # Hard-coded sample counts used to derive steps per epoch.
    # NOTE(review): presumably these match the image counts in the two
    # directories below - confirm against the data.
    num_train = 14470
    num_test = 3215
    train_steps = num_train // batch_size
    validation_steps = num_test // batch_size

    # Build the data generators (augmented for training, plain for test).
    train_generator = transfer.DataGen('/home/jjin/skin_diagnosis/class_test_3/train/', image_size, image_size, batch_size, True)
    validation_generator = transfer.DataGen("/home/jjin/skin_diagnosis/class_test_3/test/", image_size, image_size, batch_size, False)

    # NOTE(review): the '.h6' extension and the space inside the
    # InceptionV3 file name below look like typos; they are kept unchanged
    # so existing output paths stay the same.

    # VGG19: one long training run.
    model = transfer.VGG19_model(nb_classes=18, img_rows=image_size, img_cols=image_size, is_plot_model=False)
    history_ft1 = transfer.train_model(model, epo, train_generator, train_steps, validation_generator, validation_steps, 'vgg19_model_weights.h6', is_load_model=False)

    # ResNet50 is deliberately not trained by this script; build it with
    # transfer.ResNet50_model(...) and pass it to train_model if needed.

    # InceptionV3: train in 10 rounds of 100 epochs so the model is saved
    # after every round. The same in-memory model object is passed each
    # round.
    model = transfer.InceptionV3_model(nb_classes=18, img_rows=image_size, img_cols=image_size, is_plot_model=True)
    merged_history = {}
    for _ in range(10):
        history_ft2 = transfer.train_model(model, 100, train_generator, train_steps, validation_generator, validation_steps, 'inception_v3_model_weights .h6', is_load_model=False)
        # Accumulate every round; otherwise only the last round's 100
        # epochs would survive the loop and reach the plot.
        for key, values in history_ft2.history.items():
            merged_history.setdefault(key, []).extend(values)
    # Reuse the last history object as the carrier for the merged series.
    history_ft2.history = merged_history

    # Accuracy / loss curves for both trained networks.
    transfer.plot_training(history_ft1)
    transfer.plot_training(history_ft2)
在這裡有幾點要提醒一下:雖然 3 個網絡都搭建出來了,但我只訓練了其中的兩個網絡。其中,在訓練 InceptionV3 時,我把訓練過程分為了 10 個循環,每個循環的 epoch 是 100,這是為了每一個循環後都能保存一下模型,而不至於因為某些原因導致訓練中斷,模型沒有保存下來。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯繫站長郵箱:is@yisu.com 進行舉報,並提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。