您好,登錄后才能下訂單哦!
本文小編為大家詳細介紹“RNN中的Dropout怎么實現(xiàn)”,內(nèi)容詳細,步驟清晰,細節(jié)處理妥當(dāng),希望這篇“RNN中的Dropout怎么實現(xiàn)”文章能幫助大家解決疑惑,下面跟著小編的思路慢慢深入,一起來學(xué)習(xí)新知識吧。
我們可以簡單的在RNN之前或之后加一個DropOut層,但是如果我們想在RNN層中間加上DropOut的話,就得用DropoutWrapper了。下面代碼在每個RNN層的輸入都應(yīng)用Dropout,對每個輸入有50%的概率丟棄。
keep_prob = 0.5
cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
cell_drop = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=keep_prob)
multi_layer_cell = tf.contrib.rnn.MultiRNNCell([cell_drop] * n_layers)
rnn_outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)
當(dāng)然,我們也可以通過設(shè)置output_keep_prob來對輸出進行dropout。
其實,細心的童鞋可能已經(jīng)發(fā)現(xiàn),上面的代碼是有問題的,因為我們在前面CNN中應(yīng)用Dropout的時候是有一個is_training的placeholder來區(qū)分是在training還是testing應(yīng)用的。但是上面代碼并沒有。確實,上面代碼的最大問題就是在testing的時候,也會應(yīng)用Dropout,當(dāng)然,這并不是我們想要的。不幸的是,DropoutWrapper并不支持is_training的placeholder,因此,我們要么自己重寫一個DropoutWapper類,要么我們有兩個計算圖,一個是用來training,另一個用來testing。這里我們看下兩個計算圖是怎么實現(xiàn)的,如下:
import sys
is_training = (sys.argv[-1] == "train")
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_steps, n_outputs])
cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
if is_training:
cell = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=keep_prob)
multi_layer_cell = tf.contrib.rnn.MultiRNNCell([cell] * n_layers)
rnn_outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)
[...] # build the rest of the graph
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
if is_training:
init.run()
for iteration in range(n_iterations):
[...] # train the model
save_path = saver.save(sess, "/tmp/my_model.ckpt")
else:
saver.restore(sess, "/tmp/my_model.ckpt")
[...] # use the model
讀到這里,這篇“RNN中的Dropout怎么實現(xiàn)”文章已經(jīng)介紹完畢,想要掌握這篇文章的知識點還需要大家自己動手實踐使用過才能領(lǐng)會,如果想了解更多相關(guān)內(nèi)容的文章,歡迎關(guān)注億速云行業(yè)資訊頻道。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。