您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“PyTorch梯度裁剪如何避免訓(xùn)練loss nan”,內(nèi)容簡(jiǎn)而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領(lǐng)大家一起研究并學(xué)習(xí)一下“PyTorch梯度裁剪如何避免訓(xùn)練loss nan”這篇文章吧。
from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()
其中,max_norm為梯度的最大范數(shù),也是梯度裁剪時(shí)主要設(shè)置的參數(shù)。
備注:網(wǎng)上有同學(xué)提醒在(強(qiáng)化學(xué)習(xí))使用了梯度裁剪之后訓(xùn)練時(shí)間會(huì)大大增加。目前在我的檢測(cè)網(wǎng)絡(luò)訓(xùn)練中暫時(shí)還沒(méi)有碰到這個(gè)問(wèn)題,以后遇到再來(lái)更新。
補(bǔ)充:pytorch訓(xùn)練過(guò)程中出現(xiàn)nan的排查思路
看看代碼中在這種操作的時(shí)候有沒(méi)有加一個(gè)很小的數(shù),但是這個(gè)數(shù)數(shù)量級(jí)要和運(yùn)算的數(shù)的數(shù)量級(jí)要差很多。一般是1e-8。
optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()
max_norm一般是1,3,5。
就按照下面的流程來(lái)判斷。
...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么說(shuō)明可能是在forward的過(guò)程中出現(xiàn)了第一條列舉的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么說(shuō)明forward過(guò)程沒(méi)問(wèn)題,可能是梯度爆炸,所以用梯度裁剪試試
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判斷參數(shù)是不是nan, 如果不是判斷step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判斷,參數(shù)和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特別是梯度出現(xiàn)了Nan,考慮學(xué)習(xí)速率是否太大,調(diào)小學(xué)習(xí)速率或者換個(gè)優(yōu)化器試試。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)
以上是“PyTorch梯度裁剪如何避免訓(xùn)練loss nan”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。