溫馨提示×

PyTorch中怎么創(chuàng)建自定義自動求導(dǎo)函數(shù)

小億
86
2024-05-10 15:31:54

要創(chuàng)建自定義自動求導(dǎo)函數(shù),需要繼承torch.autograd.Function類,并實現(xiàn)forward和backward方法。以下是一個簡單的示例:

import torch

class CustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)  # 保存輸入用于反向傳播
        output = input * 2
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output * 2  # 計算輸入的梯度
        return grad_input

# 使用自定義自動求導(dǎo)函數(shù)
x = torch.tensor([3.0], requires_grad=True)
custom_func = CustomFunction.apply
y = custom_func(x)

# 計算梯度
y.backward()
print(x.grad)  # 輸出tensor([2.])

在上面的示例中,我們定義了一個叫做CustomFunction的自定義自動求導(dǎo)函數(shù),實現(xiàn)了forward和backward方法。我們可以像使用其他PyTorch函數(shù)一樣使用這個自定義函數(shù),并計算梯度。

0