溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

怎么用Java訓(xùn)練出一只不死鳥

發(fā)布時間:2021-11-24 14:43:28 來源:億速云 閱讀:167 作者:iii 欄目:大數(shù)據(jù)

本篇內(nèi)容介紹了“怎么用Java訓(xùn)練出一只不死鳥”的有關(guān)知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領(lǐng)大家學(xué)習(xí)一下如何處理這些情況吧!希望大家仔細(xì)閱讀,能夠?qū)W有所成!

增強學(xué)習(xí)(RL)的架構(gòu)

在這一節(jié)會介紹主要用到的算法以及神經(jīng)網(wǎng)絡(luò),幫助你更好的了解如何進行訓(xùn)練。本項目與 DeepLearningFlappyBird 使用了類似的方法進行訓(xùn)練。算法整體的架構(gòu)是 Q-Learning + 卷積神經(jīng)網(wǎng)絡(luò)(CNN),把游戲每一幀的狀態(tài)存儲起來,即小鳥采用的動作和采用動作之后的效果,這些將作為卷積神經(jīng)網(wǎng)絡(luò)的訓(xùn)練數(shù)據(jù)。

CNN 訓(xùn)練簡述

CNN 的輸入數(shù)據(jù)為連續(xù)的 4 幀圖像,我們將這圖像 stack 起來作為小鳥當(dāng)前的“observation”,圖像會轉(zhuǎn)換成灰度圖以減少所需的訓(xùn)練資源。圖像存儲的矩陣形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 數(shù)組里的元素就是當(dāng)前幀的像素值,這些數(shù)據(jù)將輸入到 CNN 后將輸出 (batch size, 2) 的矩陣,矩陣的第二個維度就是小鳥 (振翅不采取動作) 對應(yīng)的收益。

訓(xùn)練數(shù)據(jù)

在小鳥采取動作后,我們會得到 preObservation and currentObservation 即是兩組 4 幀的連續(xù)的圖像表示小鳥動作前和動作后的狀態(tài)。然后我們將 preObservation, currentObservation, action, reward, terminal 組成的五元組作為一個 step 存進 replayBuffer 中。它是一個有限大小的訓(xùn)練數(shù)據(jù)集,他會隨著最新的操作動態(tài)更新內(nèi)容。

public void step(NDList action, boolean training) {
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    stepFrame();
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        restartGame();
    }
}

訓(xùn)練的三個周期

訓(xùn)練分為 3 個不同的周期以更好地生成訓(xùn)練數(shù)據(jù):

  • Observe(觀察) 周期:隨機產(chǎn)生訓(xùn)練數(shù)據(jù)

  • Explore (探索) 周期:隨機與推理動作結(jié)合更新訓(xùn)練數(shù)據(jù)

  • Training (訓(xùn)練) 周期:推理動作主導(dǎo)產(chǎn)生新數(shù)據(jù)

通過這種訓(xùn)練模式,我們可以更好的達到預(yù)期效果。

處于 Explore 周期時,我們會根據(jù)權(quán)重選取隨機的動作或使用模型推理出的動作來作為小鳥的動作。訓(xùn)練前期,隨機動作的權(quán)重會非常大,因為模型的決策十分不準(zhǔn)確 (甚至不如隨機)。在訓(xùn)練后期時,隨著模型學(xué)習(xí)的動作逐步增加,我們會不斷增加模型推理動作的權(quán)重并最終使它成為主導(dǎo)動作。調(diào)節(jié)隨機動作的參數(shù)叫做 epsilon 它會隨著訓(xùn)練的過程不斷變化。

public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    } else return baseAgent.chooseAction(env, training);
}

訓(xùn)練邏輯

首先,我們會從 replayBuffer 中隨機抽取一批數(shù)據(jù)作為作為訓(xùn)練集。然后將 preObservation 輸入到神經(jīng)網(wǎng)絡(luò)得到所有行為的 reward(Q)作為預(yù)測值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
        .mul(actionInput.singletonOrThrow())
        .sum(new int[]{1}));

postObservation 同樣會輸入到神經(jīng)網(wǎng)絡(luò),根據(jù)馬爾科夫決策過程以及貝爾曼價值函數(shù)計算出所有行為的 reward(targetQ)作為真實值:

