您好,登錄后才能下訂單哦!
小編給大家分享一下pytorch如何實現(xiàn)用CNN和LSTM對文本進(jìn)行分類方式,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!
model.py:
#!/usr/bin/python # -*- coding: utf-8 -*- import torch from torch import nn import numpy as np from torch.autograd import Variable import torch.nn.functional as F class TextRNN(nn.Module): """文本分類,RNN模型""" def __init__(self): super(TextRNN, self).__init__() # 三個待輸入的數(shù)據(jù) self.embedding = nn.Embedding(5000, 64) # 進(jìn)行詞嵌入 # self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bidirectional=True) self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2, bidirectional=True) self.f1 = nn.Sequential(nn.Linear(256,128), nn.Dropout(0.8), nn.ReLU()) self.f2 = nn.Sequential(nn.Linear(128,10), nn.Softmax()) def forward(self, x): x = self.embedding(x) x,_ = self.rnn(x) x = F.dropout(x,p=0.8) x = self.f1(x[:,-1,:]) return self.f2(x) class TextCNN(nn.Module): def __init__(self): super(TextCNN, self).__init__() self.embedding = nn.Embedding(5000,64) self.conv = nn.Conv1d(64,256,5) self.f1 = nn.Sequential(nn.Linear(256*596, 128), nn.ReLU()) self.f2 = nn.Sequential(nn.Linear(128, 10), nn.Softmax()) def forward(self, x): x = self.embedding(x) x = x.detach().numpy() x = np.transpose(x,[0,2,1]) x = torch.Tensor(x) x = Variable(x) x = self.conv(x) x = x.view(-1,256*596) x = self.f1(x) return self.f2(x)
train.py:
# coding: utf-8 from __future__ import print_function import torch from torch import nn from torch import optim from torch.autograd import Variable import os import numpy as np from model import TextRNN,TextCNN from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab base_dir = 'cnews' train_dir = os.path.join(base_dir, 'cnews.train.txt') test_dir = os.path.join(base_dir, 'cnews.test.txt') val_dir = os.path.join(base_dir, 'cnews.val.txt') vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') def train(): x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,600)#獲取訓(xùn)練數(shù)據(jù)每個字的id和對應(yīng)標(biāo)簽的oe-hot形式 x_val, y_val = process_file(val_dir, word_to_id, cat_to_id,600) #使用LSTM或者CNN model = TextRNN() # model = TextCNN() #選擇損失函數(shù) Loss = nn.MultiLabelSoftMarginLoss() # Loss = nn.BCELoss() # Loss = nn.MSELoss() optimizer = optim.Adam(model.parameters(),lr=0.001) best_val_acc = 0 for epoch in range(1000): batch_train = batch_iter(x_train, y_train,100) for x_batch, y_batch in batch_train: x = np.array(x_batch) y = np.array(y_batch) x = torch.LongTensor(x) y = torch.Tensor(y) # y = torch.LongTensor(y) x = Variable(x) y = Variable(y) out = model(x) loss = Loss(out,y) optimizer.zero_grad() loss.backward() optimizer.step() accracy = np.mean((torch.argmax(out,1)==torch.argmax(y,1)).numpy()) #對模型進(jìn)行驗證 if (epoch+1)%20 == 0: batch_val = batch_iter(x_val, y_val, 100) for x_batch, y_batch in batch_train: x = np.array(x_batch) y = np.array(y_batch) x = torch.LongTensor(x) y = torch.Tensor(y) # y = torch.LongTensor(y) x = Variable(x) y = Variable(y) out = model(x) loss = Loss(out, y) optimizer.zero_grad() loss.backward() optimizer.step() accracy = np.mean((torch.argmax(out, 1) == torch.argmax(y, 1)).numpy()) if accracy > best_val_acc: torch.save(model.state_dict(),'model_params.pkl') best_val_acc = accracy print(accracy) if __name__ == '__main__': #獲取文本的類別及其對應(yīng)id的字典 categories, cat_to_id = read_category() #獲取訓(xùn)練文本中所有出現(xiàn)過的字及其所對應(yīng)的id words, word_to_id = read_vocab(vocab_dir) #獲取字?jǐn)?shù) vocab_size = len(words) train()
test.py:
# coding: utf-8 from __future__ import print_function import os import tensorflow.contrib.keras as kr import torch from torch import nn from cnews_loader import read_category, read_vocab from model import TextRNN from torch.autograd import Variable import numpy as np try: bool(type(unicode)) except NameError: unicode = str base_dir = 'cnews' vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') class TextCNN(nn.Module): def __init__(self): super(TextCNN, self).__init__() self.embedding = nn.Embedding(5000,64) self.conv = nn.Conv1d(64,256,5) self.f1 = nn.Sequential(nn.Linear(152576, 128), nn.ReLU()) self.f2 = nn.Sequential(nn.Linear(128, 10), nn.Softmax()) def forward(self, x): x = self.embedding(x) x = x.detach().numpy() x = np.transpose(x,[0,2,1]) x = torch.Tensor(x) x = Variable(x) x = self.conv(x) x = x.view(-1,152576) x = self.f1(x) return self.f2(x) class CnnModel: def __init__(self): self.categories, self.cat_to_id = read_category() self.words, self.word_to_id = read_vocab(vocab_dir) self.model = TextCNN() self.model.load_state_dict(torch.load('model_params.pkl')) def predict(self, message): # 支持不論在python2還是python3下訓(xùn)練的模型都可以在2或者3的環(huán)境下運行 content = unicode(message) data = [self.word_to_id[x] for x in content if x in self.word_to_id] data = kr.preprocessing.sequence.pad_sequences([data],600) data = torch.LongTensor(data) y_pred_cls = self.model(data) class_index = torch.argmax(y_pred_cls[0]).item() return self.categories[class_index] class RnnModel: def __init__(self): self.categories, self.cat_to_id = read_category() self.words, self.word_to_id = read_vocab(vocab_dir) self.model = TextRNN() self.model.load_state_dict(torch.load('model_rnn_params.pkl')) def predict(self, message): # 支持不論在python2還是python3下訓(xùn)練的模型都可以在2或者3的環(huán)境下運行 content = unicode(message) data = [self.word_to_id[x] for x in content if x in self.word_to_id] data = kr.preprocessing.sequence.pad_sequences([data], 600) data = torch.LongTensor(data) y_pred_cls = self.model(data) class_index = torch.argmax(y_pred_cls[0]).item() return self.categories[class_index] if __name__ == '__main__': model = CnnModel() # model = RnnModel() test_demo = ['湖人助教力助科比恢復(fù)手感 他也是阿泰的精神導(dǎo)師新浪體育訊記者戴高樂報道 上賽季,科比的右手食指遭遇重創(chuàng),他的投籃手感也因此大受影響。不過很快科比就調(diào)整了自己的投籃手型,并通過這一方式讓自己的投籃命中率回升。而在這科比背后,有一位特別助教對科比幫助很大,他就是查克·珀森。珀森上賽季擔(dān)任湖人的特別助教,除了幫助科比調(diào)整投籃手型之外,他的另一個重要任務(wù)就是擔(dān)任阿泰的精神導(dǎo)師。來到湖人隊之后,阿泰收斂起了暴躁的脾氣,成為湖人奪冠路上不可或缺的一員,珀森的“心靈按摩”功不可沒。經(jīng)歷了上賽季的成功之后,珀森本賽季被“升職”成為湖人隊的全職助教,每場比賽,他都會坐在球場邊,幫助禪師杰克遜一起指揮湖人球員在場上拼殺。對于珀森的工作,禪師非常欣賞,“查克非常善于分析問題,”菲爾·杰克遜說,“他總是在尋找問題的答案,同時也在找造成這一問題的原因,這是我們都非常樂于看到的。我會在平時把防守中出現(xiàn)的一些問題交給他,然后他會通過組織球員練習(xí)找到解決的辦法。他在球員時代曾是一名很好的外線投手,不過現(xiàn)在他與內(nèi)線球員的配合也相當(dāng)不錯。', '弗老大被裁美國媒體看熱鬧“特權(quán)”在中國像蠢蛋弗老大要走了。雖然他只在首鋼男籃效力了13天,而且表現(xiàn)毫無亮點,大大地讓球迷和俱樂部失望了,但就像中國人常說的“好聚好散”,隊友還是友好地與他告別,俱樂部與他和平分手,球迷還請他留下了在北京的最后一次簽名。相比之下,弗老大的同胞美國人卻沒那么“寬容”。他們嘲諷這位NBA前巨星的英雄遲暮,批評他在CBA的業(yè)余表現(xiàn),還驚訝于中國人的“大方”。今天,北京首鋼俱樂部將與弗朗西斯繼續(xù)商討解約一事。從昨日的進(jìn)展來看,雙方可以做到“買賣不成人意在”,但回到美國后,恐怕等待弗朗西斯的就沒有這么輕松的環(huán)境了。進(jìn)展@北京昨日與隊友告別 最后一次為球迷簽名弗朗西斯在13天里為首鋼隊打了4場比賽,3場的得分為0,只有一場得了2分。昨天是他來到北京的第14天,雖然他與首鋼還未正式解約,但雙方都明白“緣分已盡”。下午,弗朗西斯來到首鋼俱樂部與隊友們告別。弗朗西斯走到隊友身邊,依次與他們握手擁抱?!澳銈兌紝ξ液芎茫才诺臈l件也很好,我很喜歡這支球隊,想融入你們,但我現(xiàn)在真的很不適應(yīng)。希望你們'] for i in test_demo: print(i,":",model.predict(i))
cnews_loader.py:
# coding: utf-8 import sys from collections import Counter import numpy as np import tensorflow.contrib.keras as kr if sys.version_info[0] > 2: is_py3 = True else: reload(sys) sys.setdefaultencoding("utf-8") is_py3 = False def native_word(word, encoding='utf-8'): """如果在python2下面使用python3訓(xùn)練的模型,可考慮調(diào)用此函數(shù)轉(zhuǎn)化一下字符編碼""" if not is_py3: return word.encode(encoding) else: return word def native_content(content): if not is_py3: return content.decode('utf-8') else: return content def open_file(filename, mode='r'): """ 常用文件操作,可在python2和python3間切換. mode: 'r' or 'w' for read or write """ if is_py3: return open(filename, mode, encoding='utf-8', errors='ignore') else: return open(filename, mode) def read_file(filename): """讀取文件數(shù)據(jù)""" contents, labels = [], [] with open_file(filename) as f: for line in f: try: label, content = line.strip().split('\t') if content: contents.append(list(native_content(content))) labels.append(native_content(label)) except: pass return contents, labels def build_vocab(train_dir, vocab_dir, vocab_size=5000): """根據(jù)訓(xùn)練集構(gòu)建詞匯表,存儲""" data_train, _ = read_file(train_dir) all_data = [] for content in data_train: all_data.extend(content) counter = Counter(all_data) count_pairs = counter.most_common(vocab_size - 1) words, _ = list(zip(*count_pairs)) # 添加一個 <PAD> 來將所有文本pad為同一長度 words = ['<PAD>'] + list(words) open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n') def read_vocab(vocab_dir): """讀取詞匯表""" # words = open_file(vocab_dir).read().strip().split('\n') with open_file(vocab_dir) as fp: # 如果是py2 則每個值都轉(zhuǎn)化為unicode words = [native_content(_.strip()) for _ in fp.readlines()] word_to_id = dict(zip(words, range(len(words)))) return words, word_to_id def read_category(): """讀取分類目錄,固定""" categories = ['體育', '財經(jīng)', '房產(chǎn)', '家居', '教育', '科技', '時尚', '時政', '游戲', '娛樂'] categories = [native_content(x) for x in categories] cat_to_id = dict(zip(categories, range(len(categories)))) return categories, cat_to_id def to_words(content, words): """將id表示的內(nèi)容轉(zhuǎn)換為文字""" return ''.join(words[x] for x in content) def process_file(filename, word_to_id, cat_to_id, max_length=600): """將文件轉(zhuǎn)換為id表示""" contents, labels = read_file(filename)#讀取訓(xùn)練數(shù)據(jù)的每一句話及其所對應(yīng)的類別 data_id, label_id = [], [] for i in range(len(contents)): data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])#將每句話id化 label_id.append(cat_to_id[labels[i]])#每句話對應(yīng)的類別的id # # # 使用keras提供的pad_sequences來將文本pad為固定長度 x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length) y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 將標(biāo)簽轉(zhuǎn)換為one-hot表示 # return x_pad, y_pad def batch_iter(x, y, batch_size=64): """生成批次數(shù)據(jù)""" data_len = len(x) num_batch = int((data_len - 1) / batch_size) + 1 indices = np.random.permutation(np.arange(data_len)) x_shuffle = x[indices] y_shuffle = y[indices] for i in range(num_batch): start_id = i * batch_size end_id = min((i + 1) * batch_size, data_len) yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
1.PyTorch是相當(dāng)簡潔且高效快速的框架;2.設(shè)計追求最少的封裝;3.設(shè)計符合人類思維,它讓用戶盡可能地專注于實現(xiàn)自己的想法;4.與google的Tensorflow類似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開發(fā)更新;5.PyTorch作者親自維護(hù)的論壇 供用戶交流和求教問題6.入門簡單
以上是“pytorch如何實現(xiàn)用CNN和LSTM對文本進(jìn)行分類方式”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對大家有所幫助,如果還想學(xué)習(xí)更多知識,歡迎關(guān)注億速云行業(yè)資訊頻道!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。