溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

TensorFlow實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)CNN

發(fā)布時(shí)間:2020-08-22 18:51:22 來(lái)源:腳本之家 閱讀:378 作者:marsjhao 欄目:開發(fā)技術(shù)

一、卷積神經(jīng)網(wǎng)絡(luò)CNN簡(jiǎn)介

卷積神經(jīng)網(wǎng)絡(luò)(ConvolutionalNeuralNetwork,CNN)最初是為解決圖像識(shí)別等問題設(shè)計(jì)的,CNN現(xiàn)在的應(yīng)用已經(jīng)不限于圖像和視頻,也可用于時(shí)間序列信號(hào),比如音頻信號(hào)和文本數(shù)據(jù)等。CNN作為一個(gè)深度學(xué)習(xí)架構(gòu)被提出的最初訴求是降低對(duì)圖像數(shù)據(jù)預(yù)處理的要求,避免復(fù)雜的特征工程。在卷積神經(jīng)網(wǎng)絡(luò)中,第一個(gè)卷積層會(huì)直接接受圖像像素級(jí)的輸入,每一層卷積(濾波器)都會(huì)提取數(shù)據(jù)中最有效的特征,這種方法可以提取到圖像中最基礎(chǔ)的特征,而后再進(jìn)行組合和抽象形成更高階的特征,因此CNN在理論上具有對(duì)圖像縮放、平移和旋轉(zhuǎn)的不變性。

卷積神經(jīng)網(wǎng)絡(luò)CNN的要點(diǎn)就是局部連接(LocalConnection)、權(quán)值共享(WeightsSharing)和池化層(Pooling)中的降采樣(Down-Sampling)。其中,局部連接和權(quán)值共享降低了參數(shù)量,使訓(xùn)練復(fù)雜度大大下降并減輕了過擬合。同時(shí)權(quán)值共享還賦予了卷積網(wǎng)絡(luò)對(duì)平移的容忍性,池化層降采樣則進(jìn)一步降低了輸出參數(shù)量并賦予模型對(duì)輕度形變的容忍性,提高了模型的泛化能力??梢园丫矸e層卷積操作理解為用少量參數(shù)在圖像的多個(gè)位置上提取相似特征的過程。

更多請(qǐng)參見:深度學(xué)習(xí)之卷積神經(jīng)網(wǎng)絡(luò)CNN

二、TensorFlow代碼實(shí)現(xiàn)

#!/usr/bin/env python2 
# -*- coding: utf-8 -*- 
""" 
Created on Thu Mar 9 22:01:46 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1) #標(biāo)準(zhǔn)差為0.1的正態(tài)分布 
 return tf.Variable(initial) 
 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) #偏差初始化為0.1 
 return tf.Variable(initial) 
 
def conv2d(x, W): 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 
       strides=[1, 2, 2, 1], padding='SAME') 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
# -1代表先不考慮輸入的圖片例子多少這個(gè)維度,1是channel的數(shù)量 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
keep_prob = tf.placeholder(tf.float32) 
 
# 構(gòu)建卷積層1 
W_conv1 = weight_variable([5, 5, 1, 32]) # 卷積核5*5,1個(gè)channel,32個(gè)卷積核,形成32個(gè)featuremap 
b_conv1 = bias_variable([32]) # 32個(gè)featuremap的偏置 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) # 用relu非線性處理 
h_pool1 = max_pool_2x2(h_conv1) # pooling池化 
 
# 構(gòu)建卷積層2 
W_conv2 = weight_variable([5, 5, 32, 64]) # 注意這里channel值是32 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
# 構(gòu)建全連接層1 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool3 = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool3, W_fc1) + b_fc1) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
# 構(gòu)建全連接層2 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.arg_max(y_conv, 1), tf.arg_max(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
tf.global_variables_initializer().run() 
 
for i in range(20001): 
 batch = mnist.train.next_batch(50) 
 if i % 100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g" %(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5}) 
print("test accuracy %g" %accuracy.eval(feed_dict={x: mnist.test.images, 
         y_: mnist.test.labels, keep_prob: 1.0})) 

三、代碼解讀

該代碼是用TensorFlow實(shí)現(xiàn)一個(gè)簡(jiǎn)單的卷積神經(jīng)網(wǎng)絡(luò),在數(shù)據(jù)集MNIST上,預(yù)期可以實(shí)現(xiàn)99.2%左右的準(zhǔn)確率。結(jié)構(gòu)上使用兩個(gè)卷積層和一個(gè)全連接層。

首先載入MNIST數(shù)據(jù)集,采用獨(dú)熱編碼,并創(chuàng)建tf.InteractiveSession。然后為后續(xù)即將多次使用的部分代碼創(chuàng)建函數(shù),包括權(quán)重初始化weight_variable、偏置初始化bias_variable、卷積層conv2d、最大池化max_pool_2x2。其中權(quán)重初始化的時(shí)候要進(jìn)行含有噪聲的非對(duì)稱初始化,打破完全對(duì)稱。又由于我們要使用ReLU單元,也需要給偏置bias增加一些小的正值(0.1)用來(lái)避免死亡節(jié)點(diǎn)(dead neurons)。

構(gòu)建卷積神經(jīng)網(wǎng)絡(luò)之前,先要定義輸入的placeholder,特征x和真實(shí)標(biāo)簽y_,將1*784格式的特征x轉(zhuǎn)換reshape為28*28的圖片格式,又由于只有一個(gè)通道且不確定輸入樣本的數(shù)量,故最終尺寸為[-1, 28, 28, 1]。

接下來(lái)定義第一個(gè)卷積層,首先初始化weights和bias,然后使用conv2d進(jìn)行卷積操作并加上偏置,隨后使用ReLU激活函數(shù)進(jìn)行非線性處理,最后使用最大池化函數(shù)對(duì)卷積的輸出結(jié)果進(jìn)行池化操作。

相同的步驟定義第二個(gè)卷積層,不同的地方是卷積核的數(shù)量為64,也就是說(shuō)這一層的卷積會(huì)提取64種特征。經(jīng)過兩層不變尺寸的卷積和兩次尺寸減半的池化,第二個(gè)卷積層后的輸出尺寸為7*7*64。將其reshape為長(zhǎng)度為7*7*64的1-D向量。經(jīng)過ReLU后,為了減輕過擬合,使用一個(gè)Dropout層,在訓(xùn)練時(shí)隨機(jī)丟棄部分節(jié)點(diǎn)的數(shù)據(jù)減輕過擬合,在預(yù)測(cè)的時(shí)候保留全部數(shù)據(jù)來(lái)追求最好的測(cè)試性能。

最后加一個(gè)Softmax層,得到最后的預(yù)測(cè)概率。隨后的定義損失函數(shù)、優(yōu)化器、評(píng)測(cè)準(zhǔn)確率不再詳細(xì)贅述。

訓(xùn)練過程首先進(jìn)行初始化全部參數(shù),訓(xùn)練時(shí)keep_prob比率設(shè)置為0.5,評(píng)測(cè)時(shí)設(shè)置為1。訓(xùn)練完成后,在最終的測(cè)試集上進(jìn)行全面的測(cè)試,得到整體的分類準(zhǔn)確率。

經(jīng)過實(shí)驗(yàn),這個(gè)CNN的模型可以得到99.2%的準(zhǔn)確率,相比于MLP又有了較大幅度的提高。

四、其他解讀補(bǔ)充

1. tf.nn.conv2d(x,W, strides=[1, 1, 1, 1], padding='SAME')

tf.nn.conv2d是TensorFlow的2維卷積函數(shù),x和W都是4-D的tensors。x是輸入input shape=[batch,in_height, in_width, in_channels],W是卷積的參數(shù)filter / kernel shape=[filter_height, filter_width, in_channels,out_channels]。strides參數(shù)是長(zhǎng)度為4的1-D參數(shù),代表了卷積核(滑動(dòng)窗口)移動(dòng)的步長(zhǎng),其中對(duì)于圖片strides[0]和strides[3]必須是1,都是1表示不遺漏地劃過圖片的每一個(gè)點(diǎn)。padding參數(shù)中SAME代表給邊界加上Padding讓卷積的輸出和輸入保持相同的尺寸。

2. tf.nn.max_pool(x,ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

tf.nn.max_pool是TensorFlow中的最大池化函數(shù),x是4-D的輸入tensor shape=[batch, height, width, channels],ksize參數(shù)表示池化窗口的大小,取一個(gè)4維向量,一般是[1, height, width, 1],因?yàn)槲覀儾幌朐赽atch和channels上做池化,所以這兩個(gè)維度設(shè)為了1,strides與tf.nn.conv2d相同,strides=[1, 2, 2, 1]可以縮小圖片尺寸。padding參數(shù)也參見tf.nn.conv2d。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI