要利用Torch構建自定義的損失函數(shù),你需要定義一個新的損失函數(shù)類,并重寫其中的forward方法來計算損失值。接下來是一個示例代碼:
import torch
import torch.nn as nn
class CustomLoss(nn.Module):
def __init__(self):
super(CustomLoss, self).__init__()
def forward(self, pred, target):
loss = torch.mean((pred - target) ** 2) # 例如,這里定義為均方誤差損失函數(shù)
return loss
# 使用自定義的損失函數(shù)
criterion = CustomLoss()
pred = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
loss = criterion(pred, target)
print(loss)
在上面的示例中,我們首先定義了一個名為CustomLoss的新的損失函數(shù)類,然后在其中定義了forward方法來計算損失值。在這個例子中,我們將損失函數(shù)定義為均方誤差損失函數(shù)。最后,我們實例化了這個自定義的損失函數(shù)類,并計算了預測值和目標值之間的損失值。
通過這種方式,你可以通過Torch構建自定義的損失函數(shù),并在模型訓練中使用它。