Эх сурвалжийг харах

update the text_badcase script and add auto upload s3 function

Shuimo 1 жил өмнө
parent
commit
90216330c1

+ 19 - 6
.github/workflows/benchmark.yml

@@ -5,7 +5,13 @@ name: PDF
 on:
   push:
     branches:
-      - master
+      - "master"
+    paths-ignore:
+      - "cmds/**"
+      - "**.md"
+  pull_request:
+    branches:
+      - "master"
     paths-ignore:
       - "cmds/**"
       - "**.md"
@@ -18,14 +24,16 @@ jobs:
       fail-fast: true
 
     steps:
+    - name: config-net
+      run: |
+        export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
+        export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
     - name: PDF benchmark
       uses: actions/checkout@v3
       with:
         fetch-depth: 2
     - name: check-requirements
       run: |
-        export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
-        export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
         changed_files=$(git diff --name-only -r HEAD~1 HEAD)
         echo $changed_files
         if [[ $changed_files =~ "requirements.txt" ]]; then
@@ -36,12 +44,17 @@ jobs:
     - name: benchmark
       run: |
         echo "start test"
-        cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip output.json
+        cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip badcase.json overall.json base_data.json
   notify_to_feishu:
-    if: ${{ (github.ref_name == 'master') }}
+    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
     needs: [pdf-test]
     runs-on: [pdf]
     steps:
     - name: notify
       run: |
-        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}'  ${{ secrets.WEBHOOK_URL }}
+        curl  ${{ secrets.WEBHOOK_URL }} -H 'Content-Type: application/json'  -d '{
+        "msgtype": "text",
+        "text": {
+            "content": "'${{ github.repository }}' GitHubAction Failed!\n 细节请查看:https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"
+        } 
+        }'                                                                                                                            

+ 4 - 3
demo/ocr_demo.py

@@ -115,8 +115,9 @@ def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_
 if __name__ == '__main__':
     pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.pdf"
     json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.json"
-    ocr_local_parse(pdf_path, json_file_path)
-    # book_name = "科数网/edu_00011318"
-    # ocr_online_parse(book_name)
+    # ocr_local_parse(pdf_path, json_file_path)
+    
+    book_name = "数学新星网/edu_00001236"
+    ocr_online_parse(book_name)
     
     pass

+ 4 - 5
magic_pdf/dict2md/ocr_mkcontent.py

@@ -1,4 +1,3 @@
-from magic_pdf.libs.commons import s3_image_save_path, join_path
 from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
 from magic_pdf.libs.ocr_content_type import ContentType
@@ -56,7 +55,7 @@ def ocr_mk_mm_markdown(pdf_info_dict: dict):
                         if not span.get('image_path'):
                             continue
                         else:
-                            content = f"![]({join_path(s3_image_save_path, span['image_path'])})"
+                            content = f"![]({span['image_path']})"
                     else:
                         content = ocr_escape_special_markdown_char(span['content'])  # 转义特殊符号
                         if span['type'] == ContentType.InlineEquation:
@@ -123,7 +122,7 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
                         content = f"\n$$\n{span['content']}\n$$\n"
                     elif span_type in [ContentType.Image, ContentType.Table]:
                         if mode == 'mm':
-                            content = f"\n![]({join_path(s3_image_save_path, span['image_path'])})\n"
+                            content = f"\n![]({span['image_path']})\n"
                         elif mode == 'nlp':
                             pass
                     if content != '':
@@ -195,13 +194,13 @@ def line_to_standard_format(line):
                 if span['type'] == ContentType.Image:
                     content = {
                         'type': 'image',
-                        'img_path': join_path(s3_image_save_path, span['image_path'])
+                        'img_path': span['image_path']
                     }
                     return content
                 elif span['type'] == ContentType.Table:
                     content = {
                         'type': 'table',
-                        'img_path': join_path(s3_image_save_path, span['image_path'])
+                        'img_path': span['image_path']
                     }
                     return content
         else:

+ 2 - 2
magic_pdf/io/AbsReaderWriter.py

@@ -10,10 +10,10 @@ class AbsReaderWriter(ABC):
 
     def __init__(self, parent_path):
         # 初始化代码可以在这里添加,如果需要的话
-        self.parent_path = parent_path # 对于本地目录是父目录,对于s3是会写到这个apth下。
+        self.parent_path = parent_path # 对于本地目录是父目录,对于s3是会写到这个path下。
 
     @abstractmethod
-    def read(self, path: str, mode="text"):
+    def read(self, path: str, mode=MODE_TXT):
         """
         无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
         """

+ 36 - 21
magic_pdf/io/DiskReaderWriter.py

@@ -1,48 +1,63 @@
 import os
 from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
 from loguru import logger
+
+
+MODE_TXT = "text"
+MODE_BIN = "binary"
 class DiskReaderWriter(AbsReaderWriter):
+
     def __init__(self, parent_path, encoding='utf-8'):
         self.path = parent_path
         self.encoding = encoding
 
-    def read(self, mode="text"):
-        if not os.path.exists(self.path):
-            logger.error(f"文件 {self.path} 不存在")
-            raise Exception(f"文件 {self.path} 不存在")
-        if mode == "text":
-            with open(self.path, 'r', encoding = self.encoding) as f:
+    def read(self, path, mode=MODE_TXT):
+        if os.path.isabs(path):
+            abspath = path
+        else:
+            abspath = os.path.join(self.path, path)
+        if not os.path.exists(abspath):
+            logger.error(f"文件 {abspath} 不存在")
+            raise Exception(f"文件 {abspath} 不存在")
+        if mode == MODE_TXT:
+            with open(abspath, 'r', encoding = self.encoding) as f:
                 return f.read()
-        elif mode == "binary":
-            with open(self.path, 'rb') as f:
+        elif mode == MODE_BIN:
+            with open(abspath, 'rb') as f:
                 return f.read()
         else:
             raise ValueError("Invalid mode. Use 'text' or 'binary'.")
 
-    def write(self, data, mode="text"):
-        if mode == "text":
-            with open(self.path, 'w', encoding=self.encoding) as f:
-                f.write(data)
-                logger.info(f"内容已成功写入 {self.path}")
+    def write(self, content, path, mode=MODE_TXT):
+        if os.path.isabs(path):
+            abspath = path
+        else:
+            abspath = os.path.join(self.path, path)
+        if mode == MODE_TXT:
+            with open(abspath, 'w', encoding=self.encoding) as f:
+                f.write(content)
+                logger.info(f"内容已成功写入 {abspath}")
 
-        elif mode == "binary":
-            with open(self.path, 'wb') as f:
-                f.write(data)
-                logger.info(f"内容已成功写入 {self.path}")
+        elif mode == MODE_BIN:
+            with open(abspath, 'wb') as f:
+                f.write(content)
+                logger.info(f"内容已成功写入 {abspath}")
         else:
             raise ValueError("Invalid mode. Use 'text' or 'binary'.")
 
+    def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'):
+        return self.read(path)
 
 # 使用示例
 if __name__ == "__main__":
-    file_path = "example.txt"
-    drw = DiskReaderWriter(file_path)
+    file_path = "io/example.txt"
+    drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf")
 
     # 写入内容到文件
-    drw.write(b"Hello, World!", mode="binary")
+    drw.write(b"Hello, World!", path="io/example.txt", mode="binary")
 
     # 从文件读取内容
-    content = drw.read()
+    content = drw.read(path=file_path)
     if content:
         logger.info(f"从 {file_path} 读取的内容: {content}")
 

+ 58 - 23
magic_pdf/io/S3ReaderWriter.py

@@ -1,16 +1,19 @@
-
-
 from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key
 import boto3
 from loguru import logger
 from boto3.s3.transfer import TransferConfig
 from botocore.config import Config
+import os
+
+MODE_TXT = "text"
+MODE_BIN = "binary"
 
 
 class S3ReaderWriter(AbsReaderWriter):
