溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

用BERT進(jìn)行中文短文本分類

發(fā)布時(shí)間:2020-07-15 08:40:21 來源:網(wǎng)絡(luò) 閱讀:2271 作者:nineteens 欄目:編程語言

  1. 環(huán)境配置

  本實(shí)驗(yàn)使用操作系統(tǒng):Ubuntu 18.04.3 LTS 4.15.0-29-generic GNU/Linux操作系統(tǒng)。

  1.1 查看CUDA版本

  cat /usr/local/cuda/version.txt

  輸出:

  CUDA Version 10.0.130*

  1.2 查看 cudnn版本

  cat /usr/local/cuda/include/cudnn.h | grep CUDNN_MAJOR -A 2

  輸出:

  #define CUDNN_MINOR 6

  #define CUDNN_PATCHLEVEL 3

  --

  #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)

  如果沒有安裝 cuda 和 cudnn,到官網(wǎng)根自己的 GPU 型號(hào)版本安裝即可

  1.3 安裝tensorflow-gpu

  通過Anaconda創(chuàng)建虛擬環(huán)境來安裝tensorflow-gpu(Anaconda安裝步驟就不說了)

  創(chuàng)建虛擬環(huán)境

  虛擬環(huán)境名為:tensorflow

  conda create -n tensorflow python=3.7.1

  進(jìn)入虛擬環(huán)境

  下次使用也可以通過此命令進(jìn)入虛擬環(huán)境

  source activate tensorflow

  安裝tensorflow-gpu

  不推薦直接pip install tensorflow-gpu 因?yàn)樗俣缺容^慢。可以從豆瓣的鏡像中下載,速度還是很快的。https://pypi.doubanio.com/simple/tensorflow-gpu/

  找到自己適用的版本(cp37表示python版本為3.7)

  然后通過pip install 安裝

  pip install https://pypi.doubanio.com/packages/15/21/17f941058556b67ce6d1e3f0e0932c9c2deaf457e3d45eecd93f2c20827d/tensorflow_gpu-1.14.0rc1-cp37-cp37m-manylinux1_x86_64.whl

  我選擇了1.14.0的tensorflow-gpu linux版本,python版本為3.7。使用BERT的話,tensorflow-gpu版本必須大于1.11.0。同時(shí),不建議選擇2.0版本,2.0版本好像修改了一些方法,還需要自己手動(dòng)修改代碼

  環(huán)境測試

  在tensorflow虛擬環(huán)境中,python命令進(jìn)入Python環(huán)境中,輸入import tensorflow,看是否能成功導(dǎo)入

  

  2. 準(zhǔn)備工作

  

  2.1 預(yù)訓(xùn)練模型下載

  Bert-base Chinese

  BERT-wwm :由哈工大和訊飛聯(lián)合實(shí)驗(yàn)室發(fā)布的,效果比Bert-base Chinese要好一些(鏈接地址為訊飛云,密碼:mva8。無奈當(dāng)時(shí)用wwm訓(xùn)練完提交結(jié)果時(shí),提交通道已經(jīng)關(guān)閉了,嗚嗚)

  bert_model.ckpt:負(fù)責(zé)模型變量載入

  vocab.txt:訓(xùn)練時(shí)中文文本采用的字典

  bert_config.json:BERT在訓(xùn)練時(shí),可選調(diào)整的一些參數(shù)

  2.2 數(shù)據(jù)準(zhǔn)備

  1)將自己的數(shù)據(jù)集格式改成如下格式:第一列是標(biāo)簽,第二列是文本數(shù)據(jù),中間用tab隔開(若測試集沒有標(biāo)簽,只保留一列樣本數(shù)據(jù))。 分別將訓(xùn)練集、驗(yàn)證集、測試集文件名改為train.tsv、val.tsv、test.tsv。文件格式為UTF-8(無BOM)

  2)新建data文件夾,存放這三個(gè)文件。

  3)預(yù)訓(xùn)練模型解壓,存放到新建文件夾chinese中

  2.3 代碼修改

  我們需要對(duì)bert源碼中run_classifier.py進(jìn)行兩處修改

  1)在run_classifier.py中添加我們的任務(wù)類

  可以參照其他Processor類,添加自己的任務(wù)類

  # 自定義Processor類

  class MyProcessor(DataProcessor):

  def __init__(self):

  self.labels = ['Addictive Behavior',

  'Address',

  'Age',

  'Alcohol Consumer',

  'Allergy Intolerance',

  'Bedtime',

  'Blood Donation',

  'Capacity',

  'Compliance with Protocol',

  'Consent',

  'Data Accessible',

  'Device',

  'Diagnostic',

  'Diet',

  'Disabilities',

  'Disease',

  'Education',

  'Encounter',

  'Enrollment in other studies',

  'Ethical Audit',

  'Ethnicity',

  'Exercise',

  'Gender',

  'Healthy',

  'Laboratory Examinations',

  'Life Expectancy',

  'Literacy',

  'Multiple',

  'Neoplasm Status',

  'Non-Neoplasm Disease Stage',

  'Nursing',

  'Oral related',

  'Organ or Tissue Status',

  'Pharmaceutical Substance or Drug',

  'Pregnancy-related Activity',

  'Receptor Status',

  'Researcher Decision',

  'Risk Assessment',

  'Sexual related',

  'Sign',

  'Smoking Status',

  'Special Patient Characteristic',

  'Symptom',

  'Therapy or Surgery']

  def get_train_examples(self, data_dir):

  return self._create_examples(

  self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):

  return self._create_examples(

  self._read_tsv(os.path.join(data_dir, "val.tsv")), "val")

  def get_test_examples(self, data_dir):

  return self._create_examples(

  self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):

  return self.labels

  def _create_examples(self, lines, set_type):

  examples = []

  for (i, line) in enumerate(lines):

  guid = "%s-%s" % (set_type, i)

  if set_type == "test":

  """

  因?yàn)槲业臏y試集中沒有標(biāo)簽,所以對(duì)test進(jìn)行單獨(dú)處理,

  test的label值設(shè)為任意一標(biāo)簽(一定是存在的類標(biāo)簽,

  不然predict時(shí)會(huì)keyError),如果測試集中有標(biāo)簽,就

  不需要if了,統(tǒng)一處理即可。

  """

  text_a = tokenization.convert_to_unicode(line[0])

  label = "Address"

  else:

  text_a = tokenization.convert_to_unicode(line[1])

  label = tokenization.convert_to_unicode(line[0])

  examples.append(

  InputExample(guid=guid, text_a=text_a, text_b=None, label=label))

  return examples

  2)修改processor字典

  def main(_):

  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {

  "cola": ColaProcessor,

  "mnli": MnliProcessor,

  "mrpc": MrpcProcessor,

  "xnli": XnliProcessor,

  "mytask": MyProcessor, # 將自己的Processor添加到字典

  }

  3 開工

  3.1 配置訓(xùn)練腳本

  創(chuàng)建并運(yùn)行run.sh這個(gè)文件

  python run_classifier.py \

  --data_dir=data \

  --task_name=mytask \

  --do_train=true \

  --do_eval=true \

  --vocab_file=chinese/vocab.txt \

  --bert_config_file=chinese/bert_config.json \

  --init_checkpoint=chinese/bert_model.ckpt \

  --max_seq_length=128 \

  --train_batch_size=8 \

  --learning_rate=2e-5 \

  --num_train_epochs=3.0

  --output_dir=out \

  fine-tune需要一定的時(shí)間,我的訓(xùn)練集有兩萬條,驗(yàn)證集有八千條,GPU為2080Ti,需要20分鐘左右。如果顯存不夠大,記得適當(dāng)調(diào)整max_seq_length 和 train_batch_size

  3.2 預(yù)測

  創(chuàng)建并運(yùn)行test.sh(注:init_checkpoint為自己之前輸出模型地址)

  python run_classifier.py \

  --task_name=mytask \

  --do_predict=true \

  --data_dir=data \

  --vocab_file=chinese/vocab.txt \

  --bert_config_file=chinese/bert_config.json \

  --init_checkpoint=out \

  --max_seq_length=128 \

  --output_dir=out

  預(yù)測完會(huì)在out目錄下生成test_results.tsv。生成文件中,每一行對(duì)應(yīng)你訓(xùn)練集中的每一個(gè)樣本,每一列對(duì)應(yīng)的是每一類的概率(對(duì)應(yīng)之前自定義的label列表)。如第5行第8列表示第5個(gè)樣本是第8類的概率。

  3.3 預(yù)測結(jié)果處理鄭州婦科醫(yī)院 http://www.zykdfkyy.com/

  因?yàn)轭A(yù)測結(jié)果是概率,我們需要對(duì)其處理,選取每一行中的最大值最為預(yù)測值,并轉(zhuǎn)換成對(duì)應(yīng)的真實(shí)標(biāo)簽。

  data_dir = "C:\\test_results.tsv"

  lable = ['Addictive Behavior',

  'Address',

  'Age',

  'Alcohol Consumer',

  'Allergy Intolerance',

  'Bedtime',

  'Blood Donation',

  'Capacity',

  'Compliance with Protocol',

  'Consent',

  'Data Accessible',

  'Device',

  'Diagnostic',

  'Diet',

  'Disabilities',

  'Disease',

  'Education',

  'Encounter',

  'Enrollment in other studies',

  'Ethical Audit',

  'Ethnicity',

  'Exercise',

  'Gender',

  'Healthy',

  'Laboratory Examinations',

  'Life Expectancy',

  'Literacy',

  'Multiple',

  'Neoplasm Status',

  'Non-Neoplasm Disease Stage',

  'Nursing',

  'Oral related',

  'Organ or Tissue Status',

  'Pharmaceutical Substance or Drug',

  'Pregnancy-related Activity',

  'Receptor Status',

  'Researcher Decision',

  'Risk Assessment',

  'Sexual related',

  'Sign',

  'Smoking Status',

  'Special Patient Characteristic',

  'Symptom',

  'Therapy or Surgery']

  # 用pandas讀取test_result.tsv,將標(biāo)簽設(shè)置為列名

  data_df = pd.read_table(data_dir, sep="\t", names=lable, encoding="utf-8")

  label_test = []

  for i in range(data_df.shape[0]):

  # 獲取一行中最大值對(duì)應(yīng)的列名,追加到列表

  label_test.append(data_df.loc[i, :].idxmax())


向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請(qǐng)聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI