您好,登錄后才能下訂單哦!
本篇內容介紹了“DialoGPT是什么”的有關知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領大家學習一下如何處理這些情況吧!希望大家仔細閱讀,能夠學有所成!
引言
Large-scale pretraining for dialogue
DialoGPT是基于GPT-2的對話生成預訓練模型,在Reddit數據集上訓練。
假定已經設置好環境,
在gpt2_training/eval_utils.py中增加inference函數:
def inference_model_results(model, tokenizer, inference_dataloader, args, max_batches=1):
    """Run the LM on batches from ``inference_dataloader`` and print greedy predictions.

    For every sample in a batch, the input tokens up to the first
    ``<|endoftext|>`` are printed as ``post``.  The argmax token at each
    position of the model's logits, again up to the first ``<|endoftext|>``,
    is printed as ``inference``.  The model is switched to eval mode and left
    there.

    Args:
        model: object exposing ``eval()`` and
            ``inference(input_ids, position_ids, token_type_ids)`` returning
            logits; the argmax is taken over the last dimension (presumably
            ``(batch, seq_len, vocab_size)`` -- TODO confirm).
        tokenizer: GPT-2 tokenizer; only its ``decoder`` (id -> token string)
            mapping is used.
        inference_dataloader: iterable of 6-tuples of tensors
            ``(input_ids, position_ids, token_ids, label_ids, src_len, _)``.
        args: namespace providing ``device`` and ``no_token_id``.
        max_batches: number of batches to run.  The default of 1 keeps the
            original "first batch only" behavior; ``None`` runs the whole
            loader.

    Returns:
        None.  Results are only printed to stdout.
    """
    import itertools

    # use the same signature with eval_model_generation
    logger.info('run model inference, using eval mode, '
                'please change it back to train after calling this function')
    model.eval()

    def _decode(token_id_rows):
        """Turn a 2-D tensor of token ids into one space-joined string per row."""
        texts = []
        for row in token_id_rows.cpu().tolist():
            words = []
            for token_id in row:
                token = tokenizer.decoder[token_id]
                if token == "<|endoftext|>":
                    break
                # GPT-2's byte-level BPE encodes a leading space as "\u0120"
                # ("Ġ"); strip it because words are re-joined with spaces.
                words.append(token[1:] if token.startswith("\u0120") else token)
            texts.append(" ".join(words))
        return texts

    with torch.no_grad():
        # islice stops after max_batches without pulling an extra batch
        # from the loader; islice(x, None) iterates everything.
        for batch in itertools.islice(inference_dataloader, max_batches):
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, position_ids, token_ids, _label_ids, _src_len, _ = batch
            if args.no_token_id:
                token_ids = None
            logits = model.inference(input_ids, position_ids, token_ids)
            posts = _decode(input_ids)
            # Argmax on the device so only ids, not the full logits tensor,
            # are copied to the CPU.
            inferences = _decode(logits.argmax(dim=-1))
            logger.info("model inference results")
            for post, inference in zip(posts, inferences):
                print("post: ", post)
                print("inference: ", inference)
    # todo: collect and return the decoded results instead of only printing
    return None
在modeling_gpt2.py的class GPT2LMHeadModel(GPT2PreTrainedModel)中增加inference函數(作為該類的方法):
def inference(self, input_ids, position_ids=None, token_type_ids=None, past=None):
    """Return language-model logits for ``input_ids``.

    Runs ``self.transformer`` (which must return exactly two values: hidden
    states and presents) and projects the hidden states through
    ``self.lm_head``.  The presents are discarded and no loss is computed.
    """
    transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past)
    hidden, _presents = transformer_outputs
    logits = self.lm_head(hidden)
    return logits
新建自定義的inference_LSP.py文件
文件內容:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
* @Desc: run inference with a GPT-2 (DialoGPT) checkpoint on
  --inference_input_file and print the model's predictions.
Modified based on Huggingface GPT-2 implementation
'''
import json
import os
import sys
import argparse
import logging
import time
import tqdm
import datetime
import torch
import numpy as np
from os.path import join
from torch.distributed import get_rank, get_world_size
from lsp_model import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam
from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length
from gpt2_training.eval_utils import eval_model_loss, inference_model_results
from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader
from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list
# NOTE(review): this unconditionally pins the process to GPU 0 and overrides
# any CUDA_VISIBLE_DEVICES the caller exported.  With --local_rank >= 1 only a
# single device is visible, so torch.cuda.set_device(local_rank) presumably
# fails -- confirm whether distributed inference is meant to be supported.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_LOG_DATEFMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)
logger = logging.getLogger(__name__)

# Constants carried over from the training script; none of them is referenced
# elsewhere in this inference script.
INF = 10 ** 8
CACHE_EMPTY_STEP = EVAL_STEP = 10_000
#########################################################################
# Prepare Parser
##########################################################################
# Command-line interface.  --seed, --num_optim_steps and --loss_scale are not
# read directly by this script; load_model receives the whole args namespace
# and may use some of them -- they look like leftovers from the training CLI.
parser = argparse.ArgumentParser()
add_arg = parser.add_argument

add_arg('--model_name_or_path', type=str, required=True,
        help='pretrained model name or path to local checkpoint')
add_arg("--seed", type=int, default=42)
add_arg("--max_seq_length", type=int, default=128)
add_arg("--init_checkpoint", type=str, required=True)
add_arg("--inference_input_file", type=str, required=True)
add_arg("--inference_batch_size", type=int, default=8)
add_arg("--num_optim_steps", type=int, default=1000000,
        help="new API specifies num update steps")
add_arg("--fp16", type=boolean_string, default=True)
add_arg("--normalize_data", type=boolean_string, default=True)
add_arg("--loss_scale", type=float, default=0)
add_arg("--no_token_id", type=boolean_string, default=True)
add_arg("--log_dir", type=str, required=True)

# distributed
add_arg('--local_rank', type=int, default=-1,
        help='for torch.distributed')
add_arg('--config', help='JSON config file')

# do normal parsing
args = parser.parse_args()
if args.config is not None:
    # Override argparse defaults with the values from the JSON config file.
    # A context manager closes the file (the original leaked the handle).
    with open(args.config) as config_file:
        opts = json.load(config_file)
    for k, v in opts.items():
        if isinstance(v, str):
            # PHILLY ENV special cases: expand every placeholder present
            # (the original elif expanded only the first one).
            for placeholder in ('PHILLY_JOB_DIRECTORY', 'PHILLY_LOG_DIRECTORY'):
                if placeholder in v:
                    v = v.replace(placeholder, os.environ[placeholder])
        setattr(args, k, v)

    # Command-line flags take precedence over the JSON config: re-parse argv
    # and re-apply every option the user passed explicitly, in both the
    # "--flag value" and "--flag=value" forms.
    argv = sys.argv[1:]
    overrides, _ = parser.parse_known_args(argv)
    explicit_flags = {a.split('=', 1)[0] for a in argv if a.startswith('--')}
    for k, v in vars(overrides).items():
        if f'--{k}' in explicit_flags:
            setattr(args, k, v)
    setattr(args, 'local_rank', overrides.local_rank)
if args.local_rank == -1:
    # Single-process run: use CUDA if available, otherwise the CPU.
    logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    args.device, args.n_gpu = device, n_gpu
else:
    # distributed training: one process per GPU, selected by --local_rank
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    # Initializes the distributed backend which will take care of
    # synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
    n_gpu = torch.distributed.get_world_size()
    # Each process drives exactly one device.
    args.device, args.n_gpu = device, 1

logger.info("device: %s n_gpu: %s, distributed training: %s, "
            "16-bits training: %s",
            device, n_gpu, bool(args.local_rank != -1), args.fp16)
timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
log_dir = args.log_dir
# inference_log.txt is opened inside log_dir later in the script; create the
# directory up front so a fresh --log_dir does not fail with FileNotFoundError.
os.makedirs(log_dir, exist_ok=True)

# Dump every resolved argument (after config/CLI merging) to the log.
logger.info('Input Argument Information')
for arg_name, arg_value in vars(args).items():
    logger.info('%-28s %s', arg_name, arg_value)
#########################################################################
# Prepare Data Set
##########################################################################
print("Prepare Data")
config_path = join(args.model_name_or_path, 'config.json')
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
config = GPT2Config.from_json_file(config_path)

# Loader that is passed to inference_model_results further down.
inference_dataloader_loss = DynamicBatchingLoader(
    args.inference_input_file, enc, args.normalize_data,
    args.inference_batch_size, args.max_seq_length)

# NOTE(review): built but not referenced anywhere else in this script.
inference_dataloader_gen = get_eval_list_same_length(
    args.inference_input_file, enc, args.inference_batch_size, True)
#########################################################################
# Prepare Model
##########################################################################
print("Prepare Model")
logger.info("Prepare Model")
base_model = GPT2LMHeadModel(config)
model = load_model(base_model, args.init_checkpoint, args, verbose=True)

if args.local_rank != -1:
    # Distributed run: make sure every rank holds the same weights (original
    # note: "when from scratch make sure initial models are the same").
    # all_reduce_and_rescale_tensors presumably sums across ranks and divides
    # by the world size -- confirm in gpt2_training.distributed.
    world_size = float(torch.distributed.get_world_size())
    all_reduce_and_rescale_tensors([p.data for p in model.parameters()], world_size)

# Training leftover; not used anywhere else in this inference script.
no_decay = ['bias', 'ln']  # no decay for bias and LayerNorm (ln)
#########################################################################
# Inference !
##########################################################################
print("Model inference")
logger.info("Model inference")
epoch = 0
if args.local_rank != -1:
    n_gpu = 1

# Context manager guarantees inference_log.txt is closed even if inference
# raises (the original only closed it on the success path).  Line buffering
# (buffering=1) is kept so writes would appear immediately.
with open(join(log_dir, 'inference_log.txt'), 'a+', buffering=1) as inference_logger:
    # todo modify loss out.
    results = inference_model_results(model, enc, inference_dataloader_loss, args)
    # todo output format: write per-sample results to inference_logger, e.g.
    # print('{},{},{},{},{}'.format(epoch + 1, global_step + 1, step + 1, eval_loss, eval_ppl), file=inference_logger)
    logger.info("inference_final_results:")
    if results is None:
        logger.info("current results are None")
    else:
        logger.info(results)
python inference_LSP.py --model_name_or_path ./models/medium/ --init_checkpoint ./12_5_self_output/GPT2.1e-05.8.3gpu.2019-12-04225327/GP2-pretrain-step-50000.pkl --inference_input_file ./selfdata/attack_chatbot.tsv --log_dir inference_logs_dir/
Inference
python inference_LSP.py --model_name_or_path ./models/medium/ --init_checkpoint ./12_5_self_output/GPT2.1e-05.8.3gpu.2019-12-04225327/GP2-pretrain-step-50000.pkl --inference_input_file ./selfdata/attack_chatbot.tsv --log_dir inference_logs_dir/
validset.tsv:
python inference_LSP.py --model_name_or_path ./models/medium/ --init_checkpoint ./12_5_self_output/GPT2.1e-05.8.3gpu.2019-12-04225327/GP2-pretrain-step-50000.pkl --inference_input_file ./selfdata/validset.tsv --log_dir inference_logs_dir/
./models/medium/medium_ft.pkl
“DialoGPT是什么”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識可以關注億速云網站,小編將為大家輸出更多高質量的實用文章!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。