溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

python的tf.train.batch函數(shù)怎么用

發(fā)布時間:2022-05-05 09:29:49 來源:億速云 閱讀:160 作者:iii 欄目:開發(fā)技術(shù)

這篇文章主要介紹“python的tf.train.batch函數(shù)怎么用”的相關(guān)知識,小編通過實際案例向大家展示操作過程,操作方法簡單快捷,實用性強,希望這篇“python的tf.train.batch函數(shù)怎么用”文章能幫助大家解決問題。

tf.train.batch函數(shù)

tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

其中:

1、tensors:利用slice_input_producer獲得的數(shù)據(jù)組合。

2、batch_size:設(shè)置每次從隊列中獲取出隊數(shù)據(jù)的數(shù)量。

3、num_threads:用來控制線程的數(shù)量,如果其值不唯一,由于線程執(zhí)行的特性,數(shù)據(jù)獲取可能變成亂序。

4、capacity:一個整數(shù),用來設(shè)置隊列中元素的最大數(shù)量

5、allow_samller_final_batch:當(dāng)其為True時,如果隊列中的樣本數(shù)量小于batch_size,出隊的數(shù)量會以最終遺留下來的樣本進行出隊;當(dāng)其為False時,小于batch_size的樣本不會做出隊處理。

6、name:名字

測試代碼

1、allow_samller_final_batch=True

import pandas as pd
import numpy as np
import tensorflow as tf
# 生成數(shù)據(jù)
def generate_data():
    num = 18
    label = np.arange(num)
    return label
# 獲取數(shù)據(jù)
def get_batch_data():
    label = generate_data()
    input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)
    label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True)
    return label_batch
# 數(shù)據(jù)組
label = get_batch_data()
sess = tf.Session()
# 初始化變量
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
# 初始化batch訓(xùn)練的參數(shù)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
    while not coord.should_stop():
        # 自動獲取下一組數(shù)據(jù)
        l = sess.run(label)
        print(l)
except tf.errors.OutOfRangeError:
    print('Done training')
finally:
    coord.request_stop()
coord.join(threads)
sess.close()

運行結(jié)果為:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
[17]
Done training

2、allow_samller_final_batch=False

相比allow_samller_final_batch=True,輸出結(jié)果少了[17]

import pandas as pd
import numpy as np
import tensorflow as tf
# 生成數(shù)據(jù)
def generate_data():
    num = 18
    label = np.arange(num)
    return label
# 獲取數(shù)據(jù)
def get_batch_data():
    label = generate_data()
    input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)
    label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False)
    return label_batch
# 數(shù)據(jù)組
label = get_batch_data()
sess = tf.Session()
# 初始化變量
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
# 初始化batch訓(xùn)練的參數(shù)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
    while not coord.should_stop():
        # 自動獲取下一組數(shù)據(jù)
        l = sess.run(label)
        print(l)
except tf.errors.OutOfRangeError:
    print('Done training')
finally:
    coord.request_stop()
coord.join(threads)
sess.close()

運行結(jié)果為:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
Done training

關(guān)于“python的tf.train.batch函數(shù)怎么用”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識,可以關(guān)注億速云行業(yè)資訊頻道,小編每天都會為大家更新不同的知識點。

向AI問一下細節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI