Procházet zdrojové kódy

Merge branch 'master' into master

myhloli před 1 rokem
rodič
revize
c8b06ad589

+ 14 - 4
.github/workflows/benchmark.yml

@@ -9,7 +9,7 @@ on:
     paths-ignore:
       - "cmds/**"
       - "**.md"
-
+  workflow_dispatch:
 jobs:
   pdf-test:
     runs-on: pdf
@@ -18,14 +18,16 @@ jobs:
       fail-fast: true
 
     steps:
+    - name: config-net
+      run: |
+        export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
+        export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
     - name: PDF benchmark
       uses: actions/checkout@v3
       with:
         fetch-depth: 2
     - name: check-requirements
       run: |
-        export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
-        export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
         changed_files=$(git diff --name-only -r HEAD~1 HEAD)
         echo $changed_files
         if [[ $changed_files =~ "requirements.txt" ]]; then
@@ -36,4 +38,12 @@ jobs:
     - name: benchmark
       run: |
         echo "start test"
-        cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip output.json
+        cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip badcase.json overall.json base_data.json
+  notify_to_feishu:
+    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
+    needs: [pdf-test]
+    runs-on: [pdf]
+    steps:
+    - name: notify
+      run: |
+        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}'  ${{ secrets.WEBHOOK_URL }}

+ 8 - 6
.github/workflows/update_base.yml

@@ -1,11 +1,12 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
 
-name: PDF
+name: update-base
 on:
-release:
-  types: [published]
-
+  push:
+    tags:
+      - '*released'
+  workflow_dispatch:
 jobs:
   pdf-test:
     runs-on: pdf
@@ -15,6 +16,7 @@ jobs:
     steps:
     - name: update-base
       uses: actions/checkout@v3
+    - name: start-update
       run: |
-          python update_base.py
-  
+        echo "start test"
+  

+ 1 - 0
demo/ocr_demo.py

@@ -116,6 +116,7 @@ if __name__ == '__main__':
     pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.pdf"
     json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.json"
     # ocr_local_parse(pdf_path, json_file_path)
+    
     book_name = "数学新星网/edu_00001236"
     ocr_online_parse(book_name)
     

+ 24 - 10
magic_pdf/io/AbsReaderWriter.py

@@ -1,20 +1,34 @@
-
 from abc import ABC, abstractmethod
 
 
 class AbsReaderWriter(ABC):
     """
     同时支持二进制和文本读写的抽象类
-    TODO
     """
+    MODE_TXT = "text"
+    MODE_BIN = "binary"
+
+    def __init__(self, parent_path):
+        # 初始化代码可以在这里添加,如果需要的话
+        self.parent_path = parent_path # 对于本地目录是父目录,对于s3是会写到这个apth下。
+
+    @abstractmethod
+    def read(self, path: str, mode="text"):
+        """
+        无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
+        """
+        raise NotImplementedError
+
     @abstractmethod
-    def read(self, path: str):
-        pass
+    def write(self, content: str, path: str, mode=MODE_TXT):
+        """
+        无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
+        """
+        raise NotImplementedError
 
     @abstractmethod
-    def write(self, path: str, content: str):
-        pass
-    
-    
-    
-    
+    def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'):
+        """
+        无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
+        """
+        raise NotImplementedError

+ 49 - 0
magic_pdf/io/DiskReaderWriter.py

