فهرست منبع

add an option to freely output 'badcase.json

Shuimo 1 سال پیش
والد
کامیت
d14579373a
2فایلهای تغییر یافته به همراه78 افزوده شده و 88 حذف شده
  1. 40 45
      tools/ocr_badcase.py
  2. 38 43
      tools/text_badcase.py

+ 40 - 45
tools/ocr_badcase.py

@@ -756,26 +756,28 @@ def merge_json_data(json_test_df, json_standard_df):
 
     return inner_merge, standard_exist, test_exist
 
-def save_results(result_dict,overall_report_dict,badcase_path,overall_path,):
+def generate_filename(base_path):
     """
-    将结果字典保存为JSON文件至指定路径。
-
+    生成带有当前时间戳的输出文件名。
     参数:
-    - result_dict: 包含计算结果的字典。
-    - overall_path: 结果文件的保存路径,包括文件名。
+    - base_path: 基础路径和文件名前缀。
+    返回:
+    - 带有当前时间戳的完整输出文件名。
     """
-    # 打开指定的文件以写入
-    with open(badcase_path, 'w', encoding='utf-8') as f:
-        # 将结果字典转换为JSON格式并写入文件
-        json.dump(result_dict, f, ensure_ascii=False, indent=4)
-
-    print(f"计算结果已经保存到文件:{badcase_path}")
+    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    return f"{base_path}_{current_time}.json"
 
-    with open(overall_path, 'w', encoding='utf-8') as f:
-    # 将结果字典转换为JSON格式并写入文件
-        json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
+def save_results(data_dict, file_path):
+    """
+    将数据字典保存为JSON文件至指定路径。
+    参数:
+    - data_dict: 包含数据的字典。
+    - file_path: 结果文件的保存路径,包括文件名。
+    """
+    with open(file_path, 'w', encoding='utf-8') as f:
+        json.dump(data_dict, f, ensure_ascii=False, indent=4)
+    print(f"结果已经保存到文件:{file_path}")
 
-    print(f"计算结果已经保存到文件:{overall_path}")
 
 def upload_to_s3(file_path, bucket_name, s3_directory, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
     """
@@ -801,20 +803,7 @@ def upload_to_s3(file_path, bucket_name, s3_directory, AWS_ACCESS_KEY, AWS_SECRE
     except ClientError as e:
         print(f"上传文件时发生错误:{e}")
 
-def generate_filename(badcase_path,overall_path):
-    """
-    生成带有当前时间戳的输出文件名。
-
-    参数:
-    - base_path: 基础路径和文件名前缀。
 
-    返回:
-    - 带有当前时间戳的完整输出文件名。
-    """
-    # 获取当前时间并格式化为字符串
-    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
-    # 构建并返回完整的输出文件名
-    return f"{badcase_path}_{current_time}.json",f"{overall_path}_{current_time}.json"
 
 
 
@@ -831,7 +820,7 @@ def compare_edit_distance(json_file, overall_report):
 
 
 
-def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path, s3_bucket_name=None, s3_file_directory=None, 
+def main(standard_file, test_file, zip_file, overall_path, base_data_path, badcase_path=None, s3_bucket_name=None, s3_file_directory=None, 
          aws_access_key=None, aws_secret_key=None, end_point_url=None):
     """
     主函数,执行整个评估流程。
@@ -840,8 +829,9 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
     - standard_file: 标准文件的路径。
     - test_file: 测试文件的路径。
     - zip_file: 压缩包的路径的路径。
-    - badcase_path: badcase文件的基础路径和文件名前缀。
+    - badcase_path: badcase文件的基础路径和文件名前缀(可选)
     - overall_path: overall文件的基础路径和文件名前缀。
+    - base_data_path: 基础数据路径。
     - s3_bucket_name: S3桶名称(可选)。
     - s3_file_directory: S3上的文件保存目录(可选)。
     - AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS访问凭证和端点URL(可选)。
@@ -855,46 +845,51 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
     # 合并JSON数据
     inner_merge, standard_exist, test_exist = merge_json_data(json_test_origin, json_standard_origin)
 
-    #计算总体指标
-    overall_report_dict=overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'],standard_exist, test_exist)
-    # 计算指标
-    result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)
+    # 计算总体指标
+    overall_report_dict = overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], standard_exist, test_exist)
 
     # 生成带时间戳的输出文件名
-    badcase_file,overall_file = generate_filename(badcase_path,overall_path)
+    if badcase_path:
+        badcase_file = generate_filename(badcase_path)
+        result_dict =  result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)
+        save_results(result_dict, badcase_file)
 
