溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶(hù)服務(wù)條款》

pytorch 6中batch_train批訓(xùn)練操作的示例分析

發(fā)布時(shí)間:2021-05-31 10:06:19 來(lái)源:億速云 閱讀:419 作者:小新 欄目:開(kāi)發(fā)技術(shù)

這篇文章主要介紹pytorch 6中batch_train批訓(xùn)練操作的示例分析,文中介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們一定要看完!

看代碼吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8個(gè)數(shù)據(jù)同時(shí)傳入網(wǎng)路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 設(shè)置不隨機(jī)打亂數(shù)據(jù) random shuffle for training
    num_workers=2,              # 使用兩個(gè)進(jìn)程提取數(shù)據(jù),subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的數(shù)據(jù)使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有數(shù)據(jù)利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

補(bǔ)充:pytorch批訓(xùn)練bug

問(wèn)題描述:

在進(jìn)行pytorch神經(jīng)網(wǎng)絡(luò)批訓(xùn)練的時(shí)候,有時(shí)會(huì)出現(xiàn)報(bào)錯(cuò) 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解決辦法:

第一步:

檢查(重點(diǎn)?。。。。?:

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor類(lèi),我第一次出錯(cuò)就是因?yàn)閭魅氲氖莢ariable

可以這樣將數(shù)據(jù)變?yōu)閠ensor類(lèi):

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

實(shí)例化一個(gè)DataLoader對(duì)象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

這樣就可以批訓(xùn)練了

需要注意的是:train_loader輸出的是tensor,在訓(xùn)練網(wǎng)絡(luò)時(shí),需要變成Variable

以上是“pytorch 6中batch_train批訓(xùn)練操作的示例分析”這篇文章的所有內(nèi)容,感謝各位的閱讀!希望分享的內(nèi)容對(duì)大家有幫助,更多相關(guān)知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀(guān)點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI