溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶(hù)服務(wù)條款》

tensorflow tf.train.batch之?dāng)?shù)據(jù)批量讀取的示例分析

發(fā)布時(shí)間:2021-07-26 10:26:10 來(lái)源:億速云 閱讀:118 作者:小新 欄目:開(kāi)發(fā)技術(shù)

這篇文章給大家分享的是有關(guān)tensorflow tf.train.batch之?dāng)?shù)據(jù)批量讀取的示例分析的內(nèi)容。小編覺(jué)得挺實(shí)用的,因此分享給大家做個(gè)參考,一起跟隨小編過(guò)來(lái)看看吧。

在進(jìn)行大量數(shù)據(jù)訓(xùn)練神經(jīng)網(wǎng)絡(luò)的時(shí)候,可能需要批量讀取數(shù)據(jù)。于是參考了這篇文章的代碼,結(jié)果發(fā)現(xiàn)數(shù)據(jù)一直批量循環(huán)輸出,不會(huì)在數(shù)據(jù)的末尾自動(dòng)停止。

然后發(fā)現(xiàn)這篇博文說(shuō)slice_input_producer()這個(gè)函數(shù)有一個(gè)形參num_epochs,通過(guò)設(shè)置它的值就可以控制全部數(shù)據(jù)循環(huán)輸出幾次。

于是我設(shè)置之后出現(xiàn)以下的報(bào)錯(cuò):

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value input_producer/input_producer/limit_epochs/epochs

     [[Node: input_producer/input_producer/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@input_producer/input_producer/limit_epochs/epochs"], limit=2, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer/input_producer/limit_epochs/epochs)]]

找了好久,都不知道為什么會(huì)錯(cuò),于是只好去看看slice_input_producer()函數(shù)的源碼,結(jié)果在源碼中發(fā)現(xiàn)作者說(shuō)這個(gè)num_epochs如果不是空的話,就是一個(gè)局部變量,需要先調(diào)用global_variables_initializer()函數(shù)初始化。

于是我調(diào)用了之后,一切就正常了,特此記錄下來(lái),希望其他人遇到的時(shí)候能夠及時(shí)找到原因。

哈哈,這是筆者第一次通過(guò)閱讀源碼解決了問(wèn)題,心情還是有點(diǎn)小激動(dòng)。啊啊,扯遠(yuǎn)了,上最終成功的代碼:

import pandas as pd
import numpy as np
import tensorflow as tf


def generate_data():
  num = 25
  label = np.asarray(range(0, num))
  images = np.random.random([num, 5])
  print('label size :{}, image size {}'.format(label.shape, images.shape))
  return images,label

def get_batch_data():
  label, images = generate_data()
  input_queue = tf.train.slice_input_producer([images, label], shuffle=False,num_epochs=2)
  image_batch, label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False)
  return image_batch,label_batch


images,label = get_batch_data()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())#就是這一行
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
  while not coord.should_stop():
    i,l = sess.run([images,label])
    print(i)
    print(l)
except tf.errors.OutOfRangeError:
  print('Done training')
finally:
  coord.request_stop()
coord.join(threads)
sess.close()

感謝各位的閱讀!關(guān)于“tensorflow tf.train.batch之?dāng)?shù)據(jù)批量讀取的示例分析”這篇文章就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,讓大家可以學(xué)到更多知識(shí),如果覺(jué)得文章不錯(cuò),可以把它分享出去讓更多的人看到吧!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI