File overview

Merge pull request #808 from opendatalab/dev

Dev->0.9 release
Xiaomeng Zhao 1 year ago
parent
commit
6575adeafe
67 changed files with 729 additions and 192 deletions
  1. README.md (+33 -16)
  2. README_zh-CN.md (+32 -14)
  3. demo/demo.py (+2 -15)
  4. demo/magic_pdf_parse_main.py (+14 -9)
  5. docs/FAQ_en_us.md (+0 -0)
  6. docs/FAQ_zh_cn.md (+0 -0)
  7. docs/README_Ubuntu_CUDA_Acceleration_en_US.md (+0 -2)
  8. docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md (+0 -2)
  9. docs/README_Windows_CUDA_Acceleration_en_US.md (+0 -2)
  10. docs/README_Windows_CUDA_Acceleration_zh_CN.md (+0 -2)
  11. docs/chemical_knowledge_introduction/introduction.pdf (+0 -0)
  12. docs/chemical_knowledge_introduction/introduction.xmind (+0 -0)
  13. docs/download_models.py (+0 -0)
  14. docs/download_models_hf.py (+0 -0)
  15. docs/how_to_download_models_en.md (+3 -1)
  16. docs/how_to_download_models_zh_cn.md (+4 -8)
  17. docs/images/MinerU-logo-hq.png (+0 -0)
  18. docs/images/MinerU-logo.png (+0 -0)
  19. docs/images/datalab_logo.png (+0 -0)
  20. docs/images/flowchart_en.png (+0 -0)
  21. docs/images/flowchart_zh_cn.png (+0 -0)
  22. docs/images/layout_example.png (+0 -0)
  23. docs/images/poly.png (+0 -0)
  24. docs/images/project_panorama_en.png (+0 -0)
  25. docs/images/project_panorama_zh_cn.png (+0 -0)
  26. docs/images/spans_example.png (+0 -0)
  27. docs/images/web_demo_1.png (+0 -0)
  28. docs/output_file_en_us.md (+0 -0)
  29. docs/output_file_zh_cn.md (+0 -0)
  30. magic_pdf/config/__init__.py (+0 -0)
  31. magic_pdf/dict2md/ocr_mkcontent.py (+11 -11)
  32. magic_pdf/libs/Constants.py (+8 -2)
  33. magic_pdf/libs/draw_bbox.py (+39 -13)
  34. magic_pdf/model/magic_model.py (+228 -27)
  35. magic_pdf/model/pdf_extract_kit.py (+25 -6)
  36. magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py (+8 -1)
  37. magic_pdf/pdf_parse_union_core_v2.py (+102 -32)
  38. magic_pdf/pre_proc/ocr_detect_all_bboxes.py (+31 -24)
  39. magic_pdf/pre_proc/ocr_dict_merge.py (+27 -1)
  40. magic_pdf/tools/cli.py (+1 -1)
  41. magic_pdf/utils/__init__.py (+0 -0)
  42. next_docs/en/.readthedocs.yaml (+0 -0)
  43. next_docs/en/Makefile (+0 -0)
  44. next_docs/en/_static/image/logo.png (+0 -0)
  45. next_docs/en/api.rst (+0 -0)
  46. next_docs/en/api/data_reader_writer.rst (+0 -0)
  47. next_docs/en/api/dataset.rst (+0 -0)
  48. next_docs/en/api/io.rst (+0 -0)
  49. next_docs/en/api/read_api.rst (+0 -0)
  50. next_docs/en/api/schemas.rst (+0 -0)
  51. next_docs/en/api/utils.rst (+0 -0)
  52. next_docs/en/conf.py (+0 -0)
  53. next_docs/en/index.rst (+0 -0)
  54. next_docs/en/make.bat (+0 -0)
  55. next_docs/requirements.txt (+0 -0)
  56. next_docs/zh_cn/.readthedocs.yaml (+0 -0)
  57. next_docs/zh_cn/Makefile (+0 -0)
  58. next_docs/zh_cn/_static/image/logo.png (+0 -0)
  59. next_docs/zh_cn/conf.py (+0 -0)
  60. next_docs/zh_cn/index.rst (+0 -0)
  61. next_docs/zh_cn/make.bat (+0 -0)
  62. projects/README.md (+1 -2)
  63. projects/README_zh-CN.md (+1 -1)
  64. projects/multi_gpu/README.md (+46 -0)
  65. projects/multi_gpu/client.py (+39 -0)
  66. projects/multi_gpu/server.py (+74 -0)
  67. projects/multi_gpu/small_ocr.pdf (BIN)

File diff limited because there are too many changes
+ 33 - 16
README.md


File diff limited because there are too many changes
+ 32 - 14
README_zh-CN.md


+ 2 - 15
demo/demo.py

@@ -1,35 +1,22 @@
 import os
-import json
 
 from loguru import logger
-
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
-import magic_pdf.model as model_config 
-model_config.__use_inside_model__ = True
 
 try:
     current_script_dir = os.path.dirname(os.path.abspath(__file__))
     demo_name = "demo1"
     pdf_path = os.path.join(current_script_dir, f"{demo_name}.pdf")
-    model_path = os.path.join(current_script_dir, f"{demo_name}.json")
     pdf_bytes = open(pdf_path, "rb").read()
-    # model_json = json.loads(open(model_path, "r", encoding="utf-8").read())
-    model_json = []  # model_json传空list使用内置模型解析
-    jso_useful_key = {"_pdf_type": "", "model_list": model_json}
+    jso_useful_key = {"_pdf_type": "", "model_list": []}
     local_image_dir = os.path.join(current_script_dir, 'images')
     image_dir = str(os.path.basename(local_image_dir))
     image_writer = DiskReaderWriter(local_image_dir)
     pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
     pipe.pipe_classify()
-    """如果没有传入有效的模型数据,则使用内置model解析"""
-    if len(model_json) == 0:
-        if model_config.__use_inside_model__:
-            pipe.pipe_analyze()
-        else:
-            logger.error("need model list input")
-            exit(1)
+    pipe.pipe_analyze()
     pipe.pipe_parse()
     md_content = pipe.pipe_mk_markdown(image_dir, drop_mode="none")
     with open(f"{demo_name}.md", "w", encoding="utf-8") as f:

+ 14 - 9
demo/magic_pdf_parse_main.py

@@ -4,13 +4,12 @@ import copy
 
 from loguru import logger
 
+from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-import magic_pdf.model as model_config
 
-model_config.__use_inside_model__ = True
 
 # todo: 设备类型选择 (?)
 
@@ -47,11 +46,20 @@ def json_md_dump(
     )
 
 
+# 可视化
+def draw_visualization_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name):
+    # 画布局框,附带排序结果
+    draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+    # 画 span 框
+    draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+
+
 def pdf_parse_main(
         pdf_path: str,
         parse_method: str = 'auto',
         model_json_path: str = None,
         is_json_md_dump: bool = True,
+        is_draw_visualization_bbox: bool = True,
         output_dir: str = None
 ):
     """
@@ -108,11 +116,7 @@ def pdf_parse_main(
 
         # 如果没有传入模型数据,则使用内置模型解析
         if not model_json:
-            if model_config.__use_inside_model__:
-                pipe.pipe_analyze()  # 解析
-            else:
-                logger.error("need model list input")
-                exit(1)
+            pipe.pipe_analyze()  # 解析
 
         # 执行解析
         pipe.pipe_parse()
@@ -121,10 +125,11 @@ def pdf_parse_main(
         content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
         md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
 
-
         if is_json_md_dump:
             json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
 
+        if is_draw_visualization_bbox:
+            draw_visualization_bbox(pipe.pdf_mid_data['pdf_info'], pdf_bytes, output_path, pdf_name)
 
     except Exception as e:
         logger.exception(e)
@@ -132,5 +137,5 @@ def pdf_parse_main(
 
 # 测试
 if __name__ == '__main__':
-    pdf_path = r"C:\Users\XYTK2\Desktop\2024-2016-gb-cd-300.pdf"
+    pdf_path = r"D:\project\20240617magicpdf\Magic-PDF\demo\demo1.pdf"
     pdf_parse_main(pdf_path)

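For context, the hunk above adds an `is_draw_visualization_bbox` switch to `pdf_parse_main` and wires it to the new `draw_visualization_bbox` helper. A minimal call sketch based only on the signature shown in this diff; the path is a placeholder, and the import line assumes the demo script is importable from the working directory:

```python
# Hypothetical usage sketch; every value below is a placeholder.
from magic_pdf_parse_main import pdf_parse_main  # assumption: demo script on sys.path

pdf_parse_main(
    pdf_path='demo/demo1.pdf',         # input PDF
    parse_method='auto',               # default from the signature
    model_json_path=None,              # None -> built-in model analysis via pipe_analyze()
    is_json_md_dump=True,              # dump JSON/Markdown outputs
    is_draw_visualization_bbox=True,   # new flag: draw layout and span bboxes
    output_dir=None,
)
```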
+ 0 - 0
old_docs/FAQ_en_us.md → docs/FAQ_en_us.md


+ 0 - 0
old_docs/FAQ_zh_cn.md → docs/FAQ_zh_cn.md


+ 0 - 2
old_docs/README_Ubuntu_CUDA_Acceleration_en_US.md → docs/README_Ubuntu_CUDA_Acceleration_en_US.md

@@ -97,8 +97,6 @@ magic-pdf -p small_ocr.pdf
 
 If your graphics card has at least **8GB** of VRAM, follow these steps to test CUDA acceleration:
 
-> ❗ Due to the extremely limited nature of 8GB VRAM for running this application, you need to close all other programs using VRAM to ensure that 8GB of VRAM is available when running this application.
-
 1. Modify the value of `"device-mode"` in the `magic-pdf.json` configuration file located in your home directory.
    ```json
    {

+ 0 - 2
old_docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md → docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md

@@ -98,8 +98,6 @@ magic-pdf -p small_ocr.pdf
 
 如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
-> ❗️因8GB显存运行本应用非常极限,需要关闭所有其他正在使用显存的程序以确保本应用运行时有足额8GB显存可用。
-
 **1.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
 
 ```json

+ 0 - 2
old_docs/README_Windows_CUDA_Acceleration_en_US.md → docs/README_Windows_CUDA_Acceleration_en_US.md

@@ -60,8 +60,6 @@ Download a sample file from the repository and test it.
 
 If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
 
-> ❗ Due to the extremely limited nature of 8GB VRAM for running this application, you need to close all other programs using VRAM to ensure that 8GB of VRAM is available when running this application.
-
 1. **Overwrite the installation of torch and torchvision** supporting CUDA.
 
    ```

+ 0 - 2
old_docs/README_Windows_CUDA_Acceleration_zh_CN.md → docs/README_Windows_CUDA_Acceleration_zh_CN.md

@@ -61,8 +61,6 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
 
 如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
-> ❗️因8GB显存运行本应用非常极限,需要关闭所有其他正在使用显存的程序以确保本应用运行时有足额8GB显存可用。
-
 **1.覆盖安装支持cuda的torch和torchvision**
 
 ```bash

+ 0 - 0
old_docs/chemical_knowledge_introduction/introduction.pdf → docs/chemical_knowledge_introduction/introduction.pdf


+ 0 - 0
old_docs/chemical_knowledge_introduction/introduction.xmind → docs/chemical_knowledge_introduction/introduction.xmind


+ 0 - 0
old_docs/download_models.py → docs/download_models.py


+ 0 - 0
old_docs/download_models_hf.py → docs/download_models_hf.py


+ 3 - 1
old_docs/how_to_download_models_en.md → docs/how_to_download_models_en.md

@@ -22,7 +22,9 @@ The configuration file can be found in the user directory, with the filename `ma
 
 > Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
 
-If you previously downloaded model files via git lfs, you can navigate to the previous download directory and use the `git pull` command to update the model.
+When magic-pdf <= 0.8.1, if you have previously downloaded the model files via git lfs, you can navigate to the previous download directory and update the models using the `git pull` command.
+
+> For versions 0.9.x and later, due to the repository change and the addition of the layout sorting model in PDF-Extract-Kit 1.0, the models cannot be updated using the `git pull` command. Instead, a Python script must be used for one-click updates.
 
 ## 2. Models downloaded via Hugging Face or Model Scope
 

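The English note above mentions a "one-click" Python update script without naming it. A hedged pointer only, based on the files visible in this PR's file list (docs/download_models.py and docs/download_models_hf.py); the zh_CN counterpart of this doc states the script downloads the models and fills the model directory into magic-pdf.json automatically:

```python
# Invocation sketch (assumption: either script under docs/ is the updater referred to):
#
#   python docs/download_models_hf.py   # or: python docs/download_models.py
```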
+ 4 - 8
old_docs/how_to_download_models_zh_cn.md → docs/how_to_download_models_zh_cn.md

@@ -34,14 +34,10 @@ python脚本会自动下载模型文件并配置好配置文件中的模型目
 
 > 由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
 
-如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
-
-> 0.9.x及以后版本由于新增layout排序模型,且该模型和此前的模型不在同一仓库,不能通过`git pull`命令更新,需要单独下载。
->
-> ```
-> from modelscope import snapshot_download
-> snapshot_download('ppaanngggg/layoutreader')
-> ```
+当magic-pdf <= 0.8.1时,如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
+
+> 0.9.x及以后版本由于PDF-Extract-Kit 1.0更换仓库和新增layout排序模型,不能通过`git pull`命令更新,需要使用python脚本一键更新。
+
 
 ## 2. 通过 Hugging Face 或 Model Scope 下载过模型
 

+ 0 - 0
old_docs/images/MinerU-logo-hq.png → docs/images/MinerU-logo-hq.png


+ 0 - 0
old_docs/images/MinerU-logo.png → docs/images/MinerU-logo.png


+ 0 - 0
old_docs/images/datalab_logo.png → docs/images/datalab_logo.png


+ 0 - 0
old_docs/images/flowchart_en.png → docs/images/flowchart_en.png


+ 0 - 0
old_docs/images/flowchart_zh_cn.png → docs/images/flowchart_zh_cn.png


+ 0 - 0
old_docs/images/layout_example.png → docs/images/layout_example.png


+ 0 - 0
old_docs/images/poly.png → docs/images/poly.png


+ 0 - 0
old_docs/images/project_panorama_en.png → docs/images/project_panorama_en.png


+ 0 - 0
old_docs/images/project_panorama_zh_cn.png → docs/images/project_panorama_zh_cn.png


+ 0 - 0
old_docs/images/spans_example.png → docs/images/spans_example.png


+ 0 - 0
old_docs/images/web_demo_1.png → docs/images/web_demo_1.png


+ 0 - 0
old_docs/output_file_en_us.md → docs/output_file_en_us.md


+ 0 - 0
old_docs/output_file_zh_cn.md → docs/output_file_zh_cn.md


+ 0 - 0
docs/en/api/io.rst → magic_pdf/config/__init__.py


+ 11 - 11
magic_pdf/dict2md/ocr_mkcontent.py

@@ -70,17 +70,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                     para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 2nd.拼image_caption
                     if block['type'] == BlockType.ImageCaption:
-                        para_text += merge_para_with_text(block)
-                for block in para_block['blocks']:  # 2nd.拼image_caption
+                        para_text += merge_para_with_text(block) + '  \n'
+                for block in para_block['blocks']:  # 3rd.拼image_footnote
                     if block['type'] == BlockType.ImageFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
         elif para_type == BlockType.Table:
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
                 for block in para_block['blocks']:  # 1st.拼table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
                 for block in para_block['blocks']:  # 2nd.拼table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
@@ -95,7 +95,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TableFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
 
         if para_text.strip() == '':
             continue
@@ -180,18 +180,18 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
             'text_format': 'latex',
         }
     elif para_type == BlockType.Image:
-        para_content = {'type': 'image'}
+        para_content = {'type': 'image', 'img_caption': [], 'img_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.ImageBody:
                 para_content['img_path'] = join_path(
                     img_buket_path,
                     block['lines'][0]['spans'][0]['image_path'])
             if block['type'] == BlockType.ImageCaption:
-                para_content['img_caption'] = merge_para_with_text(block)
+                para_content['img_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.ImageFootnote:
-                para_content['img_footnote'] = merge_para_with_text(block)
+                para_content['img_footnote'].append(merge_para_with_text(block))
     elif para_type == BlockType.Table:
-        para_content = {'type': 'table'}
+        para_content = {'type': 'table', 'table_caption': [], 'table_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
                 if block["lines"][0]["spans"][0].get('latex', ''):
@@ -200,9 +200,9 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
                     para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:
-                para_content['table_caption'] = merge_para_with_text(block)
+                para_content['table_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.TableFootnote:
-                para_content['table_footnote'] = merge_para_with_text(block)
+                para_content['table_footnote'].append(merge_para_with_text(block))
 
     para_content['page_idx'] = page_idx
 

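Net effect of the ocr_mkcontent.py changes on the standard-format output: image and table captions/footnotes become lists (one entry per caption or footnote block) instead of single strings, and captions/footnotes in Markdown gain a trailing line break. An illustrative content-list entry, with placeholder values:

```python
# Field names come from the diff above; the path and text are placeholders.
image_item = {
    'type': 'image',
    'img_path': 'images/example.jpg',              # joined with img_buket_path
    'img_caption': ['Figure 1: example caption'],  # now a list of strings
    'img_footnote': [],                            # now a list of strings
    'page_idx': 0,
}
```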
+ 8 - 2
magic_pdf/libs/Constants.py

@@ -23,14 +23,20 @@ TABLE_MASTER_DICT = "table_master_structure_dict.txt"
 TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
 
 # pp detect model dir
-DETECT_MODEL_DIR = "ch_PP-OCRv3_det_infer"
+DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer"
 
 # pp rec model dir
-REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
+REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer"
 
 # pp rec char dict path
 REC_CHAR_DICT = "ppocr_keys_v1.txt"
 
+# pp rec copy rec directory
+PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
+
+# pp rec copy det directory
+PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
+
 
 class MODEL_NAME:
     # pp table structure algorithm

+ 39 - 13
magic_pdf/libs/draw_bbox.py

@@ -141,11 +141,33 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
 
     layout_bbox_list = []
 
+    table_type_order = {
+        'table_caption': 1,
+        'table_body': 2,
+        'table_footnote': 3
+    }
     for page in pdf_info:
         page_block_list = []
         for block in page['para_blocks']:
-            bbox = block['bbox']
-            page_block_list.append(bbox)
+            if block['type'] in [
+                BlockType.Text,
+                BlockType.Title,
+                BlockType.InterlineEquation,
+                BlockType.List,
+                BlockType.Index,
+            ]:
+                bbox = block['bbox']
+                page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Image]:
+                for sub_block in block['blocks']:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Table]:
+                sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
+                for sub_block in sorted_blocks:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+
         layout_bbox_list.append(page_block_list)
 
     pdf_docs = fitz.open('pdf', pdf_bytes)
@@ -153,11 +175,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
     for i, page in enumerate(pdf_docs):
 
         draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
-        draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True)  # color !
+        # draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True)  # color !
         draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
         draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
         draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
-        draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
+        # draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
         draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
         draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
         draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
@@ -338,19 +360,23 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in ['text', 'title', 'interline_equation']:
+            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in ['table', 'image']:
-                bbox = block['bbox']
-                index = block['index']
-                page_line_list.append({'index': index, 'bbox': bbox})
-            # for line in block['lines']:
-            #     bbox = line['bbox']
-            #     index = line['index']
-            #     page_line_list.append({'index': index, 'bbox': bbox})
+            if block['type'] in [BlockType.Image, BlockType.Table]:
+                for sub_block in block['blocks']:
+                    if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+                        for line in sub_block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                    elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
+                        for line in sub_block['lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
         sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
         layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
     pdf_docs = fitz.open('pdf', pdf_bytes)

+ 228 - 27
magic_pdf/model/magic_model.py

@@ -1,3 +1,4 @@
+import enum
 import json
 
 from magic_pdf.data.dataset import Dataset
@@ -10,6 +11,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.local_math import float_gt
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
+from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
@@ -17,6 +19,14 @@ CAPATION_OVERLAP_AREA_RATIO = 0.6
 MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
 
 
+class PosRelationEnum(enum.Enum):
+    LEFT = 'left'
+    RIGHT = 'right'
+    UP = 'up'
+    BOTTOM = 'bottom'
+    ALL = 'all'
+
+
 class MagicModel:
     """每个函数没有得到元素的时候返回空list."""
 
@@ -124,8 +134,7 @@ class MagicModel:
             l1 = bbox1[2] - bbox1[0]
             l2 = bbox2[2] - bbox2[0]
 
-        min_l, max_l = min(l1, l2), max(l1, l2)
-        if (max_l - min_l) * 1.0 / max_l > 0.4:
+        if l2 > l1 and (l2 - l1) / l1 > 0.3:
             return float('inf')
 
         return bbox_distance(bbox1, bbox2)
@@ -591,9 +600,24 @@ class MagicModel:
         return ret, total_subject_object_dis
 
     def __tie_up_category_by_distance_v2(
-        self, page_no, subject_category_id, object_category_id
+        self,
+        page_no: int,
+        subject_category_id: int,
+        object_category_id: int,
+        priority_pos: PosRelationEnum,
     ):
+        """_summary_
+
+        Args:
+            page_no (int): _description_
+            subject_category_id (int): _description_
+            object_category_id (int): _description_
+            priority_pos (PosRelationEnum): _description_
 
+        Returns:
+            _type_: _description_
+        """
+        AXIS_MULPLICITY = 0.5
         subjects = self.__reduct_overlap(
             list(
                 map(
@@ -617,67 +641,244 @@ class MagicModel:
                 )
             )
         )
+        M = len(objects)
 
         subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
         objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
-        dis = [[float('inf')] * len(subjects) for _ in range(len(objects))]
+
+        sub_obj_map_h = {i: [] for i in range(len(subjects))}
+
+        dis_by_directions = {
+            'top': [[-1, float('inf')]] * M,
+            'bottom': [[-1, float('inf')]] * M,
+            'left': [[-1, float('inf')]] * M,
+            'right': [[-1, float('inf')]] * M,
+        }
+
         for i, obj in enumerate(objects):
+            l_x_axis, l_y_axis = (
+                obj['bbox'][2] - obj['bbox'][0],
+                obj['bbox'][3] - obj['bbox'][1],
+            )
+            axis_unit = min(l_x_axis, l_y_axis)
             for j, sub in enumerate(subjects):
-                dis[i][j] = self._bbox_distance(sub['bbox'], obj['bbox'])
 
-        sub_obj_map_h = {i: [] for i in range(len(subjects))}
-        for i in range(len(objects)):
-            min_l_idx = 0
-            for j in range(1, len(subjects)):
-                if dis[i][j] == float('inf'):
+                bbox1, bbox2, _ = _remove_overlap_between_bbox(
+                    objects[i]['bbox'], subjects[j]['bbox']
+                )
+                left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
+                flags = [left, right, bottom, top]
+                if sum([1 if v else 0 for v in flags]) > 1:
+                    continue
+
+                if left:
+                    if dis_by_directions['left'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['left'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if right:
+                    if dis_by_directions['right'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['right'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if bottom:
+                    if dis_by_directions['bottom'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['bottom'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if top:
+                    if dis_by_directions['top'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['top'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+
+            if (
+                dis_by_directions['top'][i][1] != float('inf')
+                and dis_by_directions['bottom'][i][1] != float('inf')
+                and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
+            ):
+                RATIO = 3
+                if (
+                    abs(
+                        dis_by_directions['top'][i][1]
+                        - dis_by_directions['bottom'][i][1]
+                    )
+                    < RATIO * axis_unit
+                ):
+
+                    if priority_pos == PosRelationEnum.BOTTOM:
+                        sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
+                    else:
+                        sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
                     continue
-                if dis[i][j] < dis[i][min_l_idx]:
-                    min_l_idx = j
 
-            if dis[i][min_l_idx] < float('inf'):
-                sub_obj_map_h[min_l_idx].append(i)
+            if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
+                'right'
+            ][i][1] != float('inf'):
+                if dis_by_directions['left'][i][1] != float(
+                    'inf'
+                ) and dis_by_directions['right'][i][1] != float('inf'):
+                    if AXIS_MULPLICITY * axis_unit >= abs(
+                        dis_by_directions['left'][i][1]
+                        - dis_by_directions['right'][i][1]
+                    ):
+                        left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
+                            'bbox'
+                        ]
+                        right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
+                            'bbox'
+                        ]
+
+                        left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
+                        right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
+
+                        if (
+                            abs(left_sub_bbox_y_axis - l_y_axis)
+                            + dis_by_directions['left'][i][0]
+                            > abs(right_sub_bbox_y_axis - l_y_axis)
+                            + dis_by_directions['right'][i][0]
+                        ):
+                            left_or_right = dis_by_directions['right'][i]
+                        else:
+                            left_or_right = dis_by_directions['left'][i]
+                    else:
+                        left_or_right = dis_by_directions['left'][i]
+                        if left_or_right[1] > dis_by_directions['right'][i][1]:
+                            left_or_right = dis_by_directions['right'][i]
+                else:
+                    left_or_right = dis_by_directions['left'][i]
+                    if left_or_right[1] == float('inf'):
+                        left_or_right = dis_by_directions['right'][i]
+            else:
+                left_or_right = [-1, float('inf')]
+
+            if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
+                'bottom'
+            ][i][1] != float('inf'):
+                if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
+                    'bottom'
+                ][i][1] != float('inf'):
+                    if AXIS_MULPLICITY * axis_unit >= abs(
+                        dis_by_directions['top'][i][1]
+                        - dis_by_directions['bottom'][i][1]
+                    ):
+                        top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
+                        bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
+
+                        top_bottom_x_axis = top_bottom[2] - top_bottom[0]
+                        bottom_top_x_axis = bottom_top[2] - bottom_top[0]
+                        if (
+                            abs(top_bottom_x_axis - l_x_axis)
+                            + dis_by_directions['bottom'][i][1]
+                            > abs(bottom_top_x_axis - l_x_axis)
+                            + dis_by_directions['top'][i][1]
+                        ):
+                            top_or_bottom = dis_by_directions['top'][i]
+                        else:
+                            top_or_bottom = dis_by_directions['bottom'][i]
+                    else:
+                        top_or_bottom = dis_by_directions['top'][i]
+                        if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
+                            top_or_bottom = dis_by_directions['bottom'][i]
+                else:
+                    top_or_bottom = dis_by_directions['top'][i]
+                    if top_or_bottom[1] == float('inf'):
+                        top_or_bottom = dis_by_directions['bottom'][i]
             else:
-                print(i, 'no nearest')
+                top_or_bottom = [-1, float('inf')]
+
+            if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
+                if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
+                    'inf'
+                ):
+                    if AXIS_MULPLICITY * axis_unit >= abs(
+                        left_or_right[1] - top_or_bottom[1]
+                    ):
+                        y_axis_bbox = subjects[left_or_right[0]]['bbox']
+                        x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
+
+                        if (
+                            abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
+                            > abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
+                            / l_y_axis
+                        ):
+                            sub_obj_map_h[left_or_right[0]].append(i)
+                        else:
+                            sub_obj_map_h[top_or_bottom[0]].append(i)
+                    else:
+                        if left_or_right[1] > top_or_bottom[1]:
+                            sub_obj_map_h[top_or_bottom[0]].append(i)
+                        else:
+                            sub_obj_map_h[left_or_right[0]].append(i)
+                else:
+                    if left_or_right[1] != float('inf'):
+                        sub_obj_map_h[left_or_right[0]].append(i)
+                    else:
+                        sub_obj_map_h[top_or_bottom[0]].append(i)
         ret = []
         for i in sub_obj_map_h.keys():
             ret.append(
                 {
-                    'sub_bbox': subjects[i]['bbox'],
-                    'obj_bboxes': [objects[j]['bbox'] for j in sub_obj_map_h[i]],
+                    'sub_bbox': {
+                        'bbox': subjects[i]['bbox'],
+                        'score': subjects[i]['score'],
+                    },
+                    'obj_bboxes': [
+                        {'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
+                        for j in sub_obj_map_h[i]
+                    ],
                     'sub_idx': i,
                 }
             )
         return ret
 
     def get_imgs_v2(self, page_no: int):
-        with_captions = self.__tie_up_category_by_distance_v2(page_no, 3, 4)
+        with_captions = self.__tie_up_category_by_distance_v2(
+            page_no, 3, 4, PosRelationEnum.BOTTOM
+        )
         with_footnotes = self.__tie_up_category_by_distance_v2(
-            page_no, 3, CategoryId.ImageFootnote
+            page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
         )
         ret = []
         for v in with_captions:
             record = {
-                'image_bbox': v['sub_bbox'],
-                'image_caption_bbox_list': v['obj_bboxes'],
+                'image_body': v['sub_bbox'],
+                'image_caption_list': v['obj_bboxes'],
             }
             filter_idx = v['sub_idx']
             d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
-            record['image_footnote_bbox_list'] = d['obj_bboxes']
+            record['image_footnote_list'] = d['obj_bboxes']
             ret.append(record)
         return ret
 
     def get_tables_v2(self, page_no: int) -> list:
-        with_captions = self.__tie_up_category_by_distance_v2(page_no, 5, 6)
-        with_footnotes = self.__tie_up_category_by_distance_v2(page_no, 5, 7)
+        with_captions = self.__tie_up_category_by_distance_v2(
+            page_no, 5, 6, PosRelationEnum.UP
+        )
+        with_footnotes = self.__tie_up_category_by_distance_v2(
+            page_no, 5, 7, PosRelationEnum.ALL
+        )
         ret = []
         for v in with_captions:
             record = {
-                'table_bbox': v['sub_bbox'],
-                'table_caption_bbox_list': v['obj_bboxes'],
+                'table_body': v['sub_bbox'],
+                'table_caption_list': v['obj_bboxes'],
             }
             filter_idx = v['sub_idx']
             d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
-            record['table_footnote_bbox_list'] = d['obj_bboxes']
+            record['table_footnote_list'] = d['obj_bboxes']
             ret.append(record)
         return ret
 

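For reference, the records returned by the reworked `get_imgs_v2`/`get_tables_v2` now carry score-bearing dicts instead of bare bbox lists, under the renamed keys shown above. A sketch of one image record with placeholder numbers:

```python
# Keys come from the diff; coordinates and scores are invented placeholder values.
record = {
    'image_body': {'bbox': [100, 200, 400, 500], 'score': 0.98},
    'image_caption_list': [{'bbox': [100, 510, 400, 540], 'score': 0.95}],
    'image_footnote_list': [],
}
```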
+ 25 - 6
magic_pdf/model/pdf_extract_kit.py

@@ -1,7 +1,8 @@
 from loguru import logger
 import os
 import time
-
+from pathlib import Path
+import shutil
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
@@ -37,19 +38,24 @@ except ImportError as e:
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
+# from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
-    else:
+        # table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
+        logger.error("StructEqTable is under upgrade, the current version does not support it.")
+        exit(1)
+    elif table_model_type == MODEL_NAME.TABLE_MASTER:
         config = {
             "model_dir": model_path,
             "device": _device_
         }
         table_model = ppTableModel(config)
+    else:
+        logger.error("table model type not allow")
+        exit(1)
     return table_model
 
 
@@ -83,7 +89,7 @@ def doclayout_yolo_model_init(weight):
     return model
 
 
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=2.4):
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
     if lang is not None:
         model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     else:
@@ -297,6 +303,17 @@ class CustomPEKModel:
                 device=self.device
             )
 
+            home_directory = Path.home()
+            det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
+            rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
+            det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
+            rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
+
+            if not os.path.exists(det_dest_dir):
+                shutil.copytree(det_source, det_dest_dir)
+            if not os.path.exists(rec_dest_dir):
+                shutil.copytree(rec_source, rec_dest_dir)
+
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -314,7 +331,7 @@ class CustomPEKModel:
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
             layout_res = []
-            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.15, iou=0.45, verbose=True, device=self.device)[0]
+            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
             for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                 new_item = {
@@ -472,3 +489,5 @@ class CustomPEKModel:
         logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
 
         return layout_res
+
+

+ 8 - 1
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py

@@ -1,5 +1,12 @@
-from struct_eqtable.model import StructTable
+from loguru import logger
+
+try:
+    from struct_eqtable.model import StructTable
+except ImportError:
+    logger.error("StructEqTable is under upgrade, the current version does not support it.")
 from pypandoc import convert_text
+
+
 class StructTableModel:
     def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
         # init

+ 102 - 32
magic_pdf/pdf_parse_union_core_v2.py

@@ -1,3 +1,4 @@
+import copy
 import os
 import statistics
 import time
@@ -15,7 +16,7 @@ from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.local_math import float_equal
-from magic_pdf.libs.ocr_content_type import ContentType
+from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
@@ -29,7 +30,7 @@ from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
     ocr_prepare_bboxes_for_layout_split_v2
 from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
                                                fix_block_spans,
-                                               fix_discarded_block)
+                                               fix_discarded_block, fix_block_spans_v2)
 from magic_pdf.pre_proc.ocr_span_list_modify import (
     get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
     remove_overlaps_min_spans)
@@ -173,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
 
 def cal_block_index(fix_blocks, sorted_bboxes):
     for block in fix_blocks:
-        # if block['type'] in ['text', 'title', 'interline_equation']:
-        #     line_index_list = []
-        #     if len(block['lines']) == 0:
-        #         block['index'] = sorted_bboxes.index(block['bbox'])
-        #     else:
-        #         for line in block['lines']:
-        #             line['index'] = sorted_bboxes.index(line['bbox'])
-        #             line_index_list.append(line['index'])
-        #         median_value = statistics.median(line_index_list)
-        #         block['index'] = median_value
-        #
-        # elif block['type'] in ['table', 'image']:
-        #     block['index'] = sorted_bboxes.index(block['bbox'])
 
         line_index_list = []
         if len(block['lines']) == 0:
@@ -197,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
             median_value = statistics.median(line_index_list)
             block['index'] = median_value
 
-        # 删除图表block中的虚拟line信息
-        if block['type'] in ['table', 'image']:
-            del block['lines']
+        # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
+        if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+            block['virtual_lines'] = copy.deepcopy(block['lines'])
+            block['lines'] = copy.deepcopy(block['real_lines'])
+            del block['real_lines']
 
     return fix_blocks
 
@@ -218,13 +208,12 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
         ):  # 可能是双列结构,可以切细点
             lines = int(block_height / line_height) + 1
         else:
-            # 如果block的宽度超过0.4页面宽度,则将block分成3行
+            # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
             if block_weight > page_w * 0.4:
                 line_height = (y1 - y0) / 3
                 lines = 3
-            elif block_weight > page_w * 0.25:  # 否则将block分成两行
-                line_height = (y1 - y0) / 2
-                lines = 2
+            elif block_weight > page_w * 0.25:  # (可能是三列结构,也切细点)
+                lines = int(block_height / line_height) + 1
             else:  # 判断长宽比
                 if block_height / block_weight > 1.2:  # 细长的不分
                     return [[x0, y0, x1, y1]]
@@ -250,7 +239,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
     for block in fix_blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             if len(block['lines']) == 0:
                 bbox = block['bbox']
                 lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
@@ -261,8 +254,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
                 for line in block['lines']:
                     bbox = line['bbox']
                     page_line_list.append(bbox)
-        elif block['type'] in ['table', 'image']:
+        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
             bbox = block['bbox']
+            block["real_lines"] = copy.deepcopy(block['lines'])
             lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
             block['lines'] = []
             for line in lines:
@@ -316,7 +310,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
 def get_line_height(blocks):
     page_line_height_list = []
     for block in blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             for line in block['lines']:
                 bbox = line['bbox']
                 page_line_height_list.append(int(bbox[3] - bbox[1]))
@@ -326,6 +324,63 @@ def get_line_height(blocks):
         return 10
 
 
+def process_groups(groups, body_key, caption_key, footnote_key):
+    body_blocks = []
+    caption_blocks = []
+    footnote_blocks = []
+    for i, group in enumerate(groups):
+        group[body_key]['group_id'] = i
+        body_blocks.append(group[body_key])
+        for caption_block in group[caption_key]:
+            caption_block['group_id'] = i
+            caption_blocks.append(caption_block)
+        for footnote_block in group[footnote_key]:
+            footnote_block['group_id'] = i
+            footnote_blocks.append(footnote_block)
+    return body_blocks, caption_blocks, footnote_blocks
+
+
+def process_block_list(blocks, body_type, block_type):
+    indices = [block['index'] for block in blocks]
+    median_index = statistics.median(indices)
+
+    body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
+
+    return {
+        'type': block_type,
+        'bbox': body_bbox,
+        'blocks': blocks,
+        'index': median_index,
+    }
+
+
+def revert_group_blocks(blocks):
+    image_groups = {}
+    table_groups = {}
+    new_blocks = []
+    for block in blocks:
+        if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
+            group_id = block['group_id']
+            if group_id not in image_groups:
+                image_groups[group_id] = []
+            image_groups[group_id].append(block)
+        elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
+            group_id = block['group_id']
+            if group_id not in table_groups:
+                table_groups[group_id] = []
+            table_groups[group_id].append(block)
+        else:
+            new_blocks.append(block)
+
+    for group_id, blocks in image_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
+
+    for group_id, blocks in table_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
+
+    return new_blocks
+
+
 def parse_page_core(
     page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
 ):
@@ -333,8 +388,20 @@ def parse_page_core(
     drop_reason = []
 
     """从magic_model对象中获取后面会用到的区块信息"""
-    img_blocks = magic_model.get_imgs(page_id)
-    table_blocks = magic_model.get_tables(page_id)
+    # img_blocks = magic_model.get_imgs(page_id)
+    # table_blocks = magic_model.get_tables(page_id)
+
+    img_groups = magic_model.get_imgs_v2(page_id)
+    table_groups = magic_model.get_tables_v2(page_id)
+
+    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
+        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
+    )
+
+    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
+        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
+    )
+
     discarded_blocks = magic_model.get_discarded(page_id)
     text_blocks = magic_model.get_text_blocks(page_id)
     title_blocks = magic_model.get_title_blocks(page_id)
@@ -370,8 +437,8 @@ def parse_page_core(
     interline_equation_blocks = []
     if len(interline_equation_blocks) > 0:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks,
-            table_blocks,
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
             discarded_blocks,
             text_blocks,
             title_blocks,
@@ -381,8 +448,8 @@ def parse_page_core(
         )
     else:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks,
-            table_blocks,
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
             discarded_blocks,
             text_blocks,
             title_blocks,
@@ -419,7 +486,7 @@ def parse_page_core(
     block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
 
     """对block进行fix操作"""
-    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
+    fix_blocks = fix_block_spans_v2(block_with_spans)
 
     """获取所有line并计算正文line的高度"""
     line_height = get_line_height(fix_blocks)
@@ -430,6 +497,9 @@ def parse_page_core(
     """根据line的中位数算block的序列关系"""
     fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
 
+    """将image和table的block还原回group形式参与后续流程"""
+    fix_blocks = revert_group_blocks(fix_blocks)
+
     """重排block"""
     sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
 

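A tiny worked example of the new `process_groups` helper, assuming it can be imported from `magic_pdf.pdf_parse_union_core_v2` as added above: every block in a group receives the same `group_id`, which `revert_group_blocks` later uses to reassemble the image/table groups after sorting.

```python
# Minimal sketch; the input dict mirrors the get_imgs_v2 output shape, values are placeholders.
from magic_pdf.pdf_parse_union_core_v2 import process_groups

groups = [{
    'image_body': {'bbox': [100, 200, 400, 500], 'score': 0.98},
    'image_caption_list': [{'bbox': [100, 510, 400, 540], 'score': 0.95}],
    'image_footnote_list': [],
}]
bodies, captions, footnotes = process_groups(
    groups, 'image_body', 'image_caption_list', 'image_footnote_list')
assert bodies[0]['group_id'] == 0 and captions[0]['group_id'] == 0
```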
+ 31 - 24
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
     return all_bboxes, all_discarded_blocks, drop_reasons
 
 
-def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks,
-                                        title_blocks, interline_equation_blocks, page_w, page_h):
+def add_bboxes(blocks, block_type, bboxes):
+    for block in blocks:
+        x0, y0, x1, y1 = block['bbox']
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
+        else:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
+
+
+def ocr_prepare_bboxes_for_layout_split_v2(
+        img_body_blocks, img_caption_blocks, img_footnote_blocks,
+        table_body_blocks, table_caption_blocks, table_footnote_blocks,
+        discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
+):
     all_bboxes = []
-    all_discarded_blocks = []
-    for image in img_blocks:
-        x0, y0, x1, y1 = image['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])
 
-    for table in table_blocks:
-        x0, y0, x1, y1 = table['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
-
-    for text in text_blocks:
-        x0, y0, x1, y1 = text['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
-
-    for title in title_blocks:
-        x0, y0, x1, y1 = title['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])
-
-    for interline_equation in interline_equation_blocks:
-        x0, y0, x1, y1 = interline_equation['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])
+    add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
+    add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
+    add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
+    add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
+    add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
+    add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
+    add_bboxes(text_blocks, BlockType.Text, all_bboxes)
+    add_bboxes(title_blocks, BlockType.Title, all_bboxes)
+    add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
 
     '''block嵌套问题解决'''
     '''文本框与标题框重叠,优先信任文本框'''
@@ -96,12 +101,14 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
     '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
     # 通过后续大框套小框逻辑删除
 
-    '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
+    '''discarded_blocks'''
+    all_discarded_blocks = []
+    add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
+
+    '''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
     footnote_blocks = []
     for discarded in discarded_blocks:
         x0, y0, x1, y1 = discarded['bbox']
-        all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
-        # 将footnote加入到all_bboxes中,用来计算layout
         if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
             footnote_blocks.append([x0, y0, x1, y1])
 

+ 27 - 1
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -49,7 +49,7 @@ def merge_spans_to_line(spans):
                 continue
 
             # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
-            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.6):
+            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
                 current_line.append(span)
             else:
                 # 否则,开始新行
@@ -153,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
             'type': block_type,
             'bbox': block_bbox,
         }
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            block_dict["group_id"] = block[-1]
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']
@@ -201,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
     return fix_blocks
 
 
+def fix_block_spans_v2(block_with_spans):
+    """1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
+    需要将caption和footnote的text_span放入相应img_block和table_block内的
+    caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
+    fix_blocks = []
+    for block in block_with_spans:
+        block_type = block['type']
+
+        if block_type in [BlockType.Text, BlockType.Title,
+                          BlockType.ImageCaption, BlockType.ImageFootnote,
+                          BlockType.TableCaption, BlockType.TableFootnote
+                          ]:
+            block = fix_text_block(block)
+        elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
+            block = fix_interline_block(block)
+        else:
+            continue
+        fix_blocks.append(block)
+    return fix_blocks
+
+
 def fix_discarded_block(discarded_block_with_spans):
     fix_discarded_blocks = []
     for block in discarded_block_with_spans:

+ 1 - 1
magic_pdf/tools/cli.py

@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""",
     help="""
     Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
     You should input "Abbreviation" with language form url:
-    https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
+    https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
     """,
     default=None,
 )

+ 0 - 0
docs/en/api/schemas.rst → magic_pdf/utils/__init__.py


+ 0 - 0
docs/en/.readthedocs.yaml → next_docs/en/.readthedocs.yaml


+ 0 - 0
docs/en/Makefile → next_docs/en/Makefile


+ 0 - 0
docs/en/_static/image/logo.png → next_docs/en/_static/image/logo.png


+ 0 - 0
docs/en/api.rst → next_docs/en/api.rst


+ 0 - 0
docs/en/api/data_reader_writer.rst → next_docs/en/api/data_reader_writer.rst


+ 0 - 0
docs/en/api/dataset.rst → next_docs/en/api/dataset.rst


+ 0 - 0
next_docs/en/api/io.rst


+ 0 - 0
docs/en/api/read_api.rst → next_docs/en/api/read_api.rst


+ 0 - 0
next_docs/en/api/schemas.rst


+ 0 - 0
docs/en/api/utils.rst → next_docs/en/api/utils.rst


+ 0 - 0
docs/en/conf.py → next_docs/en/conf.py


+ 0 - 0
docs/en/index.rst → next_docs/en/index.rst


+ 0 - 0
docs/en/make.bat → next_docs/en/make.bat


+ 0 - 0
docs/requirements.txt → next_docs/requirements.txt


+ 0 - 0
docs/zh_cn/.readthedocs.yaml → next_docs/zh_cn/.readthedocs.yaml


+ 0 - 0
docs/zh_cn/Makefile → next_docs/zh_cn/Makefile


+ 0 - 0
docs/zh_cn/_static/image/logo.png → next_docs/zh_cn/_static/image/logo.png


+ 0 - 0
docs/zh_cn/conf.py → next_docs/zh_cn/conf.py


+ 0 - 0
docs/zh_cn/index.rst → next_docs/zh_cn/index.rst


+ 0 - 0
docs/zh_cn/make.bat → next_docs/zh_cn/make.bat


+ 1 - 2
projects/README.md

@@ -6,5 +6,4 @@
 - [gradio_app](./gradio_app/README.md): Build a web app based on gradio
 - [web_demo](./web_demo/README.md): MinerU online [demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/) localized deployment version
 - [web_api](./web_api/README.md): Web API Based on FastAPI
-
-
+- [multi_gpu](./multi_gpu/README.md): Multi-GPU parallel processing based on LitServe

+ 1 - 1
projects/README_zh-CN.md

@@ -6,4 +6,4 @@
 - [gradio_app](./gradio_app/README_zh-CN.md): 基于 Gradio 的 Web 应用
 - [web_demo](./web_demo/README_zh-CN.md): MinerU在线[demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/)本地化部署版本
 - [web_api](./web_api/README.md): 基于 FastAPI 的 Web API
-
+- [multi_gpu](./multi_gpu/README.md): 基于 LitServe 的多 GPU 并行处理

+ 46 - 0
projects/multi_gpu/README.md

@@ -0,0 +1,46 @@
+## 项目简介
+本项目提供基于 LitServe 的多 GPU 并行处理方案。LitServe 是一个简便且灵活的 AI 模型服务引擎,基于 FastAPI 构建。它为 FastAPI 增强了批处理、流式传输和 GPU 自动扩展等功能,无需为每个模型单独重建 FastAPI 服务器。
+
+## 环境配置
+请使用以下命令配置所需的环境:
+```bash
+pip install -U litserve python-multipart filetype
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
+pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118
+```
+
+## 快速使用
+### 1. 启动服务端
+以下示例展示了如何启动服务端,支持自定义设置:
+```python
+server = ls.LitServer(
+    MinerUAPI(output_dir='/tmp'),  # 可自定义输出文件夹
+    accelerator='cuda',  # 启用 GPU 加速
+    devices='auto',  # "auto" 使用所有 GPU
+    workers_per_device=1,  # 每个 GPU 启动一个服务实例
+    timeout=False  # 设置为 False 以禁用超时
+)
+server.run(port=8000)  # 设定服务端口为 8000
+```
+
+启动服务端命令:
+```bash
+python server.py
+```
+
+### 2. 启动客户端
+以下代码展示了客户端的使用方式,可根据需求修改配置:
+```python
+files = ['demo/small_ocr.pdf']  # 替换为文件路径,支持 jpg/jpeg、png、pdf 文件
+n_jobs = np.clip(len(files), 1, 8)  # 设置并发线程数,此处最大为 8,可根据自身修改
+results = Parallel(n_jobs, prefer='threads', verbose=10)(
+    delayed(do_parse)(p) for p in files
+)
+print(results)
+```
+
+启动客户端命令:
+```bash
+python client.py
+```
+好了,你的文件会自动在多个 GPU 上并行处理!🍻🍻🍻

+ 39 - 0
projects/multi_gpu/client.py

@@ -0,0 +1,39 @@
+import base64
+import requests
+import numpy as np
+from loguru import logger
+from joblib import Parallel, delayed
+
+
+def to_b64(file_path):
+    try:
+        with open(file_path, 'rb') as f:
+            return base64.b64encode(f.read()).decode('utf-8')
+    except Exception as e:
+        raise Exception(f'File: {file_path} - Info: {e}')
+
+
+def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs):
+    try:
+        response = requests.post(url, json={
+            'file': to_b64(file_path),
+            'kwargs': kwargs
+        })
+
+        if response.status_code == 200:
+            output = response.json()
+            output['file_path'] = file_path
+            return output
+        else:
+            raise Exception(response.text)
+    except Exception as e:
+        logger.error(f'File: {file_path} - Info: {e}')
+
+
+if __name__ == '__main__':
+    files = ['small_ocr.pdf']
+    n_jobs = np.clip(len(files), 1, 8)
+    results = Parallel(n_jobs, prefer='threads', verbose=10)(
+        delayed(do_parse)(p) for p in files
+    )
+    print(results)

+ 74 - 0
projects/multi_gpu/server.py

@@ -0,0 +1,74 @@
+import os
+import fitz
+import torch
+import base64
+import litserve as ls
+from uuid import uuid4
+from fastapi import HTTPException
+from filetype import guess_extension
+from magic_pdf.tools.common import do_parse
+from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+
+
+class MinerUAPI(ls.LitAPI):
+    def __init__(self, output_dir='/tmp'):
+        self.output_dir = output_dir
+
+    def setup(self, device):
+        if device.startswith('cuda'):
+            os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
+            if torch.cuda.device_count() > 1:
+                raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")
+
+        model_manager = ModelSingleton()
+        model_manager.get_model(True, False)
+        model_manager.get_model(False, False)
+        print(f'Model initialization complete on {device}!')
+
+    def decode_request(self, request):
+        file = request['file']
+        file = self.to_pdf(file)
+        opts = request.get('kwargs', {})
+        opts.setdefault('debug_able', False)
+        opts.setdefault('parse_method', 'auto')
+        return file, opts
+
+    def predict(self, inputs):
+        try:
+            do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1])
+            return pdf_name
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+        finally:
+            self.clean_memory()
+
+    def encode_response(self, response):
+        return {'output_dir': response}
+
+    def clean_memory(self):
+        import gc
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+        gc.collect()
+
+    def to_pdf(self, file_base64):
+        try:
+            file_bytes = base64.b64decode(file_base64)
+            file_ext = guess_extension(file_bytes)
+            with fitz.open(stream=file_bytes, filetype=file_ext) as f:
+                if f.is_pdf: return f.tobytes()
+                return f.convert_to_pdf()
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+
+if __name__ == '__main__':
+    server = ls.LitServer(
+        MinerUAPI(output_dir='/tmp'),
+        accelerator='cuda',
+        devices='auto',
+        workers_per_device=1,
+        timeout=False
+    )
+    server.run(port=8000)

BIN
projects/multi_gpu/small_ocr.pdf


Some files were not shown because too many files changed in this diff