您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“在pytorch中如何對(duì)非葉節(jié)點(diǎn)的變量進(jìn)行梯度計(jì)算”,內(nèi)容簡(jiǎn)而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領(lǐng)大家一起研究并學(xué)習(xí)一下“在pytorch中如何對(duì)非葉節(jié)點(diǎn)的變量進(jìn)行梯度計(jì)算”這篇文章吧。
在pytorch中一般只對(duì)葉節(jié)點(diǎn)進(jìn)行梯度計(jì)算,也就是下圖中的d,e節(jié)點(diǎn),而對(duì)非葉節(jié)點(diǎn),也即是c,b節(jié)點(diǎn)則沒有顯式地去保留其中間計(jì)算過程中的梯度(因?yàn)橐话銇碚f只有葉節(jié)點(diǎn)才需要去更新),這樣可以節(jié)省很大部分的顯存,但是在調(diào)試過程中,有時(shí)候我們需要對(duì)中間變量梯度進(jìn)行監(jiān)控,以確保網(wǎng)絡(luò)的有效性,這個(gè)時(shí)候我們需要打印出非葉節(jié)點(diǎn)的梯度,為了實(shí)現(xiàn)這個(gè)目的,我們可以通過兩種手段進(jìn)行。
注冊(cè)hook函數(shù)
Tensor.register_hook[2] 可以注冊(cè)一個(gè)反向梯度傳導(dǎo)時(shí)的hook函數(shù),這個(gè)hook函數(shù)將會(huì)在每次計(jì)算 關(guān)于該張量 的時(shí)候 被調(diào)用,經(jīng)常用于調(diào)試的時(shí)候打印出非葉節(jié)點(diǎn)梯度。當(dāng)然,通過這個(gè)手段,你也可以自定義某一層的梯度更新方法。[3] 具體到這里的打印非葉節(jié)點(diǎn)的梯度,代碼如:
def hook_y(grad): print(grad) x = Variable(torch.ones(2, 2), requires_grad=True) y = x + 2 z = y * y * 3 y.register_hook(hook_y) out = z.mean() out.backward()
輸出如:
tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
retain_grad()
Tensor.retain_grad()顯式地保存非葉節(jié)點(diǎn)的梯度,當(dāng)然代價(jià)就是會(huì)增加顯存的消耗,而用hook函數(shù)的方法則是在反向計(jì)算時(shí)直接打印,因此不會(huì)增加顯存消耗,但是使用起來retain_grad()要比hook函數(shù)方便一些。代碼如:
x = Variable(torch.ones(2, 2), requires_grad=True) y = x + 2 y.retain_grad() z = y * y * 3 out = z.mean() out.backward() print(y.grad)
輸出如:
tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
以上是“在pytorch中如何對(duì)非葉節(jié)點(diǎn)的變量進(jìn)行梯度計(jì)算”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。