您好,登錄后才能下訂單哦!
這篇文章主要介紹“pytorch中retain_graph==True的作用是什么”的相關(guān)知識(shí),小編通過(guò)實(shí)際案例向大家展示操作過(guò)程,操作方法簡(jiǎn)單快捷,實(shí)用性強(qiáng),希望這篇“pytorch中retain_graph==True的作用是什么”文章能幫助大家解決問(wèn)題。
總的來(lái)說(shuō)進(jìn)行一次backward之后,各個(gè)節(jié)點(diǎn)的值會(huì)清除,這樣進(jìn)行第二次backward會(huì)報(bào)錯(cuò),如果加上retain_graph==True后,可以再來(lái)一次backward。
官方定義:
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
大意是如果設(shè)置為False,計(jì)算圖中的中間變量在計(jì)算完后就會(huì)被釋放。
但是在平時(shí)的使用中這個(gè)參數(shù)默認(rèn)都為False從而提高效率,和creat_graph的值一樣。
假設(shè)一個(gè)我們有一個(gè)輸入x,y = x **2, z = y*4,然后我們有兩個(gè)輸出,一個(gè)output_1 = z.mean(),另一個(gè)output_2 = z.sum()。
然后我們對(duì)兩個(gè)output執(zhí)行backward。
import torch x = torch.randn((1,4),dtype=torch.float32,requires_grad=True) y = x ** 2 z = y * 4 print(x) print(y) print(z) loss1 = z.mean() loss2 = z.sum() print(loss1,loss2) loss1.backward() # 這個(gè)代碼執(zhí)行正常,但是執(zhí)行完中間變量都free了,所以下一個(gè)出現(xiàn)了問(wèn)題 print(loss1,loss2) loss2.backward() # 這時(shí)會(huì)引發(fā)錯(cuò)誤
程序正常執(zhí)行到第12行,所有的變量正常保存。
但是在第13行報(bào)錯(cuò):
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
分析:計(jì)算節(jié)點(diǎn)數(shù)值保存了,但是計(jì)算圖x-y-z結(jié)構(gòu)被釋放了,而計(jì)算loss2的backward仍然試圖利用x-y-z的結(jié)構(gòu),因此會(huì)報(bào)錯(cuò)。
因此需要retain_graph參數(shù)為T(mén)rue去保留中間參數(shù)從而兩個(gè)loss的backward()不會(huì)相互影響。
正確的代碼應(yīng)當(dāng)把第11行以及之后改成
1 # 假如你需要執(zhí)行兩次backward,先執(zhí)行第一個(gè)的backward,再執(zhí)行第二個(gè)backward
2 loss1.backward(retain_graph=True)# 這里參數(shù)表明保留backward后的中間參數(shù)。
3 loss2.backward() # 執(zhí)行完這個(gè)后,所有中間變量都會(huì)被釋放,以便下一次的循環(huán)
4 #如果是在訓(xùn)練網(wǎng)絡(luò)optimizer.step() # 更新參數(shù)
create_graph參數(shù)比較簡(jiǎn)單,參考官方定義:
create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.
(Pytorch:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time)
retain_graph設(shè)置True,一般多用于兩次backward
# 假如有兩個(gè)Loss,先執(zhí)行第一個(gè)的backward,再執(zhí)行第二個(gè)backward loss1.backward(retain_graph=True) # 這樣計(jì)算圖就不會(huì)立即釋放 loss2.backward() # 執(zhí)行完這個(gè)后,所有中間變量都會(huì)被釋放,以便下一次的循環(huán) optimizer.step() # 更新參數(shù)
retain_graph設(shè)置True后一定要知道釋放,否則顯卡會(huì)占用越來(lái)越多,代碼速度也會(huì)跑的越來(lái)越慢。
第一種是輸入的原因。
// Example x = torch.randn((100,1), requires_grad = True) y = 1 + 2 * x + 0.3 * torch.randn(100,1) x_train, y_train = x[:70], y[:70] x_val, y_val = x[70:], y[70:] for epoch in range(n_epochs): ... prediction = model(x_train) loss.backward() ...
在多次循環(huán)的過(guò)程中,input的梯度沒(méi)有清除,而且我們也不需要計(jì)算輸入的梯度,因此將x的require_grad設(shè)置為False就可以解決問(wèn)題。
第二種是我在訓(xùn)練LSTM時(shí)候發(fā)現(xiàn)的。
class LSTMpred(nn.Module): def __init__(self, input_size, hidden_dim): self.hidden = self.init_hidden() ... def init_hidden(self): #這里我們是需要個(gè)隱層參數(shù)的 return (torch.zeros(1, 1, self.hidden_dim, requires_grad=True), torch.zeros(1, 1, self.hidden_dim, requires_grad=True)) def forward(self, seq): ...
這里面的self.hidden我們?cè)诿恳淮斡?xùn)練的時(shí)候都要重新初始化隱層參數(shù):
for epoch in range(Epoch): ... model.hidden = model.init_hidden() modout = model(seq) ...
其實(shí),想想這幾種情況都是一回事,都是網(wǎng)絡(luò)在反向傳播中不允許多個(gè)backward(),也就是梯度下降反饋的時(shí)候,有多個(gè)循環(huán)過(guò)程中共用了同一個(gè)需要計(jì)算梯度的變量,在前一個(gè)循環(huán)清除梯度后,后面一個(gè)循環(huán)過(guò)程就會(huì)在這個(gè)變量上栽跟頭(個(gè)人想法)。
關(guān)于“pytorch中retain_graph==True的作用是什么”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識(shí),可以關(guān)注億速云行業(yè)資訊頻道,小編每天都會(huì)為大家更新不同的知識(shí)點(diǎn)。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。