溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

怎么在pytorch中處理可變長度序列

發(fā)布時(shí)間:2021-05-22 15:59:19 來源:億速云 閱讀:320 作者:Leah 欄目:開發(fā)技術(shù)

這篇文章給大家介紹怎么在pytorch中處理可變長度序列,內(nèi)容非常詳細(xì),感興趣的小伙伴們可以參考借鑒,希望對大家能有所幫助。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 這個(gè)類的實(shí)例不能手動(dòng)創(chuàng)建。它們只能被 pack_padded_sequence() 實(shí)例化。

PackedSequence對象包括:

一個(gè)data對象:一個(gè)torch.Variable(令牌的總數(shù),每個(gè)令牌的維度),在這個(gè)簡單的例子中有五個(gè)令牌序列(用整數(shù)表示):(18,1)

一個(gè)batch_sizes對象:每個(gè)時(shí)間步長的令牌數(shù)列表,在這個(gè)例子中為:[6,5,2,4,1]

用pack_padded_sequence函數(shù)來構(gòu)造這個(gè)對象非常的簡單:

怎么在pytorch中處理可變長度序列

如何構(gòu)造一個(gè)PackedSequence對象(batch_first = True)

PackedSequence對象有一個(gè)很不錯(cuò)的特性,就是我們無需對序列解包(這一步操作非常慢)即可直接在PackedSequence數(shù)據(jù)變量上執(zhí)行許多操作。特別是我們可以對令牌執(zhí)行任何操作(即對令牌的順序/上下文不敏感)。當(dāng)然,我們也可以使用接受PackedSequence作為輸入的任何一個(gè)pyTorch模塊(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

這里的pack,理解成壓緊比較好。 將一個(gè) 填充過的變長序列 壓緊。(填充時(shí)候,會(huì)有冗余,所以壓緊一下)

輸入的形狀可以是(T×B×* )。T是最長序列長度,B是batch size,*代表任意維度(可以是0)。如果batch_first=True的話,那么相應(yīng)的 input size 就是 (B×T×*)。

Variable中保存的序列,應(yīng)該按序列長度的長短排序,長的在前,短的在后。即input[:,0]代表的是最長的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是維度大于等于2的input都可以作為這個(gè)函數(shù)的參數(shù)。你可以用它來打包labels,然后用RNN的輸出和打包后的labels來計(jì)算loss。通過PackedSequence對象的.data屬性可以獲取 Variable。

參數(shù)說明:

input (Variable) – 變長序列 被填充后的 batch

lengths (list[int]) – Variable 中 每個(gè)序列的長度。

batch_first (bool, optional) – 如果是True,input的形狀應(yīng)該是B*T*size。

返回值:

一個(gè)PackedSequence 對象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函數(shù)的功能是將一個(gè)填充后的變長序列壓緊。 這個(gè)操作和pack_padded_sequence()是相反的。把壓緊的序列再填充回來。

返回的Varaible的值的size是 T×B×*, T 是最長序列的長度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素將會(huì)以它們長度的逆序排列。

參數(shù)說明:

sequence (PackedSequence) – 將要被填充的 batch

batch_first (bool, optional) – 如果為True,返回的數(shù)據(jù)的格式為 B×T×*。

返回值: 一個(gè)tuple,包含被填充后的序列,和batch中序列的長度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

輸出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

pytorch的優(yōu)點(diǎn)

1.PyTorch是相當(dāng)簡潔且高效快速的框架;2.設(shè)計(jì)追求最少的封裝;3.設(shè)計(jì)符合人類思維,它讓用戶盡可能地專注于實(shí)現(xiàn)自己的想法;4.與google的Tensorflow類似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開發(fā)更新;5.PyTorch作者親自維護(hù)的論壇 供用戶交流和求教問題6.入門簡單

關(guān)于怎么在pytorch中處理可變長度序列就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,可以學(xué)到更多知識。如果覺得文章不錯(cuò),可以把它分享出去讓更多的人看到。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI