您好,登錄后才能下訂單哦!
今天就跟大家聊聊有關(guān)怎么在C++中加載TorchScript模型,可能很多人都不太了解,為了讓大家更加了解,小編給大家總結(jié)了以下內(nèi)容,希望大家根據(jù)這篇文章可以有所收獲。
步驟1:將PyTorch模型轉(zhuǎn)換為Torch腳本
PyTorch模型從Python到C ++的旅程由Torch Script啟動,Torch Script是PyTorch模型的一種表示形式,可以由Torch Script編譯器理解,編譯和序列化。如果您是從使用vanilla“eager” API編寫的現(xiàn)有PyTorch模型開始的,則必須首先將模型轉(zhuǎn)換為Torch腳本。在最常見的情況下(如下所述),這只需要花費(fèi)很少的功夫。如果您已經(jīng)有了Torch腳本模塊,則可以跳到本教程的下一部分。
有兩種將PyTorch模型轉(zhuǎn)換為Torch腳本的方法。第一種稱為跟蹤,一種機(jī)制,其中通過使用示例輸入對模型的結(jié)構(gòu)進(jìn)行一次評估,并記錄這些輸入在模型中的流量,從而捕獲模型的結(jié)構(gòu)。這適用于有限使用控制流的模型。第二種方法是在模型中添加顯式批注,以告知Torch Script編譯器可以根據(jù)Torch Script語言施加的約束直接解析和編譯模型代碼。
提示:您可以在官方 Torch腳本參考 中找到有關(guān)這兩種方法的完整文檔,以及使用方法的進(jìn)一步指導(dǎo)。
方法1:通過跟蹤轉(zhuǎn)換為Torch腳本
要將PyTorch模型通過跟蹤轉(zhuǎn)換為Torch腳本,必須將模型的實(shí)例以及示例輸入傳遞給 torch.jit.trace 函數(shù)。這將產(chǎn)生一個 torch.jit.ScriptModule 對象,該對象的模型評估痕跡將嵌入模塊的 forward 方法中:
import torch import torchvision # 你模型的一個實(shí)例. model = torchvision.models.resnet18() # 您通常會提供給模型的forward()方法的示例輸入。 example = torch.rand(1, 3, 224, 224) # 使用`torch.jit.trace `來通過跟蹤生成`torch.jit.ScriptModule` traced_script_module = torch.jit.trace(model, example)
現(xiàn)在可以對跟蹤的 ScriptModule 進(jìn)行評估,使其與常規(guī)PyTorch模塊相同:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224)) In[2]: output[0, :5] Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
方法2:通過注釋轉(zhuǎn)換為Torch腳本
在某些情況下,例如,如果模型采用特定形式的控制流,則可能需要直接在Torch腳本中編寫模型并相應(yīng)地注釋模型。例如,假設(shè)您具有以下vanilla Pytorch模型:
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M)) def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output
因?yàn)榇四K的前向方法使用取決于輸入的控制流,所以它不適合跟蹤。相反,我們可以將其轉(zhuǎn)換為 ScriptModule 。為了將模塊轉(zhuǎn)換為 ScriptModule ,需要使用 torch.jit.script 編譯模塊,如下所示:
class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M)) def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output my_module = MyModule(10,20) sm = torch.jit.script(my_module)
如果您需要在 nn.Module 中排除某些方法,因?yàn)樗鼈兪褂昧?nbsp;TorchScript 尚不支持的Python功能,則可以使用 @torch.jit.ignore 對其進(jìn)行注釋
my_module 是 ScriptModule 的實(shí)例,可以序列化。
步驟2:將腳本模塊序列化為文件
一旦有了ScriptModule(通過跟蹤或注釋PyTorch模型),您就可以將其序列化為文件了。稍后,您將可以使用C ++從此文件加載模塊并執(zhí)行它,而無需依賴Python。假設(shè)我們要序列化先前在跟蹤示例中顯示的 ResNet18 模型。要執(zhí)行此序列化,只需在模塊上調(diào)用 save 并傳遞一個文件名即可:
traced_script_module.save("traced_resnet_model.pt")
這將在您的工作目錄中生成 traced_resnet_model.pt 文件。如果您還想序列化 my_module ,請調(diào)用 my_module.save("my_module_model.pt") 我們現(xiàn)在已經(jīng)正式離開Python領(lǐng)域,并準(zhǔn)備跨入C ++領(lǐng)域。
步驟3:在C ++中加載腳本模塊
要在C ++中加載序列化的PyTorch模型,您的應(yīng)用程序必須依賴于PyTorch C ++ API(也稱為LibTorch)。LibTorch發(fā)行版包含共享庫,頭文件和CMake構(gòu)建配置文件的集合。雖然CMake不是依賴LibTorch的要求,但它是推薦的方法,并且將來會得到很好的支持。 對于本教程,我們將使用CMake和LibTorch構(gòu)建一個最小的C ++應(yīng)用程序,該應(yīng)用程序簡單地加載并執(zhí)行序列化的PyTorch模型。
最小的C ++應(yīng)用程序
讓我們從討論加載模塊的代碼開始。以下將已經(jīng)做:
include <torch/script.h> // One-stop header. #include <iostream> #include <memory> int main(int argc, const char* argv[]) { if (argc != 2) { std::cerr << "usage: example-app <path-to-exported-script-module>\n"; return -1; } torch::jit::script::Module module; try { // 使用以下命令從文件中反序列化腳本模塊: torch::jit::load(). module = torch::jit::load(argv[1]); } catch (const c10::Error& e) { std::cerr << "error loading the model\n"; return -1; } std::cout << "ok\n"; }
標(biāo)頭包含運(yùn)行示例所需的LibTorch庫中的所有相關(guān)包含。我們的應(yīng)用程序接受序列化的PyTorch ScriptModule的文件路徑作為其唯一的命令行參數(shù),然后使用 torch::jit::load() 函數(shù)繼續(xù)對該模塊進(jìn)行反序列化,該函數(shù)將此文件路徑作為輸入。作為返回,我們收到一個 Torch::jit::script::Module 對象。我們將稍后討論如何執(zhí)行它。
取決于LibTorch和構(gòu)建應(yīng)用程序
假設(shè)我們將以上代碼存儲在名為 example-app.cpp 的文件中。最小的 CMakeLists.txt 可能看起來很簡單:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR) project(custom_ops) find_package(Torch REQUIRED) add_executable(example-app example-app.cpp) target_link_libraries(example-app "${TORCH_LIBRARIES}") set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
建立示例應(yīng)用程序的最后一件事是LibTorch發(fā)行版。您可以隨時從PyTorch網(wǎng)站的 下載頁面 上獲取最新的穩(wěn)定版本。如果下載并解壓縮最新的歸檔文件,則應(yīng)收到具有以下目錄結(jié)構(gòu)的文件夾:
libtorch/ bin/ include/ lib/ share/
find_package(Torch)
提示;在Windows上,調(diào)試和發(fā)行版本不兼容ABI。 如果您打算以調(diào)試模式構(gòu)建項目,請嘗試使用LibTorch的調(diào)試版本。
最后一步是構(gòu)建應(yīng)用程序。為此,假定示例目錄的布局如下:
example-app/ CMakeLists.txt example-app.cpp
現(xiàn)在,我們可以運(yùn)行以下命令從 example-app/ 文件夾中構(gòu)建應(yīng)用程序:
mkdir build cd build cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. make
/path/to/libtorch 應(yīng)該是解壓縮的LibTorch發(fā)行版的完整路徑。如果一切順利,它將看起來像這樣:
root@4b5a67132e81:/example-app# mkdir build root@4b5a67132e81:/example-app# cd build root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. -- The C compiler identification is GNU 5.4.0 -- The CXX compiler identification is GNU 5.4.0 -- Check for working C compiler: /usr/bin/cc -- Check for working C compiler: /usr/bin/cc -- works -- Detecting C compiler ABI info -- Detecting C compiler ABI info - done -- Detecting C compile features -- Detecting C compile features - done -- Check for working CXX compiler: /usr/bin/c++ -- Check for working CXX compiler: /usr/bin/c++ -- works -- Detecting CXX compiler ABI info -- Detecting CXX compiler ABI info - done -- Detecting CXX compile features -- Detecting CXX compile features - done -- Looking for pthread.h -- Looking for pthread.h - found -- Looking for pthread_create -- Looking for pthread_create - not found -- Looking for pthread_create in pthreads -- Looking for pthread_create in pthreads - not found -- Looking for pthread_create in pthread -- Looking for pthread_create in pthread - found -- Found Threads: TRUE -- Configuring done -- Generating done -- Build files have been written to: /example-app/build root@4b5a67132e81:/example-app/build# make Scanning dependencies of target example-app [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o [100%] Linking CXX executable example-app [100%] Built target example-app
如果我們提供了我們之前創(chuàng)建的到示例應(yīng)用程序二進(jìn)制文件的跟蹤ResNet18模型 traced_resnet_model.pt 的路徑,則應(yīng)該以友好的“ ok”作為獎勵。 請注意,如果嘗試使用 my_module_model.pt 運(yùn)行此示例,則會收到一條錯誤消息,提示您輸入的形狀不兼容。 my_module_model.pt 需要1D而不是4D。
root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt ok
步驟4:在C ++中執(zhí)行腳本模塊
成功用C ++加載了序列化的ResNet18之后,我們現(xiàn)在只需執(zhí)行幾行代碼即可!讓我們將這些行添加到C ++應(yīng)用程序的 main() 函數(shù)中:
// 創(chuàng)建輸入向量 std::vector<torch::jit::IValue> inputs; inputs.push_back(torch::ones({1, 3, 224, 224})); // 執(zhí)行模型并將輸出轉(zhuǎn)化為張量 at::Tensor output = module.forward(inputs).toTensor(); std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
前兩行設(shè)置了我們模型的輸入。我們創(chuàng)建一個 torch::jit::IValue 的向量(類型為type-erased的值 Script::Module 方法接受并返回),并添加單個輸入。要創(chuàng)建輸入張量,我們使用 torch::ones() ,等效于C ++ API中的 torch.ones 。然后,我們運(yùn)行 script::Module 的 forward 方法,并向其傳遞我們創(chuàng)建的輸入向量。作為回報,我們得到一個新的IValue,通過調(diào)用 toTensor() 將其轉(zhuǎn)換為張量。
提示:要總體上了解有關(guān)torch::ones和PyTorch C ++ API之類的功能的更多信息,請參閱其文檔,網(wǎng)址為https://pytorch.org/cppdocs。
PyTorch C ++ API提供了與Python API幾乎相同的功能奇偶校驗(yàn),使您可以像在Python中一樣進(jìn)一步操縱和處理張量。
在最后一行中,我們打印輸出的前五個條目。由于在本教程前面的部分中,我們向Python中的模型提供了相同的輸入,因此理想情況下,我們應(yīng)該看到相同的輸出。讓我們通過重新編譯我們的應(yīng)用程序并以相同的序列化模型運(yùn)行它來進(jìn)行嘗試:
root@4b5a67132e81:/example-app/build# make Scanning dependencies of target example-app [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o [100%] Linking CXX executable example-app [100%] Built target example-app root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt -0.2698 -0.0381 0.4023 -0.3010 -0.0448 [ Variable[CPUFloatType]{1,5} ]
作為參考,Python以前的輸出為:
tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
看來匹配得很好!
提示:要將模型移至GPU內(nèi)存,可以編寫model.to(at::kCUDA);。通過調(diào)用tensor.to(at::kCUDA),確保模型的輸入也位于CUDA內(nèi)存中,
這將在CUDA內(nèi)存中返回新的張量。
看完上述內(nèi)容,你們對怎么在C++中加載TorchScript模型有進(jìn)一步的了解嗎?如果還想了解更多知識或者相關(guān)內(nèi)容,請關(guān)注億速云行業(yè)資訊頻道,感謝大家的支持。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。