PyTorch中怎么使用自動(dòng)求導(dǎo)

小億
88
2024-05-10 15:32:54

PyTorch中使用自動(dòng)求導(dǎo)可以通過(guò)定義一個(gè)torch.Tensor對(duì)象,并設(shè)置requires_grad=True來(lái)告訴PyTorch需要對(duì)該對(duì)象進(jìn)行求導(dǎo)。然后可以使用backward()方法對(duì)目標(biāo)函數(shù)進(jìn)行求導(dǎo)。下面是一個(gè)簡(jiǎn)單的示例:

import torch

# 創(chuàng)建一個(gè)需要求導(dǎo)的張量
x = torch.tensor([2.0], requires_grad=True)

# 定義一個(gè)函數(shù) f = x^2
def f(x):
    return x**2

# 計(jì)算 f 在 x=2 處的值
output = f(x)
print(output)

# 對(duì) f 進(jìn)行反向傳播,計(jì)算梯度
output.backward()

# 查看梯度值
print(x.grad)

在這個(gè)示例中,我們創(chuàng)建了一個(gè)張量x,并定義了一個(gè)函數(shù)f(x) = x^2,然后計(jì)算了函數(shù)在x=2處的值,并對(duì)其進(jìn)行反向傳播,計(jì)算出梯度值。最后可以通過(guò)x.grad查看梯度值。

0