怎么使用pytorch預(yù)測(cè)模型

小億
143
2024-04-11 12:23:11

要使用PyTorch來預(yù)測(cè)模型,首先需要加載已經(jīng)訓(xùn)練好的模型,并準(zhǔn)備輸入數(shù)據(jù)。然后使用模型對(duì)輸入數(shù)據(jù)進(jìn)行預(yù)測(cè),得到輸出結(jié)果。

以下是一個(gè)使用PyTorch預(yù)測(cè)模型的簡(jiǎn)單示例代碼:

import torch
import torch.nn as nn
import torch.optim as optim

# 定義一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)模型
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加載已經(jīng)訓(xùn)練好的模型
model = SimpleNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()

# 準(zhǔn)備輸入數(shù)據(jù)
input_data = torch.randn(1, 10)

# 使用模型進(jìn)行預(yù)測(cè)
output = model(input_data)

print(output)

在上面的示例中,首先定義了一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)模型SimpleNN,然后加載了已經(jīng)訓(xùn)練好的模型參數(shù)model.pth。接著準(zhǔn)備輸入數(shù)據(jù)input_data,最后使用模型對(duì)輸入數(shù)據(jù)進(jìn)行預(yù)測(cè),得到輸出結(jié)果output。

需要注意的是,在預(yù)測(cè)時(shí)需要將模型設(shè)置為評(píng)估模式(model.eval()),這可以確保在預(yù)測(cè)時(shí)不會(huì)影響模型的參數(shù)。

0