溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Pytorch如何實(shí)現(xiàn)常用乘法算子TensorRT

發(fā)布時(shí)間:2022-06-02 10:06:55 來(lái)源:億速云 閱讀:278 作者:iii 欄目:開(kāi)發(fā)技術(shù)

這篇文章主要介紹了Pytorch如何實(shí)現(xiàn)常用乘法算子TensorRT的相關(guān)知識(shí),內(nèi)容詳細(xì)易懂,操作簡(jiǎn)單快捷,具有一定借鑒價(jià)值,相信大家閱讀完這篇Pytorch如何實(shí)現(xiàn)常用乘法算子TensorRT文章都會(huì)有所收獲,下面我們一起來(lái)看看吧。

1.乘法運(yùn)算總覽

先把 pytorch 中的一些常用的乘法運(yùn)算進(jìn)行一個(gè)總覽:

  • torch.mm:用于兩個(gè)矩陣 (不包括向量) 的乘法,如維度 (m, n) 的矩陣乘以維度 (n, p) 的矩陣;

  • torch.bmm:用于帶 batch 的三維向量的乘法,如維度 (b, m, n) 的矩陣乘以維度 (b, n, p) 的矩陣;

  • torch.mul:用于同維度矩陣的逐像素點(diǎn)相乘,也即點(diǎn)乘,如維度 (m, n) 的矩陣點(diǎn)乘維度 (m, n) 的矩陣。該方法支持廣播,也即支持矩陣和元素點(diǎn)乘;

  • torch.mv:用于矩陣和向量的乘法,矩陣在前,向量在后,如維度 (m, n) 的矩陣乘以維度為 (n) 的向量,輸出維度為 (m);

  • torch.matmul:用于兩個(gè)張量相乘,或矩陣與向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;

  • @:作用相當(dāng)于 torch.matmul;

  • *:作用相當(dāng)于 torch.mul;

如上進(jìn)行了一些具體羅列,可以歸納出,常用的乘法無(wú)非兩種:矩陣乘 和 點(diǎn)乘,所以下面分這兩類進(jìn)行介紹。

2.乘法算子實(shí)現(xiàn)

2.1矩陣乘算子實(shí)現(xiàn)

先來(lái)看看矩陣乘法的 pytorch 的實(shí)現(xiàn) (以下實(shí)現(xiàn)在終端):

>>> import torch
>>> # torch.mm
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99, 88)
>>> c = torch.mm(a, b)
>>> c.shape
torch.size([66, 88])
>>>
>>> # torch.bmm
>>> a = torch.randn(3, 66, 99)
>>> b = torch.randn(3, 99, 77)
>>> c = torch.bmm(a, b)
>>> c.shape
torch.size([3, 66, 77])
>>>
>>> # torch.mv
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99)
>>> c = torch.mv(a, b)
>>> c.shape
torch.size([66])
>>>
>>> # torch.matmul
>>> a = torch.randn(32, 3, 66, 99)
>>> b = torch.randn(32, 3, 99, 55)
>>> c = torch.matmul(a, b)
>>> c.shape
torch.size([32, 3, 66, 55])
>>>
>>> # @
>>> d = a @ b
>>> d.shape
torch.size([32, 3, 66, 55])

來(lái)看 TensorRT 的實(shí)現(xiàn),以上乘法都可使用 addMatrixMultiply 方法覆蓋,對(duì)應(yīng) torch.matmul,先來(lái)看該方法的定義:

//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
  ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
  return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}

可以看到這個(gè)方法有四個(gè)傳參,對(duì)應(yīng)兩個(gè)張量和其 operation。來(lái)看這個(gè)算子在 TensorRT 中怎么添加:

// 構(gòu)造張量 Tensor0
nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);
// 構(gòu)造張量 Tensor1
nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);

// 添加矩陣乘法
nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);

// 獲取輸出
matmulOutput = Matmul_layer->getOputput(0);

2.2點(diǎn)乘算子實(shí)現(xiàn)

再來(lái)看看點(diǎn)乘的 pytorch 的實(shí)現(xiàn) (以下實(shí)現(xiàn)在終端):

>>> import torch
>>> # torch.mul
>>> a = torch.randn(66, 99)
>>> b = torch.randn(66, 99)
>>> c = torch.mul(a, b)
>>> c.shape
torch.size([66, 99])
>>> d = 0.125
>>> e = torch.mul(a, d)
>>> e.shape
torch.size([66, 99])
>>> # *
>>> f = a * b
>>> f.shape
torch.size([66, 99])

來(lái)看 TensorRT 的實(shí)現(xiàn),以上乘法都可使用 addScale 方法覆蓋,這在圖像預(yù)處理中十分常用,先來(lái)看該方法的定義:

//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//!              This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//!              and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
  return mImpl->addScale(input, mode, shift, scale, power);
}

 可以看到有三個(gè)模式:

  • kUNIFORM:weights 為一個(gè)值,對(duì)應(yīng)張量乘一個(gè)元素;

  • kCHANNEL:weights 維度和輸入張量通道的 c 維度對(duì)應(yīng),可以做一些以通道為基準(zhǔn)的預(yù)處理;

  • kELEMENTWISE:weights 維度和輸入張量的 c、h、w 對(duì)應(yīng),不考慮 batch,所以是輸入的后三維;

再來(lái)看這個(gè)算子在 TensorRT 中怎么添加:

// 構(gòu)造張量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);

// scalemode選擇,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;

// 構(gòu)建 Weights 類型的 shift、scale、power,其中 volume 為元素?cái)?shù)量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };

// !! 注意這里還需要對(duì) shift、scale、power 的 values 進(jìn)行賦值,若只是乘法只需要對(duì) scale 進(jìn)行賦值就行

// 添加張量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);

// 獲取輸出
scaleOutput = Scale_layer->getOputput(0);

有一點(diǎn)你可能會(huì)比較疑惑,既然是點(diǎn)乘,那么輸入只需要兩個(gè)張量就可以了,為啥這里有 input、shift、scale、power 四個(gè)張量這么多呢。解釋一下,input 不用說(shuō),就是輸入張量,而 shift 表示加法參數(shù)、scale 表示乘法參數(shù)、power 表示指數(shù)參數(shù),說(shuō)到這里,你應(yīng)該能發(fā)現(xiàn),這個(gè)函數(shù)除了我們上面講的點(diǎn)乘外還有其他更加豐富的運(yùn)算功能。

關(guān)于“Pytorch如何實(shí)現(xiàn)常用乘法算子TensorRT”這篇文章的內(nèi)容就介紹到這里,感謝各位的閱讀!相信大家對(duì)“Pytorch如何實(shí)現(xiàn)常用乘法算子TensorRT”知識(shí)都有一定的了解,大家如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI