溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

應(yīng)用Tensorflow2.0的Eager模式是怎么快速構(gòu)建神經(jīng)網(wǎng)絡(luò)的

發(fā)布時(shí)間:2021-12-23 15:54:41 來源:億速云 閱讀:203 作者:柒染 欄目:大數(shù)據(jù)

應(yīng)用Tensorflow2.0的Eager模式是怎么快速構(gòu)建神經(jīng)網(wǎng)絡(luò)的,相信很多沒有經(jīng)驗(yàn)的人對此束手無策,為此本文總結(jié)了問題出現(xiàn)的原因和解決方法,通過這篇文章希望你能解決這個(gè)問題。

TensorFlow是開發(fā)深度學(xué)習(xí)算法的主流框架,近來隨著keras和pytorch等框架的崛起,它受到了不小挑戰(zhàn),為了應(yīng)對競爭它本身也在進(jìn)化,最近新出的2.0版本使得框架的應(yīng)用更加簡易和容易上手,本節(jié)我們就如何使用它2.0版本提出的eager模式進(jìn)行探討,在后面章節(jié)中我們將使用它來開發(fā)較為復(fù)雜的生成型對抗性網(wǎng)絡(luò)。
最新流行的深度學(xué)習(xí)框架keras一大特點(diǎn)是接口的易用性和可理解性,它在Tensorflow的基礎(chǔ)上進(jìn)行了深度封裝,它把很多技術(shù)細(xì)節(jié)隱藏起來,同時(shí)調(diào)整設(shè)計(jì)模式,使得基于keras的開發(fā)比Tensorflow要簡單得多。  但keras對應(yīng)的問題是,封裝太好雖然有利于易用性,但是不利于開發(fā)人員,特別是初學(xué)者對模型設(shè)計(jì)的深入理解,由于我們主題是學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)的設(shè)計(jì)原理,由于keras對模型設(shè)計(jì)模式的細(xì)節(jié)封裝過度,因此反而不利于學(xué)習(xí)者。  為了兼顧易用性和對設(shè)計(jì)細(xì)節(jié)的把握性,我選擇TF2.0帶來的Eager模式,這樣就能魚和熊掌兼得。
我們首先看看Eager模式和傳統(tǒng)模式有何區(qū)別。  傳統(tǒng)模式一大特點(diǎn)是代碼首先要?jiǎng)?chuàng)建一個(gè)會話對象,深度學(xué)習(xí)網(wǎng)絡(luò)模型實(shí)際上是由多種運(yùn)算節(jié)點(diǎn)構(gòu)成的一張運(yùn)算圖,模型運(yùn)行時(shí)需要依賴會話對象對運(yùn)算圖的驅(qū)動(dòng)和管理,我們先看看傳統(tǒng)模式的基本開發(fā)流程:

import tensorflow as tf        a = tf.constant(3.0)    b = tf.placeholder(dtype = tf.float32)    c = tf.add(a,b)    sess = tf.Session() #創(chuàng)建會話對象    init = tf.global_variables_initializer()    sess.run(init) #初始化會話對象    feed = {        b: 2.0    } #對變量b賦值    c_res = sess.run(c, feed) #通過會話驅(qū)動(dòng)計(jì)算圖獲取計(jì)算結(jié)果    print(c_res)
從上面代碼看你會感覺有一種別扭,placeholder用來開辟一塊內(nèi)存,然后通過feed再把數(shù)值賦值到被開辟的內(nèi)存中,然后再使用run驅(qū)動(dòng)整個(gè)計(jì)算流程的運(yùn)轉(zhuǎn),這種設(shè)計(jì)模式與傳統(tǒng)編程模式的區(qū)別在于饒了一個(gè)彎,對很多TF的初學(xué)者而言,一開始要花不少精力去適應(yīng)這種模式。
我們再看看eager模式下上面代碼的設(shè)計(jì)過程,首先要注意一點(diǎn)是,要開啟eager模式,需要在最開始處先執(zhí)行如下代碼:

import tensorflow as tf    import tensorflow.contrib.eager as tfe    tf.enable_eager_execution()
代碼執(zhí)行后TF就進(jìn)入eager模式,接下來我們看看如何實(shí)現(xiàn)前面的運(yùn)算步驟:

  
def  add(num1, num2):        a = tf.convert_to_tensor(num1) #將數(shù)值轉(zhuǎn)換為TF張量,這有利于加快運(yùn)算速度        b = tf.convert_to_tensor(num2)        c = a + b        return c.numpy() #將張量轉(zhuǎn)換為數(shù)值        add_res = add(3.0, 4.0)        print(add_res)
代碼運(yùn)行后輸出結(jié)果7.0,可以看到eager模式的特點(diǎn)是省掉了傳統(tǒng)模式繞彎的特點(diǎn),它可以像傳統(tǒng)編程模式那樣從上到下的方式執(zhí)行所有運(yùn)算步驟,不需要特別去創(chuàng)建一個(gè)會話對象,然后再通過會話對象驅(qū)動(dòng)所有運(yùn)算步驟的執(zhí)行,這種設(shè)計(jì)模式就更加簡單易懂
我們看看如何使用eager模式開發(fā)一個(gè)簡單的神經(jīng)網(wǎng)絡(luò)。  類似”Hello World!”,在神經(jīng)網(wǎng)絡(luò)編程中常用與入門的練手項(xiàng)目叫鳶尾花識別,它的花瓣特征明顯,不同品種對應(yīng)花瓣的寬度和長度不同,因此可以通過通過神經(jīng)網(wǎng)絡(luò)讀取花瓣信息后識別出其對應(yīng)的品種,首先我們先加載相應(yīng)訓(xùn)練數(shù)據(jù):

from sklearn import datasets, preprocessing, model_selection    data = datasets.load_iris() #加載數(shù)據(jù)到內(nèi)存    x = preprocessing.MinMaxScaler(feature_range = (-1, 1)).fit_transform(data['data']) #將數(shù)據(jù)數(shù)值預(yù)處理到(-1,1)之間方便網(wǎng)絡(luò)識別        #把不同分類的品種用向量表示,例如有三個(gè)不同品種,那么分別用(1,0,0),(0,1,0),(0,0,1)表示        y = preprocessing.OneHotEncoder(sparse = False).fit_transform(data['target'].reshape(-1, 1))    x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y, test_size = 0.25, stratify = y) #將數(shù)據(jù)分成訓(xùn)練集合測試集    print(len(x_train))
代碼運(yùn)行后可以看到擁有訓(xùn)練的數(shù)據(jù)有112條。  接下來我們創(chuàng)建一個(gè)簡單的三層網(wǎng)絡(luò):
 class IrisClassifyModel(object):

        def  __init__(self, hidden_unit, output_unit):

            #這里只構(gòu)建兩層網(wǎng)絡(luò),第一層是輸入數(shù)據(jù)
   
            self.hidden_layer = tf.keras.layers.Dense(units = hidden_unit, activation = tf.nn.tanh, use_bias = True, name="hidden_layer")

            self.output_layer = tf.keras.layers.Dense(units = output_unit, activation = None, use_bias = True, name="output_layer")

        def  __call__(self, inputs):

            return self.output_layer(self.hidden_layer(inputs))
我們用如下代碼檢測一下網(wǎng)絡(luò)構(gòu)建的正確性:

  
#構(gòu)造輸入數(shù)據(jù)檢驗(yàn)網(wǎng)絡(luò)是否正常運(yùn)行    model = IrisClassifyModel(10, 3)    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))        for x, y in tfe.Iterator(train_dataset.batch(32)):        output = model(x)        print(output.numpy())        break  
代碼如果正確運(yùn)行并輸出相應(yīng)結(jié)果,那表明網(wǎng)絡(luò)設(shè)計(jì)沒有太大問題。  接著我們用下面代碼設(shè)計(jì)損失函數(shù)和統(tǒng)計(jì)網(wǎng)絡(luò)預(yù)測的準(zhǔn)確性:

  
def  make_loss(model, inputs, labels):        return  tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(logits = model(inputs), labels = labels))    opt = tf.train.AdamOptimizer(learning_rate = 0.01)    def train(model, x, y):        opt.minimize(lambda:make_loss(model, x, y))    accuracy = tfe.metrics.Accuracy()        def  check_accuracy(model, x_batch, y_batch): #統(tǒng)計(jì)網(wǎng)絡(luò)判斷結(jié)果的準(zhǔn)確性        accuracy(tf.argmax(model(tf.constant(x_batch)), axis = 1), tf.argmax(tf.constant(y_batch), axis = 1))        return accuracy
最后我們啟動(dòng)網(wǎng)絡(luò)訓(xùn)練流程,然后將網(wǎng)絡(luò)訓(xùn)練的結(jié)果繪制出來:

import numpy as np    model = IrisClassifyModel(10, 3)    epochs = 50        acc_history = np.zeros(epochs)    for epoch in range(epochs):    for (x_batch, y_batch) in tfe.Iterator(train_dataset.shuffle(1000).batch(32)):    train(model, x_batch, y_batch)     acc = check_accuracy(model, x_batch, y_batch)          acc_history[epoch] = acc.result().numpy()    import matplotlib.pyplot as plt        plt.figure()        plt.plot(acc_history)    plt.xlabel('Epoch')    plt.ylabel('Accuracy')    plt.show()  
上面代碼運(yùn)行后結(jié)果如下:  
 應(yīng)用Tensorflow2.0的Eager模式是怎么快速構(gòu)建神經(jīng)網(wǎng)絡(luò)的
可以看到網(wǎng)絡(luò)經(jīng)過訓(xùn)練后準(zhǔn)確率達(dá)到95%以上。  本節(jié)的目的是為了介紹TF2.0的eager模式,為后面開發(fā)更復(fù)雜的網(wǎng)絡(luò)做技術(shù)準(zhǔn)備。


看完上述內(nèi)容,你們掌握應(yīng)用Tensorflow2.0的Eager模式是怎么快速構(gòu)建神經(jīng)網(wǎng)絡(luò)的的方法了嗎?如果還想學(xué)到更多技能或想了解更多相關(guān)內(nèi)容,歡迎關(guān)注億速云行業(yè)資訊頻道,感謝各位的閱讀!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI