您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“python中如何實(shí)現(xiàn)tensorflow實(shí)現(xiàn)斑馬線識(shí)別功能”,內(nèi)容簡(jiǎn)而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領(lǐng)大家一起研究并學(xué)習(xí)一下“python中如何實(shí)現(xiàn)tensorflow實(shí)現(xiàn)斑馬線識(shí)別功能”這篇文章吧。
數(shù)據(jù)集的構(gòu)成:
test | train |
---|---|
zebra corssing:56 | zebra corssing:168 |
other:54 | other:164 |
import tensorflow as tf from tensorflow.keras.preprocessing.image import ImageDataGenerator import numpy as np import matplotlib.pyplot as plt import keras
train_dir=r'C:\Users\zx\深度學(xué)習(xí)\Zebra\train' test_dir=r'C:\Users\zx\深度學(xué)習(xí)\Zebra\test' train_datagen = ImageDataGenerator(rescale=1/255, rotation_range=10, #旋轉(zhuǎn) horizontal_flip=True) train_generator = train_datagen.flow_from_directory(train_dir, (50,50), batch_size=1, class_mode='binary', shuffle=False) test_datagen = ImageDataGenerator(rescale=1/255) test_generator = test_datagen.flow_from_directory(test_dir, (50,50), batch_size=1, class_mode='binary', shuffle=False)
模型的建立仁者見(jiàn)智,可自己調(diào)節(jié)尋找更好的模型。
model = tf.keras.models.Sequential([ # 第一層卷積,卷積核為,共16個(gè),輸入為150*150*1 tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(50,50,3)), tf.keras.layers.MaxPooling2D((2,2)), # 第二層卷積,卷積核為3*3,共32個(gè), tf.keras.layers.Conv2D(32,(3,3),activation='relu'), tf.keras.layers.MaxPooling2D((2,2)), # 第三層卷積,卷積核為3*3,共64個(gè), tf.keras.layers.Conv2D(64,(3,3),activation='relu'), tf.keras.layers.MaxPooling2D((2,2)), # 第四層卷積,卷積核為3*3,共128個(gè) # tf.keras.layers.Conv2D(128,(3,3),activation='relu'), # tf.keras.layers.MaxPooling2D((2,2)), # 數(shù)據(jù)鋪平 tf.keras.layers.Flatten(), tf.keras.layers.Dense(32,activation='relu'), tf.keras.layers.Dense(16,activation='relu'), tf.keras.layers.Dense(2,activation='softmax') ]) print(model.summary()) model.compile(optimize='adam', loss=tf.keras.losses.sparse_categorical_crossentropy, metrics=['acc'])
history = model.fit(train_generator, epochs=20, verbose=1) model.save('./Zebra.h6')
模型訓(xùn)練過(guò)程:
可以看到我們的模型在20輪的訓(xùn)練后acc從0.63上升到了0.96左右。
model.evaluate(test_generator)
#可視化 plt.plot(history.history['acc'], label='accuracy') plt.xlabel('Epoch') plt.ylabel('Accuracy') plt.ylim([0.7, 1]) plt.legend(loc='lower right') plt.title('acc') plt.show()
雖然我們的模型在訓(xùn)練過(guò)程中acc一度達(dá)到0.96,但測(cè)試集才是檢驗(yàn)?zāi)P偷奈ㄒ粯?biāo)準(zhǔn),在model.evaluate(test_generator)中的評(píng)分只有0.91左右,說(shuō)明我們的模型已經(jīng)能以很高的正確率來(lái)完成”斑馬線“與“非斑馬線”的二分類問(wèn)題了,但我們還是要查看具體是哪些數(shù)據(jù)沒(méi)有被模型正確得識(shí)別。
pred=model.predict(test_generator) #獲取test集的輸出 filenames = test_generator.filenames #獲取test數(shù)據(jù)的文件名
錯(cuò)誤輸出過(guò)程:
1,循環(huán)測(cè)試集長(zhǎng)度,通過(guò)if語(yǔ)句先判斷others還是zebra,再通過(guò)one-hot編碼判斷是否預(yù)測(cè)正確。
2,根據(jù)labels可知others': 0, 'zebra crossing': 1,以此來(lái)判斷是否預(yù)測(cè)正確。
3,對(duì) filenames[0]='others\\103.png',進(jìn)行切片處理。
4,找到others的‘s'或 zebra crossing的‘g',使用find()在基礎(chǔ)上+2為正切片的起點(diǎn)(樣本編號(hào)前有'\'符號(hào),故+2才能正確取出編號(hào))。
5,如 :將filenames[i]的值賦給a,a[int(a.find('s')+2):]則表示為 'xx.png'。
6,將取出的樣本編號(hào)與路徑拼接,讀取后作圖。
7,break跳出循環(huán)。
for i in range(len(filenames)): if filenames[i][:6]=='others': if np.argmax(pred[i]) != 0: a=filenames[i] plt.figure() print('預(yù)測(cè)錯(cuò)誤的圖片:'+a[int(a.find('s')+2):]) print('錯(cuò)誤識(shí)別為"zebra crossing",正確類型是"others"') print('預(yù)測(cè)標(biāo)簽為:'+str(np.argmax(pred[i]))+',真實(shí)標(biāo)簽為:0') img = plt.imread('Zebra/test/others/'+a[int(a.find('s')+2):]) plt.imshow(img) plt.title(a[int(a.find('s')+2):]) plt.grid(False) break if filenames[i][:6]=='zebra ': if np.argmax(pred[i]) != 1: b= filenames[i] plt.figure() print('預(yù)測(cè)錯(cuò)誤的圖片:'+b[int(b.find('g')+2):]) print('錯(cuò)誤識(shí)別為"others",正確類型是"zebra crossing"') print('預(yù)測(cè)標(biāo)簽為:'+str(np.argmax(pred[i]))+',真實(shí)標(biāo)簽為:1') img = plt.imread('Zebra/test/zebra crossing/'+b[int(b.find('g')+2):]) plt.imshow(img) plt.title(b[int(b.find('g')+2):]) plt.grid(False) break
看到這個(gè)錯(cuò)誤樣本,我猜想可能是因?yàn)榘唏R線的部分只占了圖像的一半左右,所以預(yù)測(cè)錯(cuò)誤了。
這里是我做預(yù)測(cè)判斷的思路,本可以不這么復(fù)雜的可以用test_generator.labels來(lái)獲取數(shù)據(jù)的標(biāo)簽,再做判斷。
test_generator.labels
上面只輸出了第一個(gè)錯(cuò)誤的樣本,所以接下來(lái)我們要看所有錯(cuò)誤預(yù)測(cè)的樣本
sum=0 for i in range(len(filenames)): if filenames[i][:6]=='others': if np.argmax(pred[i]) != 0: a=filenames[i] print('預(yù)測(cè)錯(cuò)誤的圖片:'+a[int(a.find('s')+2):]+',錯(cuò)誤識(shí)別為"zebra crossing",正確類型是"others"') sum=sum+1 if filenames[i][:6]=='zebra ': if np.argmax(pred[i]) != 1: b= filenames[i] print('預(yù)測(cè)錯(cuò)誤的圖片:'+b[int(b.find('g')+2):]+',錯(cuò)誤識(shí)別為"others",正確類型是"zebra crossing"') sum=sum+1 print('錯(cuò)誤率:'+str(sum/100)+'%') print('正確率:'+str((10000-sum)/100)+'%')
在構(gòu)建模型時(shí)我嘗試在最后一層只用一個(gè)神經(jīng)元,用sigmoid激活函數(shù),其他參數(shù)不變,在同樣epochs=20的條件,也能很快收斂,達(dá)到很高的acc,測(cè)試集的評(píng)分也能在0.9左右,但是在最后輸出全部錯(cuò)誤樣本的時(shí)候發(fā)現(xiàn)錯(cuò)誤的樣本遠(yuǎn)超過(guò)softmax,可能其中有些參數(shù)我沒(méi)有根據(jù)sigmoid來(lái)調(diào)整,所以會(huì)有如此高的錯(cuò)誤率。
以上是“python中如何實(shí)現(xiàn)tensorflow實(shí)現(xiàn)斑馬線識(shí)別功能”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。