溫馨提示×

PyTorch中怎么導(dǎo)出模型

小億
127
2024-05-10 19:20:57

要導(dǎo)出PyTorch模型,可以使用torch.save()函數(shù)將模型參數(shù)保存到文件中。以下是一個簡單的示例:

import torch
import torch.nn as nn

# 定義一個簡單的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# 保存模型
torch.save(model.state_dict(), 'model.pth')

在上面的示例中,我們定義了一個簡單的模型SimpleModel,然后使用torch.save()函數(shù)將模型的參數(shù)保存到文件model.pth中。要加載已保存的模型,可以使用torch.load()函數(shù):

# 加載模型
model = SimpleModel()
model.load_state_dict(torch.load('model.pth'))

這樣就可以將模型導(dǎo)出和加載回來,繼續(xù)進(jìn)行訓(xùn)練或推斷。

0