溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch如何使用加載訓(xùn)練好的模型做inference

發(fā)布時間:2021-06-24 09:23:36 來源:億速云 閱讀:239 作者:小新 欄目:開發(fā)技術(shù)

這篇文章主要介紹pytorch如何使用加載訓(xùn)練好的模型做inference,文中介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們一定要看完!

1、 構(gòu)建模型(# load model graph)

model = MODEL()

2、加載模型參數(shù)(# load model state_dict)

 model.load_state_dict
 (
 {

 k.replace('module.',''):v for k,v in

 torch.load(config.model_path, map_location=config.device).items()

 }
 )
 
model = self.model.to(config.device)

* config.device 指定使用哪塊GPU或者CPU  

*k.replace('module.',''):v 防止torch.DataParallel訓(xùn)練的模型出現(xiàn)加載錯誤

(解決RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1問題)

3、設(shè)置當(dāng)前階段為inference(# predict)

model.eval()

以上是“pytorch如何使用加載訓(xùn)練好的模型做inference”這篇文章的所有內(nèi)容,感謝各位的閱讀!希望分享的內(nèi)容對大家有幫助,更多相關(guān)知識,歡迎關(guān)注億速云行業(yè)資訊頻道!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI