PyTorch中怎么使用損失函數(shù)

小億
86
2024-05-10 15:36:59

在PyTorch中,可以使用torch.nn模塊中提供的損失函數(shù)來(lái)計(jì)算模型的損失。以下是一個(gè)使用損失函數(shù)計(jì)算模型損失的示例代碼:

import torch
import torch.nn as nn

# 定義模型
model = nn.Linear(10, 1)

# 定義損失函數(shù)
criterion = nn.MSELoss()

# 生成輸入數(shù)據(jù)和目標(biāo)數(shù)據(jù)
input_data = torch.randn(1, 10)
target_data = torch.randn(1, 1)

# 前向傳播
output = model(input_data)

# 計(jì)算損失
loss = criterion(output, target_data)

print(loss)

在上面的示例中,我們首先定義了一個(gè)簡(jiǎn)單的線性模型和一個(gè)均方誤差損失函數(shù)。然后生成輸入數(shù)據(jù)和目標(biāo)數(shù)據(jù),通過(guò)模型的前向傳播得到輸出,最后使用損失函數(shù)計(jì)算模型的損失。通過(guò)調(diào)用loss.backward()方法,可以計(jì)算損失函數(shù)相對(duì)于模型參數(shù)的梯度,進(jìn)而進(jìn)行模型的參數(shù)更新。

0