溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

如何在Pytorch 中使用retain_graph

發(fā)布時(shí)間:2021-01-11 15:47:30 來源:億速云 閱讀:344 作者:Leah 欄目:開發(fā)技術(shù)

這期內(nèi)容當(dāng)中小編將會(huì)給大家?guī)碛嘘P(guān)如何在Pytorch 中使用retain_graph,文章內(nèi)容豐富且以專業(yè)的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。

用法分析

在查看SRGAN源碼時(shí)有如下?lián)p失函數(shù),其中設(shè)置了retain_graph=True,其作用是什么?

		############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D網(wǎng)絡(luò)時(shí)的loss反向傳播過程中使用了retain_graph=True,目的為是為保留該過程中計(jì)算的梯度,后續(xù)G網(wǎng)絡(luò)更新時(shí)使用;

其實(shí)retain_graph這個(gè)參數(shù)在平常中我們是用不到的,但是在特殊的情況下我們會(huì)用到它,

如下代碼:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

輸出如下錯(cuò)誤信息:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正確:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有兩個(gè)Loss,先執(zhí)行第一個(gè)的backward,再執(zhí)行第二個(gè)backward
loss1.backward(retain_graph=True)
loss2.backward() # 執(zhí)行完這個(gè)后,所有中間變量都會(huì)被釋放,以便下一次的循環(huán)
optimizer.step() # 更新參數(shù)

Variable 類源代碼

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: 任意類型的封裝好的張量。
    grad: 保存與data類型和位置相匹配的梯度,此屬性難以分配并且不能重新分配。
    requires_grad: 標(biāo)記變量是否已經(jīng)由一個(gè)需要調(diào)用到此變量的子圖創(chuàng)建的bool值。只能在葉子變量上進(jìn)行修改。
    volatile: 標(biāo)記變量是否能在推理模式下應(yīng)用(如不保存歷史記錄)的bool值。只能在葉變量上更改。
    is_leaf: 標(biāo)記變量是否是圖葉子(如由用戶創(chuàng)建的變量)的bool值.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包裝的張量.
    requires_grad (bool): bool型的標(biāo)記值. **Keyword only.**
    volatile (bool): bool型的標(biāo)記值. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """計(jì)算關(guān)于當(dāng)前圖葉子變量的梯度,圖使用鏈?zhǔn)椒▌t導(dǎo)致分化
    如果Variable是一個(gè)標(biāo)量(例如它包含一個(gè)單元素?cái)?shù)據(jù)),你無需對backward()指定任何參數(shù)
    如果變量不是標(biāo)量(包含多個(gè)元素?cái)?shù)據(jù)的矢量)且需要梯度,函數(shù)需要額外的梯度;
    需要指定一個(gè)和tensor的形狀匹配的grad_output參數(shù)(y在指定方向投影對x的導(dǎo)數(shù));
    可以是一個(gè)類型和位置相匹配且包含與自身相關(guān)的不同函數(shù)梯度的張量。
    函數(shù)在葉子上累積梯度,調(diào)用前需要對該葉子進(jìn)行清零。
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              變量的梯度,如果是一個(gè)張量,除非“create_graph”是True,否則會(huì)自動(dòng)轉(zhuǎn)換成volatile型的變量。
              可以為標(biāo)量變量或不需要grad的值指定None值。如果None值可接受,則此參數(shù)可選。
      retain_graph (bool, optional): 如果為False,用來計(jì)算梯度的圖將被釋放。
                      在幾乎所有情況下,將此選項(xiàng)設(shè)置為True不是必需的,通??梢砸愿行У姆绞浇鉀Q。
                      默認(rèn)值為create_graph的值。
      create_graph (bool, optional): 為True時(shí),會(huì)構(gòu)造一個(gè)導(dǎo)數(shù)的圖,用來計(jì)算出更高階導(dǎo)數(shù)結(jié)果。
                      默認(rèn)為False,除非``gradient``是一個(gè)volatile變量。
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    每當(dāng)與variable相關(guān)的梯度被計(jì)算時(shí)調(diào)用hook,hook的申明:hook(grad)->Variable or None
    不能對hook的參數(shù)進(jìn)行修改,但可以選擇性地返回一個(gè)新的梯度以用在`grad`的相應(yīng)位置。
 
    函數(shù)返回一個(gè)handle,其``handle.remove()``方法用于將hook從模塊中移除。
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    區(qū)分隨機(jī)節(jié)點(diǎn)需要為他們提供reward值。如果圖表中包含任何的隨機(jī)操作,都應(yīng)該在其輸出上調(diào)用此函數(shù),否則會(huì)出現(xiàn)錯(cuò)誤。
    Parameters:
      reward(Tensor): 帶有每個(gè)元素獎(jiǎng)賞的張量,必須與Variable數(shù)據(jù)的設(shè)備位置和形狀相匹配。
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """返回一個(gè)從當(dāng)前圖分離出來的心變量。
    結(jié)果不需要梯度,如果輸入是volatile,則輸出也是volatile。
 
    .. 注意::
     返回變量使用與原始變量相同的數(shù)據(jù)張量,并且可以看到其中任何一個(gè)的就地修改,并且可能會(huì)觸發(fā)正確性檢查中的錯(cuò)誤。
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """從創(chuàng)建它的圖中分離出變量并作為該圖的一個(gè)葉子"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True

上述就是小編為大家分享的如何在Pytorch 中使用retain_graph了,如果剛好有類似的疑惑,不妨參照上述分析進(jìn)行理解。如果想知道更多相關(guān)知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI