在PyTorch中,torch.load()
函數(shù)用于加載保存的模型或張量。其基本語法如下:
torch.load(filepath, map_location=None, pickle_module=<module 'pickle' from '...'>)
filepath
是保存模型或張量的文件路徑。map_location
是一個(gè)可選參數(shù),用于指定設(shè)備將模型/張量加載到哪個(gè)位置。可以是一個(gè)字符串,表示設(shè)備名稱(如’cpu’、'cuda:0’等),也可以是一個(gè)torch.device對(duì)象。默認(rèn)值為None,表示加載到與保存時(shí)設(shè)備相同的位置。pickle_module
是一個(gè)可選參數(shù),用于覆蓋默認(rèn)的pickle模塊。默認(rèn)值為Python內(nèi)置的pickle模塊。以下是torch.load()
函數(shù)的使用示例:
import torch
# 加載保存的模型
model = torch.load('model.pth')
# 加載保存的張量
tensor = torch.load('tensor.pt')
# 加載保存的模型,并將其加載到指定設(shè)備上
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = torch.load('model.pth', map_location=device)
# 加載保存的模型,使用自定義的pickle模塊
import pickle5 as pickle
model = torch.load('model.pth', pickle_module=pickle)
注意,torch.load()
函數(shù)只能加載在相同版本的PyTorch中保存的模型或張量。如果模型或張量是在不同版本的PyTorch中保存的,則需要使用其他方法進(jìn)行轉(zhuǎn)換或加載。