您好,登錄后才能下訂單哦!
如何對(duì)loss進(jìn)行mask
pytorch官方教程中有一個(gè)Chatbot教程,就是利用seq2seq和注意力機(jī)制實(shí)現(xiàn)的,感覺(jué)和機(jī)器翻譯沒(méi)什么不同啊,如果對(duì)話(huà)中一句話(huà)有下一句,那么就把這一對(duì)句子加入模型進(jìn)行訓(xùn)練。其中在訓(xùn)練階段,損失函數(shù)通常需要進(jìn)行mask操作,因?yàn)橐粋€(gè)batch中句子的長(zhǎng)度通常是不一樣的,一個(gè)batch中不足長(zhǎng)度的位置需要進(jìn)行填充(pad)補(bǔ)0,最后生成句子計(jì)算loss時(shí)需要忽略那些原本是pad的位置的值,即只保留mask中值為1位置的值,忽略值為0位置的值,具體演示如下:
import torch import torch.nn as nn import torch.nn.functional as F import itertools DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") PAD_token = 0
首先是pad函數(shù)和建立mask矩陣,矩陣的維度應(yīng)該和目標(biāo)一致。
def zeroPadding(l, fillvalue=PAD_token): # 輸入:[[1, 1, 1], [2, 2], [3]] # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已經(jīng)是轉(zhuǎn)置后的 [L, B] return list(itertools.zip_longest(*l, fillvalue=fillvalue)) def binaryMatrix(l): # 將targets里非pad部分標(biāo)記為1,pad部分標(biāo)記為0 m = [] for i, seq in enumerate(l): m.append([]) for token in seq: if token == PAD_token: m[i].append(0) else: m[i].append(1) return m
假設(shè)現(xiàn)在輸入一個(gè)batch中有三個(gè)句子,我們按照長(zhǎng)度從大到小排好序,LSTM或是GRU的輸入和輸出我們需要利用pack_padded_sequence和pad_packed_sequence進(jìn)行打包和解包,感覺(jué)也是在進(jìn)行mask操作。
inputs = [[1, 2, 3], [4, 5], [6]] # 輸入句,一個(gè)batch,需要按照長(zhǎng)度從大到小排好序 inputs_lengths = [3, 2, 1] targets = [[1, 2], [1, 2, 3], [1]] # 目標(biāo)句,這里的長(zhǎng)度是不確定的,mask是針對(duì)targets的 inputs_batch = torch.LongTensor(zeroPadding(inputs)) inputs_lengths = torch.LongTensor(inputs_lengths) targets_batch = torch.LongTensor(zeroPadding(targets)) targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意這里是ByteTensor print(inputs_batch) print(targets_batch) print(targets_mask)
打印后結(jié)果如下,可見(jiàn)維度統(tǒng)一變成了[L, B],并且mask和target長(zhǎng)得一樣。另外,seq2seq模型處理時(shí)for循環(huán)每次讀取一行,預(yù)測(cè)下一行的值(即[B, L]時(shí)的一列預(yù)測(cè)下一列)。
tensor([[ 1, 4, 6], [ 2, 5, 0], [ 3, 0, 0]]) tensor([[ 1, 1, 1], [ 2, 2, 0], [ 0, 3, 0]]) tensor([[ 1, 1, 1], [ 1, 1, 0], [ 0, 1, 0]], dtype=torch.uint8)
現(xiàn)在假設(shè)我們將inputs輸入模型后,模型讀入sos后預(yù)測(cè)的第一行為outputs1, 維度為[B, vocab_size],即每個(gè)詞在詞匯表中的概率,模型輸出之前需要softmax。
outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]]) print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000], [ 0.3000, 0.6000, 0.1000], [ 0.4000, 0.5000, 0.1000]])
先看看兩個(gè)函數(shù)
torch.gather(input, dim, index, out=None)->Tensor
沿著某個(gè)軸,按照指定維度采集數(shù)據(jù),對(duì)于3維數(shù)據(jù),相當(dāng)于進(jìn)行如下操作:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
比如在這里,在第1維,選第二個(gè)元素。
# 收集每行的第2個(gè)元素 temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]])) print(temp)
tensor([[ 0.1000], [ 0.6000], [ 0.5000]])
torch.masked_select(input, mask, out=None)->Tensor
根據(jù)mask(ByteTensor)選取對(duì)應(yīng)位置的值,返回一維張量。
例如在這里我們選取temp大于等于0.5的值。
mask = temp.ge(0.5) # 大于等于0.5 print(mask) print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0], [ 1], [ 1]], dtype=torch.uint8) tensor([ 0.6000, 0.5000])
然后我們就可以計(jì)算loss了,這里是負(fù)對(duì)數(shù)損失函數(shù),之前模型的輸出要進(jìn)行softmax。
# 計(jì)算一個(gè)batch內(nèi)的平均負(fù)對(duì)數(shù)似然損失,即只考慮mask為1的元素 def maskNLLLoss(inp, target, mask): nTotal = mask.sum() # 收集目標(biāo)詞的概率,并取負(fù)對(duì)數(shù) crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1))) # 只保留mask中值為1的部分,并求均值 loss = crossEntropy.masked_select(mask).mean() loss = loss.to(DEVICE) return loss, nTotal.item()
這里我們計(jì)算第一行的平均損失。
# 計(jì)算預(yù)測(cè)的第一行和targets的第一行的loss maskNLLLoss(outputs1, targets_batch[0], targets_mask[0]) (tensor(1.1689, device='cuda:0'), 3)
最后進(jìn)行最后把所有行的loss累加起來(lái)變?yōu)閠otal_loss.backward()進(jìn)行反向傳播就可以了。
以上這篇pytorch實(shí)現(xiàn)seq2seq時(shí)對(duì)loss進(jìn)行mask的方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持億速云。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀(guān)點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。