c++如何調(diào)用pytorch模型

小億
190
2024-03-28 15:05:23

要在C++中調(diào)用PyTorch模型,可以使用LibTorch庫(kù)。以下是一個(gè)簡(jiǎn)單的示例代碼,演示了如何加載一個(gè)PyTorch模型并使用輸入數(shù)據(jù)進(jìn)行推理:

#include <torch/torch.h>
#include <iostream>

int main() {
    // 加載模型
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("path/to/model.pt");
    } catch (const c10::Error& e) {
        std::cerr << "Error loading the model\n";
        return -1;
    }

    // 準(zhǔn)備輸入數(shù)據(jù)
    torch::Tensor input = torch::ones({1, 3, 224, 224});  // 示例輸入數(shù)據(jù)

    // 運(yùn)行推理
    at::Tensor output = module.forward({input}).toTensor();

    // 輸出結(jié)果
    std::cout << "Output tensor: " << output << std::endl;

    return 0;
}

在這個(gè)示例中,首先加載了一個(gè)PyTorch模型(假設(shè)模型保存在model.pt文件中)。然后創(chuàng)建了一個(gè)示例輸入張量input,并將其傳遞給模型進(jìn)行推理。最后,輸出了模型的輸出張量。

請(qǐng)注意,為了能夠編譯這段代碼,需要在項(xiàng)目中鏈接LibTorch庫(kù)并設(shè)置正確的包含路徑。更多關(guān)于LibTorch的用法和配置信息,請(qǐng)參考PyTorch官方文檔。

0