您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“tensorflow中固定部分參數(shù)訓練和只訓練部分參數(shù)的示例分析”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“tensorflow中固定部分參數(shù)訓練和只訓練部分參數(shù)的示例分析”這篇文章吧。
在使用tensorflow來訓練一個模型的時候,有時候需要依靠驗證集來判斷模型是否已經過擬合,是否需要停止訓練。
1.首先想到的是用tf.placeholder()載入不同的數(shù)據來進行計算,比如
def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(input_) conv2 = tf.layers.conv2d(conv1) return conv2 input_ = tf.placeholder() output = inference(input_) ... calculate_loss_op = ... train_op = ... ... with tf.Session() as sess: sess.run([loss, train_op], feed_dict={input_: train_data}) if validation == True: sess.run([loss], feed_dict={input_: validate_date})
這種方式很簡單,也很直接了然。
2.但是,如果處理的數(shù)據量很大的時候,使用 tf.placeholder() 來載入數(shù)據會嚴重地拖慢訓練的進度,因此,常用tfrecords文件來讀取數(shù)據。
此時,很容易想到,將不同的值傳入inference()函數(shù)中進行計算。
train_batch, label_batch = decode_train() val_train_batch, val_label_batch = decode_validation() train_result = inference(train_batch) ... loss = .. train_op = ... ... if validation == True: val_result = inference(val_train_batch) val_loss = .. with tf.Session() as sess: sess.run([loss, train_op]) if validation == True: sess.run([val_result, val_loss])
這種方式看似能夠直接調用inference()來對驗證數(shù)據進行前向傳播計算,但是,實則會在原圖上添加上許多新的結點,這些結點的參數(shù)都是需要重新初始化的,也是就是說,驗證的時候并不是使用訓練的權重。
3.用一個tf.placeholder來控制是否訓練、驗證。
def inference(input_): ... ... ... return inference_result train_batch, label_batch = decode_train() val_batch, val_label = decode_validation() is_training = tf.placeholder(tf.bool, shape=()) x = tf.cond(is_training, lambda: train_batch, lambda: val_batch) y = tf.cond(is_training, lambda: train_label, lambda: val_label) logits = inference(x) loss = cal_loss(logits, y) train_op = optimize(loss) with tf.Session() as sess: loss, _ = sess.run([loss, train_op], feed_dict={is_training: True}) if validation == True: loss = sess.run(loss, feed_dict={is_training: False})
使用這種方式就可以在一個大圖里創(chuàng)建一個分支條件,從而通過控制placeholder來控制是否進行驗證。
以上是“tensorflow中固定部分參數(shù)訓練和只訓練部分參數(shù)的示例分析”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業(yè)資訊頻道!
免責聲明:本站發(fā)布的內容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。