0 Star 4 Fork 2

科大讯飞/Chinese-ELECTRA

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
cmrc2018_drcd_evaluate.py 4.16 KB
一键复制 编辑 原始数据 按行查看 历史
Christian Clauss 提交于 2020-10-27 19:49 +08:00 . ur"strings" are syntax errors in Python 3
# -*- coding: utf-8 -*-
'''
Evaluation script for CMRC 2018
version: v5
Note:
v5 formatted output, add usage description
v4 fixed segmentation issues
'''
from __future__ import print_function
from collections import Counter, OrderedDict
import string
import re
import argparse
import json
import sys
# Python 2 only: switch the interpreter's default encoding to UTF-8 so that
# implicit str<->unicode conversions of Chinese text do not fail.
# ``reload`` and ``sys.setdefaultencoding`` do not exist on Python 3, so the
# calls are guarded to keep the module importable there.
if sys.version_info[0] < 3:
    reload(sys)  # noqa: F821 -- builtin on Python 2 only
    sys.setdefaultencoding('utf8')
import nltk
import pdb
def mixed_segmentation(in_str, rm_punc=False):
    """Segment mixed Chinese/English text into tokens for F1 scoring.

    The input is lower-cased and stripped, then scanned character by
    character:

    * a CJK character in U+4E00..U+9FA5, or a character listed in
      ``sp_char``, becomes a token of its own;
    * every other character is buffered, and each buffered run is split
      with ``nltk.word_tokenize`` when a CJK/``sp_char`` token or the end of
      the input is reached.

    Args:
        in_str: text to segment. Unicode text, UTF-8 encoded bytes, or any
            other object (converted with ``str()``) is accepted, on both
            Python 2 and Python 3.
        rm_punc: when True, characters in ``sp_char`` are discarded instead
            of being emitted. A discarded character does not break the
            surrounding buffered run.

    Returns:
        list of token strings.
    """
    # Normalise to text. ``str`` has no ``decode`` method on Python 3, so
    # bytes are decoded explicitly and other objects go through str() first.
    if not isinstance(in_str, type(u'')):
        if not isinstance(in_str, bytes):
            in_str = str(in_str)
        if isinstance(in_str, bytes):  # raw bytes input, or Python 2 str()
            in_str = in_str.decode('utf-8')
    in_str = in_str.lower().strip()

    # Keep this list in sync with the one in remove_punctuation().
    # NOTE(review): '……' is two characters long, so it can never equal a
    # single ``char`` and is effectively dead. It is left unchanged so
    # scores do not drift from earlier runs of this script.
    sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
               ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
               '「','」','(',')','-','~','『','』']

    segs_out = []
    pending = []  # buffered characters that are neither CJK nor sp_char
    for char in in_str:
        if rm_punc and char in sp_char:
            continue
        # Single-character range test, equivalent to re.search(u'[\u4e00-\u9fa5]', char).
        if u'\u4e00' <= char <= u'\u9fa5' or char in sp_char:
            if pending:
                segs_out.extend(nltk.word_tokenize(''.join(pending)))
                pending = []
            segs_out.append(char)
        else:
            pending.append(char)
    # Flush the trailing non-CJK run, if any.
    if pending:
        segs_out.extend(nltk.word_tokenize(''.join(pending)))
    return segs_out
def remove_punctuation(in_str):
    """Lower-case and strip ``in_str``, then delete every ``sp_char`` character.

    Used to normalise answers and predictions before the exact-match
    comparison.

    Args:
        in_str: unicode text, UTF-8 encoded bytes, or any other object
            (converted with ``str()``). Works on Python 2 and Python 3.

    Returns:
        The normalised text.
    """
    # Normalise to text. ``str`` has no ``decode`` method on Python 3, so
    # bytes are decoded explicitly and other objects go through str() first.
    if not isinstance(in_str, type(u'')):
        if not isinstance(in_str, bytes):
            in_str = str(in_str)
        if isinstance(in_str, bytes):  # raw bytes input, or Python 2 str()
            in_str = in_str.decode('utf-8')
    in_str = in_str.lower().strip()

    # Keep this list in sync with the one in mixed_segmentation().
    # NOTE(review): '……' is two characters long and can never match a single
    # character, so it removes nothing. It is kept unchanged for score parity.
    sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
               ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
               '「','」','(',')','-','~','『','』']
    return ''.join(char for char in in_str if char not in sp_char)
def find_lcs(s1, s2):
    """Find the longest contiguous run shared by two sequences.

    Classic dynamic programme over suffix-match lengths. Only the previous
    row is kept. When several runs share the maximal length, the one ending
    earliest in ``s1`` is returned.

    Args:
        s1, s2: indexable sequences (strings or token lists).

    Returns:
        ``(run, length)``, where ``run`` is the matching slice of ``s1`` and
        ``length`` is its length (``s1[0:0]`` and 0 when nothing matches).
    """
    width = len(s2) + 1
    best_len, best_end = 0, 0
    prev_row = [0] * width
    for i, left in enumerate(s1, 1):
        cur_row = [0] * width
        for j, right in enumerate(s2, 1):
            if left != right:
                continue
            cur_row[j] = prev_row[j - 1] + 1
            if cur_row[j] > best_len:
                best_len, best_end = cur_row[j], i
        prev_row = cur_row
    return s1[best_end - best_len:best_end], best_len
def evaluate(ground_truth_file, prediction_file):
    """Score predictions against a parsed CMRC 2018-style dataset.

    Args:
        ground_truth_file: parsed dataset JSON. Its ``"data"`` list holds
            instances with a ``"paragraphs"`` list. Each paragraph has a
            ``"qas"`` list of questions carrying ``"id"`` and ``"answers"``
            (each answer a dict with a ``"text"`` key).
        prediction_file: parsed prediction JSON, mapping question id to the
            predicted answer.

    Returns:
        ``(f1_score, em_score, total_count, skip_count)``. Scores are
        percentages averaged over *all* questions: unanswered questions are
        reported on stderr, counted in ``skip_count`` and score 0. Both
        scores are 0.0 when the dataset contains no questions.
    """
    f1 = 0
    em = 0
    total_count = 0
    skip_count = 0
    for instance in ground_truth_file["data"]:
        for para in instance["paragraphs"]:
            for qas in para['qas']:
                total_count += 1
                query_id = qas['id'].strip()
                answers = [x["text"] for x in qas['answers']]
                if query_id not in prediction_file:
                    sys.stderr.write('Unanswered question: {}\n'.format(query_id))
                    skip_count += 1
                    continue
                prediction = prediction_file[query_id]
                # Convert to text without relying on str.decode, which only
                # exists on Python 2.
                if not isinstance(prediction, type(u'')):
                    if not isinstance(prediction, bytes):
                        prediction = str(prediction)
                    if isinstance(prediction, bytes):
                        prediction = prediction.decode('utf-8')
                f1 += calc_f1_score(answers, prediction)
                em += calc_em_score(answers, prediction)
    if total_count == 0:
        # An empty dataset would otherwise raise ZeroDivisionError below.
        return 0.0, 0.0, total_count, skip_count
    f1_score = 100.0 * f1 / total_count
    em_score = 100.0 * em / total_count
    return f1_score, em_score, total_count, skip_count
def calc_f1_score(answers, prediction):
    """Return the best token-level F1 of ``prediction`` over ``answers``.

    Both sides are segmented with ``mixed_segmentation(..., rm_punc=True)``.
    The overlap is the longest common contiguous token run found by
    ``find_lcs``. Precision and recall are that run's length divided by the
    prediction and answer token counts, respectively.

    Args:
        answers: iterable of reference answer strings.
        prediction: the predicted answer string.

    Returns:
        The maximum F1 (0..1) across answers. Returns 0 when ``answers`` is
        empty, matching calc_em_score instead of letting ``max([])`` raise
        ValueError.
    """
    if not answers:
        return 0
    # The prediction's segmentation does not depend on the answer, so it is
    # computed once rather than once per answer.
    prediction_segs = mixed_segmentation(prediction, rm_punc=True)
    f1_scores = []
    for ans in answers:
        ans_segs = mixed_segmentation(ans, rm_punc=True)
        _, lcs_len = find_lcs(ans_segs, prediction_segs)
        if lcs_len == 0:
            # No overlap; also guarantees the divisions below are non-zero.
            f1_scores.append(0)
            continue
        precision = 1.0 * lcs_len / len(prediction_segs)
        recall = 1.0 * lcs_len / len(ans_segs)
        f1_scores.append((2 * precision * recall) / (precision + recall))
    return max(f1_scores)
def calc_em_score(answers, prediction):
    """Return 1 if ``prediction`` exactly matches any answer, else 0.

    Answers and prediction are compared after ``remove_punctuation``
    normalisation (lower-cased, stripped, punctuation deleted). Answers are
    checked in order, and checking stops at the first match. An empty
    ``answers`` yields 0.
    """
    matched = any(
        remove_punctuation(ans) == remove_punctuation(prediction)
        for ans in answers
    )
    return 1 if matched else 0
if __name__ == '__main__':
    # Command-line entry point: score a prediction file against a dataset
    # file and print the results as a single JSON object on stdout.
    parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
    parser.add_argument('dataset_file', help='Official dataset file')
    parser.add_argument('prediction_file', help='Your prediction File')
    args = parser.parse_args()
    # Context managers close both files once parsed (previously leaked).
    with open(args.dataset_file, 'rb') as dataset_fh:
        ground_truth_file = json.load(dataset_fh)
    with open(args.prediction_file, 'rb') as prediction_fh:
        prediction_file = json.load(prediction_fh)
    F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
    AVG = (EM + F1) * 0.5
    # OrderedDict keeps the printed key order stable on older Pythons.
    output_result = OrderedDict()
    output_result['AVERAGE'] = '%.3f' % AVG
    output_result['F1'] = '%.3f' % F1
    output_result['EM'] = '%.3f' % EM
    output_result['TOTAL'] = TOTAL
    output_result['SKIP'] = SKIP
    output_result['FILE'] = args.prediction_file
    print(json.dumps(output_result))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/iflytek/Chinese-ELECTRA.git
[email protected]:iflytek/Chinese-ELECTRA.git
iflytek
Chinese-ELECTRA
Chinese-ELECTRA
master

搜索帮助