溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

pytorch模型怎么轉(zhuǎn)onnx模型

發(fā)布時(shí)間:2022-08-30 14:11:48 來源:億速云 閱讀:168 作者:iii 欄目:開發(fā)技術(shù)

這篇文章主要介紹“pytorch模型怎么轉(zhuǎn)onnx模型”,在日常操作中,相信很多人在pytorch模型怎么轉(zhuǎn)onnx模型問題上存在疑惑,小編查閱了各式資料,整理出簡(jiǎn)單好用的操作方法,希望對(duì)大家解答”pytorch模型怎么轉(zhuǎn)onnx模型”的疑惑有所幫助!接下來,請(qǐng)跟著小編一起來學(xué)習(xí)吧!

    學(xué)習(xí)內(nèi)容

    前提條件:需要安裝onnx 和 onnxruntime,可以通過 pip install onnx 和 pip install onnxruntime 進(jìn)行安裝

    1 . pytorch 轉(zhuǎn) onnx

    pytorch 轉(zhuǎn) onnx 只需要一個(gè)函數(shù) torch.onnx.export

    torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)

    參數(shù)說明:

    • model——需要導(dǎo)出的pytorch模型

    • args——模型的輸入?yún)?shù),滿足輸入層的shape正確即可。

    • path——輸出的onnx模型的位置。例如‘yolov5.onnx’。

    • export_params——輸出模型是否可訓(xùn)練。default=True,表示導(dǎo)出trained model,否則untrained。

    • verbose——是否打印模型轉(zhuǎn)換信息。default=False。

    • input_names——輸入節(jié)點(diǎn)名稱。default=None。

    • output_names——輸出節(jié)點(diǎn)名稱。default=None。

    • do_constant_folding——是否使用常量折疊(不了解),默認(rèn)即可。default=True。

    • dynamic_axes——模型的輸入輸出有時(shí)是可變的,如Rnn,或者輸出圖像的batch可變,可通過該參數(shù)設(shè)置。如輸入層的shape為(b,3,h,w),batch,height,width是可變的,但是chancel是固定三通道。
      格式如下 :
      1)僅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}
      2)僅dict<int, string> dynamic_axes={&lsquo;input&rsquo;:{0:&lsquo;batch&rsquo;,2:&lsquo;height&rsquo;,3:&lsquo;width&rsquo;},&lsquo;output&rsquo;:{0:&lsquo;batch&rsquo;,1:&lsquo;c&rsquo;}}
      3)mixed dynamic_axes={&lsquo;input&rsquo;:{0:&lsquo;batch&rsquo;,2:&lsquo;height&rsquo;,3:&lsquo;width&rsquo;},&lsquo;output&rsquo;:[0,1]}

    • opset_version&mdash;&mdash;opset的版本,低版本不支持upsample等操作。

    import torch
    import torch.nn
    import onnx
    
    model = torch.load('best.pt')
    model.eval()
    
    input_names = ['input']
    output_names = ['output']
    
    x = torch.randn(1,3,32,32,requires_grad=True)
    
    torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')

    2 . 運(yùn)行onnx模型

    檢查onnx模型,并使用onnxruntime運(yùn)行。

    import onnx
    import onnxruntime as ort
    
    model = onnx.load('best.onnx')
    onnx.checker.check_model(model)
    
    session = ort.InferenceSession('best.onnx')
    x=np.random.randn(1,3,32,32).astype(np.float32)  # 注意輸入type一定要np.float32!!!!!
    # x= torch.randn(batch_size,chancel,h,w)
    
    
    outputs = session.run(None,input = { 'input' : x })

    參數(shù)說明:

    • output_names: default=None
      用來指定輸出哪些,以及順序
      若為None,則按序輸出所有的output,即返回[output_0,output_1]
      若為[&lsquo;output_1&rsquo;,&lsquo;output_0&rsquo;],則返回[output_1,output_0]
      若為[&lsquo;output_0&rsquo;],則僅返回[output_0:tensor]

    • input:dict
      可以通過session.get_inputs().name獲得名稱
      其中key值要求與torch.onnx.export中設(shè)定的一致

    3.onnx模型輸出與pytorch模型比對(duì)

    import numpy as np
    np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)

    如前所述,經(jīng)驗(yàn)表明,ONNX 模型的運(yùn)行效率明顯優(yōu)于原 PyTorch 模型,這似乎是源于 ONNX 模型生成過程中的優(yōu)化,這也導(dǎo)致了模型的生成過程比較耗時(shí),但整體效率依舊可觀。

    此外,根據(jù)對(duì) ONNX 模型和 PyTorch 模型運(yùn)行結(jié)果的統(tǒng)計(jì)分析(誤差的均值和標(biāo)準(zhǔn)差),可以看出 ONNX 模型的運(yùn)行結(jié)果誤差很小、基本可靠。

    到此,關(guān)于“pytorch模型怎么轉(zhuǎn)onnx模型”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實(shí)踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識(shí),請(qǐng)繼續(xù)關(guān)注億速云網(wǎng)站,小編會(huì)繼續(xù)努力為大家?guī)砀鄬?shí)用的文章!

    向AI問一下細(xì)節(jié)

    免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

    AI