溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

怎么理解BiLSTM和CRF算法

發(fā)布時間:2021-11-02 17:11:57 來源:億速云 閱讀:533 作者:iii 欄目:web開發(fā)

本篇內(nèi)容介紹了“怎么理解BiLSTM和CRF算法”的有關(guān)知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領(lǐng)大家學(xué)習(xí)一下如何處理這些情況吧!希望大家仔細(xì)閱讀,能夠?qū)W有所成!

1.前言

給定一個句子 "什么是地攤經(jīng)濟",其正確的分詞方式是 "什么 / 是 / 地攤 / 經(jīng)濟",每個字對應(yīng)的分詞標(biāo)簽是 "be / s / be /  be"。從下面的圖片可以看出 LSTM 在做序列標(biāo)注時的問題。

怎么理解BiLSTM和CRF算法

BiLSTM 分詞

BiLSTM 可以預(yù)測出每一個字屬于不同標(biāo)簽的概率,然后使用 Softmax  得到概率最大的標(biāo)簽,作為該位置的預(yù)測值。這樣在預(yù)測的時候會忽略了標(biāo)簽之間的關(guān)聯(lián)性,如上圖中 BiLSTM 把第一個詞預(yù)測成 s,把第二個詞預(yù)測成  e。但是實際上在分詞時 s 后面是不會出現(xiàn) e 的,因此 BiLSTM 沒有考慮標(biāo)簽間聯(lián)系。

因此 BiLSTM+CRF 在 BiLSTM 的輸出層加上一個 CRF,使得模型可以考慮類標(biāo)之間的相關(guān)性,標(biāo)簽之間的相關(guān)性就是 CRF  中的轉(zhuǎn)移矩陣,表示從一個狀態(tài)轉(zhuǎn)移到另一個狀態(tài)的概率。假設(shè) CRF 的轉(zhuǎn)移矩陣如下圖所示。

怎么理解BiLSTM和CRF算法

CRF 狀態(tài)轉(zhuǎn)移矩陣

則對于前兩個字 "什么",其標(biāo)簽為 "se" 的概率 =0.8×0×0.7=0,而標(biāo)簽為 "be" 的概率=0.6×0.5×0.7=0.21。

因此,BiLSTM+CRF 考慮的是整個類標(biāo)路徑的概率而不僅僅是單個類標(biāo)的概率,在 BiLSTM 輸出層加上 CRF 后,如下所示。

怎么理解BiLSTM和CRF算法

BiLSTM+CRF 分詞

最終算得所有路徑中,besbebe 的概率最大,因此預(yù)測結(jié)果為 besbebe。

2.BiLSTM+CRF 模型

CRF 包括兩種特征函數(shù),不熟悉的童鞋可以看下之前的文章。第一種特征函數(shù)是狀態(tài)特征函數(shù),也稱為發(fā)射概率,表示字 x 對應(yīng)標(biāo)簽 y 的概率。

怎么理解BiLSTM和CRF算法

CRF 狀態(tài)特征函數(shù)

在 BiLSTM+CRF 中,這一個特征函數(shù) (發(fā)射概率) 直接使用 LSTM 的輸出計算得到,如第一小節(jié)中的圖所示,LSTM  可以計算出每一時刻位置對應(yīng)不同標(biāo)簽的概率。

CRF 的第二個特征函數(shù)是狀態(tài)轉(zhuǎn)移特征函數(shù),表示從一個狀態(tài) y1 轉(zhuǎn)移到另一個狀態(tài) y2 的概率。

怎么理解BiLSTM和CRF算法

CRF 狀態(tài)轉(zhuǎn)移特征函數(shù)

CRF 的狀態(tài)轉(zhuǎn)移特征函數(shù)可以用一個狀態(tài)轉(zhuǎn)移矩陣表示,在訓(xùn)練時需要調(diào)整狀態(tài)轉(zhuǎn)移矩陣的元素值。因此 BiLSTM+CRF 需要在 BiLSTM  的模型內(nèi)增加一個狀態(tài)轉(zhuǎn)移矩陣。在代碼中如下。

class BiLSTM_CRF(nn.Module):     def __init__(self, vocab_size, tag2idx, embedding_dim, hidden_dim):         self.word_embeds = nn.Embedding(vocab_size, embedding_dim)         self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,                             num_layers=1, bidirectional=True)          # 對應(yīng) CRF 的發(fā)射概率,即每一個位置對應(yīng)不同類標(biāo)的概率         self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)                  # 轉(zhuǎn)移矩陣,維度等于標(biāo)簽數(shù)量,表示從一個標(biāo)簽轉(zhuǎn)移到另一標(biāo)簽的概率         self.transitions = nn.Parameter(             torch.randn(len(tag2idx), len(tag2idx))

給定句子 x,其標(biāo)簽序列為 y 的概率用下面的公式計算。

怎么理解BiLSTM和CRF算法

p(y|x)

公式中的 score 用下面的式子計算,其中 Emit 對應(yīng)發(fā)射概率 (即 LSTM 輸出的概率),而 Trans 對應(yīng)了轉(zhuǎn)移概率 (即 CRF  轉(zhuǎn)移矩陣對應(yīng)的數(shù)值)

怎么理解BiLSTM和CRF算法

score 的計算公式

BiLSTM+CRF 采用最大似然法訓(xùn)練,對應(yīng)的損失函數(shù)如下:

怎么理解BiLSTM和CRF算法

損失函數(shù)

其中 score(x,y) 比較容易計算,而 Z(x) 是所有標(biāo)簽序列 (y) 打分的指數(shù)之和,如果序列的長度是 l,標(biāo)簽個數(shù)是 k,則序列的數(shù)量為  (k^l)。無法直接計算,因此要用前向算法進行計算。

用目前主流的深度學(xué)習(xí)框架,對 loss 進行求導(dǎo)和梯度下降,即可優(yōu)化 BiLSTM+CRF。訓(xùn)練好模型之后可以采用 viterbi 算法 (動態(tài)規(guī)劃)  找出最優(yōu)的路徑。

3.損失函數(shù)計算

計算 BiLSTM+CRF 損失函數(shù)的難點在于計算 log Z(x),用 F 表示 log Z(x),如下公式所示。

怎么理解BiLSTM和CRF算法

我們將 score 拆分,變成發(fā)射概率 p 和轉(zhuǎn)移概率 T 的和。為了簡化問題,我們假設(shè)序列的長度為3,則可以分別計算寫出長度為 1、2、3 時候的  log Z 值,如下所示。

怎么理解BiLSTM和CRF算法

上式中 p 表示發(fā)射概率,T 表示轉(zhuǎn)移概率,Start 表示開始,End 表示句子結(jié)束。F(3) 即是最終得到的 log Z(x)  值。通過對上式進行變換,可以將 F(3) 轉(zhuǎn)成遞歸的形式,如下。

怎么理解BiLSTM和CRF算法

可以看到上式中每一步的操作都是一樣的,操作包括 log_sum_exp,例如 F(1):

  • 首先需要計算 exp,對于所有 y1,計算 exp(p(y1)+T(Start,y1))

  • 求和,對上一步得到的 exp 值進行求和

  • 求 log,對求和的結(jié)果計算 log

因此可以寫出前向算法計算 log Z 的代碼,如下所示:

def forward_algorithm(self, probs):     def forward_algorithm(probs):     """     probs: LSTM 輸出的概率值,尺寸為 [seq_len, num_tags],num_tags 是標(biāo)簽的個數(shù)     """      # forward_var (可以理解為文章中的 F) 保存前一時刻的值,是一個向量,維度等于 num_tags     # 初始時只有 Start 為 0,其他的都取一個很小的值 (-10000.)     forward_var = torch.full((1, num_tags), -10000.0)  # [1, num_tags]     forward_var[0][Start] = 0.0      for p in probs:  # probs [seq_len, num_tags],遍歷序列         alphas_t = []  # alphas_t 保存下一時刻取不同標(biāo)簽的累積概率值         for next_tag in range(num_tags): # 遍歷標(biāo)簽              # 下一時刻發(fā)射 next_tag 的概率             emit_score = p[next_tag].view(1, -1).expand(1, num_tags)              # 從所有標(biāo)簽轉(zhuǎn)移到 next_tag 的概率, transitions 是一個矩陣,長寬都是 num_tags             trans_score = transitions[next_tag].view(1, -1)              # next_tag_ver = F(i-1) + p + T             next_tag_var = forward_var + trans_score + emit_score              alphas_t.append(log_sum_exp(next_tag_var).view(1))          forward_var = torch.cat(alphas_t).view(1, -1)      terminal_var = forward_var + self.transitions[Stop] # 最后轉(zhuǎn)移到 Stop 表示句子結(jié)束     alpha = log_sum_exp(terminal_var)     return alpha

4.viterbi 算法解碼

訓(xùn)練好模型后,預(yù)測過程需要用 viterbi 算法對序列進行解碼,感興趣的童鞋可以參看《統(tǒng)計學(xué)習(xí)方法》。下面介紹一下 viterbi  的公式,首先是一些符號的意義,如下:

怎么理解BiLSTM和CRF算法

然后可以得到 viterbi 算法的遞推公式

怎么理解BiLSTM和CRF算法

最終可以根據(jù) viterbi 計算得到的值,往前查找最合適的序列

怎么理解BiLSTM和CRF算法

“怎么理解BiLSTM和CRF算法”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識可以關(guān)注億速云網(wǎng)站,小編將為大家輸出更多高質(zhì)量的實用文章!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI