您好,登錄后才能下訂單哦!
本篇內(nèi)容介紹了“怎么用Java訓(xùn)練出一只不死鳥”的有關(guān)知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領(lǐng)大家學(xué)習(xí)一下如何處理這些情況吧!希望大家仔細(xì)閱讀,能夠?qū)W有所成!
在這一節(jié)會介紹主要用到的算法以及神經(jīng)網(wǎng)絡(luò),幫助你更好的了解如何進行訓(xùn)練。本項目與 DeepLearningFlappyBird 使用了類似的方法進行訓(xùn)練。算法整體的架構(gòu)是 Q-Learning + 卷積神經(jīng)網(wǎng)絡(luò)(CNN),把游戲每一幀的狀態(tài)存儲起來,即小鳥采用的動作和采用動作之后的效果,這些將作為卷積神經(jīng)網(wǎng)絡(luò)的訓(xùn)練數(shù)據(jù)。
CNN 的輸入數(shù)據(jù)為連續(xù)的 4 幀圖像,我們將這圖像 stack 起來作為小鳥當(dāng)前的“observation”,圖像會轉(zhuǎn)換成灰度圖以減少所需的訓(xùn)練資源。圖像存儲的矩陣形式是 (batch size, 4 (frames), 80 (width), 80 (height))
數(shù)組里的元素就是當(dāng)前幀的像素值,這些數(shù)據(jù)將輸入到 CNN 后將輸出 (batch size, 2)
的矩陣,矩陣的第二個維度就是小鳥 (振翅不采取動作) 對應(yīng)的收益。
在小鳥采取動作后,我們會得到 preObservation and currentObservation
即是兩組 4 幀的連續(xù)的圖像表示小鳥動作前和動作后的狀態(tài)。然后我們將 preObservation, currentObservation, action, reward, terminal
組成的五元組作為一個 step 存進 replayBuffer 中。它是一個有限大小的訓(xùn)練數(shù)據(jù)集,他會隨著最新的操作動態(tài)更新內(nèi)容。
public void step(NDList action, boolean training) { if (action.singletonOrThrow().getInt(1) == 1) { bird.birdFlap(); } stepFrame(); NDList preObservation = currentObservation; currentObservation = createObservation(currentImg); FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(), preObservation, currentObservation, action, currentReward, currentTerminal); if (training) { replayBuffer.addStep(step); } if (gameState == GAME_OVER) { restartGame(); } }
訓(xùn)練分為 3 個不同的周期以更好地生成訓(xùn)練數(shù)據(jù):
Observe(觀察) 周期:隨機產(chǎn)生訓(xùn)練數(shù)據(jù)
Explore (探索) 周期:隨機與推理動作結(jié)合更新訓(xùn)練數(shù)據(jù)
Training (訓(xùn)練) 周期:推理動作主導(dǎo)產(chǎn)生新數(shù)據(jù)
通過這種訓(xùn)練模式,我們可以更好的達到預(yù)期效果。
處于 Explore 周期時,我們會根據(jù)權(quán)重選取隨機的動作或使用模型推理出的動作來作為小鳥的動作。訓(xùn)練前期,隨機動作的權(quán)重會非常大,因為模型的決策十分不準(zhǔn)確 (甚至不如隨機)。在訓(xùn)練后期時,隨著模型學(xué)習(xí)的動作逐步增加,我們會不斷增加模型推理動作的權(quán)重并最終使它成為主導(dǎo)動作。調(diào)節(jié)隨機動作的參數(shù)叫做 epsilon 它會隨著訓(xùn)練的過程不斷變化。
public NDList chooseAction(RlEnv env, boolean training) { if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) { return env.getActionSpace().randomAction(); } else return baseAgent.chooseAction(env, training); }
首先,我們會從 replayBuffer 中隨機抽取一批數(shù)據(jù)作為作為訓(xùn)練集。然后將 preObservation 輸入到神經(jīng)網(wǎng)絡(luò)得到所有行為的 reward(Q)作為預(yù)測值:
NDList QReward = trainer.forward(preInput); NDList Q = new NDList(QReward.singletonOrThrow() .mul(actionInput.singletonOrThrow()) .sum(new int[]{1}));
postObservation 同樣會輸入到神經(jīng)網(wǎng)絡(luò),根據(jù)馬爾科夫決策過程以及貝爾曼價值函數(shù)計算出所有行為的 reward(targetQ)作為真實值:
// 將 postInput 輸入到神經(jīng)網(wǎng)絡(luò)中得到 targetQReward 是 (batchsize,2) 的矩陣。根據(jù) Q-learning 的算法,每一次的 targetQ 需要根據(jù)當(dāng)前環(huán)境是否結(jié)束算出不同的值,因此需要將每一個 step 的 targetQ 單獨算出后再將 targetQ 堆積成 NDList。 NDList targetQReward = trainer.forward(postInput); NDArray[] targetQValue = new NDArray[batchSteps.length]; for (int i = 0; i < batchSteps.length; i++) { if (batchSteps[i].isTerminal()) { targetQValue[i] = batchSteps[i].getReward(); } else { targetQValue[i] = targetQReward.singletonOrThrow().get(i) .max() .mul(rewardDiscount) .add(rewardInput.singletonOrThrow().get(i)); } } NDList targetQBatch = new NDList(); Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value))); NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));
在訓(xùn)練結(jié)束時,計算 Q 和 targetQ 的損失值,并在 CNN 中更新權(quán)重。
我們采用了采用了 3 個卷積層,4 個 relu 激活函數(shù)以及 2 個全連接層的神經(jīng)網(wǎng)絡(luò)架構(gòu)。
layer | input shape | output shape |
---|---|---|
conv2d | (batchSize, 4, 80, 80) | (batchSize,4,20,20) |
conv2d | (batchSize, 4, 20 ,20) | (batchSize, 32, 9, 9) |
conv2d | (batchSize, 32, 9, 9) | (batchSize, 64, 7, 7) |
linear | (batchSize, 3136) | (batchSize, 512) |
linear | (batchSize, 512) | (batchSize, 2) |
DJL 的 RL 庫中提供了非常方便的用于實現(xiàn)強化學(xué)習(xí)的接口:(RlEnv, RlAgent, ReplayBuffer)。
實現(xiàn) RlAgent 接口即可構(gòu)建一個可以進行訓(xùn)練的智能體。
在現(xiàn)有的游戲環(huán)境中實現(xiàn) RlEnv 接口即可生成訓(xùn)練所需的數(shù)據(jù)。
創(chuàng)建 ReplayBuffer 可以存儲并動態(tài)更新訓(xùn)練數(shù)據(jù)。
在實現(xiàn)這些接口后,只需要調(diào)用 step 方法:
RlEnv.step(action, training);
這個方法會將 RlAgent 決策出的動作輸入到游戲環(huán)境中獲得反饋。我們可以在 RlEnv 中提供的 runEnviroment
方法中調(diào)用 step 方法,然后只需要重復(fù)執(zhí)行 runEnvironment
方法,即可不斷地生成用于訓(xùn)練的數(shù)據(jù)。
public Step[] runEnvironment(RlAgent agent, boolean training) { // run the game NDList action = agent.chooseAction(this, training); step(action, training); if (training) { batchSteps = this.getBatch(); } return batchSteps; }
我們將 ReplayBuffer 可存儲的 step 數(shù)量設(shè)置為 50000,在 observe 周期我們會先向 replayBuffer 中存儲 1000 個使用隨機動作生成的 step,這樣可以使智能體更快地從隨機動作中學(xué)習(xí)。
在 explore 和 training 周期,神經(jīng)網(wǎng)絡(luò)會隨機從 replayBuffer 中生成訓(xùn)練集并將它們輸入到模型中訓(xùn)練。我們使用 Adam 優(yōu)化器和 MSE 損失函數(shù)迭代神經(jīng)網(wǎng)絡(luò)。
首先將圖像大小 resize 成 80x80
并轉(zhuǎn)為灰度圖,這有助于在不丟失信息的情況下提高訓(xùn)練速度。
public static NDArray imgPreprocess(BufferedImage observation) { return NDImageUtils.toTensor( NDImageUtils.resize( ImageFactory.getInstance().fromImage(observation) .toNDArray(NDManager.newBaseManager(), Image.Flag.GRAYSCALE) ,80,80)); }
然后我們把連續(xù)的四幀圖像作為一個輸入,為了獲得連續(xù)四幀的連續(xù)圖像,我們維護了一個全局的圖像隊列保存游戲線程中的圖像,每一次動作后替換掉最舊的一幀,然后把隊列里的圖像 stack 成一個單獨的 NDArray。
public NDList createObservation(BufferedImage currentImg) { NDArray observation = GameUtil.imgPreprocess(currentImg); if (imgQueue.isEmpty()) { for (int i = 0; i < 4; i++) { imgQueue.offer(observation); } return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1)); } else { imgQueue.remove(); imgQueue.offer(observation); NDArray[] buf = new NDArray[4]; int i = 0; for (NDArray nd : imgQueue) { buf[i++] = nd; } return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1)); } }
一旦以上部分完成,我們就可以開始訓(xùn)練了。訓(xùn)練優(yōu)化為了獲得最佳的訓(xùn)練性能,我們關(guān)閉了 GUI 以加快樣本生成速度。并使用 Java 多線程將訓(xùn)練循環(huán)和樣本生成循環(huán)分別在不同的線程中運行。
List<Callable<Object>> callables = new ArrayList<>(numOfThreads); callables.add(new GeneratorCallable(game, agent, training)); if(training) { callables.add(new TrainerCallable(model, agent)); }
“怎么用Java訓(xùn)練出一只不死鳥”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識可以關(guān)注億速云網(wǎng)站,小編將為大家輸出更多高質(zhì)量的實用文章!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。