@@ -0,0 +1,49 @@
+import os
+from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
+from loguru import logger
+class DiskReaderWriter(AbsReaderWriter):
+    def __init__(self, parent_path, encoding='utf-8'):
+        self.path = parent_path
+        self.encoding = encoding
+
+    def read(self, mode="text"):
+        if not os.path.exists(self.path):
+            logger.error(f"文件 {self.path} 不存在")
+            raise Exception(f"文件 {self.path} 不存在")
+        if mode == "text":
+            with open(self.path, 'r', encoding = self.encoding) as f:
+                return f.read()
+        elif mode == "binary":
+            with open(self.path, 'rb') as f:
+                return f.read()
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+
+    def write(self, data, mode="text"):
+        if mode == "text":
+            with open(self.path, 'w', encoding=self.encoding) as f:
+                f.write(data)
+                logger.info(f"内容已成功写入 {self.path}")
+
+        elif mode == "binary":
+            with open(self.path, 'wb') as f:
+                f.write(data)
+                logger.info(f"内容已成功写入 {self.path}")
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+
+
+# 使用示例
+if __name__ == "__main__":
+    file_path = "example.txt"
+    drw = DiskReaderWriter(file_path)
+
+    # 写入内容到文件
+    drw.write(b"Hello, World!", mode="binary")
+
+    # 从文件读取内容
+    content = drw.read()
+    if content:
+        logger.info(f"从 {file_path} 读取的内容: {content}")
+
+

+ 66 - 12
magic_pdf/io/S3ReaderWriter.py

@@ -1,18 +1,72 @@
 
 
-from magic_pdf.io import AbsReaderWriter
+from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
+from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key
+import boto3
+from loguru import logger
+from boto3.s3.transfer import TransferConfig
+from botocore.config import Config
 
 
-class DiskReaderWriter(AbsReaderWriter):
-    def __init__(self, parent_path, encoding='utf-8'):
-        self.path = parent_path
-        self.encoding = encoding
+class S3ReaderWriter(AbsReaderWriter):
+    def __init__(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
+        self.client = self._get_client(ak, sk, endpoint_url, addressing_style)
 
-    def read(self):
-        with open(self.path, 'rb') as f:
-            return f.read()
+    def _get_client(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
+        s3_client = boto3.client(
+            service_name="s3",
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(s3={"addressing_style": addressing_style},
+                          retries={'max_attempts': 5, 'mode': 'standard'}),
+        )
+        return s3_client
+    def read(self, s3_path, mode="text", encoding="utf-8"):
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        res = self.client.get_object(Bucket=bucket_name, Key=bucket_key)
+        body = res["Body"].read()
+        if mode == 'text':
+            data = body.decode(encoding)  # Decode bytes to text
+        elif mode == 'binary':
+            data = body
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        return data
 
-    def write(self, data):
-        with open(self.path, 'wb') as f:
-            f.write(data)
-            
+    def write(self, data, s3_path, mode="text", encoding="utf-8"):
+        if mode == 'text':
+            body = data.encode(encoding)  # Encode text data as bytes
+        elif mode == 'binary':
+            body = data
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        self.client.put_object(Body=body, Bucket=bucket_name, Key=bucket_key)
+        logger.info(f"内容已写入 {s3_path} ")
+
+
+if __name__ == "__main__":
+    # Config the connection info
+    ak = ""
+    sk = ""
+    endpoint_url = ""
+    addressing_style = ""
+
+    # Create an S3ReaderWriter object
+    s3_reader_writer = S3ReaderWriter(ak, sk, endpoint_url, addressing_style)
+
+    # Write text data to S3
+    text_data = "This is some text data"
+    s3_reader_writer.write(data=text_data, s3_path = "s3://bucket_name/ebook/test/test.json", mode='text')
+
+    # Read text data from S3
+    text_data_read = s3_reader_writer.read(s3_path = "s3://bucket_name/ebook/test/test.json", mode='text')
+    logger.info(f"Read text data from S3: {text_data_read}")
+    # Write binary data to S3
+    binary_data = b"This is some binary data"
+    s3_reader_writer.write(data=text_data, s3_path = "s3://bucket_name/ebook/test/test2.json", mode='binary')
+
+    # Read binary data from S3
+    binary_data_read = s3_reader_writer.read(s3_path = "s3://bucket_name/ebook/test/test2.json", mode='binary')
+    logger.info(f"Read binary data from S3: {binary_data_read}")

+ 22 - 3
magic_pdf/para/para_split.py

@@ -183,11 +183,31 @@ def __valign_lines(blocks, layout_bboxes):
     return new_layout_bboxes
 
 
+def __align_text_in_layout(blocks, layout_bboxes):
+    """
+    由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。
+    """
+    for layout in layout_bboxes:
+        lb = layout['layout_bbox']
+        blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], lb)]
+        if len(blocks_in_layoutbox)==0:
+            continue
+        
+        for block in blocks_in_layoutbox:
+            for line in block['lines']:
+                x0, x1 = line['bbox'][0], line['bbox'][2]
+                if x0 < lb[0]:
+                    line['bbox'][0] = lb[0]
+                if x1 > lb[2]:
+                    line['bbox'][2] = lb[2]
+    
+ 
 def __common_pre_proc(blocks, layout_bboxes):
     """
     不分语言的,对文本进行预处理
     """
     #__add_line_period(blocks, layout_bboxes)
+    __align_text_in_layout(blocks, layout_bboxes)
     aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)
     
     return aligned_layout_bboxes
@@ -233,7 +253,6 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
     layout_paras = []
     right_tail_distance = 1.5 * char_avg_len
     
-    
     for lines in lines_group:
         paras = []
         total_lines = len(lines)
@@ -575,8 +594,8 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
     
     
     return connected_layout_paras, page_list_info
-   
-
+       
+    
 def para_split(pdf_info_dict, debug_mode, lang="en"):
     """
     根据line和layout情况进行分段

+ 87 - 0
tools/base_data.json

@@ -0,0 +1,87 @@
+{
+    "accuracy": 1.0,
+    "precision": 1.0,
+    "recall": 1.0,
+    "f1_score": 1.0,
+    "pdf间的平均编辑距离": 133.10256410256412,
+    "pdf间的平均bleu": 0.28838311595434046,
+    "分段准确率": 0.07220216606498195,
+    "行内公式准确率": {
+        "accuracy": 0.004835727492533068,
+        "precision": 0.008790072388831437,
+        "recall": 0.010634970284641852,
+        "f1_score": 0.009624911535739562
+    },
+    "行内公式编辑距离": 1.6176470588235294,
+    "行内公式bleu": 0.17154724654721457,
+    "行间公式准确率": {
+        "accuracy": 0.08490566037735849,
+        "precision": 0.1836734693877551,
+        "recall": 0.13636363636363635,
+        "f1_score": 0.1565217391304348
+    },
+    "行间公式编辑距离": 113.22222222222223,
+    "行间公式bleu": 0.2531053359913409,
+    "丢弃文本准确率": {
+        "accuracy": 0.00035398230088495576,
+        "precision": 0.0006389776357827476,
+        "recall": 0.0007930214115781126,
+        "f1_score": 0.0007077140835102619
+    },
+    "丢弃文本标签准确率": {
+        "color_background_header_txt_block": {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1-score": 0.0,
+            "support": 41.0
+        },
+        "header": {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1-score": 0.0,
+            "support": 4.0
+        },
+        "footnote": {
+            "precision": 1.0,
+            "recall": 0.009708737864077669,
+            "f1-score": 0.019230769230769232,
+            "support": 103.0
+        },
+        "on-table": {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1-score": 0.0,
+            "support": 665.0
+        },
+        "rotate": {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1-score": 0.0,
+            "support": 63.0
+        },
+        "on-image": {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1-score": 0.0,
+            "support": 380.0
+        },
+        "micro avg": {
+            "precision": 1.0,
+            "recall": 0.0007961783439490446,
+            "f1-score": 0.0015910898965791568,
+            "support": 1256.0
+        }
+    },
+    "丢弃图片准确率": {
+        "accuracy": 0.0,
+        "precision": 0.0,
+        "recall": 0.0,
+        "f1_score": 0.0
+    },
+    "丢弃表格准确率": {
+        "accuracy": 0.0,
+        "precision": 0.0,
+        "recall": 0.0,
+        "f1_score": 0.0
+    }
+}

+ 177 - 13
tools/ocr_badcase.py

@@ -413,7 +413,9 @@ def bbox_match_indicator_dropped_text_block(test_dropped_text_bboxs, standard_dr
 
     # 计算和返回标签匹配指标
     text_block_tag_report = classification_report(y_true=standard_tag, y_pred=test_tag, labels=list(set(standard_tag) - {'None'}), output_dict=True, zero_division=0)
-
+    del text_block_tag_report["macro avg"]
+    del text_block_tag_report["weighted avg"]
+    
     return text_block_report, text_block_tag_report
 
 def handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_page_tag, standard_page_bbox):
@@ -500,6 +502,142 @@ def merge_json_data(json_test_df, json_standard_df):
     return inner_merge, standard_exist, test_exist
 
 
+def consolidate_data(test_data, standard_data, key_path):
+    """
+    Consolidates data from test and standard datasets based on the provided key path.
+    
+    :param test_data: Dictionary containing the test dataset.
+    :param standard_data: Dictionary containing the standard dataset.
+    :param key_path: List of keys leading to the desired data within the dictionaries.
+    :return: List containing all items from both test and standard data at the specified key path.
+    """
+    # Initialize an empty list to hold the consolidated data
+    overall_data_standard = []
+    overall_data_test = []
+    
+    # Helper function to recursively navigate through the dictionaries based on the key path
+    def extract_data(source_data, keys):
+        for key in keys[:-1]:
+            source_data = source_data.get(key, {})
+        return source_data.get(keys[-1], [])
+    
+    for data in extract_data(standard_data, key_path):
+    # 假设每个 single_table_tags 已经是一个列表,直接将它的元素添加到总列表中
+        overall_data_standard.extend(data)
+    
+    for data in extract_data(test_data, key_path):
+         overall_data_test.extend(data)
+    # Extract and extend the overall data list with items from both test and standard datasets
+
+    
+    return overall_data_standard, overall_data_test
+
+def overall_calculate_metrics(inner_merge, json_test, json_standard,standard_exist, test_exist):
+
+    process_data_standard = process_equations_and_blocks(json_standard, is_standard=True)
+    process_data_test = process_equations_and_blocks(json_test, is_standard=False)
+
+
+    overall_report = {}
+    overall_report['accuracy']=metrics.accuracy_score(standard_exist,test_exist)
+    overall_report['precision']=metrics.precision_score(standard_exist,test_exist)
+    overall_report['recall']=metrics.recall_score(standard_exist,test_exist)
+    overall_report['f1_score']=metrics.f1_score(standard_exist,test_exist)
+    overall_report
+
+    test_para_text = np.asarray(process_data_test['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    standard_para_text = np.asarray(process_data_standard['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    ids_yes = inner_merge['id'][inner_merge['pass_label'] == 'yes'].tolist()
+
+    pdf_dis = {}
+    pdf_bleu = {}
+
+    # 对pass_label为'yes'的数据计算编辑距离和BLEU得分
+    for idx,(a, b, id) in enumerate(zip(test_para_text, standard_para_text, ids_yes)):
+        a1 = ''.join(a)
+        b1 = ''.join(b)
+        pdf_dis[id] = Levenshtein_Distance(a, b)
+        pdf_bleu[id] = sentence_bleu([a1], b1)
+
+    overall_report['pdf间的平均编辑距离'] = np.mean(list(pdf_dis.values()))
+    overall_report['pdf间的平均bleu'] = np.mean(list(pdf_bleu.values()))
+
+    # Consolidate equations bboxs inline
+    overall_equations_bboxs_inline_standard,overall_equations_bboxs_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "inline"])
+
+    # # Consolidate equations texts inline
+    overall_equations_texts_inline_standard,overall_equations_texts_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "inline"])
+
+    # Consolidate equations bboxs interline
+    overall_equations_bboxs_interline_standard,overall_equations_bboxs_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "interline"])
+
+    # Consolidate equations texts interline
+    overall_equations_texts_interline_standard,overall_equations_texts_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "interline"])
+
+    overall_dropped_bboxs_text_standard,overall_dropped_bboxs_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs","text"])
+
+    overall_dropped_tags_text_standard,overall_dropped_tags_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_tags","text"])
+
+    overall_dropped_bboxs_image_standard,overall_dropped_bboxs_image_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs","image"])
+
+
+    overall_dropped_bboxs_table_standard,overall_dropped_bboxs_table_test=consolidate_data(process_data_test, process_data_standard,["dropped_bboxs","table"])
+
+
+    para_nums_test = process_data_test['para_nums']
+    para_nums_standard=process_data_standard['para_nums']
+    overall_para_nums_standard = [item for sublist in para_nums_standard for item in (sublist if isinstance(sublist, list) else [sublist])]
+    overall_para_nums_test = [item for sublist in para_nums_test for item in (sublist if isinstance(sublist, list) else [sublist])]
+
+
+    test_para_num=np.array(overall_para_nums_test)
+    standard_para_num=np.array(overall_para_nums_standard)
+    acc_para=np.mean(test_para_num==standard_para_num)
+
+
+    overall_report['分段准确率'] = acc_para
+
+    # 行内公式准确率和编辑距离、bleu
+    overall_report['行内公式准确率'] = bbox_match_indicator_general(
+        overall_equations_bboxs_inline_test,
+        overall_equations_bboxs_inline_standard)
+
+    overall_report['行内公式编辑距离'], overall_report['行内公式bleu'] = equations_indicator(
+        overall_equations_bboxs_inline_test,
+        overall_equations_bboxs_inline_standard,
+        overall_equations_texts_inline_test,
+        overall_equations_texts_inline_standard)
+
+    # 行间公式准确率和编辑距离、bleu
+    overall_report['行间公式准确率'] = bbox_match_indicator_general(
+        overall_equations_bboxs_interline_test,
+        overall_equations_bboxs_interline_standard)
+
+    overall_report['行间公式编辑距离'], overall_report['行间公式bleu'] = equations_indicator(
+        overall_equations_bboxs_interline_test,
+        overall_equations_bboxs_interline_standard,
+        overall_equations_texts_interline_test,
+        overall_equations_texts_interline_standard)
+
+    # 丢弃文本准确率,丢弃文本标签准确率
+    overall_report['丢弃文本准确率'], overall_report['丢弃文本标签准确率'] = bbox_match_indicator_dropped_text_block(
+        overall_dropped_bboxs_text_test,
+        overall_dropped_bboxs_text_standard,
+        overall_dropped_tags_text_standard,
+        overall_dropped_tags_text_test)
+
+    # 丢弃图片准确率
+    overall_report['丢弃图片准确率'] = bbox_match_indicator_general(
+        overall_dropped_bboxs_image_test,
+        overall_dropped_bboxs_image_standard)
+
+    # 丢弃表格准确率
+    overall_report['丢弃表格准确率'] = bbox_match_indicator_general(
+        overall_dropped_bboxs_table_test,
+        overall_dropped_bboxs_table_standard)
+
+    return overall_report
+
 
 
 def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origin):
@@ -602,21 +740,27 @@ def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origi
     return result_dict
 
 
-def save_results(result_dict, output_path):
+
+def save_results(result_dict,overall_report_dict,badcase_path,overall_path,):
     """
     将结果字典保存为JSON文件至指定路径。
 
     参数:
     - result_dict: 包含计算结果的字典。
-    - output_path: 结果文件的保存路径,包括文件名。
+    - overall_path: 结果文件的保存路径,包括文件名。
     """
     # 打开指定的文件以写入
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(badcase_path, 'w', encoding='utf-8') as f:
         # 将结果字典转换为JSON格式并写入文件
         json.dump(result_dict, f, ensure_ascii=False, indent=4)
 
-    print(f"计算结果已经保存到文件:{output_path}")
+    print(f"计算结果已经保存到文件:{badcase_path}")
+
+    with open(overall_path, 'w', encoding='utf-8') as f:
+    # 将结果字典转换为JSON格式并写入文件
+        json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
 
+    print(f"计算结果已经保存到文件:{overall_path}")
 
 def upload_to_s3(file_path, bucket_name, s3_file_name,AWS_ACCESS_KEY,AWS_SECRET_KEY,END_POINT_URL):
     """
@@ -634,7 +778,7 @@ def upload_to_s3(file_path, bucket_name, s3_file_name,AWS_ACCESS_KEY,AWS_SECRET_
     except ClientError as e:
         print(f"上传文件时发生错误:{e}")
 
-def generate_output_filename(base_path):
+def generate_filename(badcase_path,overall_path):
     """
     生成带有当前时间戳的输出文件名。
 
@@ -647,13 +791,24 @@ def generate_output_filename(base_path):
     # 获取当前时间并格式化为字符串
     current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
     # 构建并返回完整的输出文件名
-    return f"{base_path}_{current_time}.json"
+    return f"{badcase_path}_{current_time}.json",f"{overall_path}_{current_time}.json"
 
 
 
+def compare_edit_distance(json_file, overall_report):
+    with open(json_file, 'r',encoding='utf-8') as f:
+        json_data = json.load(f)
+    
+    json_edit_distance = json_data['pdf间的平均编辑距离']
+    
+    if overall_report['pdf间的平均编辑距离'] >= json_edit_distance:
+        return 0
+    else:
+        return 1
+
 
 
-def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
+def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path,s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
     """
     主函数,执行整个评估流程。
     
@@ -661,7 +816,8 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
     - standard_file: 标准文件的路径。
     - test_file: 测试文件的路径。
     - zip_file: 压缩包的路径的路径。
-    - base_output_path: 结果文件的基础路径和文件名前缀。
+    - badcase_path: badcase文件的基础路径和文件名前缀。
+    - overall_path: overall文件的基础路径和文件名前缀。
     - s3_bucket_name: S3桶名称(可选)。
     - s3_file_name: S3上的文件名(可选)。
     - AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS访问凭证和端点URL(可选)。
@@ -675,21 +831,29 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
     # 合并JSON数据
     inner_merge, standard_exist, test_exist = merge_json_data(json_test_origin, json_standard_origin)
 
+    #计算总体指标
+    overall_report_dict=overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'],standard_exist, test_exist)
     # 计算指标
     result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)
 
     # 生成带时间戳的输出文件名
