溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶(hù)服務(wù)條款》

pytorch中nn.RNN()怎么使用

發(fā)布時(shí)間:2022-12-03 09:18:10 來(lái)源:億速云 閱讀:156 作者:iii 欄目:開(kāi)發(fā)技術(shù)

這篇文章主要介紹“pytorch中nn.RNN()怎么使用”,在日常操作中,相信很多人在pytorch中nn.RNN()怎么使用問(wèn)題上存在疑惑,小編查閱了各式資料,整理出簡(jiǎn)單好用的操作方法,希望對(duì)大家解答”pytorch中nn.RNN()怎么使用”的疑惑有所幫助!接下來(lái),請(qǐng)跟著小編一起來(lái)學(xué)習(xí)吧!

參數(shù)說(shuō)明

  • input_size輸入特征的維度, 一般rnn中輸入的是詞向量,那么 input_size 就等于一個(gè)詞向量的維度

  • hidden_size隱藏層神經(jīng)元個(gè)數(shù),或者也叫輸出的維度(因?yàn)閞nn輸出為各個(gè)時(shí)間步上的隱藏狀態(tài))

  • num_layers網(wǎng)絡(luò)的層數(shù)

  • nonlinearity激活函數(shù)

  • bias是否使用偏置

  • batch_first輸入數(shù)據(jù)的形式,默認(rèn)是 False,就是這樣形式,(seq(num_step), batch, input_dim),也就是將序列長(zhǎng)度放在第一位,batch 放在第二位

  • dropout是否應(yīng)用dropout, 默認(rèn)不使用,如若使用將其設(shè)置成一個(gè)0-1的數(shù)字即可

  • birdirectional是否使用雙向的 rnn,默認(rèn)是 False

  • 注意某些參數(shù)的默認(rèn)值在標(biāo)題中已注明

輸入輸出shape

  • input_shape = [時(shí)間步數(shù), 批量大小, 特征維度] = [num_steps(seq_length), batch_size, input_dim]

  • 在前向計(jì)算后會(huì)分別返回輸出和隱藏狀態(tài)h,其中輸出指的是隱藏層在各個(gè)時(shí)間步上計(jì)算并輸出的隱藏狀態(tài),它們通常作為后續(xù)輸出層的輸?。需要強(qiáng)調(diào)的是,該“輸出”本身并不涉及輸出層計(jì)算,形狀為(時(shí)間步數(shù), 批量大小, 隱藏單元個(gè)數(shù));隱藏狀態(tài)指的是隱藏層在最后時(shí)間步的隱藏狀態(tài):當(dāng)隱藏層有多層時(shí),每?層的隱藏狀態(tài)都會(huì)記錄在該變量中;對(duì)于像?短期記憶(LSTM),隱藏狀態(tài)是?個(gè)元組(h, c),即hidden state和cell state(此處普通rnn只有一個(gè)值)隱藏狀態(tài)h的形狀為(層數(shù), 批量大小,隱藏單元個(gè)數(shù))

代碼

rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens, )
# 定義模型, 其中vocab_size = 1027, hidden_size = 256
num_steps = 35
batch_size = 2
state = None    # 初始隱藏層狀態(tài)可以不定義
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new.shape)

輸出

torch.Size([35, 2, 256])     1       torch.Size([1, 2, 256])

具體計(jì)算過(guò)程
H t = i n p u t ∗ W x h + H t − 1 ∗ W h h + b i a s H_t = input * W_{xh} + H_{t-1} * W_{hh} + bias Ht=input∗Wxh+Ht−1∗Whh+bias
[batch_size, input_dim] * [input_dim, num_hiddens] + [batch_size, num_hiddens] *[num_hiddens, num_hiddens] +bias
可以發(fā)現(xiàn)每個(gè)隱藏狀態(tài)形狀都是[batch_size, num_hiddens], 起始輸出也是一樣的
注意:上面為了方便假設(shè)num_step=1

GRU/LSTM等參數(shù)同上面RNN

到此,關(guān)于“pytorch中nn.RNN()怎么使用”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實(shí)踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識(shí),請(qǐng)繼續(xù)關(guān)注億速云網(wǎng)站,小編會(huì)繼續(xù)努力為大家?guī)?lái)更多實(shí)用的文章!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI