溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

對(duì)tensorflow中cifar-10文檔的Read操作詳解

發(fā)布時(shí)間:2020-09-11 03:23:43 來(lái)源:腳本之家 閱讀:130 作者:luchi007 欄目:開(kāi)發(fā)技術(shù)

前言

在tensorflow的官方文檔中得卷積神經(jīng)網(wǎng)絡(luò)一章,有一個(gè)使用cifar-10圖片數(shù)據(jù)集的實(shí)驗(yàn),搭建卷積神經(jīng)網(wǎng)絡(luò)倒不難,但是那個(gè)cifar10_input文件著實(shí)讓我費(fèi)了一番心思。配合著官方文檔也算看的七七八八,但是中間還是有一些不太明白,不明白的mark一下,這次記下一些已經(jīng)明白的。

研究

cifar10_input.py文件的read操作,主要的就是下面的代碼:

if not eval_data:
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
         for i in xrange(1, 6)]
  num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
 else:
  filenames = [os.path.join(data_dir, 'test_batch.bin')]
  num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
...
filename_queue = tf.train.string_input_producer(filenames)

...

label_bytes = 1 # 2 for CIFAR-100
 result.height = 32
 result.width = 32
 result.depth = 3
 image_bytes = result.height * result.width * result.depth
 # Every record consists of a label followed by the image, with a
 # fixed number of bytes for each.
 record_bytes = label_bytes + image_bytes

 # Read a record, getting filenames from the filename_queue. No
 # header or footer in the CIFAR-10 format, so we leave header_bytes
 # and footer_bytes at their default of 0.
 reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
 result.key, value = reader.read(filename_queue)

 ...

 if shuffle:
  images, label_batch = tf.train.shuffle_batch(
    [image, label],
    batch_size=batch_size,
    num_threads=num_preprocess_threads,
    capacity=min_queue_examples + 3 * batch_size,
    min_after_dequeue=min_queue_examples)
 else:
  images, label_batch = tf.train.batch(
    [image, label],
    batch_size=batch_size,
    num_threads=num_preprocess_threads,
    capacity=min_queue_examples + 3 * batch_size)

開(kāi)始并不明白這段代碼是用來(lái)干什么的,越看越糊涂,因?yàn)橹笆褂胻ensorflow最多也就是使用哪個(gè)tf.placeholder()這個(gè)操作,并沒(méi)有使用tensorflow自帶的讀寫(xiě)方法來(lái)讀寫(xiě),所以上面的代碼看的很費(fèi)勁兒。不過(guò)我在官方文檔的How-To這個(gè)document中看到了這個(gè)東西:

Batching

def read_my_file_format(filename_queue):
 reader = tf.SomeReader()
 key, record_string = reader.read(filename_queue)
 example, label = tf.some_decoder(record_string)
 processed_example = some_processing(example)
 return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
 filename_queue = tf.train.string_input_producer(
   filenames, num_epochs=num_epochs, shuffle=True)
 example, label = read_my_file_format(filename_queue)
 # min_after_dequeue defines how big a buffer we will randomly sample
 #  from -- bigger means better shuffling but slower start up and more
 #  memory used.
 # capacity must be larger than min_after_dequeue and the amount larger
 #  determines the maximum we will prefetch. Recommendation:
 #  min_after_dequeue + (num_threads + a small safety margin) * batch_size
 min_after_dequeue = 10000
 capacity = min_after_dequeue + 3 * batch_size
 example_batch, label_batch = tf.train.shuffle_batch(
   [example, label], batch_size=batch_size, capacity=capacity,
   min_after_dequeue=min_after_dequeue)
 return example_batch, label_batch

感覺(jué)豁然開(kāi)朗,再研究一下其官方文檔API就能大約明白期間意思。最有代表性的圖示官方文檔中也給出來(lái)了,雖然官方文檔給的解釋并不多。

對(duì)tensorflow中cifar-10文檔的Read操作詳解

API我就不一一解釋了,我們下面通過(guò)實(shí)驗(yàn)來(lái)明白。

實(shí)驗(yàn)

首先在tensorflow路徑下創(chuàng)建兩個(gè)文件,分別命名為test.txt以及test2.txt,其內(nèi)容分別是:

test.txt:

test line1
test line2
test line3
test line4
test line5
test line6

test2.txt:

test2 line1
test2 line2
test2 line3
test2 line4
test2 line5
test2 line6

然后再命令行里依次鍵入下面的命令:

import tensorflow as tf
filenames=['test.txt','test2.txt']
#創(chuàng)建如上圖所示的filename_queue
filename_queue=tf.train.string_input_producer(filenames)
#選取的是每次讀取一行的TextLineReader
reader=tf.TextLineReader()
init=tf.initialize_all_variables()
#讀取文件,也就是創(chuàng)建上圖中的Reader
key,value=reader.read(filename_queue)
#讀取batch文件,batch_size設(shè)置成1,為了方便看
bs=tf.train.batch([value],batch_size=1,num_threads=1,capacity=2)
sess=tf.Session() 
#非常關(guān)鍵,這個(gè)是連通各個(gè)queue圖的關(guān)鍵          
tf.train.start_queue_runners(sess=sess)
#計(jì)算有reader的輸出
b=reader.num_records_produced()

然后我們執(zhí)行:

>>> sess.run(bs)
array(['test line1'], dtype=object)
>>> sess.run(b)
4
>>> sess.run(bs)
array(['test line2'], dtype=object)
>>> sess.run(b)
5
>>> sess.run(bs)
array(['test line3'], dtype=object)
>>> sess.run(bs)
array(['test line4'], dtype=object)
>>> sess.run(bs)
array(['test line5'], dtype=object)
>>> sess.run(bs)
array(['test line6'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test2 line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line3'], dtype=object)
>>> sess.run(bs)
array(['test2 line4'], dtype=object)
>>> sess.run(bs)
array(['test2 line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line6'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test2 line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line3'], dtype=object)
>>> sess.run(bs)
array(['test2 line4'], dtype=object)
>>> sess.run(bs)
array(['test2 line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line6'], dtype=object)
>>> sess.run(bs)
array(['test line1'], dtype=object)

我們發(fā)現(xiàn),當(dāng)batch_size設(shè)置成為1的時(shí)候,bs的輸出是按照文件行數(shù)進(jìn)行逐步打印的,原因是,我們選擇的是單個(gè)Reader進(jìn)行操作的,這個(gè)Reader先將test.txt文件讀取,然后逐行讀取并將讀取的文本送到example queue(如上圖)中,因?yàn)檫@里batch設(shè)置的是1,而且用到的是tf.train.batch()方法,中間沒(méi)有shuffle,所以自然而然是按照順序輸出的,之后Reader再讀取test2.txt。但是這里有一個(gè)疑惑,為什么reader.num_records_produced的第一個(gè)輸出不是從1開(kāi)始的,這點(diǎn)不太清楚。 另外,打印出filename_queue的size:

>>> sess.run(filename_queue.size())
32

發(fā)現(xiàn)filename_queue的size有32個(gè)之多!這點(diǎn)也不明白。。。

我們可以更改實(shí)驗(yàn)條件,將batch_size設(shè)置成2,會(huì)發(fā)現(xiàn)也是順序的輸出,而且每次輸出為2行文本(和batch_size一樣)

我們繼續(xù)更改實(shí)驗(yàn)條件,將tf.train.batch方法換成tf.train.shuffle_batch方法,文本數(shù)據(jù)不變:

import tensorflow as tf
filenames=['test.txt','test2.txt']
filename_queue=tf.train.string_input_producer(filenames)
reader=tf.TextLineReader()
init=tf.initialize_all_variables()
key,value=reader.read(filename_queue)
bs=tf.train.shuffle_batch([value],batch_size=1,num_threads=1,capacity=4,min_after_dequeue=2)
sess=tf.Session()           
tf.train.start_queue_runners(sess=sess)
b=reader.num_records_produced()

繼續(xù)剛才的執(zhí)行:

>>> sess.run(bs)
array(['test2 line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line6'], dtype=object)
>>> sess.run(bs)
array(['test2 line4'], dtype=object)
>>> sess.run(bs)
array(['test2 line3'], dtype=object)
>>> sess.run(bs)
array(['test line1'], dtype=object)
>>> sess.run(bs)
array(['test line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test line4'], dtype=object)
>>> sess.run(bs)
array(['test line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test line3'], dtype=object)

我們發(fā)現(xiàn)的是,使用了shuffle操作之后,明顯的bs的輸出變得不一樣了,變得沒(méi)有規(guī)則,然后我們看filename_queue的size:

>>> sess.run(filename_queue.size())
32

發(fā)現(xiàn)也是32,由此估計(jì)是tensorflow會(huì)根據(jù)文件大小默認(rèn)filename_queue的長(zhǎng)度。 注意這里面的capacity=4,min_after_dequeue=2這些個(gè)命令,capacity指的是example queue的最大長(zhǎng)度, 而min_after_dequeue是指在出隊(duì)列之后,example queue最少要保留的元素個(gè)數(shù),為什么需要這個(gè),其實(shí)是為了混合的更顯著。也正是有這兩個(gè)元素,讓shuffle變得可能。

到這里基本上大概的思路能明白,但是上面的實(shí)驗(yàn)都是對(duì)于單個(gè)的Reader,和上一節(jié)的圖不太一致,根據(jù)官網(wǎng)教程,為了使用多個(gè)Reader,我們可以這樣:

import tensorflow as tf
filenames=['test.txt','test2.txt']
filename_queue=tf.train.string_input_producer(filenames)
reader=tf.TextLineReader()
init=tf.initialize_all_variables()
key_list,value_list=[reader.read(filename_queue) for _ in range(2)]
bs2=tf.train.shuffle_batch_join([value_list],batch_size=1,capacity=4,min_after_dequeue=2)
sess=tf.Session()       
sess.run(init)    
tf.train.start_queue_runners(sess=sess)

運(yùn)行的結(jié)果如下:

>>> sess.run(bs2)
[array(['test2.txt:2'], dtype=object), array(['test2 line2'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:5'], dtype=object), array(['test2 line5'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:6'], dtype=object), array(['test2 line6'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:4'], dtype=object), array(['test2 line4'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:3'], dtype=object), array(['test2 line3'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:1'], dtype=object), array(['test2 line1'], dtype=object)]
>>> sess.run(bs2)
[array(['test.txt:4'], dtype=object), array(['test line4'], dtype=object)]
>>> sess.run(bs2)
[array(['test.txt:3'], dtype=object), array(['test line3'], dtype=object)]
>>> sess.run(bs2)
[array(['test.txt:2'], dtype=object), array(['test line2'], dtype=object)]

以上這篇對(duì)tensorflow中cifar-10文檔的Read操作詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持億速云。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI