溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Pytorch之保存讀取模型實(shí)例

發(fā)布時(shí)間:2020-08-22 18:48:31 來(lái)源:腳本之家 閱讀:330 作者:嘖嘖嘖biubiu 欄目:開(kāi)發(fā)技術(shù)

pytorch保存數(shù)據(jù)

pytorch保存數(shù)據(jù)的格式為.t7文件或者.pth文件,t7文件是沿用torch7中讀取模型權(quán)重的方式。而pth文件是python中存儲(chǔ)文件的常用格式。而在keras中則是使用.h6文件。

# 保存模型示例代碼
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 將epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函數(shù),注意該函數(shù)第一個(gè)參數(shù)可以是單個(gè)值也可以是字典,字典可以存更多你要保存的參數(shù)(不僅僅是權(quán)重?cái)?shù)據(jù))。

pytorch讀取數(shù)據(jù)

pytorch讀取數(shù)據(jù)使用的方法和我們平時(shí)使用預(yù)訓(xùn)練參數(shù)所用的方法是一樣的,都是使用load_state_dict這個(gè)函數(shù)。

下方的代碼和上方的保存代碼可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 從字典中依次讀取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch讀取的方法匯總,但是要注意,在使用官方的預(yù)處理模型進(jìn)行讀取時(shí),一般使用的格式是pth,使用官方的模型讀取命令會(huì)檢查你模型的格式是否正確,如果不是使用官方提供模型通過(guò)下面的函數(shù)強(qiáng)行讀取模型(將其他模型例如caffe模型轉(zhuǎn)過(guò)來(lái)的模型放到指定目錄下)會(huì)發(fā)生錯(cuò)誤。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我們有從caffe模型轉(zhuǎn)過(guò)來(lái)的pytorch模型([0-255,BGR]),我們可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的讀取函數(shù)進(jìn)行讀取即可。

以上這篇Pytorch之保存讀取模型實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持億速云。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI