要使用PyTorch來預(yù)測(cè)模型,首先需要加載已經(jīng)訓(xùn)練好的模型,并準(zhǔn)備輸入數(shù)據(jù)。然后使用模型對(duì)輸入數(shù)據(jù)進(jìn)行預(yù)測(cè),得到輸出結(jié)果。
以下是一個(gè)使用PyTorch預(yù)測(cè)模型的簡(jiǎn)單示例代碼:
import torch
import torch.nn as nn
import torch.optim as optim
# 定義一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)模型
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加載已經(jīng)訓(xùn)練好的模型
model = SimpleNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()
# 準(zhǔn)備輸入數(shù)據(jù)
input_data = torch.randn(1, 10)
# 使用模型進(jìn)行預(yù)測(cè)
output = model(input_data)
print(output)
在上面的示例中,首先定義了一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)模型SimpleNN
,然后加載了已經(jīng)訓(xùn)練好的模型參數(shù)model.pth
。接著準(zhǔn)備輸入數(shù)據(jù)input_data
,最后使用模型對(duì)輸入數(shù)據(jù)進(jìn)行預(yù)測(cè),得到輸出結(jié)果output
。
需要注意的是,在預(yù)測(cè)時(shí)需要將模型設(shè)置為評(píng)估模式(model.eval()
),這可以確保在預(yù)測(cè)時(shí)不會(huì)影響模型的參數(shù)。