您好,登錄后才能下訂單哦!
這篇文章主要介紹“PyTorch策略梯度算法怎么使用”,在日常操作中,相信很多人在PyTorch策略梯度算法怎么使用問題上存在疑惑,小編查閱了各式資料,整理出簡單好用的操作方法,希望對大家解答”PyTorch策略梯度算法怎么使用”的疑惑有所幫助!接下來,請跟著小編一起來學(xué)習(xí)吧!
策略梯度算法通過記錄回合中的所有時間步并基于回合結(jié)束時與這些時間步相關(guān)聯(lián)的獎勵來更新權(quán)重訓(xùn)練智能體。使智能體遍歷整個回合然后基于獲得的獎勵更新策略的技術(shù)稱為蒙特卡洛策略梯度。
在策略梯度算法中,模型權(quán)重在每個回合結(jié)束時沿梯度方向移動。關(guān)于梯度的計算,我們將在下一節(jié)中詳細(xì)解釋。此外,在每一時間步中,基于當(dāng)前狀態(tài)和權(quán)重計算的概率得到策略,并從中采樣一個動作。與隨機(jī)搜索和爬山算法(通過采取確定性動作以獲得更高的得分)相反,它不再確定地采取動作。因此,策略從確定性轉(zhuǎn)變?yōu)殡S機(jī)性。例如,如果向左的動作和向右的動作的概率為 [0.8,0.2]
,則表示有 80%
的概率選擇向左的動作,但這并不意味著一定會選擇向左的動作。
在本節(jié)中,我們將學(xué)習(xí)使用 PyTorch
實現(xiàn)策略梯度算法了。 導(dǎo)入所需的庫,創(chuàng)建 CartPole
環(huán)境實例,并計算狀態(tài)空間和動作空間的尺寸:
import gym import torch import matplotlib.pyplot as plt env = gym.make('CartPole-v0') n_state = env.observation_space.shape[0] print(n_state) n_action = env.action_space.n print(n_action)
定義 run_episode
函數(shù),在此函數(shù)中,根據(jù)給定輸入權(quán)重的情況下模擬一回合 CartPole
游戲,并返回獎勵和計算出的梯度。在每個時間步中執(zhí)行以下操作:
根據(jù)當(dāng)前狀態(tài)和輸入權(quán)重計算兩個動作的概率 probs
根據(jù)結(jié)果概率采樣一個動作 action
以概率作為輸入計算 softmax
函數(shù)的導(dǎo)數(shù) d_softmax
,由于只需要計算與選定動作相關(guān)的導(dǎo)數(shù),因此:
\frac {\partial p_i} {\partial z_j} = p_i(1-p_j), i=j∂zj∂pi=pi(1−pj),i=j
將所得的導(dǎo)數(shù) d_softmax
除以概率 probs
,以得與策略相關(guān)的對數(shù)導(dǎo)數(shù) d_log
根據(jù)鏈?zhǔn)椒▌t計算權(quán)重的梯度 grad
:
\frac {dy}{dx}=\frac{dy}{du}\cdot\frac{du}{dx}dxdy=dudy⋅dxdu
記錄得到的梯度 grad
執(zhí)行動作,累積獎勵并更新狀態(tài)
def run_episode(env, weight): state = env.reset() grads = [] total_reward = 0 is_done = False while not is_done: state = torch.from_numpy(state).float() # 根據(jù)當(dāng)前狀態(tài)和輸入權(quán)重計算兩個動作的概率 probs z = torch.matmul(state, weight) probs = torch.nn.Softmax(dim=0)(z) # 根據(jù)結(jié)果概率采樣一個動作 action action = int(torch.bernoulli(probs[1]).item()) # 以概率作為輸入計算 softmax 函數(shù)的導(dǎo)數(shù) d_softmax d_softmax = torch.diag(probs) - probs.view(-1, 1) * probs # 計算與策略相關(guān)的對數(shù)導(dǎo)數(shù)d_log d_log = d_softmax[action] / probs[action] # 計算權(quán)重的梯度grad grad = state.view(-1, 1) * d_log grads.append(grad) state, reward, is_done, _ = env.step(action) total_reward += reward if is_done: break return total_reward, grads
回合完成后,返回在此回合中獲得的總獎勵以及在各個時間步中計算的梯度信息,用于之后更新權(quán)重。
接下來,定義要運(yùn)行的回合數(shù),在每個回合中調(diào)用 run_episode
函數(shù),并初始化權(quán)重以及用于記錄每個回合總獎勵的變量:
n_episode = 1000 weight = torch.rand(n_state, n_action) total_rewards = []
在每個回合結(jié)束后,使用計算出的梯度來更新權(quán)重。對于回合中的每個時間步,權(quán)重都根據(jù)學(xué)習(xí)率、計算出的梯度和智能體在剩余時間步中的獲得的總獎勵進(jìn)行更新。
我們知道在回合終止之前,每一時間步的獎勵都是 1
。因此,我們用于計算每個時間步策略梯度的未來獎勵是剩余的時間步數(shù)。在每個回合之后,我們使用隨機(jī)梯度上升方法將梯度乘以未來獎勵來更新權(quán)重。這樣,一個回合中經(jīng)歷的時間步越長,權(quán)重的更新幅度就越大,這將增加獲得更大總獎勵的機(jī)會。我們設(shè)定學(xué)習(xí)率為 0.001
:
learning_rate = 0.001 for e in range(n_episode): total_reward, gradients = run_episode(env, weight) print('Episode {}: {}'.format(e + 1, total_reward)) for i, gradient in enumerate(gradients): weight += learning_rate * gradient * (total_reward - i) total_rewards.append(total_reward)
然后,我們計算通過策略梯度算法獲得的平均總獎勵:
print('Average total reward over {} episode: {}'.format(n_episode, sum(total_rewards)/n_episode))
我們可以繪制每個回合的總獎勵變化情況,如下所示:
plt.plot(total_rewards) plt.xlabel('Episode') plt.ylabel('Reward') plt.show()
在上圖中,我們可以看到獎勵會隨著訓(xùn)練回合的增加呈現(xiàn)出上升趨勢,然后能夠在最大值處穩(wěn)定。我們還可以看到,即使在收斂之后,獎勵也會振蕩,這是由于策略梯度算法是一種隨機(jī)策略算法。
最后,我們查看學(xué)習(xí)到策略在 1000
個新回合中的性能表現(xiàn),并計算平均獎勵:
n_episode_eval = 1000 total_rewards_eval = [] for e in range(n_episode_eval): total_reward, _ = run_episode(env, weight) print('Episode {}: {}'.format(e+1, total_reward)) total_rewards_eval.append(total_reward) print('Average total reward over {} episode: {}'.format(n_episode_eval, sum(total_rewards_eval)/n_episode_eval)) # Average total reward over 1000 episode: 200
進(jìn)行測試后,可以看到回合的平均獎勵接近最大值 200
??梢远啻螠y試訓(xùn)練后的模型,得到的平均獎勵較為穩(wěn)定。正如我們一開始所說的那樣,對于諸如 CartPole
之類的簡單環(huán)境,策略梯度算法可能大材小用,但它為我們解決更加復(fù)雜的問題奠定了基礎(chǔ)。
到此,關(guān)于“PyTorch策略梯度算法怎么使用”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識,請繼續(xù)關(guān)注億速云網(wǎng)站,小編會繼續(xù)努力為大家?guī)砀鄬嵱玫奈恼拢?/p>
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。