-    # 保存结果到JSON文件
-    save_results(result_dict, overall_report_dict,badcase_file,overall_file)
+    overall_file = generate_filename(overall_path)
+    save_results(overall_report_dict, overall_file)
 
-    result=compare_edit_distance(base_data_path, overall_report_dict)
-<<<<<<< HEAD
+    result = compare_edit_distance(base_data_path, overall_report_dict)
 
     if all([s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url]):
         try:
-            upload_to_s3(badcase_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
+            if badcase_path:
+                upload_to_s3(badcase_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
             upload_to_s3(overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
         except Exception as e:
             print(f"上传到S3时发生错误: {e}")
-=======
->>>>>>> ff8f62aa3c28facc192104387f131d87978064fc
+
     print(result)
     assert result == 1
 
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="主函数,执行整个评估流程。")
     parser.add_argument('standard_file', type=str, help='标准文件的路径。')
     parser.add_argument('test_file', type=str, help='测试文件的路径。')
     parser.add_argument('zip_file', type=str, help='压缩包的路径。')
-    parser.add_argument('badcase_path', type=str, help='badcase文件的基础路径和文件名前缀。')
     parser.add_argument('overall_path', type=str, help='overall文件的基础路径和文件名前缀。')
     parser.add_argument('base_data_path', type=str, help='基准文件的基础路径和文件名前缀。')
+    parser.add_argument('--badcase_path', type=str, default=None, help='badcase文件的基础路径和文件名前缀(可选)。')
     parser.add_argument('--s3_bucket_name', type=str, help='S3桶名称。', default=None)
-    parser.add_argument('--s3_file_directory', type=str, help='S3上的文件。', default=None)
+    parser.add_argument('--s3_file_directory', type=str, help='S3上的文件保存目录。', default=None)
     parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS访问密钥。', default=None)
     parser.add_argument('--AWS_SECRET_KEY', type=str, help='AWS秘密密钥。', default=None)
     parser.add_argument('--END_POINT_URL', type=str, help='AWS端点URL。', default=None)
 
     args = parser.parse_args()
 
-    main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_directory, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
+    main(args.standard_file, args.test_file, args.zip_file, args.overall_path, args.base_data_path,
+         badcase_path=args.badcase_path, s3_bucket_name=args.s3_bucket_name, 
+         s3_file_directory=args.s3_file_directory, aws_access_key=args.AWS_ACCESS_KEY, 
+         aws_secret_key=args.AWS_SECRET_KEY, end_point_url=args.END_POINT_URL)
 

+ 38 - 43
tools/text_badcase.py

@@ -768,26 +768,28 @@ def merge_json_data(json_test_df, json_standard_df):
 
     return inner_merge, standard_exist, test_exist
 
-def save_results(result_dict,overall_report_dict,badcase_path,overall_path,):
+def generate_filename(base_path):
     """
-    将结果字典保存为JSON文件至指定路径。
-
+    生成带有当前时间戳的输出文件名。
     参数:
-    - result_dict: 包含计算结果的字典。
-    - overall_path: 结果文件的保存路径,包括文件名。
+    - base_path: 基础路径和文件名前缀。
+    返回:
+    - 带有当前时间戳的完整输出文件名。
     """
-    # 打开指定的文件以写入
-    with open(badcase_path, 'w', encoding='utf-8') as f:
-        # 将结果字典转换为JSON格式并写入文件
-        json.dump(result_dict, f, ensure_ascii=False, indent=4)
-
-    print(f"计算结果已经保存到文件:{badcase_path}")
+    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    return f"{base_path}_{current_time}.json"
 
-    with open(overall_path, 'w', encoding='utf-8') as f:
-    # 将结果字典转换为JSON格式并写入文件
-        json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
+def save_results(data_dict, file_path):
+    """
+    将数据字典保存为JSON文件至指定路径。
+    参数:
+    - data_dict: 包含数据的字典。
+    - file_path: 结果文件的保存路径,包括文件名。
+    """
+    with open(file_path, 'w', encoding='utf-8') as f:
+        json.dump(data_dict, f, ensure_ascii=False, indent=4)
+    print(f"结果已经保存到文件:{file_path}")
 
-    print(f"计算结果已经保存到文件:{overall_path}")
 
     
 def upload_to_s3(file_path, bucket_name, s3_directory, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
@@ -815,21 +817,6 @@ def upload_to_s3(file_path, bucket_name, s3_directory, AWS_ACCESS_KEY, AWS_SECRE
         print(f"上传文件时发生错误:{e}")
 
 
-def generate_filename(badcase_path,overall_path):
-    """
-    生成带有当前时间戳的输出文件名。
-
-    参数:
-    - base_path: 基础路径和文件名前缀。
-
-    返回:
-    - 带有当前时间戳的完整输出文件名。
-    """
-    # 获取当前时间并格式化为字符串
-    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
-    # 构建并返回完整的输出文件名
-    return f"{badcase_path}_{current_time}.json",f"{overall_path}_{current_time}.json"
-
 
 def compare_edit_distance(json_file, overall_report):
     with open(json_file, 'r',encoding='utf-8') as f:
@@ -842,7 +829,7 @@ def compare_edit_distance(json_file, overall_report):
     else:
         return 1
     
-def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path, s3_bucket_name=None, s3_file_directory=None, 
+def main(standard_file, test_file, zip_file, overall_path, base_data_path, badcase_path=None, s3_bucket_name=None, s3_file_directory=None, 
          aws_access_key=None, aws_secret_key=None, end_point_url=None):
     """
     主函数,执行整个评估流程。
@@ -851,12 +838,14 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
     - standard_file: 标准文件的路径。
     - test_file: 测试文件的路径。
     - zip_file: 压缩包的路径的路径。
-    - badcase_path: badcase文件的基础路径和文件名前缀。
+    - badcase_path: badcase文件的基础路径和文件名前缀(可选)
     - overall_path: overall文件的基础路径和文件名前缀。
+    - base_data_path: 基础数据路径。
     - s3_bucket_name: S3桶名称(可选)。
     - s3_file_directory: S3上的文件保存目录(可选)。
     - AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS访问凭证和端点URL(可选)。
     """
+
     # 检查文件是否存在
     check_json_files_in_zip_exist(zip_file, standard_file, test_file)
 
@@ -868,39 +857,45 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
 
     #计算总体指标
     overall_report_dict=overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'],standard_exist, test_exist)
-    # 计算指标
-    result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)
-
     # 生成带时间戳的输出文件名
-    badcase_file,overall_file = generate_filename(badcase_path,overall_path)
+    if badcase_path:
+        badcase_file = generate_filename(badcase_path)
+        result_dict =  result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)
+        save_results(result_dict, badcase_file)
 
-    # 保存结果到JSON文件
-    save_results(result_dict, overall_report_dict,badcase_file,overall_file)
+    overall_file = generate_filename(overall_path)
+    save_results(overall_report_dict, overall_file)
 
-    result=compare_edit_distance(base_data_path, overall_report_dict)
+    result = compare_edit_distance(base_data_path, overall_report_dict)
 
     if all([s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url]):
         try:
-            upload_to_s3(badcase_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
+            if badcase_path:
+                upload_to_s3(badcase_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
             upload_to_s3(overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
         except Exception as e:
             print(f"上传到S3时发生错误: {e}")
+
     print(result)
+    assert result == 1
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="主函数,执行整个评估流程。")
     parser.add_argument('standard_file', type=str, help='标准文件的路径。')
     parser.add_argument('test_file', type=str, help='测试文件的路径。')
     parser.add_argument('zip_file', type=str, help='压缩包的路径。')
-    parser.add_argument('badcase_path', type=str, help='badcase文件的基础路径和文件名前缀。')
     parser.add_argument('overall_path', type=str, help='overall文件的基础路径和文件名前缀。')
     parser.add_argument('base_data_path', type=str, help='基准文件的基础路径和文件名前缀。')
+    parser.add_argument('--badcase_path', type=str, default=None, help='badcase文件的基础路径和文件名前缀(可选)。')
     parser.add_argument('--s3_bucket_name', type=str, help='S3桶名称。', default=None)
-    parser.add_argument('--s3_file_directory', type=str, help='S3上的文件。', default=None)
+    parser.add_argument('--s3_file_directory', type=str, help='S3上的文件保存目录。', default=None)
     parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS访问密钥。', default=None)
     parser.add_argument('--AWS_SECRET_KEY', type=str, help='AWS秘密密钥。', default=None)
     parser.add_argument('--END_POINT_URL', type=str, help='AWS端点URL。', default=None)
 
     args = parser.parse_args()
 
-    main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_directory, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
+    main(args.standard_file, args.test_file, args.zip_file, args.overall_path, args.base_data_path,
+         badcase_path=args.badcase_path, s3_bucket_name=args.s3_bucket_name, 
+         s3_file_directory=args.s3_file_directory, aws_access_key=args.AWS_ACCESS_KEY, 
+         aws_secret_key=args.AWS_SECRET_KEY, end_point_url=args.END_POINT_URL)