@@ -413,7 +413,9 @@ def bbox_match_indicator_dropped_text_block(test_dropped_text_bboxs, standard_dr

    # Compute and return the tag match metrics
    text_block_tag_report = classification_report(y_true=standard_tag, y_pred=test_tag, labels=list(set(standard_tag) - {'None'}), output_dict=True, zero_division=0)
-
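+    # Keep only the per-label rows in the tag report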
+    del text_block_tag_report["macro avg"]
+    del text_block_tag_report["weighted avg"]
+
    return text_block_report, text_block_tag_report

def handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_page_tag, standard_page_bbox):
@@ -500,6 +502,142 @@ def merge_json_data(json_test_df, json_standard_df):
    return inner_merge, standard_exist, test_exist


+def consolidate_data(test_data, standard_data, key_path):
+    """
+    Consolidates data from the test and standard datasets at the given key path.
+
+    :param test_data: Dictionary containing the test dataset.
+    :param standard_data: Dictionary containing the standard dataset.
+    :param key_path: List of keys leading to the desired data within the dictionaries.
+    :return: Two lists with all items found at the key path in the standard and test data, respectively.
+    """
+    # Initialize empty lists to hold the consolidated data
+    overall_data_standard = []
+    overall_data_test = []
+
+    # Helper function that walks a dictionary along the key path
+    def extract_data(source_data, keys):
+        for key in keys[:-1]:
+            source_data = source_data.get(key, {})
+        return source_data.get(keys[-1], [])
+
+    # Extend the overall lists with items from both datasets; each extracted
+    # entry is assumed to already be a list, so its elements are added directly
+    for data in extract_data(standard_data, key_path):
+        overall_data_standard.extend(data)
+
+    for data in extract_data(test_data, key_path):
+        overall_data_test.extend(data)
+
+    return overall_data_standard, overall_data_test
+
+def overall_calculate_metrics(inner_merge, json_test, json_standard, standard_exist, test_exist):
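+    """
+    Builds the overall evaluation report: detection accuracy/precision/recall/F1
+    over the existence labels, average edit distance and BLEU between matched
+    documents, paragraph-count accuracy, and bbox/text match metrics for
+    equations and dropped blocks.
+    """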
+    process_data_standard = process_equations_and_blocks(json_standard, is_standard=True)
+    process_data_test = process_equations_and_blocks(json_test, is_standard=False)
+
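+    # Page-level detection metrics over the existence labels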
+    overall_report = {}
+    overall_report['accuracy'] = metrics.accuracy_score(standard_exist, test_exist)
+    overall_report['precision'] = metrics.precision_score(standard_exist, test_exist)
+    overall_report['recall'] = metrics.recall_score(standard_exist, test_exist)
+    overall_report['f1_score'] = metrics.f1_score(standard_exist, test_exist)
+
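+    # Select the matched documents (pass_label == 'yes') for text comparison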
+    test_para_text = np.asarray(process_data_test['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    standard_para_text = np.asarray(process_data_standard['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    ids_yes = inner_merge['id'][inner_merge['pass_label'] == 'yes'].tolist()
+
+    pdf_dis = {}
+    pdf_bleu = {}
+
+    # Compute the edit distance and BLEU score for entries whose pass_label is 'yes'
+    for a, b, doc_id in zip(test_para_text, standard_para_text, ids_yes):
+        a1 = ''.join(a)
+        b1 = ''.join(b)
+        pdf_dis[doc_id] = Levenshtein_Distance(a, b)
+        pdf_bleu[doc_id] = sentence_bleu([a1], b1)
+
+    overall_report['pdf间的平均编辑距离'] = np.mean(list(pdf_dis.values()))
+    overall_report['pdf间的平均bleu'] = np.mean(list(pdf_bleu.values()))
+
+    # Consolidate inline equation bboxes and texts
+    overall_equations_bboxs_inline_standard, overall_equations_bboxs_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "inline"])
+    overall_equations_texts_inline_standard, overall_equations_texts_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "inline"])
+
+    # Consolidate interline equation bboxes and texts
+    overall_equations_bboxs_interline_standard, overall_equations_bboxs_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "interline"])
+    overall_equations_texts_interline_standard, overall_equations_texts_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "interline"])
+
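+    # Consolidate dropped text, image, and table blocks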
+    overall_dropped_bboxs_text_standard, overall_dropped_bboxs_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "text"])
+    overall_dropped_tags_text_standard, overall_dropped_tags_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_tags", "text"])
+    overall_dropped_bboxs_image_standard, overall_dropped_bboxs_image_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "image"])
+    overall_dropped_bboxs_table_standard, overall_dropped_bboxs_table_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "table"])
+
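+    # Flatten the per-document paragraph counts and compare them element-wise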
+    para_nums_test = process_data_test['para_nums']
+    para_nums_standard = process_data_standard['para_nums']
+    overall_para_nums_standard = [item for sublist in para_nums_standard for item in (sublist if isinstance(sublist, list) else [sublist])]
+    overall_para_nums_test = [item for sublist in para_nums_test for item in (sublist if isinstance(sublist, list) else [sublist])]
+
+    test_para_num = np.array(overall_para_nums_test)
+    standard_para_num = np.array(overall_para_nums_standard)
+    acc_para = np.mean(test_para_num == standard_para_num)
+
+    overall_report['分段准确率'] = acc_para
+
+    # Inline equations: match accuracy, edit distance, and BLEU
+    overall_report['行内公式准确率'] = bbox_match_indicator_general(
+        overall_equations_bboxs_inline_test,
+        overall_equations_bboxs_inline_standard)
+
+    overall_report['行内公式编辑距离'], overall_report['行内公式bleu'] = equations_indicator(
+        overall_equations_bboxs_inline_test,
+        overall_equations_bboxs_inline_standard,
+        overall_equations_texts_inline_test,
+        overall_equations_texts_inline_standard)
+
+    # Interline equations: match accuracy, edit distance, and BLEU
+    overall_report['行间公式准确率'] = bbox_match_indicator_general(
+        overall_equations_bboxs_interline_test,
+        overall_equations_bboxs_interline_standard)
+
+    overall_report['行间公式编辑距离'], overall_report['行间公式bleu'] = equations_indicator(
+        overall_equations_bboxs_interline_test,
+        overall_equations_bboxs_interline_standard,
+        overall_equations_texts_interline_test,
+        overall_equations_texts_interline_standard)
+
+    # Dropped text blocks: bbox match accuracy and tag accuracy
+    overall_report['丢弃文本准确率'], overall_report['丢弃文本标签准确率'] = bbox_match_indicator_dropped_text_block(
+        overall_dropped_bboxs_text_test,
+        overall_dropped_bboxs_text_standard,
+        overall_dropped_tags_text_standard,
+        overall_dropped_tags_text_test)
+
+    # Dropped image blocks: bbox match accuracy
+    overall_report['丢弃图片准确率'] = bbox_match_indicator_general(
+        overall_dropped_bboxs_image_test,
+        overall_dropped_bboxs_image_standard)
+
+    # Dropped table blocks: bbox match accuracy
+    overall_report['丢弃表格准确率'] = bbox_match_indicator_general(
+        overall_dropped_bboxs_table_test,
+        overall_dropped_bboxs_table_standard)
+
+    return overall_report
+


def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origin):
@@ -602,21 +740,27 @@ def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origi
    return result_dict


-def save_results(result_dict, output_path):
+
+def save_results(result_dict, overall_report_dict, badcase_path, overall_path):
    """
    Saves the result dictionaries as JSON files at the given paths.

    Parameters:
    - result_dict: Dictionary containing the per-document (badcase) results.
-    - output_path: Path where the result file is saved, including the file name.
+    - overall_report_dict: Dictionary containing the overall metrics report.
+    - badcase_path: Path where the badcase result file is saved, including the file name.
+    - overall_path: Path where the overall result file is saved, including the file name.
    """
    # Open the target file for writing
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(badcase_path, 'w', encoding='utf-8') as f:
        # Serialize the result dictionary to JSON and write it to the file
        json.dump(result_dict, f, ensure_ascii=False, indent=4)

-    print(f"The results have been saved to: {output_path}")
+    print(f"The badcase results have been saved to: {badcase_path}")
+
+    with open(overall_path, 'w', encoding='utf-8') as f:
+        # Serialize the overall report to JSON and write it to the file
+        json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
+
+    print(f"The overall report has been saved to: {overall_path}")

def upload_to_s3(file_path, bucket_name, s3_file_name, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
"""
@@ -634,7 +778,7 @@ def upload_to_s3(file_path, bucket_name, s3_file_name,AWS_ACCESS_KEY,AWS_SECRET_
    except ClientError as e:
        print(f"An error occurred while uploading the file: {e}")

-def generate_output_filename(base_path):
+def generate_filename(badcase_path, overall_path):
    """
    Generates a timestamped output file name for each report.

@@ -647,13 +791,24 @@ def generate_output_filename(base_path):
    # Get the current time formatted as a string
    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # Build and return the full output file names
-    return f"{base_path}_{current_time}.json"
+    return f"{badcase_path}_{current_time}.json", f"{overall_path}_{current_time}.json"


+def compare_edit_distance(json_file, overall_report):
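+    """
+    Compares the overall report's average edit distance against the baseline
+    stored in json_file; returns 1 if the new value is strictly smaller
+    (better), otherwise 0.
+    """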
+    with open(json_file, 'r', encoding='utf-8') as f:
+        json_data = json.load(f)
+
+    json_edit_distance = json_data['pdf间的平均编辑距离']
+
+    if overall_report['pdf间的平均编辑距离'] >= json_edit_distance:
+        return 0
+    else:
+        return 1
+

-def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
+def main(standard_file, test_file, zip_file, badcase_path, overall_path, base_data_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
    """
    Main function that executes the whole evaluation pipeline.

@@ -661,7 +816,8 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
    - standard_file: Path to the standard file.
    - test_file: Path to the test file.
    - zip_file: Path to the zip archive.
-    - base_output_path: Base path and file-name prefix for the result files.
+    - badcase_path: Base path and file-name prefix for the badcase file.
+    - overall_path: Base path and file-name prefix for the overall file.
+    - base_data_path: Base path of the baseline file used for the edit-distance comparison.
    - s3_bucket_name: S3 bucket name (optional).
    - s3_file_name: File name on S3 (optional).
    - AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS credentials and endpoint URL (optional).
@@ -675,21 +831,29 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
    # Merge the JSON data
    inner_merge, standard_exist, test_exist = merge_json_data(json_test_origin, json_standard_origin)

+    # Compute the overall metrics
+    overall_report_dict = overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], standard_exist, test_exist)
    # Compute the per-document metrics
    result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)

    # Generate timestamped output file names
-    output_file = generate_output_filename(base_output_path)
+    badcase_file, overall_file = generate_filename(badcase_path, overall_path)

    # Save the results to JSON files
-    save_results(result_dict, output_file)
+    save_results(result_dict, overall_report_dict, badcase_file, overall_file)
+
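+    # Regression gate: fail unless the new report beats the baseline edit distance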
+    result = compare_edit_distance(base_data_path, overall_report_dict)
+    print(result)
+    assert result == 1

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Main function that executes the whole evaluation pipeline.")
    parser.add_argument('standard_file', type=str, help='Path to the standard file.')
    parser.add_argument('test_file', type=str, help='Path to the test file.')
    parser.add_argument('zip_file', type=str, help='Path to the zip archive.')
-    parser.add_argument('base_output_path', type=str, help='Base path and file-name prefix for the result files.')
+    parser.add_argument('badcase_path', type=str, help='Base path and file-name prefix for the badcase file.')
+    parser.add_argument('overall_path', type=str, help='Base path and file-name prefix for the overall file.')
+    parser.add_argument('base_data_path', type=str, help='Base path and file-name prefix for the baseline file.')
    parser.add_argument('--s3_bucket_name', type=str, help='S3 bucket name.', default=None)
    parser.add_argument('--s3_file_name', type=str, help='File name on S3.', default=None)
    parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS access key.', default=None)
@@ -698,5 +862,5 @@ if __name__ == "__main__":
    args = parser.parse_args()

-    main(args.standard_file, args.test_file, args.zip_file, args.base_output_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
+    main(args.standard_file, args.test_file, args.zip_file, args.badcase_path, args.overall_path, args.base_data_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)