溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

tensorflow中固定部分參數(shù)訓練和只訓練部分參數(shù)的示例分析

發(fā)布時間:2021-07-23 14:44:35 來源:億速云 閱讀:188 作者:小新 欄目:開發(fā)技術

這篇文章主要為大家展示了“tensorflow中固定部分參數(shù)訓練和只訓練部分參數(shù)的示例分析”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“tensorflow中固定部分參數(shù)訓練和只訓練部分參數(shù)的示例分析”這篇文章吧。

在使用tensorflow來訓練一個模型的時候,有時候需要依靠驗證集來判斷模型是否已經過擬合,是否需要停止訓練。

1.首先想到的是用tf.placeholder()載入不同的數(shù)據來進行計算,比如

def inference(input_):
  """
  this is where you put your graph.
  the following is just an example.
  """
  
  conv1 = tf.layers.conv2d(input_)
 
  conv2 = tf.layers.conv2d(conv1)
 
  return conv2
 
 
input_ = tf.placeholder()
output = inference(input_)
...
calculate_loss_op = ...
train_op = ...
...
 
with tf.Session() as sess:
  sess.run([loss, train_op], feed_dict={input_: train_data})
 
  if validation == True:
    sess.run([loss], feed_dict={input_: validate_date})

這種方式很簡單,也很直接了然。

2.但是,如果處理的數(shù)據量很大的時候,使用 tf.placeholder() 來載入數(shù)據會嚴重地拖慢訓練的進度,因此,常用tfrecords文件來讀取數(shù)據。

此時,很容易想到,將不同的值傳入inference()函數(shù)中進行計算。

train_batch, label_batch = decode_train()
val_train_batch, val_label_batch = decode_validation()
 
 
train_result = inference(train_batch)
...
loss = ..
train_op = ...
...
 
if validation == True:
  val_result = inference(val_train_batch)
  val_loss = ..
  
 
with tf.Session() as sess:
  sess.run([loss, train_op])
 
  if validation == True:
    sess.run([val_result, val_loss])

這種方式看似能夠直接調用inference()來對驗證數(shù)據進行前向傳播計算,但是,實則會在原圖上添加上許多新的結點,這些結點的參數(shù)都是需要重新初始化的,也是就是說,驗證的時候并不是使用訓練的權重。

3.用一個tf.placeholder來控制是否訓練、驗證。

def inference(input_):
  ...
  ...
  ...
  
  return inference_result
 
 
train_batch, label_batch = decode_train()
val_batch, val_label = decode_validation()
 
is_training = tf.placeholder(tf.bool, shape=())
 
x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
y = tf.cond(is_training, lambda: train_label, lambda: val_label)
 
logits = inference(x)
loss = cal_loss(logits, y)
train_op = optimize(loss)
 
with tf.Session() as sess:
  
  loss, _ = sess.run([loss, train_op], feed_dict={is_training: True})
  
  if validation == True:
    loss = sess.run(loss, feed_dict={is_training: False})

使用這種方式就可以在一個大圖里創(chuàng)建一個分支條件,從而通過控制placeholder來控制是否進行驗證。

以上是“tensorflow中固定部分參數(shù)訓練和只訓練部分參數(shù)的示例分析”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業(yè)資訊頻道!

向AI問一下細節(jié)

免責聲明:本站發(fā)布的內容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI