代码拉取完成,页面将自动刷新
import os
import logging
import argparse
from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
# 设置绝对路径
import sys
sys.path.append("./")
from bert_finetune_cls.utils import init_logger, load_tokenizer, get_intent_labels, MODEL_CLASSES
# 日志对象初始化
logger = logging.getLogger(__name__)
def get_device(pred_config):
"""
获得device参数
:param pred_config:
:return:
"""
return "cuda" if torch.cuda.is_available() and not pred_config.no_cuda else "cpu"
def get_args(pred_config):
"""
得到训练好后保存的模型参数
:param pred_config:
:return:
"""
return torch.load(os.path.join(pred_config.model_dir, 'training_args.bin'))
def load_model(pred_config, args, device):
"""
加载模型
:param pred_config:
:param args: 参数
:param device: 配置
:return:
"""
# Check whether model exists
if not os.path.exists(pred_config.model_dir):
raise Exception("Model doesn't exists! Train first!")
try:
# 加载模型
model = MODEL_CLASSES[args.model_type][1].from_pretrained(args.model_dir,
args=args,
intent_label_lst=get_intent_labels(args),
)
model.to(device)
# 将模型固定不在训练
model.eval()
logger.info("***** Model Loaded *****")
except:
raise Exception("Some model files might be missing...")
return model
def read_input_file(pred_config):
"""
逐行读取输入文件
:param pred_config:
:return:
"""
lines = []
with open(pred_config.input_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
words = line.split()
lines.append(words)
return lines
def convert_input_file_to_tensor_dataset(lines,
pred_config,
args,
tokenizer,
cls_token_segment_id=0,
pad_token_segment_id=0,
sequence_a_segment_id=0,
mask_padding_with_zero=True):
"""
将原始输入数据转换成BERT模型需要的数据
:param lines: 输入文件
:param pred_config: 训练好的模型参数
:param args: 参数
:param tokenizer: 分词模型
:param cls_token_segment_id: -100
:param pad_token_segment_id: 0
:param sequence_a_segment_id: 0
:param mask_padding_with_zero: 0
:return:
"""
# 基于当前模型进行设置
# [CLS]
cls_token = tokenizer.cls_token
# [SEP]
sep_token = tokenizer.sep_token
# [UNK]
unk_token = tokenizer.unk_token
# [PAD]
pad_token_id = tokenizer.pad_token_id
all_input_ids = []
all_attention_mask = []
all_token_type_ids = []
# 循环读取每句话
for words in lines:
tokens = []
# 循环读取每句话中的每个单词
for word in words:
# 对每个单词进行分词
word_tokens = tokenizer.tokenize(word)
# 处理错误编码的单词
if not word_tokens:
word_tokens = [unk_token] # For handling the bad-encoded word
tokens.extend(word_tokens)
# Account for [CLS] and [SEP]
special_tokens_count = 2
# 如果句子长了就截断
if len(tokens) > args.max_seq_len - special_tokens_count:
tokens = tokens[: (args.max_seq_len - special_tokens_count)]
# Add [SEP] token
tokens += [sep_token]
token_type_ids = [sequence_a_segment_id] * len(tokens)
# Add [CLS] token
tokens = [cls_token] + tokens
token_type_ids = [cls_token_segment_id] + token_type_ids
# 把tokens转化为bert词表中的id
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# 长度补齐,保证长度满足最大序列长度
# 需要填充序列的长度
padding_length = args.max_seq_len - len(input_ids)
# 输入样本序列在bert词表里的索引
input_ids = input_ids + ([pad_token_id] * padding_length)
# 注意力mask,padding的部分为0,其他为1
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
# token_type_ids表示每个token属于句子1还是句子2
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
all_input_ids.append(input_ids)
all_attention_mask.append(attention_mask)
all_token_type_ids.append(token_type_ids)
# # 将数据转换成张量
all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids)
return dataset
def predict(pred_config):
# # 加载参数
args = get_args(pred_config)
device = get_device(pred_config)
# 加载模型
model = load_model(pred_config, args, device)
logger.info(args)
# 获取意图标签id
intent_label_lst = get_intent_labels(args)
# 计算损失时,忽略的label序号
pad_token_label_id = args.ignore_index
# 加载分词模型
tokenizer = load_tokenizer(args)
# 读取输入文件
lines = read_input_file(pred_config)
# 将输入文件转化为TensorDataset
dataset = convert_input_file_to_tensor_dataset(
lines,
pred_config,
args,
tokenizer,
)
# SequentialSampler:按顺序进行采样
sampler = SequentialSampler(dataset)
# 读取数据
data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)
intent_preds = None
# 循环预测每个batch
for batch in tqdm(data_loader, desc="Predicting"):
batch = tuple(t.to(device) for t in batch)
# torch.no_grad():它包裹的不需要进行梯度计算
with torch.no_grad():
inputs = {"input_ids": batch[0],
"attention_mask": batch[1],
"intent_label_ids": None,}
if args.model_type != "distilbert":
inputs["token_type_ids"] = batch[2]
# 通过前向传播得到outputs
outputs = model(**inputs)
# 意图标签预测值
intent_logits = outputs[0]
# 如果意图标签存在
if intent_preds is None:
# detach()阻断反向传播,不再有梯度
# numpy不能读取CUDA tensor 需要将它转化为 CPU tensor
intent_preds = intent_logits.detach().cpu().numpy()
# 如果意图标签不存在
else:
intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
# 获取意图标签预测的索引
intent_preds = np.argmax(intent_preds, axis=1)
# 写入到文件中
with open(pred_config.output_file, "w", encoding="utf-8") as f:
for words, intent_pred in zip(lines, intent_preds):
line = ""
f.write("{}\n".format(intent_label_lst[intent_pred]))
logger.info("Prediction Done!")
if __name__ == "__main__":
# 初始化日志
init_logger()
# 建立解析对象
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", default="sample_pred_in.txt", type=str, help="Input file for prediction")
parser.add_argument("--output_file", default="sample_pred_out.txt", type=str, help="Output file for prediction")
parser.add_argument("--model_dir", default="./atis_model", type=str, help="Path to save, load model")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
# 属性给与args实例:把parser中设置的所有"add_argument"给返回到args子类实例当中,那么parser中增加的属性内容都会在args实例中,使用即可
pred_config = parser.parse_args()
# 预测
predict(pred_config)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。