-    output_file = generate_output_filename(base_output_path)
+    badcase_file,overall_file = generate_filename(badcase_path,overall_path)
 
     # 保存结果到JSON文件
-    save_results(result_dict, output_file)
+    save_results(result_dict, overall_report_dict,badcase_file,overall_file)
+
+    result=compare_edit_distance(base_data_path, overall_report_dict)
+    print(result)
+    assert result == 1
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="主函数,执行整个评估流程。")
     parser.add_argument('standard_file', type=str, help='标准文件的路径。')
     parser.add_argument('test_file', type=str, help='测试文件的路径。')
     parser.add_argument('zip_file', type=str, help='压缩包的路径。')
-    parser.add_argument('base_output_path', type=str, help='结果文件的基础路径和文件名前缀。')
+    parser.add_argument('badcase_path', type=str, help='badcase文件的基础路径和文件名前缀。')
+    parser.add_argument('overall_path', type=str, help='overall文件的基础路径和文件名前缀。')
+    parser.add_argument('base_data_path', type=str, help='基准文件的基础路径和文件名前缀。')
     parser.add_argument('--s3_bucket_name', type=str, help='S3桶名称。', default=None)
     parser.add_argument('--s3_file_name', type=str, help='S3上的文件名。', default=None)
     parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS访问密钥。', default=None)
@@ -698,5 +862,5 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    main(args.standard_file, args.test_file, args.zip_file, args.base_output_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
+    main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)