This article explains what to watch out for when using dataset.shuffle, dataset.batch and dataset.repeat in TensorFlow.
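batch is straightforward: it sets the batch size. Note that the last batch of an epoch may contain fewer samples than batch size. It is full only when the number of samples divides evenly by batch size.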
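As a quick illustration, here is a plain-Python sketch (no TensorFlow needed) of what batching does to the 11-sample dataset used in the examples: 11 = 4 + 4 + 3, so the last batch of each epoch holds 3 samples. The helper `batch` is hypothetical. Its `drop_remainder` flag mimics the `drop_remainder=True` argument that newer TensorFlow releases accept in `Dataset.batch` to discard the short final batch.

```python
def batch(samples, batch_size, drop_remainder=False):
    """Split samples into consecutive lists of batch_size (hypothetical helper)."""
    batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the final, incomplete batch
    return batches


samples = list(range(11))  # 11 samples, like the numpy example
print([len(b) for b in batch(samples, 4)])                       # → [4, 4, 3]
print([len(b) for b in batch(samples, 4, drop_remainder=True)])  # → [4, 4]
```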
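dataset.repeat corresponds to what is usually called epochs. However, in TensorFlow the order in which it is combined with dataset.shuffle can cause samples from different epochs to be mixed together.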
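dataset.shuffle maintains a shuffle buffer that holds buffer size samples. Every sample the graph consumes is drawn at random from the shuffle buffer. Each time a sample is drawn, the next sample from the source dataset is added to the buffer to take its place.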
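The buffer mechanism can be modelled in a few lines of plain Python. This is a sketch of the behaviour described above, not TensorFlow's actual implementation. `buffered_shuffle` is a hypothetical name, and the refill goes into the slot that was just emptied, which leaves the buffer with the same contents as "remove, then append".

```python
import random

_END = object()  # sentinel marking an exhausted source


def buffered_shuffle(source, buffer_size, rng):
    """Yield source elements in the order a shuffle buffer of buffer_size would emit them."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    it = iter(source)
    buffer = []
    # Fill the buffer with (up to) the first buffer_size elements.
    for item in it:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    while buffer:
        idx = rng.randrange(len(buffer))  # pick a random buffered element
        yield buffer[idx]
        nxt = next(it, _END)
        if nxt is _END:
            buffer.pop(idx)    # source exhausted: the buffer drains
        else:
            buffer[idx] = nxt  # refill the freed slot from the source


data = list(range(11))
print(list(buffered_shuffle(data, 3, random.Random(0))))   # only locally shuffled
print(list(buffered_shuffle(data, 1, random.Random(0))))   # → [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(list(buffered_shuffle(data, 11, random.Random(0))))  # a full random permutation
```

Because the first element emitted always comes from the initially filled buffer, a small buffer can only move samples a limited distance from their original position.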
```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # run on CPU only

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)  # shuffle buffer of 3 samples
dataset = dataset.batch(4)    # batches of 4 samples
dataset = dataset.repeat(2)   # 2 epochs

# create the iterator
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 11 samples -> 3 batches per epoch, 6 batches in total.
    # Read until the iterator is exhausted instead of calling sess.run a fixed
    # number of times (a 7th call would raise OutOfRangeError).
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break
```
```
# source dataset
[[ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548 0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184 0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

# batches produced by shuffle + batch (epoch 1)
[[ 0.4236548 0.64589411]
 [ 0.60276338 0.54488318]
 [ 0.43758721 0.891773 ]
 [ 0.5488135 0.71518937]]
[[ 0.96366276 0.38344152]
 [ 0.56804456 0.92559664]
 [ 0.0202184 0.83261985]
 [ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
 [ 0.97861834 0.79915856]
 [ 0.77815675 0.87001215]]   # the last batch contains only 3 samples
# epoch 2
[[ 0.60276338 0.54488318]
 [ 0.5488135 0.71518937]
 [ 0.43758721 0.891773 ]
 [ 0.79172504 0.52889492]]
[[ 0.4236548 0.64589411]
 [ 0.56804456 0.92559664]
 [ 0.0202184 0.83261985]
 [ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
 [ 0.96366276 0.38344152]
 [ 0.97861834 0.79915856]]   # the last batch contains only 3 samples
```
1. Following the buffer size set in shuffle (3 here), first take three samples from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2. Take one sample from the buffer at random and put it into the batch, giving:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3. The shuffle buffer now holds fewer than three samples, so the next sample is pulled from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4. Again take one sample from the buffer at random and put it into the batch, giving:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
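5. And so on. This means that with a shuffle buffer size of 1 the dataset is not shuffled at all, while a buffer size equal to the number of samples in the dataset randomly shuffles the entire dataset. Sizes in between only shuffle locally: with a buffer of size k, the first sample emitted is always one of the first k samples.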
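To make steps 1 to 4 concrete, the sketch below replays them on the first six rows of the source dataset. TensorFlow chooses the buffer position at random. Here the positions (`picks=[2, 1, 1, 0]`) are hypothetical, chosen by hand so that the replay reproduces the first batch in the output above. `replay_shuffle` is an illustrative helper, not a TensorFlow API, and it follows the article's "remove, then append" description of the buffer.

```python
import itertools

_END = object()  # sentinel marking an exhausted source


def replay_shuffle(source, buffer_size, picks):
    """Replay the numbered steps using scripted buffer positions instead of random ones."""
    it = iter(source)
    buffer = list(itertools.islice(it, buffer_size))  # step 1: fill the shuffle buffer
    batch = []
    for pick in picks:
        batch.append(buffer.pop(pick))  # steps 2 and 4: move one sample into the batch
        print("shuffle buffer:", buffer, "| batch:", batch)
        nxt = next(it, _END)
        if nxt is not _END:
            buffer.append(nxt)  # step 3: top the buffer up from the source dataset
            print("shuffle buffer:", buffer)
    return batch


# First six rows of the source dataset printed above.
rows = [
    (0.5488135, 0.71518937),
    (0.60276338, 0.54488318),
    (0.4236548, 0.64589411),
    (0.43758721, 0.891773),
    (0.96366276, 0.38344152),
    (0.79172504, 0.52889492),
]
first_batch = replay_shuffle(rows, 3, picks=[2, 1, 1, 0])
print(first_batch)  # → rows 2, 1, 3, 0: the first batch of the shuffle(3) output
```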
```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # run on CPU only

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)  # buffer of 1 sample: no shuffling
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # read all 6 batches, stopping when the iterator is exhausted
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break
```

```
# source dataset
[[ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548 0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184 0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

# batches keep the original order in both epochs
[[ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548 0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184 0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]
[[ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548 0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184 0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]
```
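Note what happens if repeat is applied before shuffle:
The official documentation says that applying repeat before shuffle gives better performance, but it blurs the epoch boundaries: samples from different epochs can be interleaved, and the same sample can appear more than once in a batch.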
```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # run on CPU only

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)    # repeat first: 22 elements in one stream
dataset = dataset.shuffle(11)  # shuffle across both epochs
dataset = dataset.batch(4)

# create the iterator
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # read all 6 batches, stopping when the iterator is exhausted
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break
```

```
# source dataset
[[ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548 0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184 0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.56804456 0.92559664]
 [ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
 [ 0.43758721 0.891773 ]   # the same sample appears twice in this batch
 [ 0.43758721 0.891773 ]
 [ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492]   # the same sample appears twice in this batch
 [ 0.79172504 0.52889492]
 [ 0.60276338 0.54488318]
 [ 0.4236548 0.64589411]]
[[ 0.07103606 0.0871293 ]
 [ 0.4236548 0.64589411]
 [ 0.96366276 0.38344152]
 [ 0.5488135 0.71518937]]
[[ 0.97861834 0.79915856]
 [ 0.0202184 0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.56804456 0.92559664]]
[[ 0.0202184 0.83261985]
 [ 0.97861834 0.79915856]]   # only the last batch is short (2 samples); all earlier batches have 4
```
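The two orderings can be compared with a plain-Python model (no TensorFlow needed). It reuses the shuffle-buffer behaviour described earlier and assumes, as `Dataset.shuffle` does by default, that the data is reshuffled on every pass. With shuffle before repeat, each epoch is a complete permutation of the 11 samples and ends with its own short batch of 3. With repeat before shuffle, the 22 elements form one stream, so a batch can contain the same sample twice and only the very last batch is short (2 samples), as in the output above. `buffered_shuffle` and `batches` are illustrative helpers, not TensorFlow functions.

```python
import random

_END = object()  # sentinel marking an exhausted source


def buffered_shuffle(source, buffer_size, rng):
    """Shuffle-buffer model: emit a random buffered element, refill its slot from source."""
    it = iter(source)
    buffer = []
    for item in it:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    while buffer:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        nxt = next(it, _END)
        if nxt is _END:
            buffer.pop(idx)
        else:
            buffer[idx] = nxt


def batches(stream, batch_size):
    """Group a stream into lists of batch_size; the last list may be shorter."""
    stream = list(stream)
    return [stream[i:i + batch_size] for i in range(0, len(stream), batch_size)]


data = list(range(11))
rng = random.Random(0)

# shuffle -> batch -> repeat(2): each epoch is shuffled and batched on its own.
shuffle_first = [b for _ in range(2) for b in batches(buffered_shuffle(data, 3, rng), 4)]
# repeat(2) -> shuffle -> batch: both epochs pass through a single shuffle buffer.
repeat_first = batches(buffered_shuffle(data * 2, 11, rng), 4)

print([len(b) for b in shuffle_first])  # → [4, 4, 3, 4, 4, 3]
print([len(b) for b in repeat_first])   # → [4, 4, 4, 4, 4, 2]
# Does any batch contain a duplicate sample? Never for shuffle_first; possible for repeat_first.
print(any(len(set(b)) < len(b) for b in repeat_first))
```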
Usage example:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.string_split, make_one_shot_iterator)


def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)

    def decode_libsvm(line):
        # Each line looks like "label id:val id:val ..."
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[1:], ':')
        id_vals = tf.reshape(splits.values, splits.dense_shape)  # [num_features, 2]
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from input files using the Dataset API (one filename or a list of
    # filenames), pre-process them with multiple threads, then prefetch.
    dataset = (tf.data.TextLineDataset(filenames)
               .map(decode_libsvm, num_parallel_calls=10)
               .prefetch(500000))

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # Shuffling before repeat keeps epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    # Batching after repeat: only the final batch can be smaller than batch_size,
    # and a batch may hold the tail of one epoch and the head of the next.
    dataset = dataset.batch(batch_size)  # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
```
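`decode_libsvm` expects libsvm-style lines: a label followed by space-separated `feature_id:value` pairs. As a rough illustration of what it extracts from one line, here is a pure-Python analogue. The sample line is made up and `decode_libsvm_line` is a hypothetical helper. Unlike the TensorFlow version, which keeps `feat_ids` and `feat_vals` as `[num_features, 1]` tensors after `tf.split`, it returns flat lists.

```python
def decode_libsvm_line(line):
    """Hypothetical pure-Python analogue of decode_libsvm for one 'label id:val ...' line."""
    columns = line.split()                        # roughly like tf.string_split([line], ' ')
    label = float(columns[0])                     # first column is the label
    pairs = [col.split(':') for col in columns[1:]]
    feat_ids = [int(fid) for fid, _ in pairs]     # like tf.string_to_number(..., tf.int32)
    feat_vals = [float(val) for _, val in pairs]  # like tf.string_to_number(..., tf.float32)
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, label


features, label = decode_libsvm_line("1 3:0.5 17:1 42:0.25")
print(features, label)  # → {'feat_ids': [3, 17, 42], 'feat_vals': [0.5, 1.0, 0.25]} 1.0
```

Because `Dataset.batch` stacks elements of identical shape, the TensorFlow version only batches cleanly if every line has the same number of `id:val` pairs.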