您好,登錄后才能下訂單哦!
這篇文章主要介紹“PyTorch中tensor.detach()和tensor.data的區(qū)別有哪些”的相關(guān)知識,小編通過實(shí)際案例向大家展示操作過程,操作方法簡單快捷,實(shí)用性強(qiáng),希望這篇“PyTorch中tensor.detach()和tensor.data的區(qū)別有哪些”文章能幫助大家解決問題。
以 a.data, a.detach() 為例:
兩種方法均會返回和a相同的tensor,且與原tensor a 共享數(shù)據(jù),一方改變,則另一方也改變。
所起的作用均是將變量tensor從原有的計(jì)算圖中分離出來,分離所得tensor的requires_grad = False。
data是一個屬性,.detach()是一個方法;data是不安全的,.detach()是安全的;
>>> a = torch.tensor([1,2,3.], requires_grad =True) >>> out = a.sigmoid() >>> c = out.data >>> c.zero_() tensor([ 0., 0., 0.]) >>> out # out的數(shù)值被c.zero_()修改 tensor([ 0., 0., 0.]) >>> out.sum().backward() # 反向傳播 >>> a.grad # 這個結(jié)果很嚴(yán)重的錯誤,因?yàn)閛ut已經(jīng)改變了 tensor([ 0., 0., 0.])
這是因?yàn)椋?dāng)我們修改分離后的tensor,從而導(dǎo)致原tensora發(fā)生改變。PyTorch的自動求導(dǎo)Autograd是無法捕捉到這種變化的,會依然按照求導(dǎo)規(guī)則進(jìn)行求導(dǎo),導(dǎo)致計(jì)算出錯誤的導(dǎo)數(shù)值。
其風(fēng)險性在于,如果我在某一處修改了某一個變量,求導(dǎo)的時候也無法得知這一修改,可能會在不知情的情況下計(jì)算出錯誤的導(dǎo)數(shù)值。
>>> a = torch.tensor([1,2,3.], requires_grad =True) >>> out = a.sigmoid() >>> c = out.detach() >>> c.zero_() tensor([ 0., 0., 0.]) >>> out # out的值被c.zero_()修改 !! tensor([ 0., 0., 0.]) >>> out.sum().backward() # 需要原來out得值,但是已經(jīng)被c.zero_()覆蓋了,結(jié)果報(bào)錯 RuntimeError: one of the variables needed for gradient computation has been modified by an
使用.detach()的好處在于,若是出現(xiàn)上述情況,Autograd可以檢測出某一處變量已經(jīng)發(fā)生了改變,進(jìn)而以如下形式報(bào)錯,從而避免了錯誤的求導(dǎo)。
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
從以上可以看出,是在前向傳播的過程中使用就地操作(In-place operation)導(dǎo)致了這一問題,那么就地操作是什么呢?
官方文檔中,對這個方法是這么介紹的。
返回一個新的從當(dāng)前圖中分離的 Variable。
返回的 Variable 永遠(yuǎn)不會需要梯度 如果 被 detach
的Variable volatile=True, 那么 detach 出來的 volatile 也為 True
還有一個注意事項(xiàng),即:返回的 Variable 和 被 detach 的Variable 指向同一個 tensor
import torch from torch.nn import init from torch.autograd import Variable t1 = torch.FloatTensor([1., 2.]) v1 = Variable(t1) t2 = torch.FloatTensor([2., 3.]) v2 = Variable(t2) v3 = v1 + v2 v3_detached = v3.detach() v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值 print(v3, v3_detached) # v3 中tensor 的值也會改變
可以對部分網(wǎng)絡(luò)求梯度。
如果我們有兩個網(wǎng)絡(luò) , 兩個關(guān)系是這樣的 現(xiàn)在我們想用 來為B網(wǎng)絡(luò)的參數(shù)來求梯度,但是又不想求A網(wǎng)絡(luò)參數(shù)的梯度。我們可以這樣:
# y=A(x), z=B(y) 求B中參數(shù)的梯度,不求A中參數(shù)的梯度 y = A(x) z = B(y.detach()) z.backward()
關(guān)于“PyTorch中tensor.detach()和tensor.data的區(qū)別有哪些”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識,可以關(guān)注億速云行業(yè)資訊頻道,小編每天都會為大家更新不同的知識點(diǎn)。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。