// 將 postInput 輸入到神經(jīng)網(wǎng)絡(luò)中得到 targetQReward 是 (batchsize,2) 的矩陣。根據(jù) Q-learning 的算法,每一次的 targetQ 需要根據(jù)當(dāng)前環(huán)境是否結(jié)束算出不同的值,因此需要將每一個 step 的 targetQ 單獨算出后再將 targetQ 堆積成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length]; 
for (int i = 0; i < batchSteps.length; i++) {
    if (batchSteps[i].isTerminal()) {
        targetQValue[i] = batchSteps[i].getReward();
    } else {
        targetQValue[i] = targetQReward.singletonOrThrow().get(i)
                .max()
                .mul(rewardDiscount)
                .add(rewardInput.singletonOrThrow().get(i));
    }
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在訓(xùn)練結(jié)束時,計算 Q 和 targetQ 的損失值,并在 CNN 中更新權(quán)重。

卷積神經(jīng)網(wǎng)絡(luò)模型(CNN)

我們采用了采用了 3 個卷積層,4 個 relu 激活函數(shù)以及 2 個全連接層的神經(jīng)網(wǎng)絡(luò)架構(gòu)。

layerinput shapeoutput shape
conv2d(batchSize, 4, 80, 80)(batchSize,4,20,20)
conv2d(batchSize, 4, 20 ,20)(batchSize, 32, 9, 9)
conv2d(batchSize, 32, 9, 9)(batchSize, 64, 7, 7)
linear(batchSize, 3136)(batchSize, 512)
linear(batchSize, 512)(batchSize, 2)

訓(xùn)練過程

DJL 的 RL 庫中提供了非常方便的用于實現(xiàn)強化學(xué)習(xí)的接口:(RlEnv, RlAgent, ReplayBuffer)。

  • 實現(xiàn) RlAgent 接口即可構(gòu)建一個可以進行訓(xùn)練的智能體。

  • 在現(xiàn)有的游戲環(huán)境中實現(xiàn) RlEnv 接口即可生成訓(xùn)練所需的數(shù)據(jù)。

  • 創(chuàng)建 ReplayBuffer 可以存儲并動態(tài)更新訓(xùn)練數(shù)據(jù)。

在實現(xiàn)這些接口后,只需要調(diào)用 step 方法:

RlEnv.step(action, training);

這個方法會將 RlAgent 決策出的動作輸入到游戲環(huán)境中獲得反饋。我們可以在 RlEnv 中提供的 runEnviroment 方法中調(diào)用 step 方法,然后只需要重復(fù)執(zhí)行 runEnvironment 方法,即可不斷地生成用于訓(xùn)練的數(shù)據(jù)。

public Step[] runEnvironment(RlAgent agent, boolean training) {
    // run the game
    NDList action = agent.chooseAction(this, training);
    step(action, training);
    if (training) {
        batchSteps = this.getBatch();
    }
    return batchSteps;
}

我們將 ReplayBuffer 可存儲的 step 數(shù)量設(shè)置為 50000,在 observe 周期我們會先向 replayBuffer 中存儲 1000 個使用隨機動作生成的 step,這樣可以使智能體更快地從隨機動作中學(xué)習(xí)。

在 explore 和 training 周期,神經(jīng)網(wǎng)絡(luò)會隨機從 replayBuffer 中生成訓(xùn)練集并將它們輸入到模型中訓(xùn)練。我們使用 Adam 優(yōu)化器和 MSE 損失函數(shù)迭代神經(jīng)網(wǎng)絡(luò)。

神經(jīng)網(wǎng)絡(luò)輸入預(yù)處理

首先將圖像大小 resize 成 80x80 并轉(zhuǎn)為灰度圖,這有助于在不丟失信息的情況下提高訓(xùn)練速度。

public static NDArray imgPreprocess(BufferedImage observation) {
    return NDImageUtils.toTensor(
            NDImageUtils.resize(
                    ImageFactory.getInstance().fromImage(observation)
                    .toNDArray(NDManager.newBaseManager(),
                     Image.Flag.GRAYSCALE) ,80,80));
}

然后我們把連續(xù)的四幀圖像作為一個輸入,為了獲得連續(xù)四幀的連續(xù)圖像,我們維護了一個全局的圖像隊列保存游戲線程中的圖像,每一次動作后替換掉最舊的一幀,然后把隊列里的圖像 stack 成一個單獨的 NDArray。

public NDList createObservation(BufferedImage currentImg) {
    NDArray observation = GameUtil.imgPreprocess(currentImg);
    if (imgQueue.isEmpty()) {
        for (int i = 0; i < 4; i++) {
            imgQueue.offer(observation);
        }
        return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
    } else {
        imgQueue.remove();
        imgQueue.offer(observation);
        NDArray[] buf = new NDArray[4];
        int i = 0;
        for (NDArray nd : imgQueue) {
            buf[i++] = nd;
        }
        return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
    }
}

一旦以上部分完成,我們就可以開始訓(xùn)練了。訓(xùn)練優(yōu)化為了獲得最佳的訓(xùn)練性能,我們關(guān)閉了 GUI 以加快樣本生成速度。并使用 Java 多線程將訓(xùn)練循環(huán)和樣本生成循環(huán)分別在不同的線程中運行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
    callables.add(new TrainerCallable(model, agent));
}

怎么用Java訓(xùn)練出一只不死鳥

“怎么用Java訓(xùn)練出一只不死鳥”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識可以關(guān)注億速云網(wǎng)站,小編將為大家輸出更多高質(zhì)量的實用文章!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI