要導(dǎo)出PyTorch模型,可以使用torch.save()
函數(shù)將模型參數(shù)保存到文件中。以下是一個簡單的示例:
import torch
import torch.nn as nn
# 定義一個簡單的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
# 保存模型
torch.save(model.state_dict(), 'model.pth')
在上面的示例中,我們定義了一個簡單的模型SimpleModel
,然后使用torch.save()
函數(shù)將模型的參數(shù)保存到文件model.pth
中。要加載已保存的模型,可以使用torch.load()
函數(shù):
# 加載模型
model = SimpleModel()
model.load_state_dict(torch.load('model.pth'))
這樣就可以將模型導(dǎo)出和加載回來,繼續(xù)進(jìn)行訓(xùn)練或推斷。