-    def __init__(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
+    def __init__(self, ak: str, sk: str, endpoint_url: str, addressing_style: str, parent_path: str):
         self.client = self._get_client(ak, sk, endpoint_url, addressing_style)
+        self.path = parent_path
 
     def _get_client(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
         s3_client = boto3.client(
@@ -22,51 +25,83 @@ class S3ReaderWriter(AbsReaderWriter):
                           retries={'max_attempts': 5, 'mode': 'standard'}),
         )
         return s3_client
-    def read(self, s3_path, mode="text", encoding="utf-8"):
-        bucket_name, bucket_key = parse_bucket_key(s3_path)
-        res = self.client.get_object(Bucket=bucket_name, Key=bucket_key)
+
+    def read(self, s3_relative_path, mode=MODE_TXT, encoding="utf-8"):
+        if s3_relative_path.startswith("s3://"):
+            s3_path = s3_relative_path
+        else:
+            s3_path = os.path.join(self.path, s3_relative_path)
+        bucket_name, key = parse_bucket_key(s3_path)
+        res = self.client.get_object(Bucket=bucket_name, Key=key)
         body = res["Body"].read()
-        if mode == 'text':
+        if mode == MODE_TXT:
             data = body.decode(encoding)  # Decode bytes to text
-        elif mode == 'binary':
+        elif mode == MODE_BIN:
             data = body
         else:
             raise ValueError("Invalid mode. Use 'text' or 'binary'.")
         return data
 
-    def write(self, data, s3_path, mode="text", encoding="utf-8"):
-        if mode == 'text':
-            body = data.encode(encoding)  # Encode text data as bytes
-        elif mode == 'binary':
-            body = data
+    def write(self, content, s3_relative_path, mode=MODE_TXT, encoding="utf-8"):
+        if s3_relative_path.startswith("s3://"):
+            s3_path = s3_relative_path
+        else:
+            s3_path = os.path.join(self.path, s3_relative_path)
+        if mode == MODE_TXT:
+            body = content.encode(encoding)  # Encode text data as bytes
+        elif mode == MODE_BIN:
+            body = content
         else:
             raise ValueError("Invalid mode. Use 'text' or 'binary'.")
-        bucket_name, bucket_key = parse_bucket_key(s3_path)
-        self.client.put_object(Body=body, Bucket=bucket_name, Key=bucket_key)
+        bucket_name, key = parse_bucket_key(s3_path)
+        self.client.put_object(Body=body, Bucket=bucket_name, Key=key)
         logger.info(f"内容已写入 {s3_path} ")
 
+    def read_jsonl(self, path: str, byte_start=0, byte_end=None, mode=MODE_TXT, encoding='utf-8'):
+        if path.startswith("s3://"):
+            s3_path = path
+        else:
+            s3_path = os.path.join(self.path, path)
+        bucket_name, key = parse_bucket_key(s3_path)
+
+        range_header = f'bytes={byte_start}-{byte_end}' if byte_end else f'bytes={byte_start}-'
+        res = self.client.get_object(Bucket=bucket_name, Key=key, Range=range_header)
+        body = res["Body"].read()
+        if mode == MODE_TXT:
+            data = body.decode(encoding)  # Decode bytes to text
+        elif mode == MODE_BIN:
+            data = body
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        return data
+
 
 if __name__ == "__main__":
     # Config the connection info
     ak = ""
     sk = ""
     endpoint_url = ""
-    addressing_style = ""
-
+    addressing_style = "auto"
+    bucket_name = ""
     # Create an S3ReaderWriter object
-    s3_reader_writer = S3ReaderWriter(ak, sk, endpoint_url, addressing_style)
+    s3_reader_writer = S3ReaderWriter(ak, sk, endpoint_url, addressing_style, "s3://bucket_name/")
 
     # Write text data to S3
     text_data = "This is some text data"
-    s3_reader_writer.write(data=text_data, s3_path = "s3://bucket_name/ebook/test/test.json", mode='text')
+    s3_reader_writer.write(data=text_data, s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json", mode=MODE_TXT)
 
     # Read text data from S3
-    text_data_read = s3_reader_writer.read(s3_path = "s3://bucket_name/ebook/test/test.json", mode='text')
+    text_data_read = s3_reader_writer.read(s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json", mode=MODE_TXT)
     logger.info(f"Read text data from S3: {text_data_read}")
     # Write binary data to S3
     binary_data = b"This is some binary data"
-    s3_reader_writer.write(data=text_data, s3_path = "s3://bucket_name/ebook/test/test2.json", mode='binary')
+    s3_reader_writer.write(data=text_data, s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json", mode=MODE_BIN)
 
     # Read binary data from S3
-    binary_data_read = s3_reader_writer.read(s3_path = "s3://bucket_name/ebook/test/test2.json", mode='binary')
-    logger.info(f"Read binary data from S3: {binary_data_read}")
+    binary_data_read = s3_reader_writer.read(s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json", mode=MODE_BIN)
+    logger.info(f"Read binary data from S3: {binary_data_read}")
+
+    # Range Read text data from S3
+    binary_data_read = s3_reader_writer.read_jsonl(path=f"s3://{bucket_name}/ebook/test/test.json",
+                                                   byte_start=0, byte_end=10, mode=MODE_BIN)
+    logger.info(f"Read binary data from S3: {binary_data_read}")

+ 4 - 24
magic_pdf/libs/commons.py

@@ -24,7 +24,7 @@ error_log_path = "s3://llm-pdf-text/err_logs/"
 # json_dump_path = "s3://pdf_books_temp/json_dump/" # 这条路径仅用于临时本地测试,不能提交到main
 json_dump_path = "s3://llm-pdf-text/json_dump/"
 
-s3_image_save_path = "s3://mllm-raw-media/pdf2md_img/" # TODO 基础库不应该有这些存在的路径,应该在业务代码中定义
+# s3_image_save_path = "s3://mllm-raw-media/pdf2md_img/" # 基础库不应该有这些存在的路径,应该在业务代码中定义
 
 
 def get_top_percent_list(num_list, percent):
@@ -120,29 +120,9 @@ def read_file(pdf_path: str, s3_profile):
             return f.read()
 
 
-def get_docx_model_output(pdf_model_output, pdf_model_s3_profile, page_id):
-    if isinstance(pdf_model_output, str):
-        model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json")  # 模型输出的页面编号从1开始的
-        if os.path.exists(model_output_json_path):
-            json_from_docx = read_file(model_output_json_path, pdf_model_s3_profile)
-            model_output_json = json.loads(json_from_docx)
-        else:
-            try:
-                model_output_json_path = join_path(pdf_model_output, "model.json")
-                with open(model_output_json_path, "r", encoding="utf-8") as f:
-                    model_output_json = json.load(f)
-                    model_output_json = model_output_json["doc_layout_result"][page_id]
-            except:
-                s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json")
-                s3_model_output_json_path = join_path(pdf_model_output, f"{page_id}.json")
-                #s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id }.json")
-                # logger.warning(f"model_output_json_path: {model_output_json_path} not found. try to load from s3: {s3_model_output_json_path}")
-
-                s = read_file(s3_model_output_json_path, pdf_model_s3_profile)
-                return json.loads(s)
-
-    elif isinstance(pdf_model_output, list):
-        model_output_json = pdf_model_output[page_id]
+def get_docx_model_output(pdf_model_output, page_id):
+
+    model_output_json = pdf_model_output[page_id]
 
     return model_output_json
 

+ 15 - 0
magic_pdf/libs/hash_utils.py

@@ -0,0 +1,15 @@
+import hashlib
+
+
+def compute_md5(file_bytes):
+    hasher = hashlib.md5()
+    hasher.update(file_bytes)
+    return hasher.hexdigest().upper()
+
+
+def compute_sha256(input_string):
+    hasher = hashlib.sha256()
+    # 在Python3中,需要将字符串转化为字节对象才能被哈希函数处理
+    input_bytes = input_string.encode('utf-8')
+    hasher.update(input_bytes)
+    return hasher.hexdigest()

+ 31 - 92
magic_pdf/libs/pdf_image_tools.py

@@ -1,39 +1,23 @@
-import os
-from pathlib import Path
-from typing import Tuple
-import io
 
-# from app.common.s3 import get_s3_client
 from magic_pdf.libs.commons import fitz
 from loguru import logger
-from magic_pdf.libs.commons import parse_bucket_key, join_path
+from magic_pdf.libs.commons import join_path
+from magic_pdf.libs.hash_utils import compute_sha256
 
 
-def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str, s3_return_path=None, img_s3_client=None, upload_switch=True):
+def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter):
     """
     从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
     save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
     """
     # 拼接文件名
-    filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}.jpg"
-    # 拼接路径
-    image_save_path = join_path(save_parent_path, filename)
-    s3_img_path = join_path(s3_return_path, filename) if s3_return_path is not None else None
-    # 打印图片文件名
-    # print(f"Saved {image_save_path}")
-
-    #检查坐标
-    # x_check = int(bbox[2]) - int(bbox[0])
-    # y_check = int(bbox[3]) - int(bbox[1])
-    # if x_check <= 0 or y_check <= 0:
-    #
-    #     if image_save_path.startswith("s3://"):
-    #         logger.exception(f"传入图片坐标有误,x1<x0或y1<y0,{s3_img_path}")
-    #         return s3_img_path
-    #     else:
-    #         logger.exception(f"传入图片坐标有误,x1<x0或y1<y0,{image_save_path}")
-    #         return image_save_path
+    filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
 
+    # 老版本返回不带bucket的路径
+    img_path = join_path(return_path, filename) if return_path is not None else None
+
+    # 新版本生成平铺路径
+    img_hash256_path = f"{compute_sha256(img_path)}.jpg"
 
     # 将坐标转换为fitz.Rect对象
     rect = fitz.Rect(*bbox)
@@ -42,39 +26,17 @@ def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str
     # 截取图片
     pix = page.get_pixmap(clip=rect, matrix=zoom)
 
-    if image_save_path.startswith("s3://"):
-        if not upload_switch:
-            pass
-        else:
-            # 图片保存到s3
-            bucket_name, bucket_key = parse_bucket_key(image_save_path)
-            # 将字节流上传到s3
-            byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
-            file_obj = io.BytesIO(byte_data)
-            if img_s3_client is not None:
-                img_s3_client.upload_fileobj(file_obj, bucket_name, bucket_key)
-                # 每个图片上传任务都创建一个新的client
-                # img_s3_client_once = get_s3_client(image_save_path)
-                # img_s3_client_once.upload_fileobj(file_obj, bucket_name, bucket_key)
-            else:
-                logger.exception("must input img_s3_client")
-        return s3_img_path
-    else:
-        # 保存图片到本地
-        # 先检查一下image_save_path的父目录是否存在,如果不存在,就创建
-        parent_dir = os.path.dirname(image_save_path)
-        if not os.path.exists(parent_dir):
-            os.makedirs(parent_dir)
-        pix.save(image_save_path, jpg_quality=95)
-        # 为了直接能在markdown里看,这里把地址改为相对于mardown的地址
-        pth = Path(image_save_path)
-        image_save_path = f"{pth.parent.name}/{pth.name}"
-        return image_save_path
-
-
-def save_images_by_bboxes(book_name: str, page_num: int, page: fitz.Page, save_path: str,
-                            image_bboxes: list, images_overlap_backup:list, table_bboxes: list, equation_inline_bboxes: list,
-                            equation_interline_bboxes: list, img_s3_client) -> dict:
+    byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
+
+    imageWriter.write(data=byte_data, path=img_hash256_path, mode="binary")
+
+    return img_hash256_path
+
+
+def save_images_by_bboxes(page_num: int, page: fitz.Page, pdf_bytes_md5: str,
+                          image_bboxes: list, images_overlap_backup: list, table_bboxes: list,
+                          equation_inline_bboxes: list,
+                          equation_interline_bboxes: list, imageWriter) -> dict:
     """
     返回一个dict, key为bbox, 值是图片地址
     """
@@ -85,53 +47,30 @@ def save_images_by_bboxes(book_name: str, page_num: int, page: fitz.Page, save_p
     interline_eq_info = []
 
     # 图片的保存路径组成是这样的: {s3_or_local_path}/{book_name}/{images|tables|equations}/{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg
-    s3_return_image_path = join_path(book_name, "images")
-    image_save_path = join_path(save_path, s3_return_image_path)
-
-    s3_return_table_path = join_path(book_name, "tables")
-    table_save_path = join_path(save_path, s3_return_table_path)
-
-    s3_return_equations_inline_path = join_path(book_name, "equations_inline")
-    equation_inline_save_path = join_path(save_path, s3_return_equations_inline_path)
-
-    s3_return_equation_interline_path = join_path(book_name, "equation_interline")
-    equation_interline_save_path = join_path(save_path, s3_return_equation_interline_path)
 
+    def return_path(type):
+        return join_path(pdf_bytes_md5, type)
 
     for bbox in image_bboxes:
-        if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
+        if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
             logger.warning(f"image_bboxes: 错误的box, {bbox}")
             continue
-        
-        image_path = cut_image(bbox, page_num, page, image_save_path, s3_return_image_path, img_s3_client)
+
+        image_path = cut_image(bbox, page_num, page, return_path("images"), imageWriter)
         image_info.append({"bbox": bbox, "image_path": image_path})
-        
+
     for bbox in images_overlap_backup:
-        if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
+        if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
             logger.warning(f"images_overlap_backup: 错误的box, {bbox}")
             continue
-        image_path = cut_image(bbox, page_num, page, image_save_path, s3_return_image_path, img_s3_client)
+        image_path = cut_image(bbox, page_num, page, return_path("images"), imageWriter)
         image_backup_info.append({"bbox": bbox, "image_path": image_path})
 
     for bbox in table_bboxes:
-        if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
+        if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
             logger.warning(f"table_bboxes: 错误的box, {bbox}")
             continue
-        image_path = cut_image(bbox, page_num, page, table_save_path, s3_return_table_path, img_s3_client)
+        image_path = cut_image(bbox, page_num, page, return_path("tables"), imageWriter)
         table_info.append({"bbox": bbox, "image_path": image_path})
 
-    for bbox in equation_inline_bboxes:
-        if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
-            logger.warning(f"equation_inline_bboxes: 错误的box, {bbox}")
-            continue
-        image_path = cut_image(bbox[:4], page_num, page, equation_inline_save_path, s3_return_equations_inline_path, img_s3_client, upload_switch=False)
-        inline_eq_info.append({'bbox':bbox[:4], "image_path":image_path, "latex_text":bbox[4]})
-
-    for bbox in equation_interline_bboxes:
-        if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
-            logger.warning(f"equation_interline_bboxes: 错误的box, {bbox}")
-            continue
-        image_path = cut_image(bbox[:4], page_num, page, equation_interline_save_path, s3_return_equation_interline_path, img_s3_client, upload_switch=False)
-        interline_eq_info.append({"bbox":bbox[:4], "image_path":image_path, "latex_text":bbox[4]})
-
-    return image_info, image_backup_info,  table_info, inline_eq_info, interline_eq_info
+    return image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info

+ 6 - 50
magic_pdf/pdf_parse_by_ocr.py

@@ -1,22 +1,14 @@
-import json
-import os
 import time
-
 from loguru import logger
-
-from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_text_bbox
 from magic_pdf.libs.commons import (
-    read_file,
-    join_path,
     fitz,
-    get_img_s3_client,
     get_delta_time,
     get_docx_model_output,
 )
 from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.drop_tag import DropTag
+from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.ocr_content_type import ContentType
-from magic_pdf.libs.safe_filename import sanitize_filename
 from magic_pdf.para.para_split import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component
 from magic_pdf.pre_proc.detect_footer_by_model import parse_footers
@@ -38,38 +30,16 @@ from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox
 def parse_pdf_by_ocr(
         pdf_bytes,
         pdf_model_output,
-        save_path,
-        book_name,
-        pdf_model_profile=None,
-        image_s3_config=None,
+        imageWriter,
         start_page_id=0,
         end_page_id=None,
         debug_mode=False,
 ):
-
-    save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
-    book_name = sanitize_filename(book_name)
-    md_bookname_save_path = ""
-    if debug_mode:
-        save_path = join_path(save_tmp_path, "md")
-        pdf_local_path = join_path(save_tmp_path, "download-pdfs", book_name)
-
-        if not os.path.exists(os.path.dirname(pdf_local_path)):
-            # 如果目录不存在,创建它
-            os.makedirs(os.path.dirname(pdf_local_path))
-
-        md_bookname_save_path = join_path(save_tmp_path, "md", book_name)
-        if not os.path.exists(md_bookname_save_path):
-            # 如果目录不存在,创建它
-            os.makedirs(md_bookname_save_path)
-
-        with open(pdf_local_path + ".pdf", "wb") as pdf_file:
-            pdf_file.write(pdf_bytes)
+    pdf_bytes_md5 = compute_md5(pdf_bytes)
 
     pdf_docs = fitz.open("pdf", pdf_bytes)
     # 初始化空的pdf_info_dict
     pdf_info_dict = {}
-    img_s3_client = get_img_s3_client(save_path, image_s3_config)
 
     start_time = time.time()
 
@@ -91,16 +61,14 @@ def parse_pdf_by_ocr(
 
         # 获取当前页的模型数据
         ocr_page_info = get_docx_model_output(
-            pdf_model_output, pdf_model_profile, page_id
+            pdf_model_output, page_id
         )
 
         """从json中获取每页的页码、页眉、页脚的bbox"""
         page_no_bboxes = parse_pageNos(page_id, page, ocr_page_info)
         header_bboxes = parse_headers(page_id, page, ocr_page_info)
         footer_bboxes = parse_footers(page_id, page, ocr_page_info)
-        footnote_bboxes = parse_footnotes_by_model(
-            page_id, page, ocr_page_info, md_bookname_save_path, debug_mode=debug_mode
-        )
+        footnote_bboxes = parse_footnotes_by_model(page_id, page, ocr_page_info, debug_mode=debug_mode)
 
         # 构建需要remove的bbox字典
         need_remove_spans_bboxes_dict = {
@@ -179,7 +147,7 @@ def parse_pdf_by_ocr(
         spans, dropped_spans_by_removed_bboxes = remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict)
 
         '''对image和table截图'''
-        spans = cut_image_and_table(spans, page, page_id, book_name, save_path, img_s3_client)
+        spans = cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter)
 
         '''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
         displayed_list = []
@@ -242,16 +210,4 @@ def parse_pdf_by_ocr(
     """分段"""
     para_split(pdf_info_dict, debug_mode=debug_mode)
 
-    '''在测试时,保存调试信息'''
-    if debug_mode:
-        params_file_save_path = join_path(
-            save_tmp_path, "md", book_name, "preproc_out.json"
-        )
-        with open(params_file_save_path, "w", encoding="utf-8") as f:
-            json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
-
-        # drow_bbox
-        draw_layout_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
-        draw_text_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
-
     return pdf_info_dict

+ 27 - 110
magic_pdf/pdf_parse_by_txt.py

@@ -12,6 +12,7 @@ from magic_pdf.layout.bbox_sort import (
 )
 from magic_pdf.layout.layout_sort import LAYOUT_UNPROC, get_bboxes_layout, get_columns_cnt_of_layout, sort_text_block
 from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.markdown_utils import escape_special_markdown_char
 from magic_pdf.libs.safe_filename import sanitize_filename
 from magic_pdf.libs.vis_utils import draw_bbox_on_page, draw_layout_bbox_on_page
@@ -73,46 +74,20 @@ paraMergeException_msg = ParaMergeException().message
 def parse_pdf_by_txt(
     pdf_bytes,
     pdf_model_output,
-    save_path,
-    book_name,
-    pdf_model_profile=None,
-    image_s3_config=None,
+    imageWriter,
     start_page_id=0,
     end_page_id=None,
-    junk_img_bojids=[],
     debug_mode=False,
 ):
-
-    save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
-    md_bookname_save_path = ""
-    book_name = sanitize_filename(book_name)
-    if debug_mode:
-        save_path = join_path(save_tmp_path, "md")
-        pdf_local_path = join_path(save_tmp_path, "download-pdfs", book_name)
-
-        if not os.path.exists(os.path.dirname(pdf_local_path)):
-            # 如果目录不存在,创建它
-            os.makedirs(os.path.dirname(pdf_local_path))
-
-        md_bookname_save_path = join_path(save_tmp_path, "md", book_name)
-        if not os.path.exists(md_bookname_save_path):
-            # 如果目录不存在,创建它
-            os.makedirs(md_bookname_save_path)
-
-        with open(pdf_local_path + ".pdf", "wb") as pdf_file:
-            pdf_file.write(pdf_bytes)
+    pdf_bytes_md5 = compute_md5(pdf_bytes)
 
     pdf_docs = fitz.open("pdf", pdf_bytes)
     pdf_info_dict = {}
-    img_s3_client = get_img_s3_client(save_path, image_s3_config)  # 更改函数名和参数,避免歧义
-    # img_s3_client = "img_s3_client"  #不创建这个对象,直接用字符串占位
-
     start_time = time.time()
 
     """通过统计pdf全篇文字,识别正文字体"""
     main_text_font = get_main_text_font(pdf_docs)
 
-
     end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
     for page_id in range(start_page_id, end_page_id + 1):
         page = pdf_docs[page_id]
@@ -128,19 +103,10 @@ def parse_pdf_by_txt(
         # 对单页面非重复id的img数量做统计,如果当前页超过1500则直接return need_drop
         """
         page_imgs = page.get_images()
-        img_counts = 0
-        for img in page_imgs:
-            img_bojid = img[0]
-            if img_bojid in junk_img_bojids:  # 判断这个图片在不在junklist中
-                continue  # 如果在junklist就不用管了,跳过
-            else:
-                recs = page.get_image_rects(img, transform=True)
-                if recs:  # 如果这张图在当前页面有展示
-                    img_counts += 1
-        if img_counts >= 1500:  # 如果去除了junkimg的影响,单页img仍然超过1500的话,就排除当前pdf
-            logger.warning(
-                f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}"
-            )
+
+        # 去除对junkimg的依赖,简化逻辑
+        if len(page_imgs) > 1500:  # 如果当前页超过1500张图片,直接跳过
+            logger.warning(f"page_id: {page_id}, img_counts: {len(page_imgs)}, drop this pdf")
             result = {"need_drop": True, "drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}
             if not debug_mode:
                 return result
@@ -154,10 +120,10 @@ def parse_pdf_by_txt(
             "dict",
             flags=fitz.TEXTFLAGS_TEXT,
         )["blocks"]
-        model_output_json = get_docx_model_output(pdf_model_output, pdf_model_profile, page_id)
+        model_output_json = get_docx_model_output(pdf_model_output, page_id)
 
         # 解析图片
-        image_bboxes = parse_images(page_id, page, model_output_json, junk_img_bojids)
+        image_bboxes = parse_images(page_id, page, model_output_json)
         image_bboxes = fix_image_vertical(image_bboxes, text_raw_blocks)  # 修正图片的位置
         image_bboxes = fix_seperated_image(image_bboxes)  # 合并有边重合的图片
         image_bboxes = include_img_title(text_raw_blocks, image_bboxes)  # 向图片上方和下方寻找title,使用规则进行匹配,暂时只支持英文规则
@@ -225,22 +191,18 @@ def parse_pdf_by_txt(
         """
         ==================================================================================================================================
         """
-        if debug_mode:  # debugmode截图到本地
-            save_path = join_path(save_tmp_path, "md")
 
         # 把图、表、公式都进行截图,保存到存储上,返回图片路径作为内容
         image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info = save_images_by_bboxes(
-            book_name,
             page_id,
             page,
-            save_path,
+            pdf_bytes_md5,
             image_bboxes,
             images_overlap_backup,
             table_bboxes,
             equations_inline_bboxes,
             equations_interline_bboxes,
-            # 传入img_s3_client
-            img_s3_client,
+            imageWriter
         )  # 只要表格和图片的截图
         
         """"以下进入到公式替换环节 """
@@ -253,13 +215,13 @@ def parse_pdf_by_txt(
 
         """去掉footnote, 从文字和图片中(先去角标再去footnote试试)"""
         # 通过模型识别到的footnote
-        footnote_bboxes_by_model = parse_footnotes_by_model(page_id, page, model_output_json, md_bookname_save_path, debug_mode=debug_mode)
+        footnote_bboxes_by_model = parse_footnotes_by_model(page_id, page, model_output_json, debug_mode=debug_mode)
         # 通过规则识别到的footnote
         footnote_bboxes_by_rule = parse_footnotes_by_rule(remain_text_blocks, page_height, page_id, main_text_font)
         """进入pdf过滤器,去掉一些不合理的pdf"""
         is_good_pdf, err = pdf_filter(page, remain_text_blocks, table_bboxes, image_bboxes)
         if not is_good_pdf:
-            logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {err}")
+            logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {err}")
             if not debug_mode:
                 return err
 
@@ -273,7 +235,7 @@ def parse_pdf_by_txt(
 
         if is_text_block_horz_overlap:
             # debug_show_bbox(pdf_docs, page_id, [b['bbox'] for b in remain_text_blocks], [], [], join_path(save_path, book_name, f"{book_name}_debug.pdf"), 0)
-            logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}")
+            logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}")
             result = {"need_drop": True, "drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP}
             if not debug_mode:
                 return result
@@ -292,21 +254,21 @@ def parse_pdf_by_txt(
         layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)
         
         if len(remain_text_blocks)>0 and len(all_bboxes)>0 and len(layout_bboxes)==0:
-            logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
+            logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
             result = {"need_drop": True, "drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}
             if not debug_mode:
                 return result
 
         """以下去掉复杂的布局和超过2列的布局"""
         if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]):  # 复杂的布局
-            logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.COMPLICATED_LAYOUT}")
+            logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.COMPLICATED_LAYOUT}")
             result = {"need_drop": True, "drop_reason": DropReason.COMPLICATED_LAYOUT}
             if not debug_mode:
                 return result
 
         layout_column_width = get_columns_cnt_of_layout(layout_tree)
         if layout_column_width > 2:  # 去掉超过2列的布局pdf
-            logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
+            logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
             result = {
                 "need_drop": True,
                 "drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
@@ -390,29 +352,12 @@ def parse_pdf_by_txt(
     for page_info in pdf_info_dict.values():
         is_good_pdf, err = pdf_post_filter(page_info)
         if not is_good_pdf:
-            logger.warning(f"page_id: {i}, drop this pdf: {book_name}, reason: {err}")
+            logger.warning(f"page_id: {i}, drop this pdf: {pdf_bytes_md5}, reason: {err}")
             if not debug_mode:
                 return err
         i += 1
 
     if debug_mode:
-        params_file_save_path = join_path(save_tmp_path, "md", book_name, "preproc_out.json")
-        page_draw_rect_save_path = join_path(save_tmp_path, "md", book_name, "layout.pdf")
-        # dir_path = os.path.dirname(page_draw_rect_save_path)
-        # if not os.path.exists(dir_path):
-        #     # 如果目录不存在,创建它
-        #     os.makedirs(dir_path)
-
-        with open(params_file_save_path, "w", encoding="utf-8") as f:
-            json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
-        # 先检测本地 page_draw_rect_save_path 是否存在,如果存在则删除
-        if os.path.exists(page_draw_rect_save_path):
-            os.remove(page_draw_rect_save_path)
-        # 绘制bbox和layout到pdf
-        draw_bbox_on_page(pdf_docs, pdf_info_dict, page_draw_rect_save_path)
-        draw_layout_bbox_on_page(pdf_docs, pdf_info_dict, header, footer, page_draw_rect_save_path)
-
-    if debug_mode:
         # 打印后处理阶段耗时
         logger.info(f"post_processing_time: {get_delta_time(start_time)}")
 
@@ -429,58 +374,30 @@ def parse_pdf_by_txt(
     para_process_pipeline = ParaProcessPipeline()
 
     def _deal_with_text_exception(error_info):
-        logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {error_info}")
+        logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {error_info}")
         if error_info == denseSingleLineBlockException_msg:
-            logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}")
+            logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}")
             result = {"need_drop": True, "drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK}
             return result
         if error_info == titleDetectionException_msg:
-            logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_DETECTION_FAILED}")
+            logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_DETECTION_FAILED}")
             result = {"need_drop": True, "drop_reason": DropReason.TITLE_DETECTION_FAILED}
             return result
         elif error_info == titleLevelException_msg:
-            logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_LEVEL_FAILED}")
+            logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_LEVEL_FAILED}")
             result = {"need_drop": True, "drop_reason": DropReason.TITLE_LEVEL_FAILED}
             return result
         elif error_info == paraSplitException_msg:
-            logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.PARA_SPLIT_FAILED}")
+            logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_SPLIT_FAILED}")
             result = {"need_drop": True, "drop_reason": DropReason.PARA_SPLIT_FAILED}
             return result
         elif error_info == paraMergeException_msg:
-            logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.PARA_MERGE_FAILED}")
+            logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_MERGE_FAILED}")
             result = {"need_drop": True, "drop_reason": DropReason.PARA_MERGE_FAILED}
             return result
 
-    if debug_mode:
-        input_pdf_file = f"{pdf_local_path}.pdf"
-        output_dir = f"{save_path}/{book_name}"
-        output_pdf_file = f"{output_dir}/pdf_annos.pdf"
-
-        """
-        Call the para_process_pipeline function to process the pdf_info_dict.
-        
-        Parameters:
-        para_debug_mode: str or None
-            If para_debug_mode is None, the para_process_pipeline will not keep any intermediate results.
-            If para_debug_mode is "simple", the para_process_pipeline will only keep the annos on the pdf and the final results as a json file.
-            If para_debug_mode is "full", the para_process_pipeline will keep all the intermediate results generated during each step.
-        """
-        pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(
-            pdf_info_dict,
-            para_debug_mode="simple",
-            input_pdf_path=input_pdf_file,
-            output_pdf_path=output_pdf_file,
-        )
-        # 打印段落处理阶段耗时
-        logger.info(f"para_process_time: {get_delta_time(start_time)}")
-
-        # debug的时候不return drop信息
-        if error_info is not None:
-            _deal_with_text_exception(error_info)
-        return pdf_info_dict
-    else:
-        pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(pdf_info_dict)
-        if error_info is not None:
-            return _deal_with_text_exception(error_info)
+    pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(pdf_info_dict)
+    if error_info is not None:
+        return _deal_with_text_exception(error_info)
 
     return pdf_info_dict

+ 1 - 2
magic_pdf/pdf_parse_for_train.py

@@ -112,7 +112,6 @@ def parse_pdf_for_train(
     pdf_model_output,
     save_path,
     book_name,
-    pdf_model_profile=None,
     image_s3_config=None,
     start_page_id=0,
     end_page_id=None,
@@ -200,7 +199,7 @@ def parse_pdf_for_train(
             flags=fitz.TEXTFLAGS_TEXT,
         )["blocks"]
         model_output_json = get_docx_model_output(
-            pdf_model_output, pdf_model_profile, page_id
+            pdf_model_output, page_id
         )
 
         # 解析图片

+ 1 - 1
magic_pdf/pre_proc/detect_footnote.py

@@ -3,7 +3,7 @@ from magic_pdf.libs.commons import fitz             # pyMuPDF库
 from magic_pdf.libs.coordinate_transform import get_scale_ratio
 
 
-def parse_footnotes_by_model(page_ID: int, page: fitz.Page, json_from_DocXchain_obj: dict, md_bookname_save_path, debug_mode=False):
+def parse_footnotes_by_model(page_ID: int, page: fitz.Page, json_from_DocXchain_obj: dict, md_bookname_save_path=None, debug_mode=False):
     """
     :param page_ID: int类型,当前page在当前pdf文档中是第page_D页。
     :param page :fitz读取的当前页的内容

+ 5 - 7
magic_pdf/pre_proc/ocr_cut_image.py

@@ -3,18 +3,16 @@ from magic_pdf.libs.ocr_content_type import ContentType
 from magic_pdf.libs.pdf_image_tools import cut_image
 
 
-def cut_image_and_table(spans, page, page_id, book_name, save_path, img_s3_client):
-    def s3_return_path(type):
-        return join_path(book_name, type)
+def cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter):
 
-    def img_save_path(type):
-        return join_path(save_path, s3_return_path(type))
+    def return_path(type):
+        return join_path(pdf_bytes_md5, type)
 
     for span in spans:
         span_type = span['type']
         if span_type == ContentType.Image:
-            span['image_path'] = cut_image(span['bbox'], page_id, page, img_save_path('images'), s3_return_path=s3_return_path('images'), img_s3_client=img_s3_client)
+            span['image_path'] = cut_image(span['bbox'], page_id, page, return_path=return_path('images'), imageWriter=imageWriter)
         elif span_type == ContentType.Table:
-            span['image_path'] = cut_image(span['bbox'], page_id, page, img_save_path('tables'), s3_return_path=s3_return_path('tables'), img_s3_client=img_s3_client)
+            span['image_path'] = cut_image(span['bbox'], page_id, page, return_path=return_path('tables'), imageWriter=imageWriter)
 
     return spans

+ 55 - 4
magic_pdf/spark/spark_api.py

@@ -12,27 +12,78 @@
 其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
 
 """
-
+from loguru import logger
 
 from magic_pdf.io import AbsReaderWriter
+from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
+from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
 
 
 def parse_txt_pdf(pdf_bytes:bytes, pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs):
     """
     解析文本类pdf
     """
-    pass
+    pdf_info_dict = parse_pdf_by_txt(
+        pdf_bytes,
+        pdf_models,
+        imageWriter,
+        start_page_id=start_page,
+        debug_mode=is_debug,
+    )
+
+    pdf_info_dict["parse_type"] = "txt"
+
+    return pdf_info_dict
 
 
 def parse_ocr_pdf(pdf_bytes:bytes,  pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs):
     """
     解析ocr类pdf
     """
-    pass
+    pdf_info_dict = parse_pdf_by_ocr(
+        pdf_bytes,
+        pdf_models,
+        imageWriter,
+        start_page_id=start_page,
+        debug_mode=is_debug,
+    )
+
+    pdf_info_dict["parse_type"] = "ocr"
+
+    return pdf_info_dict
 
 
 def parse_union_pdf(pdf_bytes:bytes,  pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0,  *args, **kwargs):
     """
     ocr和文本混合的pdf,全部解析出来
     """
-    pass
+    def parse_pdf(method):
+        try:
+            return method(
+                pdf_bytes,
+                pdf_models,
+                imageWriter,
+                start_page_id=start_page,
+                debug_mode=is_debug,
+            )
+        except Exception as e:
+            logger.error(f"{method.__name__} error: {e}")
+            return None
+
+    pdf_info_dict = parse_pdf(parse_pdf_by_txt)
+
+    if pdf_info_dict is None or pdf_info_dict.get("need_drop", False):
+        logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
+        pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
+        if pdf_info_dict is None:
+            raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
+        else:
+            pdf_info_dict["parse_type"] = "ocr"
+    else:
+        pdf_info_dict["parse_type"] = "txt"
+
+    return pdf_info_dict
+
+
+def spark_json_extractor(jso:dict):
+    pass

+ 4 - 0
tools/ocr_badcase.py

@@ -867,6 +867,7 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
     save_results(result_dict, overall_report_dict,badcase_file,overall_file)
 
     result=compare_edit_distance(base_data_path, overall_report_dict)
+<<<<<<< HEAD
 
     if all([s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url]):
         try:
@@ -874,7 +875,10 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
             upload_to_s3(overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
         except Exception as e:
             print(f"上传到S3时发生错误: {e}")
+=======
+>>>>>>> ff8f62aa3c28facc192104387f131d87978064fc
     print(result)
+    assert result == 1
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="主函数,执行整个评估流程。")