Ver código fonte

Merge branch 'dev' into dev-table-model-update

liukaiwen 1 ano atrás
pai
commit
7d2dfc8091
100 arquivos alterados com 3333 adições e 845 exclusões
  1. 50 45
      .gitignore
  2. 3 2
      .pre-commit-config.yaml
  3. 32 16
      README.md
  4. 31 14
      README_zh-CN.md
  5. 2 15
      demo/demo.py
  6. 14 9
      demo/magic_pdf_parse_main.py
  7. 0 0
      docs/FAQ_en_us.md
  8. 0 0
      docs/FAQ_zh_cn.md
  9. 2 2
      docs/README_Ubuntu_CUDA_Acceleration_en_US.md
  10. 3 2
      docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md
  11. 0 2
      docs/README_Windows_CUDA_Acceleration_en_US.md
  12. 0 2
      docs/README_Windows_CUDA_Acceleration_zh_CN.md
  13. 0 0
      docs/chemical_knowledge_introduction/introduction.pdf
  14. 0 0
      docs/chemical_knowledge_introduction/introduction.xmind
  15. 21 8
      docs/download_models.py
  16. 29 9
      docs/download_models_hf.py
  17. 3 1
      docs/how_to_download_models_en.md
  18. 4 8
      docs/how_to_download_models_zh_cn.md
  19. 0 0
      docs/images/MinerU-logo-hq.png
  20. 0 0
      docs/images/MinerU-logo.png
  21. 0 0
      docs/images/datalab_logo.png
  22. 0 0
      docs/images/flowchart_en.png
  23. 0 0
      docs/images/flowchart_zh_cn.png
  24. 0 0
      docs/images/layout_example.png
  25. 0 0
      docs/images/poly.png
  26. 0 0
      docs/images/project_panorama_en.png
  27. 0 0
      docs/images/project_panorama_zh_cn.png
  28. 0 0
      docs/images/spans_example.png
  29. 0 0
      docs/images/web_demo_1.png
  30. 0 0
      docs/output_file_en_us.md
  31. 0 0
      docs/output_file_zh_cn.md
  32. 12 3
      magic-pdf.template.json
  33. 0 0
      magic_pdf/config/__init__.py
  34. 7 0
      magic_pdf/config/enums.py
  35. 32 0
      magic_pdf/config/exceptions.py
  36. 0 0
      magic_pdf/data/__init__.py
  37. 12 0
      magic_pdf/data/data_reader_writer/__init__.py
  38. 51 0
      magic_pdf/data/data_reader_writer/base.py
  39. 59 0
      magic_pdf/data/data_reader_writer/filebase.py
  40. 137 0
      magic_pdf/data/data_reader_writer/multi_bucket_s3.py
  41. 69 0
      magic_pdf/data/data_reader_writer/s3.py
  42. 194 0
      magic_pdf/data/dataset.py
  43. 0 0
      magic_pdf/data/io/__init__.py
  44. 42 0
      magic_pdf/data/io/base.py
  45. 37 0
      magic_pdf/data/io/http.py
  46. 114 0
      magic_pdf/data/io/s3.py
  47. 95 0
      magic_pdf/data/read_api.py
  48. 15 0
      magic_pdf/data/schemas.py
  49. 32 0
      magic_pdf/data/utils.py
  50. 29 224
      magic_pdf/dict2md/ocr_mkcontent.py
  51. 13 6
      magic_pdf/libs/Constants.py
  52. 35 0
      magic_pdf/libs/boxbase.py
  53. 44 26
      magic_pdf/libs/config_reader.py
  54. 65 46
      magic_pdf/libs/draw_bbox.py
  55. 38 14
      magic_pdf/model/doc_analyze_by_custom_model.py
  56. 259 14
      magic_pdf/model/magic_model.py
  57. 899 0
      magic_pdf/model/mfr_cudagraph.py
  58. 91 46
      magic_pdf/model/pdf_extract_kit.py
  59. 2 2
      magic_pdf/model/ppTableModel.py
  60. 122 82
      magic_pdf/para/para_split_v3.py
  61. 5 2
      magic_pdf/pdf_parse_by_ocr.py
  62. 5 2
      magic_pdf/pdf_parse_by_txt.py
  63. 300 156
      magic_pdf/pdf_parse_union_core_v2.py
  64. 6 7
      magic_pdf/pipe/AbsPipe.py
  65. 8 4
      magic_pdf/pipe/OCRPipe.py
  66. 8 4
      magic_pdf/pipe/TXTPipe.py
  67. 10 5
      magic_pdf/pipe/UNIPipe.py
  68. 55 26
      magic_pdf/pre_proc/ocr_detect_all_bboxes.py
  69. 27 1
      magic_pdf/pre_proc/ocr_dict_merge.py
  70. 5 13
      magic_pdf/resources/model_config/model_configs.yaml
  71. 1 1
      magic_pdf/tools/cli.py
  72. 11 6
      magic_pdf/tools/common.py
  73. 13 5
      magic_pdf/user_api.py
  74. 0 0
      magic_pdf/utils/__init__.py
  75. 11 0
      magic_pdf/utils/annotations.py
  76. 0 0
      next_docs/en/.readthedocs.yaml
  77. 0 0
      next_docs/en/Makefile
  78. 0 0
      next_docs/en/_static/image/logo.png
  79. 9 0
      next_docs/en/api.rst
  80. 44 0
      next_docs/en/api/data_reader_writer.rst
  81. 22 0
      next_docs/en/api/dataset.rst
  82. 0 0
      next_docs/en/api/io.rst
  83. 6 0
      next_docs/en/api/read_api.rst
  84. 0 0
      next_docs/en/api/schemas.rst
  85. 1 0
      next_docs/en/api/utils.rst
  86. 0 0
      next_docs/en/conf.py
  87. 12 0
      next_docs/en/index.rst
  88. 0 0
      next_docs/en/make.bat
  89. 5 0
      next_docs/requirements.txt
  90. 0 0
      next_docs/zh_cn/.readthedocs.yaml
  91. 0 0
      next_docs/zh_cn/Makefile
  92. 0 0
      next_docs/zh_cn/_static/image/logo.png
  93. 0 0
      next_docs/zh_cn/conf.py
  94. 0 0
      next_docs/zh_cn/index.rst
  95. 0 0
      next_docs/zh_cn/make.bat
  96. 1 2
      projects/README.md
  97. 1 1
      projects/README_zh-CN.md
  98. 68 12
      projects/gradio_app/app.py
  99. BIN
      projects/gradio_app/examples/2list_1table.pdf
  100. BIN
      projects/gradio_app/examples/3list_1table.pdf

+ 50 - 45
.gitignore

@@ -1,45 +1,50 @@
-*.tar
-*.tar.gz
-*.zip
-venv*/
-envs/
-slurm_logs/
-
-sync1.sh
-data_preprocess_pj1
-data-preparation1
-__pycache__
-*.log
-*.pyc
-.vscode
-debug/
-*.ipynb
-.idea
-
-# vscode history
-.history
-
-.DS_Store
-.env
-
-bad_words/
-bak/
-
-app/tests/*
-temp/
-tmp/
-tmp
-.vscode
-.vscode/
-ocr_demo
-.coveragerc
-/app/common/__init__.py
-/magic_pdf/config/__init__.py
-source.dev.env
-
-tmp
-
-projects/web/node_modules
-projects/web/dist
-
-projects/web_demo/web_demo/static/
+*.tar
+*.tar.gz
+*.zip
+venv*/
+envs/
+slurm_logs/
+
+sync1.sh
+data_preprocess_pj1
+data-preparation1
+__pycache__
+*.log
+*.pyc
+.vscode
+debug/
+*.ipynb
+.idea
+
+# vscode history
+.history
+
+.DS_Store
+.env
+
+bad_words/
+bak/
+
+app/tests/*
+temp/
+tmp/
+tmp
+.vscode
+.vscode/
+ocr_demo
+.coveragerc
+/app/common/__init__.py
+/magic_pdf/config/__init__.py
+source.dev.env
+
+tmp
+
+projects/web/node_modules
+projects/web/dist
+
+projects/web_demo/web_demo/static/
+cli_debug/
+debug_utils/
+
+# sphinx docs
+_build/

+ 3 - 2
.pre-commit-config.yaml

@@ -3,7 +3,7 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-        args: ["--max-line-length=120", "--ignore=E131,E125,W503,W504,E203"]
+        args: ["--max-line-length=150", "--ignore=E131,E125,W503,W504,E203"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:
@@ -12,11 +12,12 @@ repos:
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: ["--style={based_on_style: google, column_limit: 120, indent_width: 4}"]
+        args: ["--style={based_on_style: google, column_limit: 150, indent_width: 4}"]
   - repo: https://github.com/codespell-project/codespell
     rev: v2.2.1
     hooks:
       - id: codespell
+        args: ['--skip', '*.json']
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.3.0
     hooks:

Diferenças do arquivo suprimidas por serem muito extensas
+ 32 - 16
README.md


Diferenças do arquivo suprimidas por serem muito extensas
+ 31 - 14
README_zh-CN.md


+ 2 - 15
demo/demo.py

@@ -1,35 +1,22 @@
 import os
-import json
 
 from loguru import logger
-
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
-import magic_pdf.model as model_config 
-model_config.__use_inside_model__ = True
 
 try:
     current_script_dir = os.path.dirname(os.path.abspath(__file__))
     demo_name = "demo1"
     pdf_path = os.path.join(current_script_dir, f"{demo_name}.pdf")
-    model_path = os.path.join(current_script_dir, f"{demo_name}.json")
     pdf_bytes = open(pdf_path, "rb").read()
-    # model_json = json.loads(open(model_path, "r", encoding="utf-8").read())
-    model_json = []  # model_json传空list使用内置模型解析
-    jso_useful_key = {"_pdf_type": "", "model_list": model_json}
+    jso_useful_key = {"_pdf_type": "", "model_list": []}
     local_image_dir = os.path.join(current_script_dir, 'images')
     image_dir = str(os.path.basename(local_image_dir))
     image_writer = DiskReaderWriter(local_image_dir)
     pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
     pipe.pipe_classify()
-    """如果没有传入有效的模型数据,则使用内置model解析"""
-    if len(model_json) == 0:
-        if model_config.__use_inside_model__:
-            pipe.pipe_analyze()
-        else:
-            logger.error("need model list input")
-            exit(1)
+    pipe.pipe_analyze()
     pipe.pipe_parse()
     md_content = pipe.pipe_mk_markdown(image_dir, drop_mode="none")
     with open(f"{demo_name}.md", "w", encoding="utf-8") as f:

+ 14 - 9
demo/magic_pdf_parse_main.py

@@ -4,13 +4,12 @@ import copy
 
 from loguru import logger
 
+from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-import magic_pdf.model as model_config
 
-model_config.__use_inside_model__ = True
 
 # todo: 设备类型选择 (?)
 
@@ -47,11 +46,20 @@ def json_md_dump(
     )
 
 
+# 可视化
+def draw_visualization_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name):
+    # 画布局框,附带排序结果
+    draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+    # 画 span 框
+    draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+
+
 def pdf_parse_main(
         pdf_path: str,
         parse_method: str = 'auto',
         model_json_path: str = None,
         is_json_md_dump: bool = True,
+        is_draw_visualization_bbox: bool = True,
         output_dir: str = None
 ):
     """
@@ -108,11 +116,7 @@ def pdf_parse_main(
 
         # 如果没有传入模型数据,则使用内置模型解析
         if not model_json:
-            if model_config.__use_inside_model__:
-                pipe.pipe_analyze()  # 解析
-            else:
-                logger.error("need model list input")
-                exit(1)
+            pipe.pipe_analyze()  # 解析
 
         # 执行解析
         pipe.pipe_parse()
@@ -121,10 +125,11 @@ def pdf_parse_main(
         content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
         md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
 
-
         if is_json_md_dump:
             json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
 
+        if is_draw_visualization_bbox:
+            draw_visualization_bbox(pipe.pdf_mid_data['pdf_info'], pdf_bytes, output_path, pdf_name)
 
     except Exception as e:
         logger.exception(e)
@@ -132,5 +137,5 @@ def pdf_parse_main(
 
 # 测试
 if __name__ == '__main__':
-    pdf_path = r"C:\Users\XYTK2\Desktop\2024-2016-gb-cd-300.pdf"
+    pdf_path = r"D:\project\20240617magicpdf\Magic-PDF\demo\demo1.pdf"
     pdf_parse_main(pdf_path)

+ 0 - 0
old_docs/FAQ_en_us.md → docs/FAQ_en_us.md


+ 0 - 0
old_docs/FAQ_zh_cn.md → docs/FAQ_zh_cn.md


+ 2 - 2
old_docs/README_Ubuntu_CUDA_Acceleration_en_US.md → docs/README_Ubuntu_CUDA_Acceleration_en_US.md

@@ -8,6 +8,8 @@ nvidia-smi
 
 If you see information similar to the following, it means that the NVIDIA drivers are already installed, and you can skip Step 2.
 
+Note: `CUDA Version` should be >= 12.1. If the displayed version number is less than 12.1, please upgrade the driver.
+
 ```plaintext
 +---------------------------------------------------------------------------------------+
 | NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
@@ -95,8 +97,6 @@ magic-pdf -p small_ocr.pdf
 
 If your graphics card has at least **8GB** of VRAM, follow these steps to test CUDA acceleration:
 
-> ❗ Due to the extremely limited nature of 8GB VRAM for running this application, you need to close all other programs using VRAM to ensure that 8GB of VRAM is available when running this application.
-
 1. Modify the value of `"device-mode"` in the `magic-pdf.json` configuration file located in your home directory.
    ```json
    {

+ 3 - 2
old_docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md → docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md

@@ -8,6 +8,9 @@ nvidia-smi
 
 如果看到类似如下的信息,说明已经安装了nvidia驱动,可以跳过步骤2
 
+注意:`CUDA Version` 显示的版本号应 >= 12.1,如显示的版本号小于12.1,请升级驱动
+
+```plaintext
 ```
 +---------------------------------------------------------------------------------------+
 | NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
@@ -95,8 +98,6 @@ magic-pdf -p small_ocr.pdf
 
 如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
-> ❗️因8GB显存运行本应用非常极限,需要关闭所有其他正在使用显存的程序以确保本应用运行时有足额8GB显存可用。
-
 **1.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
 
 ```json

+ 0 - 2
old_docs/README_Windows_CUDA_Acceleration_en_US.md → docs/README_Windows_CUDA_Acceleration_en_US.md

@@ -60,8 +60,6 @@ Download a sample file from the repository and test it.
 
 If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
 
-> ❗ Due to the extremely limited nature of 8GB VRAM for running this application, you need to close all other programs using VRAM to ensure that 8GB of VRAM is available when running this application.
-
 1. **Overwrite the installation of torch and torchvision** supporting CUDA.
 
    ```

+ 0 - 2
old_docs/README_Windows_CUDA_Acceleration_zh_CN.md → docs/README_Windows_CUDA_Acceleration_zh_CN.md

@@ -61,8 +61,6 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
 
 如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
-> ❗️因8GB显存运行本应用非常极限,需要关闭所有其他正在使用显存的程序以确保本应用运行时有足额8GB显存可用。
-
 **1.覆盖安装支持cuda的torch和torchvision**
 
 ```bash

+ 0 - 0
old_docs/chemical_knowledge_introduction/introduction.pdf → docs/chemical_knowledge_introduction/introduction.pdf


+ 0 - 0
old_docs/chemical_knowledge_introduction/introduction.xmind → docs/chemical_knowledge_introduction/introduction.xmind


+ 21 - 8
old_docs/download_models.py → docs/download_models.py

@@ -5,16 +5,21 @@ import requests
 from modelscope import snapshot_download
 
 
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
 def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
     else:
-        # 下载JSON文件
-        response = requests.get(url)
-        response.raise_for_status()  # 检查请求是否成功
-
-        # 解析JSON内容
-        data = response.json()
+        data = download_json(url)
 
     # 修改内容
     for key, value in modifications.items():
@@ -26,13 +31,21 @@ def download_and_modify_json(url, local_filename, modifications):
 
 
 if __name__ == '__main__':
-    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
     layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
     model_dir = model_dir + '/models'
     print(f'model_dir is: {model_dir}')
     print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
 
-    json_url = 'https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json'
+    json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
     config_file_name = 'magic-pdf.json'
     home_dir = os.path.expanduser('~')
     config_file = os.path.join(home_dir, config_file_name)

+ 29 - 9
old_docs/download_models_hf.py → docs/download_models_hf.py

@@ -5,16 +5,21 @@ import requests
 from huggingface_hub import snapshot_download
 
 
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
 def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
     else:
-        # 下载JSON文件
-        response = requests.get(url)
-        response.raise_for_status()  # 检查请求是否成功
-
-        # 解析JSON内容
-        data = response.json()
+        data = download_json(url)
 
     # 修改内容
     for key, value in modifications.items():
@@ -26,13 +31,28 @@ def download_and_modify_json(url, local_filename, modifications):
 
 
 if __name__ == '__main__':
-    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
-    layoutreader_model_dir = snapshot_download('hantian/layoutreader')
+
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
+
+    layoutreader_pattern = [
+        "*.json",
+        "*.safetensors",
+    ]
+    layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
+
     model_dir = model_dir + '/models'
     print(f'model_dir is: {model_dir}')
     print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
 
-    json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json'
+    json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
     config_file_name = 'magic-pdf.json'
     home_dir = os.path.expanduser('~')
     config_file = os.path.join(home_dir, config_file_name)

+ 3 - 1
old_docs/how_to_download_models_en.md → docs/how_to_download_models_en.md

@@ -22,7 +22,9 @@ The configuration file can be found in the user directory, with the filename `ma
 
 > Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
 
-If you previously downloaded model files via git lfs, you can navigate to the previous download directory and use the `git pull` command to update the model.
+When magic-pdf <= 0.8.1, if you have previously downloaded the model files via git lfs, you can navigate to the previous download directory and update the models using the `git pull` command.
+
+> For versions 0.9.x and later, due to the repository change and the addition of the layout sorting model in PDF-Extract-Kit 1.0, the models cannot be updated using the `git pull` command. Instead, a Python script must be used for one-click updates.
 
 ## 2. Models downloaded via Hugging Face or Model Scope
 

+ 4 - 8
old_docs/how_to_download_models_zh_cn.md → docs/how_to_download_models_zh_cn.md

@@ -34,14 +34,10 @@ python脚本会自动下载模型文件并配置好配置文件中的模型目
 
 > 由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
 
-如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
-
-> 0.9.x及以后版本由于新增layout排序模型,且该模型和此前的模型不在同一仓库,不能通过`git pull`命令更新,需要单独下载。
->
-> ```
-> from modelscope import snapshot_download
-> snapshot_download('ppaanngggg/layoutreader')
-> ```
+当magic-pdf <= 0.8.1时,如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
+
+> 0.9.x及以后版本由于PDF-Extract-Kit 1.0更换仓库和新增layout排序模型,不能通过`git pull`命令更新,需要使用python脚本一键更新。
+
 
 ## 2. 通过 Hugging Face 或 Model Scope 下载过模型
 

+ 0 - 0
old_docs/images/MinerU-logo-hq.png → docs/images/MinerU-logo-hq.png


+ 0 - 0
old_docs/images/MinerU-logo.png → docs/images/MinerU-logo.png


+ 0 - 0
old_docs/images/datalab_logo.png → docs/images/datalab_logo.png


+ 0 - 0
old_docs/images/flowchart_en.png → docs/images/flowchart_en.png


+ 0 - 0
old_docs/images/flowchart_zh_cn.png → docs/images/flowchart_zh_cn.png


+ 0 - 0
old_docs/images/layout_example.png → docs/images/layout_example.png


+ 0 - 0
old_docs/images/poly.png → docs/images/poly.png


+ 0 - 0
old_docs/images/project_panorama_en.png → docs/images/project_panorama_en.png


+ 0 - 0
old_docs/images/project_panorama_zh_cn.png → docs/images/project_panorama_zh_cn.png


+ 0 - 0
old_docs/images/spans_example.png → docs/images/spans_example.png


+ 0 - 0
old_docs/images/web_demo_1.png → docs/images/web_demo_1.png


+ 0 - 0
old_docs/output_file_en_us.md → docs/output_file_en_us.md


+ 0 - 0
old_docs/output_file_zh_cn.md → docs/output_file_zh_cn.md


+ 12 - 3
magic-pdf.template.json

@@ -6,9 +6,18 @@
     "models-dir":"/tmp/models",
     "layoutreader-model-dir":"/tmp/layoutreader",
     "device-mode":"cpu",
+    "layout-config": {
+        "model": "layoutlmv3"
+    },
+    "formula-config": {
+        "mfd_model": "yolo_v8_mfd",
+        "mfr_model": "unimernet_small",
+        "enable": true
+    },
     "table-config": {
-        "model": "TableMaster",
-        "is_table_recog_enable": false,
+        "model": "tablemaster",
+        "enable": false,
         "max_time": 400
-    }
+    },
+    "config_version": "1.0.0"
 }

+ 0 - 0
magic_pdf/config/__init__.py


+ 7 - 0
magic_pdf/config/enums.py

@@ -0,0 +1,7 @@
+
+import enum
+
+
+class SupportedPdfParseMethod(enum.Enum):
+    OCR = 'ocr'
+    TXT = 'txt'

+ 32 - 0
magic_pdf/config/exceptions.py

@@ -0,0 +1,32 @@
+
+class FileNotExisted(Exception):
+
+    def __init__(self, path):
+        self.path = path
+
+    def __str__(self):
+        return f'File {self.path} does not exist.'
+
+
+class InvalidConfig(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid config: {self.msg}'
+
+
+class InvalidParams(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid params: {self.msg}'
+
+
+class EmptyData(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Empty data: {self.msg}'

+ 0 - 0
magic_pdf/data/__init__.py


+ 12 - 0
magic_pdf/data/data_reader_writer/__init__.py

@@ -0,0 +1,12 @@
+from magic_pdf.data.data_reader_writer.filebase import \
+    FileBasedDataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.filebase import \
+    FileBasedDataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
+    MultiBucketS3DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
+    MultiBucketS3DataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.s3 import S3DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.s3 import S3DataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.base import DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.base import DataWriter  # noqa: F401

+ 51 - 0
magic_pdf/data/data_reader_writer/base.py

@@ -0,0 +1,51 @@
+
+from abc import ABC, abstractmethod
+
+
+class DataReader(ABC):
+
+    def read(self, path: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return self.read_at(path)
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file at offset and limit.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the length of bytes want to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of the file
+        """
+        pass
+
+
+class DataWriter(ABC):
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write the data to the file.
+
+        Args:
+            path (str): the target file where to write
+            data (bytes): the data want to write
+        """
+        pass
+
+    def write_string(self, path: str, data: str) -> None:
+        """Write the data to file, the data will be encoded to bytes.
+
+        Args:
+            path (str): the target file where to write
+            data (str): the data want to write
+        """
+        self.write(path, data.encode())

+ 59 - 0
magic_pdf/data/data_reader_writer/filebase.py

@@ -0,0 +1,59 @@
+import os
+
+from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
+
+
+class FileBasedDataReader(DataReader):
+    def __init__(self, parent_dir: str = ''):
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the length of bytes want to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of file
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        with open(fn_path, 'rb') as f:
+            f.seek(offset)
+            if limit == -1:
+                return f.read()
+            else:
+                return f.read(limit)
+
+
+class FileBasedDataWriter(DataWriter):
+    def __init__(self, parent_dir: str = '') -> None:
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        with open(fn_path, 'wb') as f:
+            f.write(data)

+ 137 - 0
magic_pdf/data/data_reader_writer/multi_bucket_s3.py

@@ -0,0 +1,137 @@
+from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
+from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
+from magic_pdf.data.io.s3 import S3Reader, S3Writer
+from magic_pdf.data.schemas import S3Config
+from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
+                                       remove_non_official_s3_args)
+
+
+class MultiS3Mixin:
+    def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
+        """Initialized with multiple s3 configs.
+
+        Args:
+            default_bucket (str): the default bucket name of the relative path
+            s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
+
+        Raises:
+            InvalidConfig: default bucket config not in s3_configs
+            InvalidConfig: bucket name not unique in s3_configs
+            InvalidConfig: default bucket must be provided
+        """
+        if len(default_bucket) == 0:
+            raise InvalidConfig('default_bucket must be provided')
+
+        found_default_bucket_config = False
+        for conf in s3_configs:
+            if conf.bucket_name == default_bucket:
+                found_default_bucket_config = True
+                break
+
+        if not found_default_bucket_config:
+            raise InvalidConfig(
+                f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
+            )
+
+        uniq_bucket = set([conf.bucket_name for conf in s3_configs])
+        if len(uniq_bucket) != len(s3_configs):
+            raise InvalidConfig(
+                f'the bucket_name in s3_configs: {s3_configs} must be unique'
+            )
+
+        self.default_bucket = default_bucket
+        self.s3_configs = s3_configs
+        self._s3_clients_h: dict = {}
+
+
+class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
+    def read(self, path: str) -> bytes:
+        """Read the path from s3, selecting a different bucket client for each
+        request based on the path; range reads are also supported.
+
+        Args:
+            path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
+            for example: s3://bucket_name/path?0,100
+
+        Returns:
+            bytes: the content of s3 file
+        """
+        may_range_params = parse_s3_range_params(path)
+        if may_range_params is None or 2 != len(may_range_params):
+            byte_start, byte_len = 0, -1
+        else:
+            byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
+        path = remove_non_official_s3_args(path)
+        return self.read_at(path, byte_start, byte_len)
+
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Reader(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file with offset and limit, selecting a different bucket
+        client for each request based on the path.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
+
+        Returns:
+            bytes: the file content
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_reader = self.__get_s3_client(bucket_name)
+        else:
+            s3_reader = self.__get_s3_client(self.default_bucket)
+        return s3_reader.read_at(path, offset, limit)
+
+
+class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Writer(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data, selecting a different bucket client for each
+        request based on the path.
+
+        Args:
+            path (str): the path of file; a path starting with s3:// selects the client of the bucket named in it, otherwise the default bucket's client is used.
+            data (bytes): the data to write
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_writer = self.__get_s3_client(bucket_name)
+        else:
+            s3_writer = self.__get_s3_client(self.default_bucket)
+        return s3_writer.write(path, data)

+ 69 - 0
magic_pdf/data/data_reader_writer/s3.py

@@ -0,0 +1,69 @@
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import (
+    MultiBucketS3DataReader, MultiBucketS3DataWriter)
+from magic_pdf.data.schemas import S3Config
+
+
+class S3DataReader(MultiBucketS3DataReader):
+    """Single-bucket reader expressed as a one-entry ``MultiBucketS3DataReader``."""
+
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """Create an s3 reader bound to one bucket.
+
+        Args:
+            bucket (str): bucket name, also passed as the first argument of the parent initializer
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
+            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        bucket_conf = S3Config(
+            bucket_name=bucket,
+            access_key=ak,
+            secret_key=sk,
+            endpoint_url=endpoint_url,
+            addressing_style=addressing_style,
+        )
+        super().__init__(bucket, [bucket_conf])
+
+
+class S3DataWriter(MultiBucketS3DataWriter):
+    """Single-bucket writer expressed as a one-entry ``MultiBucketS3DataWriter``."""
+
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """Create an s3 writer bound to one bucket.
+
+        Args:
+            bucket (str): bucket name, also passed as the first argument of the parent initializer
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
+            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        bucket_conf = S3Config(
+            bucket_name=bucket,
+            access_key=ak,
+            secret_key=sk,
+            endpoint_url=endpoint_url,
+            addressing_style=addressing_style,
+        )
+        super().__init__(bucket, [bucket_conf])

+ 194 - 0
magic_pdf/data/dataset.py

@@ -0,0 +1,194 @@
+from abc import ABC, abstractmethod
+from typing import Iterator
+
+import fitz
+
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.schemas import PageInfo
+from magic_pdf.data.utils import fitz_doc_to_image
+
+
+class PageableData(ABC):
+    """Interface of a single page: image rendering, raw page access and size info."""
+
+    @abstractmethod
+    def get_image(self) -> dict:
+        """Render the page data as an image and return it as a dict."""
+
+    @abstractmethod
+    def get_doc(self) -> fitz.Page:
+        """Return the underlying PyMuPDF page object."""
+
+    @abstractmethod
+    def get_page_info(self) -> PageInfo:
+        """Describe the page.
+
+        Returns:
+            PageInfo: the page info of this page
+        """
+
+
+class Dataset(ABC):
+    """Interface of a sized, iterable collection of ``PageableData`` pages."""
+
+    @abstractmethod
+    def __len__(self) -> int:
+        """Return the number of items in the dataset."""
+
+    @abstractmethod
+    def __iter__(self) -> Iterator[PageableData]:
+        """Iterate over the page data objects."""
+
+    @abstractmethod
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """List the parse methods this dataset supports.
+
+        Returns:
+            list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
+        """
+
+    @abstractmethod
+    def data_bits(self) -> bytes:
+        """Return the bytes used to create this dataset."""
+
+    @abstractmethod
+    def get_page(self, page_id: int) -> PageableData:
+        """Return the page stored at index ``page_id``.
+
+        Args:
+            page_id (int): the index of the page
+
+        Returns:
+            PageableData: the page doc object
+        """
+
+
+class PymuDocDataset(Dataset):
+    """Dataset whose records are the pages of a PDF opened with PyMuPDF (``fitz``)."""
+
+    def __init__(self, bits: bytes):
+        """Initialize the dataset, which wraps the PyMuPDF document.
+
+        Args:
+            bits (bytes): the bytes of the pdf
+        """
+        # Keep the opened document referenced for the lifetime of the dataset:
+        # PyMuPDF pages belong to their parent document, and a temporary document
+        # may be released while the wrapped pages are still in use.
+        self._raw_fitz = fitz.open('pdf', bits)
+        self._records = [Doc(v) for v in self._raw_fitz]
+        self._data_bits = bits
+        self._raw_data = bits
+
+    def __len__(self) -> int:
+        """The page number of the pdf."""
+        return len(self._records)
+
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page doc object."""
+        return iter(self._records)
+
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """The method supported by this dataset.
+
+        Returns:
+            list[SupportedPdfParseMethod]: the supported methods (OCR and TXT)
+        """
+        return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
+
+    def data_bits(self) -> bytes:
+        """The pdf bits used to create this dataset."""
+        return self._data_bits
+
+    def get_page(self, page_id: int) -> PageableData:
+        """The page doc object.
+
+        Args:
+            page_id (int): the page doc index
+
+        Returns:
+            PageableData: the page doc object
+        """
+        return self._records[page_id]
+
+
+class ImageDataset(Dataset):
+    """Dataset built from one image, which is converted to a PDF and wrapped page by page."""
+
+    def __init__(self, bits: bytes):
+        """Initialize the dataset, which wraps the PyMuPDF document.
+
+        Args:
+            bits (bytes): the bytes of the photo, which is converted to pdf first and then opened with PyMuPDF.
+        """
+        # The image document is only needed to produce the PDF bytes; close it right after.
+        with fitz.open(stream=bits) as image_doc:
+            pdf_bytes = image_doc.convert_to_pdf()
+        # Keep the PDF document referenced for the lifetime of the dataset:
+        # PyMuPDF pages belong to their parent document, and a temporary document
+        # may be released while the wrapped pages are still in use.
+        self._raw_fitz = fitz.open('pdf', pdf_bytes)
+        self._records = [Doc(v) for v in self._raw_fitz]
+        self._raw_data = bits
+        self._data_bits = pdf_bytes
+
+    def __len__(self) -> int:
+        """The length of the dataset."""
+        return len(self._records)
+
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page object."""
+        return iter(self._records)
+
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """The method supported by this dataset.
+
+        Returns:
+            list[SupportedPdfParseMethod]: the supported methods (OCR only)
+        """
+        return [SupportedPdfParseMethod.OCR]
+
+    def data_bits(self) -> bytes:
+        """The converted pdf bits (not the original image bytes) backing this dataset."""
+        return self._data_bits
+
+    def get_page(self, page_id: int) -> PageableData:
+        """The page doc object.
+
+        Args:
+            page_id (int): the page doc index
+
+        Returns:
+            PageableData: the page doc object
+        """
+        return self._records[page_id]
+
+
+class Doc(PageableData):
+    """Thin wrapper around a PyMuPDF page that also proxies unknown attributes to it."""
+
+    def __init__(self, doc: fitz.Page):
+        self._doc = doc
+
+    def get_image(self) -> dict:
+        """Return the image info.
+
+        Returns:
+            dict: {
+                img: np.ndarray,
+                width: int,
+                height: int
+            }
+        """
+        return fitz_doc_to_image(self._doc)
+
+    def get_doc(self) -> fitz.Page:
+        """Get the PyMuPDF page object.
+
+        Returns:
+            fitz.Page: the wrapped page object
+        """
+        return self._doc
+
+    def get_page_info(self) -> PageInfo:
+        """Get the page info of the page.
+
+        Returns:
+            PageInfo: the width and height taken from the page rectangle
+        """
+        page_w = self._doc.rect.width
+        page_h = self._doc.rect.height
+        return PageInfo(w=page_w, h=page_h)
+
+    def __getattr__(self, name):
+        """Delegate attributes not found on the wrapper to the wrapped page.
+
+        Raises:
+            AttributeError: if the wrapped page has no such attribute, or if the
+                wrapper has not been initialized yet.
+        """
+        # __getattr__ only runs when normal lookup fails. If `_doc` itself is
+        # missing (e.g. on an instance created by copy/pickle before __init__),
+        # touching self._doc here would recurse forever.
+        if name == '_doc':
+            raise AttributeError(name)
+        # Let getattr raise AttributeError for unknown names instead of silently
+        # returning None, so hasattr() and typos behave as expected.
+        return getattr(self._doc, name)

+ 0 - 0
magic_pdf/data/io/__init__.py


+ 42 - 0
magic_pdf/data/io/base.py

@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+
+
+class IOReader(ABC):
+    """Interface of a reader that returns file content as bytes."""
+
+    @abstractmethod
+    def read(self, path: str) -> bytes:
+        """Read the whole file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read ``limit`` bytes starting at ``offset``.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of file
+        """
+
+
+class IOWriter(ABC):
+    """Interface of a writer that stores bytes at a path.
+
+    Derives from ``ABC`` (like ``IOReader``) so that ``@abstractmethod`` is
+    actually enforced and a subclass without ``write`` cannot be instantiated.
+    """
+
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data to write
+        """
+        pass

+ 37 - 0
magic_pdf/data/io/http.py

@@ -0,0 +1,37 @@
+
+import io
+
+import requests
+
+from magic_pdf.data.io.base import IOReader, IOWriter
+
+
+class HttpReader(IOReader):
+    """Reader that fetches content over HTTP with ``requests``."""
+
+    def read(self, url: str) -> bytes:
+        """Fetch ``url`` with an HTTP GET and return the response body.
+
+        Args:
+            url (str): the url to read
+
+        Returns:
+            bytes: the body of the response
+        """
+        # NOTE(review): no timeout is set and the status code is not checked,
+        # so the body of an error response is returned as-is.
+        response = requests.get(url)
+        return response.content
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Not Implemented."""
+        raise NotImplementedError
+
+
+class HttpWriter(IOWriter):
+    """Writer that uploads data with a multipart HTTP POST."""
+
+    def write(self, url: str, data: bytes) -> None:
+        """POST ``data`` to ``url`` as a multipart upload under the field name ``file``.
+
+        Args:
+            url (str): the url that receives the upload
+            data (bytes): the data to write
+
+        Raises:
+            AssertionError: if the response status code is not in the 2xx range
+        """
+        files = {'file': io.BytesIO(data)}
+        response = requests.post(url, files=files)
+        # Checked explicitly instead of with `assert`, which is removed when Python
+        # runs with -O. AssertionError is kept so existing handlers still match.
+        if not 200 <= response.status_code < 300:
+            raise AssertionError(
+                f'upload to {url} failed with status code {response.status_code}'
+            )

+ 114 - 0
magic_pdf/data/io/s3.py

@@ -0,0 +1,114 @@
+import boto3
+from botocore.config import Config
+
+from magic_pdf.data.io.base import IOReader, IOWriter
+
+
+class S3Reader(IOReader):
+    """Reader for objects of a single S3 bucket, backed by a boto3 client."""
+
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
+            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(
+                s3={'addressing_style': addressing_style},
+                retries={'max_attempts': 5, 'mode': 'standard'},
+            ),
+        )
+
+    def read(self, key: str) -> bytes:
+        """Read the whole object.
+
+        Args:
+            key (str): the key of the object to read
+
+        Returns:
+            bytes: the content of the object
+        """
+        return self.read_at(key)
+
+    def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            key (str): the key of the object to read
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes to read. Defaults to -1,
+                which reads up to the end of the object.
+
+        Returns:
+            bytes: the requested content; empty when ``limit`` is 0
+        """
+        if limit == 0:
+            # A zero-length read would need the inverted range
+            # 'bytes=<offset>-<offset-1>', which is not a valid byte range.
+            return b''
+
+        request = {'Bucket': self._bucket, 'Key': key}
+        if limit > -1:
+            request['Range'] = f'bytes={offset}-{offset + limit - 1}'
+        elif offset > 0:
+            request['Range'] = f'bytes={offset}-'
+        # Otherwise send no Range header at all: 'bytes=0-' is unsatisfiable
+        # for a zero-length object and makes a plain read() of it fail.
+        res = self._s3_client.get_object(**request)
+        return res['Body'].read()
+
+
+class S3Writer(IOWriter):
+    """Writer for objects of a single S3 bucket, backed by a boto3 client."""
+
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 writer client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
+            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        client_config = Config(
+            s3={'addressing_style': addressing_style},
+            retries={'max_attempts': 5, 'mode': 'standard'},
+        )
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=client_config,
+        )
+
+    def write(self, key: str, data: bytes):
+        """Store ``data`` as the object ``key`` in the configured bucket.
+
+        Args:
+            key (str): the key of the object to write
+            data (bytes): the data to write
+        """
+        self._s3_client.put_object(Body=data, Key=key, Bucket=self._bucket)

+ 95 - 0
magic_pdf/data/read_api.py

@@ -0,0 +1,95 @@
+import json
+import os
+from pathlib import Path
+
+from magic_pdf.config.exceptions import EmptyData, InvalidParams
+from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
+                                               MultiBucketS3DataReader)
+from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
+
+
+def read_jsonl(
+    s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
+) -> list[PymuDocDataset]:
+    """Read the jsonl file and return the list of PymuDocDataset.
+
+    Every non-blank line is parsed as JSON; the pdf location is taken from its
+    ``file_location`` key, falling back to ``path``.
+
+    Args:
+        s3_path_or_local (str): local file or s3 path
+        s3_client (MultiBucketS3DataReader | None, optional): s3 client that support multiple bucket. Defaults to None.
+
+    Raises:
+        InvalidParams: if s3_path_or_local is s3 path but s3_client is not provided.
+        EmptyData: if no pdf file location is provided in some line of jsonl file.
+        InvalidParams: if the file location is s3 path but s3_client is not provided
+
+    Returns:
+        list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
+    """
+    local_reader = FileBasedDataReader('')
+
+    def _read_bits(location: str) -> bytes:
+        """Read ``location`` via ``s3_client`` for s3:// paths, else from the local disk."""
+        if location.startswith('s3://'):
+            if s3_client is None:
+                raise InvalidParams('s3_client is required when s3_path is provided')
+            return s3_client.read(location)
+        return local_reader.read(location)
+
+    jsonl_bits = _read_bits(s3_path_or_local)
+    jsonl_d = [
+        json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
+    ]
+
+    # Process every record. The previous `jsonl_d[:5]` slice silently dropped
+    # everything after the fifth line, contradicting the documented contract.
+    bits_arr = []
+    for d in jsonl_d:
+        pdf_path = d.get('file_location', '') or d.get('path', '')
+        if not pdf_path:
+            raise EmptyData('pdf file location is empty')
+        bits_arr.append(_read_bits(pdf_path))
+    return [PymuDocDataset(bits) for bits in bits_arr]
+
+
+def read_local_pdfs(path: str) -> list[PymuDocDataset]:
+    """Build one PymuDocDataset per pdf found at ``path``.
+
+    Args:
+        path (str): pdf file path or directory that contains pdf files
+
+    Returns:
+        list[PymuDocDataset]: each pdf file will be converted to a PymuDocDataset
+    """
+    # Single file: read it directly.
+    if not os.path.isdir(path):
+        return [PymuDocDataset(FileBasedDataReader().read(path))]
+
+    # Directory: only the `*.pdf` entries directly inside it are read, by file name.
+    dir_reader = FileBasedDataReader(path)
+    datasets = []
+    for pdf_file in Path(path).glob('*.pdf'):
+        datasets.append(PymuDocDataset(dir_reader.read(pdf_file.name)))
+    return datasets
+
+
+def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
+    """Read images from path or directory.
+
+    A directory is walked recursively; a file is picked up when the text after
+    its last dot is one of ``suffixes``.
+
+    Args:
+        path (str): image file path or directory that contains image files
+        suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
+
+    Returns:
+        list[ImageDataset]: each image file will be converted to an ImageDataset
+    """
+    if os.path.isdir(path):
+        imgs_bits = []
+        s_suffixes = set(suffixes)
+        reader = FileBasedDataReader(path)
+        for root, _, files in os.walk(path):
+            for file in files:
+                # Require a real extension so a file literally named e.g. 'png' is not matched.
+                _, dot, suffix = file.rpartition('.')
+                if not dot or suffix not in s_suffixes:
+                    continue
+                # os.walk also descends into sub-directories, while the reader resolves
+                # names against `path`; pass the path relative to `path`, not the bare
+                # file name, so nested images are found.
+                rel_path = os.path.relpath(os.path.join(root, file), path)
+                imgs_bits.append(reader.read(rel_path))
+        return [ImageDataset(bits) for bits in imgs_bits]
+    else:
+        reader = FileBasedDataReader()
+        bits = reader.read(path)
+        return [ImageDataset(bits)]

+ 15 - 0
magic_pdf/data/schemas.py

@@ -0,0 +1,15 @@
+
+from pydantic import BaseModel, Field
+
+
+# Connection settings of one S3 bucket. Every field is a string constrained to be non-empty.
+class S3Config(BaseModel):
+    bucket_name: str = Field(description='s3 bucket name', min_length=1)
+    access_key: str = Field(description='s3 access key', min_length=1)
+    secret_key: str = Field(description='s3 secret key', min_length=1)
+    endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
+    # Only field with a default ('auto'); the others have none.
+    addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
+
+
+# Size of a single page as two floats: width `w` and height `h`.
+class PageInfo(BaseModel):
+    w: float = Field(description='the width of page')
+    h: float = Field(description='the height of page')

+ 32 - 0
magic_pdf/data/utils.py

@@ -0,0 +1,32 @@
+
+import fitz
+import numpy as np
+
+from magic_pdf.utils.annotations import ImportPIL
+
+
+@ImportPIL
+def fitz_doc_to_image(doc, dpi=200) -> dict:
+    """Render a PyMuPDF page to an RGB image and return it as a numpy array.
+
+    Args:
+        doc (_type_): PyMuPDF page
+        dpi (int, optional): the dpi used for rendering. Defaults to 200.
+
+    Returns:
+        dict:  {'img': numpy array, 'width': width, 'height': height }
+    """
+    from PIL import Image
+
+    zoom = dpi / 72
+    pixmap = doc.get_pixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)
+
+    # If either side exceeds 9000 pixels after scaling, render again without scaling.
+    if pixmap.width > 9000 or pixmap.height > 9000:
+        pixmap = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+
+    width, height = pixmap.width, pixmap.height
+    rgb_image = Image.frombytes('RGB', (width, height), pixmap.samples)
+    return {'img': np.array(rgb_image), 'width': width, 'height': height}

+ 29 - 224
magic_pdf/dict2md/ocr_mkcontent.py

@@ -1,6 +1,5 @@
 import re
 
-import wordninja
 from loguru import logger
 
 from magic_pdf.libs.commons import join_path
@@ -25,37 +24,6 @@ def __is_hyphen_at_line_end(line):
     return bool(re.search(r'[A-Za-z]+-\s*$', line))
 
 
-def split_long_words(text):
-    segments = text.split(' ')
-    for i in range(len(segments)):
-        words = re.findall(r'\w+|[^\w]', segments[i], re.UNICODE)
-        for j in range(len(words)):
-            if len(words[j]) > 10:
-                words[j] = ' '.join(wordninja.split(words[j]))
-        segments[i] = ''.join(words)
-    return ' '.join(segments)
-
-
-def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
-    markdown = []
-    for page_info in pdf_info_list:
-        paras_of_layout = page_info.get('para_blocks')
-        page_markdown = ocr_mk_markdown_with_para_core_v2(
-            paras_of_layout, 'mm', img_buket_path)
-        markdown.extend(page_markdown)
-    return '\n\n'.join(markdown)
-
-
-def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
-    markdown = []
-    for page_info in pdf_info_dict:
-        paras_of_layout = page_info.get('para_blocks')
-        page_markdown = ocr_mk_markdown_with_para_core_v2(
-            paras_of_layout, 'nlp')
-        markdown.extend(page_markdown)
-    return '\n\n'.join(markdown)
-
-
 def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
                                                 img_buket_path):
     markdown_with_para_and_pagination = []
@@ -68,69 +36,28 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
             paras_of_layout, 'mm', img_buket_path)
         markdown_with_para_and_pagination.append({
             'page_no':
-            page_no,
+                page_no,
             'md_content':
-            '\n\n'.join(page_markdown)
+                '\n\n'.join(page_markdown)
         })
         page_no += 1
     return markdown_with_para_and_pagination
 
 
-def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
-    page_markdown = []
-    for paras in paras_of_layout:
-        for para in paras:
-            para_text = ''
-            for line in para:
-                for span in line['spans']:
-                    span_type = span.get('type')
-                    content = ''
-                    language = ''
-                    if span_type == ContentType.Text:
-                        content = span['content']
-                        language = detect_lang(content)
-                        if (language == 'en'):  # 只对英文长词进行分词处理,中文分词会丢失文本
-                            content = ocr_escape_special_markdown_char(
-                                split_long_words(content))
-                        else:
-                            content = ocr_escape_special_markdown_char(content)
-                    elif span_type == ContentType.InlineEquation:
-                        content = f"${span['content']}$"
-                    elif span_type == ContentType.InterlineEquation:
-                        content = f"\n$$\n{span['content']}\n$$\n"
-                    elif span_type in [ContentType.Image, ContentType.Table]:
-                        if mode == 'mm':
-                            content = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
-                        elif mode == 'nlp':
-                            pass
-                    if content != '':
-                        if language == 'en':  # 英文语境下 content间需要空格分隔
-                            para_text += content + ' '
-                        else:  # 中文语境下,content间不需要空格分隔
-                            para_text += content
-            if para_text.strip() == '':
-                continue
-            else:
-                page_markdown.append(para_text.strip() + '  ')
-    return page_markdown
-
-
 def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                       mode,
                                       img_buket_path='',
-                                      parse_type="auto",
-                                      lang=None
                                       ):
     page_markdown = []
     for para_block in paras_of_layout:
         para_text = ''
         para_type = para_block['type']
         if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
-            para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
+            para_text = merge_para_with_text(para_block)
         elif para_type == BlockType.Title:
-            para_text = f'# {merge_para_with_text(para_block, parse_type=parse_type, lang=lang)}'
+            para_text = f'# {merge_para_with_text(para_block)}'
         elif para_type == BlockType.InterlineEquation:
-            para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
+            para_text = merge_para_with_text(para_block)
         elif para_type == BlockType.Image:
             if mode == 'nlp':
                 continue
@@ -143,17 +70,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                     para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 2nd.拼image_caption
                     if block['type'] == BlockType.ImageCaption:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
-                for block in para_block['blocks']:  # 2nd.拼image_caption
+                        para_text += merge_para_with_text(block) + '  \n'
+                for block in para_block['blocks']:  # 3rd.拼image_footnote
                     if block['type'] == BlockType.ImageFootnote:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                        para_text += merge_para_with_text(block) + '  \n'
         elif para_type == BlockType.Table:
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
                 for block in para_block['blocks']:  # 1st.拼table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                        para_text += merge_para_with_text(block) + '  \n'
                 for block in para_block['blocks']:  # 2nd.拼table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
@@ -168,7 +95,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TableFootnote:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                        para_text += merge_para_with_text(block) + '  \n'
 
         if para_text.strip() == '':
             continue
@@ -191,7 +118,7 @@ def detect_language(text):
         return 'empty'
 
 
-def merge_para_with_text(para_block, parse_type="auto", lang=None):
+def merge_para_with_text(para_block):
     para_text = ''
     for i, line in enumerate(para_block['lines']):
 
@@ -207,21 +134,11 @@ def merge_para_with_text(para_block, parse_type="auto", lang=None):
         if line_text != '':
             line_lang = detect_lang(line_text)
         for span in line['spans']:
+
             span_type = span['type']
             content = ''
             if span_type == ContentType.Text:
-                content = span['content']
-                # language = detect_lang(content)
-                language = detect_language(content)
-                # 判断是否小语种
-                if lang is not None and lang != 'en':
-                    content = ocr_escape_special_markdown_char(content)
-                else:  # 非小语种逻辑
-                    if language == 'en' and parse_type == 'ocr':  # 只对英文长词进行分词处理,中文分词会丢失文本
-                        content = ocr_escape_special_markdown_char(
-                            split_long_words(content))
-                    else:
-                        content = ocr_escape_special_markdown_char(content)
+                content = ocr_escape_special_markdown_char(span['content'])
             elif span_type == ContentType.InlineEquation:
                 content = f" ${span['content']}$ "
             elif span_type == ContentType.InterlineEquation:
@@ -242,74 +159,39 @@ def merge_para_with_text(para_block, parse_type="auto", lang=None):
     return para_text
 
 
-def para_to_standard_format(para, img_buket_path):
-    para_content = {}
-    if len(para) == 1:
-        para_content = line_to_standard_format(para[0], img_buket_path)
-    elif len(para) > 1:
-        para_text = ''
-        inline_equation_num = 0
-        for line in para:
-            for span in line['spans']:
-                language = ''
-                span_type = span.get('type')
-                content = ''
-                if span_type == ContentType.Text:
-                    content = span['content']
-                    language = detect_lang(content)
-                    if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
-                        content = ocr_escape_special_markdown_char(
-                            split_long_words(content))
-                    else:
-                        content = ocr_escape_special_markdown_char(content)
-                elif span_type == ContentType.InlineEquation:
-                    content = f"${span['content']}$"
-                    inline_equation_num += 1
-                if language == 'en':  # 英文语境下 content间需要空格分隔
-                    para_text += content + ' '
-                else:  # 中文语境下,content间不需要空格分隔
-                    para_text += content
-        para_content = {
-            'type': 'text',
-            'text': para_text,
-            'inline_equation_num': inline_equation_num,
-        }
-    return para_content
-
-
-def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type="auto", lang=None, drop_reason=None):
+def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
     para_type = para_block['type']
     para_content = {}
-    if para_type == BlockType.Text:
+    if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
         para_content = {
             'type': 'text',
-            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
+            'text': merge_para_with_text(para_block),
         }
     elif para_type == BlockType.Title:
         para_content = {
             'type': 'text',
-            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
+            'text': merge_para_with_text(para_block),
             'text_level': 1,
         }
     elif para_type == BlockType.InterlineEquation:
         para_content = {
             'type': 'equation',
-            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
+            'text': merge_para_with_text(para_block),
             'text_format': 'latex',
         }
     elif para_type == BlockType.Image:
-        para_content = {'type': 'image'}
+        para_content = {'type': 'image', 'img_caption': [], 'img_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.ImageBody:
                 para_content['img_path'] = join_path(
                     img_buket_path,
                     block['lines'][0]['spans'][0]['image_path'])
             if block['type'] == BlockType.ImageCaption:
-                para_content['img_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['img_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.ImageFootnote:
-                para_content['img_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['img_footnote'].append(merge_para_with_text(block))
     elif para_type == BlockType.Table:
-        para_content = {'type': 'table'}
+        para_content = {'type': 'table', 'table_caption': [], 'table_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
                 if block["lines"][0]["spans"][0].get('latex', ''):
@@ -318,9 +200,9 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type=
                     para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:
-                para_content['table_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['table_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.TableFootnote:
-                para_content['table_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['table_footnote'].append(merge_para_with_text(block))
 
     para_content['page_idx'] = page_idx
 
@@ -330,88 +212,11 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type=
     return para_content
 
 
-def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
-    content_list = []
-    for page_info in pdf_info_dict:
-        paras_of_layout = page_info.get('para_blocks')
-        if not paras_of_layout:
-            continue
-        for para_block in paras_of_layout:
-            para_content = para_to_standard_format_v2(para_block,
-                                                      img_buket_path)
-            content_list.append(para_content)
-    return content_list
-
-
-def line_to_standard_format(line, img_buket_path):
-    line_text = ''
-    inline_equation_num = 0
-    for span in line['spans']:
-        if not span.get('content'):
-            if not span.get('image_path'):
-                continue
-            else:
-                if span['type'] == ContentType.Image:
-                    content = {
-                        'type': 'image',
-                        'img_path': join_path(img_buket_path,
-                                              span['image_path']),
-                    }
-                    return content
-                elif span['type'] == ContentType.Table:
-                    content = {
-                        'type': 'table',
-                        'img_path': join_path(img_buket_path,
-                                              span['image_path']),
-                    }
-                    return content
-        else:
-            if span['type'] == ContentType.InterlineEquation:
-                interline_equation = span['content']
-                content = {
-                    'type': 'equation',
-                    'latex': f'$$\n{interline_equation}\n$$'
-                }
-                return content
-            elif span['type'] == ContentType.InlineEquation:
-                inline_equation = span['content']
-                line_text += f'${inline_equation}$'
-                inline_equation_num += 1
-            elif span['type'] == ContentType.Text:
-                text_content = ocr_escape_special_markdown_char(
-                    span['content'])  # 转义特殊符号
-                line_text += text_content
-    content = {
-        'type': 'text',
-        'text': line_text,
-        'inline_equation_num': inline_equation_num,
-    }
-    return content
-
-
-def ocr_mk_mm_standard_format(pdf_info_dict: list):
-    """content_list type         string
-    image/text/table/equation(行间的单独拿出来,行内的和text合并) latex        string
-    latex文本字段。 text         string      纯文本格式的文本数据。 md           string
-    markdown格式的文本数据。 img_path     string      s3://full/path/to/img.jpg."""
-    content_list = []
-    for page_info in pdf_info_dict:
-        blocks = page_info.get('preproc_blocks')
-        if not blocks:
-            continue
-        for block in blocks:
-            for line in block['lines']:
-                content = line_to_standard_format(line)
-                content_list.append(content)
-    return content_list
-
-
 def union_make(pdf_info_dict: list,
                make_mode: str,
                drop_mode: str,
                img_buket_path: str = '',
-               parse_type: str = "auto",
-               lang=None):
+               ):
     output_content = []
     for page_info in pdf_info_dict:
         drop_reason_flag = False
@@ -438,20 +243,20 @@ def union_make(pdf_info_dict: list,
             continue
         if make_mode == MakeMode.MM_MD:
             page_markdown = ocr_mk_markdown_with_para_core_v2(
-                paras_of_layout, 'mm', img_buket_path, parse_type=parse_type, lang=lang)
+                paras_of_layout, 'mm', img_buket_path)
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.NLP_MD:
             page_markdown = ocr_mk_markdown_with_para_core_v2(
-                paras_of_layout, 'nlp', parse_type=parse_type, lang=lang)
+                paras_of_layout, 'nlp')
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.STANDARD_FORMAT:
             for para_block in paras_of_layout:
                 if drop_reason_flag:
                     para_content = para_to_standard_format_v2(
-                        para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang, drop_reason=drop_reason)
+                        para_block, img_buket_path, page_idx)
                 else:
                     para_content = para_to_standard_format_v2(
-                        para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang)
+                        para_block, img_buket_path, page_idx)
                 output_content.append(para_content)
     if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
         return '\n\n'.join(output_content)

+ 13 - 6
magic_pdf/libs/Constants.py

@@ -10,18 +10,12 @@ block维度自定义字段
 # block中lines是否被删除
 LINES_DELETED = "lines_deleted"
 
-# struct eqtable
-STRUCT_EQTABLE = "struct_eqtable"
-
 # table recognition max time default value
 TABLE_MAX_TIME_VALUE = 400
 
 # pp_table_result_max_length
 TABLE_MAX_LEN = 480
 
-# pp table structure algorithm
-TABLE_MASTER = "TableMaster"
-
 # table master structure dict
 TABLE_MASTER_DICT = "table_master_structure_dict.txt"
 
@@ -44,3 +38,16 @@ PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
 PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
 
 
+class MODEL_NAME:  # namespace of string identifiers used to name models in configuration
+    # pp table structure algorithm (value is lowercase here; the removed module-level constant was "TableMaster")
+    TABLE_MASTER = "tablemaster"
+    # struct eqtable
+    STRUCT_EQTABLE = "struct_eqtable"
+
+    DocLayout_YOLO = "doclayout_yolo"  # presumably a layout-detection model id; TODO confirm against the model loader
+
+    LAYOUTLMv3 = "layoutlmv3"  # default "model" value returned by config_reader.get_layout_config
+
+    YOLO_V8_MFD = "yolo_v8_mfd"  # default "mfd_model" value returned by config_reader.get_formula_config
+
+    UniMerNet_v2_Small = "unimernet_small"  # default "mfr_model" value returned by config_reader.get_formula_config

+ 35 - 0
magic_pdf/libs/boxbase.py

@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
 
     # The area of overlap area
     return (x_right - x_left) * (y_bottom - y_top)
+
+
+def calculate_vertical_projection_overlap_ratio(block1, block2):  # overlap of the two x-ranges as a fraction of block1's width
+    """
+    Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
+
+    Args:
+        block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
+        block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
+
+    Returns:
+        float: The proportion of the x-axis covered by the vertical projection of the two blocks.
+    """
+    x0_1, _, x1_1, _ = block1  # only x0/x1 are used; the y-coordinates are ignored
+    x0_2, _, x1_2, _ = block2
+
+    # Calculate the intersection of the x-coordinates
+    x_left = max(x0_1, x0_2)
+    x_right = min(x1_1, x1_2)
+
+    if x_right < x_left:  # the two x-ranges do not overlap at all
+        return 0.0
+
+    # Length of the intersection
+    intersection_length = x_right - x_left
+
+    # Length of the x-axis projection of the first block
+    block1_length = x1_1 - x0_1
+
+    if block1_length == 0:  # zero-width block1: return 0.0 instead of dividing by zero
+        return 0.0
+
+    # Proportion of the x-axis covered by the intersection
+    # logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
+    return intersection_length / block1_length  # normalised by block1 only, so swapping the arguments can change the result

+ 44 - 26
magic_pdf/libs/config_reader.py

@@ -1,46 +1,44 @@
-"""
-根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
-
-"""
+"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
 
 import json
 import os
 
 from loguru import logger
 
+from magic_pdf.libs.Constants import MODEL_NAME
 from magic_pdf.libs.commons import parse_bucket_key
 
 # 定义配置文件名常量
-CONFIG_FILE_NAME = "magic-pdf.json"
+CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
 
 
 def read_config():
-    home_dir = os.path.expanduser("~")
-
-    config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
+    if os.path.isabs(CONFIG_FILE_NAME):
+        config_file = CONFIG_FILE_NAME
+    else:
+        home_dir = os.path.expanduser('~')
+        config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
 
     if not os.path.exists(config_file):
-        raise FileNotFoundError(f"{config_file} not found")
+        raise FileNotFoundError(f'{config_file} not found')
 
-    with open(config_file, "r", encoding="utf-8") as f:
+    with open(config_file, 'r', encoding='utf-8') as f:
         config = json.load(f)
     return config
 
 
 def get_s3_config(bucket_name: str):
-    """
-    ~/magic-pdf.json 读出来
-    """
+    """~/magic-pdf.json 读出来."""
     config = read_config()
 
-    bucket_info = config.get("bucket_info")
+    bucket_info = config.get('bucket_info')
     if bucket_name not in bucket_info:
-        access_key, secret_key, storage_endpoint = bucket_info["[default]"]
+        access_key, secret_key, storage_endpoint = bucket_info['[default]']
     else:
         access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
 
     if access_key is None or secret_key is None or storage_endpoint is None:
-        raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
+        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
 
     # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
 
@@ -49,7 +47,7 @@ def get_s3_config(bucket_name: str):
 
 def get_s3_config_dict(path: str):
     access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
-    return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
+    return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
 
 
 def get_bucket_name(path):
@@ -59,20 +57,20 @@ def get_bucket_name(path):
 
 def get_local_models_dir():
     config = read_config()
-    models_dir = config.get("models-dir")
+    models_dir = config.get('models-dir')
     if models_dir is None:
         logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
-        return "/tmp/models"
+        return '/tmp/models'
     else:
         return models_dir
 
 
 def get_local_layoutreader_model_dir():
     config = read_config()
-    layoutreader_model_dir = config.get("layoutreader-model-dir")
+    layoutreader_model_dir = config.get('layoutreader-model-dir')
     if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
-        home_dir = os.path.expanduser("~")
-        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader")
+        home_dir = os.path.expanduser('~')
+        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
         logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
         return layoutreader_at_modelscope_dir_path
     else:
@@ -81,23 +79,43 @@ def get_local_layoutreader_model_dir():
 
 def get_device():
     config = read_config()
-    device = config.get("device-mode")
+    device = config.get('device-mode')
     if device is None:
         logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
-        return "cpu"
+        return 'cpu'
     else:
         return device
 
 
 def get_table_recog_config():
     config = read_config()
-    table_config = config.get("table-config")
+    table_config = config.get('table-config')
     if table_config is None:
         logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
-        return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
+        return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
     else:
         return table_config
 
 
+def get_layout_config():  # return the "layout-config" section of the config file, or a default naming LAYOUTLMv3
+    config = read_config()  # the config file is re-read on every call
+    layout_config = config.get("layout-config")
+    if layout_config is None:  # section missing: warn and fall back to the default model
+        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
+        return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')  # doubled braces are literal braces in the f-string; result is {"model": "layoutlmv3"}
+    else:
+        return layout_config  # returned as stored in the file; its shape is not validated here
+
+
+def get_formula_config():  # return the "formula-config" section of the config file, or a default with formula handling enabled
+    config = read_config()  # the config file is re-read on every call
+    formula_config = config.get("formula-config")
+    if formula_config is None:  # section missing: warn and fall back to the defaults below
+        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
+        return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')  # default dict: mfd_model="yolo_v8_mfd", mfr_model="unimernet_small", enable=True
+    else:
+        return formula_config  # returned as stored in the file; its shape is not validated here
+
+
 if __name__ == "__main__":
     ak, sk, endpoint = get_s3_config("llm-raw")

+ 65 - 46
magic_pdf/libs/draw_bbox.py

@@ -1,3 +1,4 @@
+from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.libs.commons import fitz  # PyMuPDF
 from magic_pdf.libs.Constants import CROSS_PAGE
 from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
@@ -62,7 +63,7 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox
                     overlay=True,
                 )  # Draw the rectangle
         page.insert_text(
-            (x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
+            (x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
         )  # Insert the index in the top left corner of the rectangle
 
 
@@ -86,7 +87,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         texts = []
         interequations = []
         lists = []
-        indexs = []
+        indices = []
 
         for dropped_bbox in page['discarded_blocks']:
             page_dropped_list.append(dropped_bbox['bbox'])
@@ -122,7 +123,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
             elif block['type'] == BlockType.List:
                 lists.append(bbox)
             elif block['type'] == BlockType.Index:
-                indexs.append(bbox)
+                indices.append(bbox)
 
         tables_list.append(tables)
         tables_body_list.append(tables_body)
@@ -136,45 +137,61 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         texts_list.append(texts)
         interequations_list.append(interequations)
         lists_list.append(lists)
-        indexs_list.append(indexs)
+        indexs_list.append(indices)
 
     layout_bbox_list = []
 
+    table_type_order = {
+        'table_caption': 1,
+        'table_body': 2,
+        'table_footnote': 3
+    }
     for page in pdf_info:
         page_block_list = []
         for block in page['para_blocks']:
-            bbox = block['bbox']
-            page_block_list.append(bbox)
+            if block['type'] in [
+                BlockType.Text,
+                BlockType.Title,
+                BlockType.InterlineEquation,
+                BlockType.List,
+                BlockType.Index,
+            ]:
+                bbox = block['bbox']
+                page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Image]:
+                for sub_block in block['blocks']:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Table]:
+                sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
+                for sub_block in sorted_blocks:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+
         layout_bbox_list.append(page_block_list)
 
     pdf_docs = fitz.open('pdf', pdf_bytes)
 
     for i, page in enumerate(pdf_docs):
 
-        draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
-                                 True)
-        draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
-                                 True)  # color !
-        draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0],
-                                 True)
-        draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102],
-                                 True)
-        draw_bbox_without_number(i, tables_footnote_list, page,
-                                 [229, 255, 204], True)
-        draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
+        draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
+        # draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True)  # color !
+        draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
+        draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
+        draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
+        # draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
         draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
-        draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
-                                 True)
-        draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102],
-                              True),
+        draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
+        draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
         draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
         draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
-        draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
-                                 True)
+        draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
         draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
         draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
 
-        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)
+        draw_bbox_with_number(
+            i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
+        )
 
     # Save the PDF
     pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
@@ -237,6 +254,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
                 BlockType.Text,
                 BlockType.Title,
                 BlockType.InterlineEquation,
+                BlockType.List,
+                BlockType.Index,
             ]:
                 for line in block['lines']:
                     for span in line['spans']:
@@ -273,7 +292,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
     texts_list = []
     interequations_list = []
     pdf_docs = fitz.open('pdf', pdf_bytes)
-    magic_model = MagicModel(model_list, pdf_docs)
+    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
     for i in range(len(model_list)):
         page_dropped_list = []
         tables_body, tables_caption, tables_footnote = [], [], []
@@ -299,8 +318,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
                 imgs_body.append(bbox)
             elif layout_det['category_id'] == CategoryId.ImageCaption:
                 imgs_caption.append(bbox)
-            elif layout_det[
-                'category_id'] == CategoryId.InterlineEquation_YOLO:
+            elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
                 interequations.append(bbox)
             elif layout_det['category_id'] == CategoryId.Abandon:
                 page_dropped_list.append(bbox)
@@ -319,18 +337,15 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
         imgs_footnote_list.append(imgs_footnote)
 
     for i, page in enumerate(pdf_docs):
-        draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158],
-                              True)  # color !
+        draw_bbox_with_number(
+            i, dropped_bbox_list, page, [158, 158, 158], True
+        )  # color !
         draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
-        draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102],
-                              True)
-        draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204],
-                              True)
+        draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
+        draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
         draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
-        draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255],
-                              True)
-        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
-                              True)
+        draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
+        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
         draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
         draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
         draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
@@ -345,19 +360,23 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in ['text', 'title', 'interline_equation']:
+            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in ['table', 'image']:
-                bbox = block['bbox']
-                index = block['index']
-                page_line_list.append({'index': index, 'bbox': bbox})
-            # for line in block['lines']:
-            #     bbox = line['bbox']
-            #     index = line['index']
-            #     page_line_list.append({'index': index, 'bbox': bbox})
+            if block['type'] in [BlockType.Image, BlockType.Table]:
+                for sub_block in block['blocks']:
+                    if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+                        for line in sub_block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                    elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
+                        for line in sub_block['lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
         sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
         layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
     pdf_docs = fitz.open('pdf', pdf_bytes)

+ 38 - 14
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -5,7 +5,8 @@ import numpy as np
 from loguru import logger
 
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
+    get_formula_config
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -68,14 +69,17 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, ocr: bool, show_log: bool, lang=None):
-        key = (ocr, show_log, lang)
+    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
+        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
         if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
+            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
+                                                  formula_enable=formula_enable, table_enable=table_enable)
         return self._models[key]
 
 
-def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
+def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
+                      layout_model=None, formula_enable=None, table_enable=None):
+
     model = None
 
     if model_config.__model_mode__ == "lite":
@@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
             # 从配置文件读取model-dir和device
             local_models_dir = get_local_models_dir()
             device = get_device()
+
+            layout_config = get_layout_config()
+            if layout_model is not None:
+                layout_config["model"] = layout_model
+
+            formula_config = get_formula_config()
+            if formula_enable is not None:
+                formula_config["enable"] = formula_enable
+
             table_config = get_table_recog_config()
-            model_input = {"ocr": ocr,
-                           "show_log": show_log,
-                           "models_dir": local_models_dir,
-                           "device": device,
-                           "table_config": table_config,
-                           "lang": lang,
-                           }
+            if table_enable is not None:
+                table_config["enable"] = table_enable
+
+            model_input = {
+                            "ocr": ocr,
+                            "show_log": show_log,
+                            "models_dir": local_models_dir,
+                            "device": device,
+                            "table_config": table_config,
+                            "layout_config": layout_config,
+                            "formula_config": formula_config,
+                            "lang": lang,
+            }
+
             custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")
@@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
 
 
 def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None, lang=None):
+                start_page_id=0, end_page_id=None, lang=None,
+                layout_model=None, formula_enable=None, table_enable=None):
+
+    if lang == "":
+        lang = None
 
     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang)
+    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
 
     with fitz.open("pdf", pdf_bytes) as doc:
         pdf_page_num = doc.page_count

+ 259 - 14
magic_pdf/model/magic_model.py

@@ -1,5 +1,6 @@
 import json
 
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                     bbox_relative_pos, box_area, calculate_iou,
                                     calculate_overlap_area_in_bbox1_area_ratio,
@@ -9,6 +10,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.local_math import float_gt
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
+from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
@@ -24,7 +26,7 @@ class MagicModel:
             need_remove_list = []
             page_no = model_page_info['page_info']['page_no']
             horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
-                model_page_info, self.__docs[page_no]
+                model_page_info, self.__docs.get_page(page_no)
             )
             layout_dets = model_page_info['layout_dets']
             for layout_det in layout_dets:
@@ -99,7 +101,7 @@ class MagicModel:
             for need_remove in need_remove_list:
                 layout_dets.remove(need_remove)
 
-    def __init__(self, model_list: list, docs: fitz.Document):
+    def __init__(self, model_list: list, docs: Dataset):
         self.__model_list = model_list
         self.__docs = docs
         """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
@@ -123,7 +125,7 @@ class MagicModel:
             l1 = bbox1[2] - bbox1[0]
             l2 = bbox2[2] - bbox2[0]
 
-        if l2 > l1 and (l2 - l1) / l1 > 0.5:
+        if l2 > l1 and (l2 - l1) / l1 > 0.3:
             return float('inf')
 
         return bbox_distance(bbox1, bbox2)
@@ -213,9 +215,8 @@ class MagicModel:
         筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
         再求出筛选出的 subjects 和 object 的最短距离
         """
-        def search_overlap_between_boxes(
-            subject_idx, object_idx
-        ):
+
+        def search_overlap_between_boxes(subject_idx, object_idx):
             idxes = [subject_idx, object_idx]
             x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
             y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
@@ -243,9 +244,9 @@ class MagicModel:
             for other_object in other_objects:
                 ratio = max(
                     ratio,
-                    get_overlap_area(
-                        merged_bbox, other_object['bbox']
-                    ) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
+                    get_overlap_area(merged_bbox, other_object['bbox'])
+                    * 1.0
+                    / box_area(all_bboxes[object_idx]['bbox']),
                 )
                 if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
                     break
@@ -363,12 +364,17 @@ class MagicModel:
                 if all_bboxes[j]['category_id'] == subject_category_id:
                     subject_idx, object_idx = j, i
 
-                if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
+                if (
+                    search_overlap_between_boxes(subject_idx, object_idx)
+                    >= MERGE_BOX_OVERLAP_AREA_RATIO
+                ):
                     dis[i][j] = float('inf')
                     dis[j][i] = dis[i][j]
                     continue
 
-                dis[i][j] = self._bbox_distance(all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox'])
+                dis[i][j] = self._bbox_distance(
+                    all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
+                )
                 dis[j][i] = dis[i][j]
 
         used = set()
@@ -584,6 +590,245 @@ class MagicModel:
                 with_caption_subject.add(j)
         return ret, total_subject_object_dis
 
+    def __tie_up_category_by_distance_v2(
+        self, page_no, subject_category_id, object_category_id
+    ):
+        """Attach each object box on a page to at most one subject box.
+
+        Subjects and objects are the ``layout_dets`` entries of page
+        ``page_no`` whose ``category_id`` equals ``subject_category_id`` /
+        ``object_category_id``, passed through ``__reduct_overlap`` and then
+        sorted by the squared distance of their top-left corner from the
+        origin.
+
+        For every object the nearest subject is tracked separately for each
+        of the four flags returned by ``bbox_relative_pos`` (left, right,
+        bottom, top); a subject/object pair for which more than one flag is
+        set is ignored.  The four candidates are then reduced in three steps
+        (left vs. right, top vs. bottom, horizontal vs. vertical winner).
+        When two candidates' distances differ by at most
+        ``AXIS_MULPLICITY * min(object width, object height)`` the one whose
+        size matches the object better wins, otherwise the nearer one wins.
+
+        Args:
+            page_no: index into ``self.__model_list``.
+            subject_category_id: ``category_id`` of the boxes objects are
+                attached to.
+            object_category_id: ``category_id`` of the boxes being attached.
+
+        Returns:
+            A list with one dict per subject (in sorted order) with keys
+            ``'sub_bbox'`` (``{'bbox', 'score'}`` of the subject),
+            ``'obj_bboxes'`` (list of ``{'score', 'bbox'}`` of the attached
+            objects) and ``'sub_idx'`` (index of the subject in the sorted
+            subject list).  Objects with no candidate in any direction are
+            not attached to anything.
+        """
+        AXIS_MULPLICITY = 0.5
+        INF = float('inf')
+        layout_dets = self.__model_list[page_no]['layout_dets']
+
+        subjects = self.__reduct_overlap(
+            [
+                {'bbox': det['bbox'], 'score': det['score']}
+                for det in layout_dets
+                if det['category_id'] == subject_category_id
+            ]
+        )
+        objects = self.__reduct_overlap(
+            [
+                {'bbox': det['bbox'], 'score': det['score']}
+                for det in layout_dets
+                if det['category_id'] == object_category_id
+            ]
+        )
+        M = len(objects)
+
+        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+
+        sub_obj_map_h = {i: [] for i in range(len(subjects))}
+
+        # Per direction and per object: [index of nearest subject, distance].
+        # [-1, inf] means "no subject found in that direction".
+        dis_by_directions = {
+            direction: [[-1, INF] for _ in range(M)]
+            for direction in ('top', 'bottom', 'left', 'right')
+        }
+
+        for i, obj in enumerate(objects):
+            l_x_axis = obj['bbox'][2] - obj['bbox'][0]
+            l_y_axis = obj['bbox'][3] - obj['bbox'][1]
+            axis_unit = min(l_x_axis, l_y_axis)
+            # Two distances closer than this are treated as a tie.
+            tolerance = AXIS_MULPLICITY * axis_unit
+
+            for j, sub in enumerate(subjects):
+                bbox1, bbox2, _ = _remove_overlap_between_bbox(
+                    obj['bbox'], sub['bbox']
+                )
+                left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
+                flags = {'left': left, 'right': right, 'bottom': bottom, 'top': top}
+                # Skip pairs for which more than one direction flag is set.
+                if sum(1 for hit in flags.values() if hit) > 1:
+                    continue
+
+                for direction, hit in flags.items():
+                    if not hit:
+                        continue
+                    # The distance is measured on the original boxes, not on
+                    # the overlap-trimmed ones.
+                    distance = bbox_distance(obj['bbox'], sub['bbox'])
+                    if dis_by_directions[direction][i][1] > distance:
+                        dis_by_directions[direction][i] = [j, distance]
+
+            left_cand = dis_by_directions['left'][i]
+            right_cand = dis_by_directions['right'][i]
+            top_cand = dis_by_directions['top'][i]
+            bottom_cand = dis_by_directions['bottom'][i]
+
+            # --- step 1: left vs. right ---------------------------------
+            if left_cand[1] != INF and right_cand[1] != INF:
+                if tolerance >= abs(left_cand[1] - right_cand[1]):
+                    left_sub_bbox = subjects[left_cand[0]]['bbox']
+                    right_sub_bbox = subjects[right_cand[0]]['bbox']
+
+                    left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
+                    right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
+
+                    # Cost = height mismatch + distance.  The distance is
+                    # element [1] of a candidate; element [0] is the subject
+                    # index and must not be added here (the top/bottom step
+                    # below uses the distance as well).
+                    left_cost = abs(left_sub_bbox_y_axis - l_y_axis) + left_cand[1]
+                    right_cost = (
+                        abs(right_sub_bbox_y_axis - l_y_axis) + right_cand[1]
+                    )
+                    left_or_right = right_cand if left_cost > right_cost else left_cand
+                else:
+                    left_or_right = (
+                        right_cand if left_cand[1] > right_cand[1] else left_cand
+                    )
+            elif left_cand[1] != INF:
+                left_or_right = left_cand
+            else:
+                # Either only the right candidate exists, or neither does
+                # (in which case this is [-1, inf]).
+                left_or_right = right_cand
+
+            # --- step 2: top vs. bottom ---------------------------------
+            if top_cand[1] != INF and bottom_cand[1] != INF:
+                if tolerance >= abs(top_cand[1] - bottom_cand[1]):
+                    bottom_sub_bbox = subjects[bottom_cand[0]]['bbox']
+                    top_sub_bbox = subjects[top_cand[0]]['bbox']
+
+                    bottom_cost = (
+                        abs((bottom_sub_bbox[2] - bottom_sub_bbox[0]) - l_x_axis)
+                        + bottom_cand[1]
+                    )
+                    top_cost = (
+                        abs((top_sub_bbox[2] - top_sub_bbox[0]) - l_x_axis)
+                        + top_cand[1]
+                    )
+                    top_or_bottom = top_cand if bottom_cost > top_cost else bottom_cand
+                else:
+                    top_or_bottom = (
+                        bottom_cand if top_cand[1] > bottom_cand[1] else top_cand
+                    )
+            elif top_cand[1] != INF:
+                top_or_bottom = top_cand
+            else:
+                top_or_bottom = bottom_cand
+
+            # --- step 3: horizontal winner vs. vertical winner ----------
+            if left_or_right[1] != INF and top_or_bottom[1] != INF:
+                if tolerance >= abs(left_or_right[1] - top_or_bottom[1]):
+                    y_axis_bbox = subjects[left_or_right[0]]['bbox']
+                    x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
+
+                    # Relative size mismatch along the shared axis.
+                    # NOTE(review): raises ZeroDivisionError for an object
+                    # box with zero width or height - confirm such boxes
+                    # cannot reach this point.
+                    if (
+                        abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
+                        > abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
+                        / l_y_axis
+                    ):
+                        chosen = left_or_right
+                    else:
+                        chosen = top_or_bottom
+                elif left_or_right[1] > top_or_bottom[1]:
+                    chosen = top_or_bottom
+                else:
+                    chosen = left_or_right
+            elif left_or_right[1] != INF:
+                chosen = left_or_right
+            elif top_or_bottom[1] != INF:
+                chosen = top_or_bottom
+            else:
+                # No subject in any direction: the object stays unattached.
+                continue
+
+            sub_obj_map_h[chosen[0]].append(i)
+
+        return [
+            {
+                'sub_bbox': {
+                    'bbox': subjects[i]['bbox'],
+                    'score': subjects[i]['score'],
+                },
+                'obj_bboxes': [
+                    {'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
+                    for j in obj_indexes
+                ],
+                'sub_idx': i,
+            }
+            for i, obj_indexes in sub_obj_map_h.items()
+        ]
+
+    def get_imgs_v2(self, page_no: int):
+        """Return the images of a page with their captions and footnotes.
+
+        Runs ``__tie_up_category_by_distance_v2`` twice with subject
+        category 3: once against object category 4 (used for the caption
+        list) and once against ``CategoryId.ImageFootnote``.  The two results
+        are joined on ``'sub_idx'``.
+
+        Returns:
+            A list of dicts with keys ``'image_body'``,
+            ``'image_caption_list'`` and ``'image_footnote_list'``.
+        """
+        caption_groups = self.__tie_up_category_by_distance_v2(page_no, 3, 4)
+        footnote_groups = self.__tie_up_category_by_distance_v2(
+            page_no, 3, CategoryId.ImageFootnote
+        )
+
+        images = []
+        for group in caption_groups:
+            sub_idx = group['sub_idx']
+            # First footnote group that belongs to the same subject.
+            matching = next(
+                candidate
+                for candidate in footnote_groups
+                if candidate['sub_idx'] == sub_idx
+            )
+            images.append(
+                {
+                    'image_body': group['sub_bbox'],
+                    'image_caption_list': group['obj_bboxes'],
+                    'image_footnote_list': matching['obj_bboxes'],
+                }
+            )
+        return images
+
+    def get_tables_v2(self, page_no: int) -> list:
+        """Return the tables of a page with their captions and footnotes.
+
+        Runs ``__tie_up_category_by_distance_v2`` twice with subject
+        category 5: once against object category 6 (used for the caption
+        list) and once against object category 7 (used for the footnote
+        list).  The two results are joined on ``'sub_idx'``.
+
+        Returns:
+            A list of dicts with keys ``'table_body'``,
+            ``'table_caption_list'`` and ``'table_footnote_list'``.
+        """
+        caption_groups = self.__tie_up_category_by_distance_v2(page_no, 5, 6)
+        footnote_groups = self.__tie_up_category_by_distance_v2(page_no, 5, 7)
+
+        tables = []
+        for group in caption_groups:
+            sub_idx = group['sub_idx']
+            # First footnote group that belongs to the same subject.
+            matching = next(
+                candidate
+                for candidate in footnote_groups
+                if candidate['sub_idx'] == sub_idx
+            )
+            tables.append(
+                {
+                    'table_body': group['sub_bbox'],
+                    'table_caption_list': group['obj_bboxes'],
+                    'table_footnote_list': matching['obj_bboxes'],
+                }
+            )
+        return tables
+
     def get_imgs(self, page_no: int):
         with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
         with_footnotes, _ = self.__tie_up_category_by_distance(
@@ -717,10 +962,10 @@ class MagicModel:
 
     def get_page_size(self, page_no: int):  # 获取页面宽高
         # 获取当前页的page对象
-        page = self.__docs[page_no]
+        page = self.__docs.get_page(page_no).get_page_info()
         # 获取当前页的宽高
-        page_w = page.rect.width
-        page_h = page.rect.height
+        page_w = page.w
+        page_h = page.h
         return page_w, page_h
 
     def __get_blocks_by_type(

+ 899 - 0
magic_pdf/model/mfr_cudagraph.py

@@ -0,0 +1,899 @@
+from typing import Optional, Tuple, Union
+import torch
+from torch import nn
+import os
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+import argparse
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+
+class PatchedMBartLearnedPositionalEmbedding(nn.Module):
+    """Learned positional embedding rebuilt from an existing embedding module.
+
+    A fresh ``nn.Embedding`` of the same size is created and its weight data
+    is pointed at ``origin.weight.data``; ``origin.offset`` is added to every
+    position index before the lookup.
+    """
+
+    def __init__(self, origin: nn.Module):
+        super().__init__()
+        self.offset = origin.offset
+        table = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
+        # Reuse the original weights instead of the random initialisation.
+        table.weight.data = origin.weight.data
+        self.embedding = table
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """Return position embeddings for ``input_ids``.
+
+        Only the first two dimensions of ``input_ids`` (batch size and
+        sequence length) are used.  ``past_key_values_length`` is added in
+        place to the position indices; ``PatchedMBartDecoder`` passes its
+        ``kvlen`` tensor here.
+        """
+        batch_size, length = input_ids.shape[:2]
+        device = self.embedding.weight.device
+        index = torch.arange(0, length, dtype=torch.long, device=device)
+        index += past_key_values_length
+        shifted = index.expand(batch_size, -1) + self.offset
+        return self.embedding(shifted)
+
+
+class PatchedMBartDecoder(nn.Module):
+    """Decoder wrapper that takes the decode position from a tensor.
+
+    The wrapper shares every sub-module with ``origin`` and delegates to it
+    unchanged whenever there is no KV cache yet or the cache is shorter than
+    the attention mask.  Otherwise it runs its own copy of the decoder loop,
+    in which position embeddings are looked up from ``kvlen`` rather than
+    from the length of the cache.
+
+    NOTE(review): presumably this exists so the decode step can be captured
+    as a CUDA graph (per the module name) - confirm against the caller.
+    """
+
+    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
+        """Wrap ``origin``; ``kvlen`` is used as the position offset."""
+        super().__init__()
+        self.origin = origin
+        self.kvlen = kvlen
+
+        # Re-export the attributes of the wrapped decoder that forward() uses.
+        self.config = origin.config
+        self.embed_tokens = origin.embed_tokens
+        self.embed_scale = origin.embed_scale
+        self._use_flash_attention_2 = origin._use_flash_attention_2
+        self.embed_positions = origin.embed_positions
+        # Optional: not every wrapped decoder has a counting context weight.
+        self.counting_context_weight = getattr(origin, 'counting_context_weight', None)
+        self.layernorm_embedding = origin.layernorm_embedding
+        self.layers = origin.layers
+        self.layer_norm = origin.layer_norm
+
+        self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        """Run the decoder, falling back to ``origin`` when needed.
+
+        Returns a ``BaseModelOutputWithPastAndCrossAttentions`` when
+        ``return_dict`` resolves to true, otherwise a tuple of the non-``None``
+        values among (hidden states, cache, all hidden states, self
+        attentions, cross attentions).
+
+        Raises:
+            ValueError: if both or neither of ``input_ids`` and
+                ``inputs_embeds`` are given, or if a head mask does not have
+                one entry per decoder layer.
+        """
+        # Delegate to the wrapped decoder when there is no cache, or when the
+        # cache holds fewer positions than the attention mask covers.
+        # NOTE(review): with a cache but attention_mask=None the size() call
+        # below fails - confirm callers always pass a mask.
+        run_origin = False
+        if past_key_values is None:
+            run_origin = True
+        elif past_key_values[0][0].size(-2) < attention_mask.size(-1):
+            run_origin = True
+
+        if run_origin:
+            return self.origin(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        # Resolve unset flags from the decoder config.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        # (past_key_values cannot be None here: that case returned above.)
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        # The offset comes from self.kvlen, not from past_key_values_length.
+        positions = self.patched_embed_positions(input, self.kvlen)
+
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+        # TODO: add counting context weight to hidden_states
+        if count_pred is not None:
+            count_context_weight = self.counting_context_weight(count_pred)
+            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # Every layer is always executed: this loop contains no LayerDrop
+            # (https://arxiv.org/abs/1909.11556) skipping.
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                ),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+            hidden_states = layer_outputs[0]
+
+            # The cache entry sits after the two attention outputs when
+            # attentions are returned, otherwise right after the hidden state.
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class PatchedMBartAttention(nn.Module):
+    """Multi-head attention that can update its KV cache in place.
+
+    Projections and hyper-parameters are taken from ``origin``.  For
+    self-attention with a cache that is already at least as long as the
+    attention mask, the new key/value are written into the cache tensors at
+    index ``kvlen`` instead of being concatenated to them.
+
+    ``dropout`` is copied from ``origin`` but is not applied in ``forward``.
+    """
+
+    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
+        """Copy settings and projection layers from ``origin``."""
+        super().__init__()
+        self.embed_dim = origin.embed_dim
+        self.num_heads = origin.num_heads
+        self.dropout = origin.dropout
+        self.head_dim = origin.head_dim
+        self.config = origin.config
+
+        self.scaling = origin.scaling
+        self.is_decoder = origin.is_decoder
+        self.is_causal = origin.is_causal
+
+        # The projection modules are shared with origin, not copied.
+        self.k_proj = origin.k_proj
+        self.v_proj = origin.v_proj
+        self.q_proj = origin.q_proj
+        self.out_proj = origin.out_proj
+        # Index of the cache slot that forward() overwrites.
+        self.kvlen = kvlen
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        """Reshape to (bsz, num_heads, seq_len, head_dim), contiguous."""
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel
+
+        Returns ``(attn_output, attn_weights_reshaped, past_key_value)``.
+        ``attn_weights_reshaped`` is ``None`` unless ``output_attentions`` is
+        set; ``past_key_value`` is replaced by the current
+        ``(key_states, value_states)`` only when ``is_decoder`` is true.
+
+        Raises:
+            ValueError: if the attention weights, attention mask, head mask
+                or attention output do not have the expected size.
+        """
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+            if past_key_value[0].size(-2) < attention_mask.size(-1):
+                # Cache shorter than the mask: grow it by concatenation.
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            else:
+                # Cache already long enough: overwrite the slot at index
+                # kvlen in place and attend over the whole cache tensor.
+                past_key_value[0][:, :, self.kvlen[None]] = key_states
+                past_key_value[1][:, :, self.kvlen[None]] = value_states
+                key_states = past_key_value[0]
+                value_states = past_key_value[1]
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # Hand the (possibly updated) key/value back as the new cache.
+            past_key_value = (key_states, value_states)
+
+        # Fold the head dimension into the batch dimension for bmm.
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            # The mask is added to the scores before the softmax.
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        # No dropout is applied: the softmax output is used directly.
+        attn_probs = attn_weights
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        # Final output projection.
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class PatchedMBartSqueezeAttention(nn.Module):
+
+    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
+        super().__init__()
+        self.embed_dim = origin.embed_dim
+        self.num_heads = origin.num_heads
+        self.dropout = origin.dropout
+        self.head_dim = origin.head_dim
+        self.squeeze_head_dim=origin.squeeze_head_dim
+        self.config = origin.config
+
+        self.scaling = origin.scaling
+        self.is_decoder = origin.is_decoder
+        self.scaling = origin.scaling
+
+        self.q_proj = origin.q_proj
+        self.k_proj = origin.k_proj
+        self.v_proj = origin.v_proj
+        self.out_proj = origin.out_proj
+        self.kvlen = kvlen
+
+    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()
+
+    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+            if past_key_value[0].size(-2) < attention_mask.size(-1):
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            else:
+                past_key_value[0][:, :, self.kvlen[None]] = key_states
+                past_key_value[1][:, :, self.kvlen[None]] = value_states
+                key_states = past_key_value[0]
+                value_states = past_key_value[1]
+        else:
+            # self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
+        value_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*value_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+def patch_model(model: nn.Module, kvlen: torch.LongTensor):
+    for name, child in model.named_children():
+        cls_name = type(child).__name__
+        if cls_name == 'MBartAttention':
+            patched_child = PatchedMBartAttention(child, kvlen)
+            model.register_module(name, patched_child)
+        elif cls_name == 'MBartSqueezeAttention':
+            patched_child = PatchedMBartSqueezeAttention(child, kvlen)
+            model.register_module(name, patched_child)
+        else:
+            patch_model(child, kvlen)
+
+    cls_name = type(model).__name__
+    if cls_name == 'CustomMBartDecoder':
+        model = PatchedMBartDecoder(model, kvlen)
+    return model
+
+
+def next_power_of_2(n: int):
+    """Return the smallest power of 2 greater than or equal to n."""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
+
+
+def get_graph_key(batch_size: int, kvlens: int):
+    batch_size = next_power_of_2(batch_size)
+    kvlens = next_power_of_2(kvlens)
+
+    batch_size = max(8, batch_size)
+    kvlens = max(32, kvlens)
+
+    return batch_size, kvlens
+
+
+class GraphRunnerImpl:
+    """One captured CUDA graph of a decoder step plus the buffers it uses.
+
+    ``capture`` records a forward pass of ``model`` into a
+    ``torch.cuda.CUDAGraph`` using views of the shared input buffers.
+    Calling the instance copies fresh inputs into those views, replays the
+    graph, and returns slices of the tensors produced during capture.
+    """
+
+    def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict):
+        # input_buffers: views returned by extract_input_buffers for this graph key.
+        # output_buffers: 'last_hidden_state' and 'past_key_values' from the captured run.
+        self.model = model
+        self.graph = graph
+        self.input_buffers = input_buffers
+        self.output_buffers = output_buffers
+
+    @staticmethod
+    def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int):
+        """Slice the shared buffers down to ``batch_size`` rows and ``kvlens`` positions.
+
+        Returns a new dict with the same keys. ``kvlen`` is passed through
+        unsliced. Each ``past_key_values`` entry is a 4-tuple: items 0 and 1
+        are cut on the batch axis and on axis 2, items 2 and 3 on the batch
+        axis only.
+        """
+        input_ids = input_buffers['input_ids'][:batch_size]
+        attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens]
+        encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size]
+        kvlen=input_buffers['kvlen']
+
+        past_key_values = []
+        for past_key_value in input_buffers['past_key_values']:
+            # Items 0/1 are sized by max_kvlens and items 2/3 by encoder_kvlens
+            # in GraphRunner.create_buffers, hence the different slicing.
+            k0 = past_key_value[0][:batch_size, :, :kvlens]
+            v0 = past_key_value[1][:batch_size, :, :kvlens]
+            k1 = past_key_value[2][:batch_size]
+            v1 = past_key_value[3][:batch_size]
+            past_key_values.append((k0, v0, k1, v1))
+
+        input_buffers = dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+            kvlen=kvlen,
+        )
+        return input_buffers
+
+    @staticmethod
+    def fill_input_buffers(
+        input_buffer: dict,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        ):
+        """Copy the caller's tensors into ``input_buffer`` in place.
+
+        ``input_ids`` and ``attention_mask`` must not be ``None`` (their sizes
+        are read). A buffer is zeroed and rewritten only when the incoming
+        tensor does not start at the same address as the buffer. The
+        ``kvlen`` scalar is finally set to ``attention_mask.size(1) - 1``.
+        """
+        batch_size = input_ids.size(0)
+        kvlens = attention_mask.size(1)
+
+        input_buffer['input_ids'][:batch_size] = input_ids
+
+        # Clear stale mask entries unless the caller passed the buffer itself.
+        if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr():
+            input_buffer['attention_mask'].fill_(0)
+        input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask
+        input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states
+
+        if past_key_values is not None:
+            for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values):
+                # Items 0/1: copied into the first kvlens-1 positions of axis 2.
+                idx = 0
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
+                idx = 1
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
+
+                # Items 2/3: copied whole along the batch axis.
+                idx = 2
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size] = kv[idx]
+                idx = 3
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size] = kv[idx]
+
+        # Shared scalar tensor handed to the patched modules via patch_model;
+        # presumably the cache position written by the current step.
+        input_buffer['kvlen'].fill_(kvlens - 1)
+
+    @classmethod
+    @torch.inference_mode()
+    def capture(cls,
+                model: nn.Module,
+                input_buffers: dict,
+                pool,
+                warmup: bool = False,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                count_pred: Optional[torch.FloatTensor] = None,
+                encoder_hidden_states: Optional[torch.FloatTensor] = None,
+                encoder_attention_mask: Optional[torch.LongTensor] = None,
+                head_mask: Optional[torch.Tensor] = None,
+                cross_attn_head_mask: Optional[torch.Tensor] = None,
+                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,):
+        """Record one forward pass of ``model`` into a CUDA graph.
+
+        The buffers are sliced to the sizes given by ``get_graph_key`` and
+        filled with the supplied inputs. When ``warmup`` is true the model is
+        run once eagerly before recording. ``pool`` is passed to
+        ``torch.cuda.graph``.
+
+        Only ``input_ids``, ``attention_mask``, ``encoder_hidden_states`` and
+        ``past_key_values`` are replaced by buffer views; the remaining
+        arguments are passed to the model exactly as given.
+
+        Returns a ``GraphRunnerImpl`` holding the graph and its buffers.
+        """
+        batch_size = input_ids.size(0)
+        kvlens = attention_mask.size(1)
+
+        # From here on batch_size/kvlens are the rounded-up bucket sizes.
+        graph_key = get_graph_key(batch_size, kvlens)
+        batch_size = graph_key[0]
+        kvlens = graph_key[1]
+
+        input_buffers = cls.extract_input_buffers(input_buffers,
+                                                  batch_size=batch_size,
+                                                  kvlens=kvlens)
+        cls.fill_input_buffers(input_buffers,
+                               input_ids,
+                               attention_mask,
+                               encoder_hidden_states,
+                               past_key_values)
+
+        # Rebind to the buffer views so the model runs on them, not on the
+        # caller's tensors.
+        input_ids = input_buffers['input_ids']
+        attention_mask = input_buffers['attention_mask']
+        encoder_hidden_states = input_buffers['encoder_hidden_states']
+        past_key_values = input_buffers['past_key_values']
+
+        if warmup:
+            # warmup
+            model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph,
+                              pool=pool):
+            outputs = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
+
+        # Tensors produced while recording; __call__ reads them after replay.
+        output_buffers = dict(
+            last_hidden_state=outputs['last_hidden_state'],
+            past_key_values=outputs['past_key_values'],
+            )
+
+        return GraphRunnerImpl(model, graph, input_buffers, output_buffers)
+
+    def __call__(self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        ):
+        """Refill the buffers, replay the graph, and return sliced outputs.
+
+        Only ``input_ids``, ``attention_mask``, ``encoder_hidden_states`` and
+        ``past_key_values`` are read; the other parameters are accepted but
+        not used here.
+
+        Returns a ``BaseModelOutputWithPastAndCrossAttentions`` whose tensors
+        are cut back to the caller's batch size and kv length.
+        """
+        batch_size = input_ids.size(0)
+        kvlens = attention_mask.size(1)
+        self.fill_input_buffers(self.input_buffers,
+                               input_ids,
+                               attention_mask,
+                               encoder_hidden_states,
+                               past_key_values)
+
+        self.graph.replay()
+
+        last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size]
+
+        past_key_values = []
+        for past_key_value in self.output_buffers['past_key_values']:
+            # Same slicing scheme as extract_input_buffers, at the real sizes.
+            k0 = past_key_value[0][:batch_size, :, :kvlens]
+            v0 = past_key_value[1][:batch_size, :, :kvlens]
+            k1 = past_key_value[2][:batch_size]
+            v1 = past_key_value[3][:batch_size]
+            past_key_values.append((k0, v0, k1, v1))
+
+        outputs = BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=last_hidden_state,
+            past_key_values=past_key_values,
+        )
+        return outputs
+
+class GraphRunner(nn.Module):
+    """Wrap a decoder so eligible single-token steps run via CUDA graphs.
+
+    The wrapped model is cast to ``dtype`` and patched with ``patch_model``.
+    ``forward`` either calls the model eagerly or dispatches to a
+    ``GraphRunnerImpl`` cached per ``get_graph_key`` bucket.
+    """
+
+    def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int, dtype:torch.dtype = torch.float16, device: torch.device = 'cuda'):
+        super().__init__()
+
+        # Scalar tensor shared with the patched modules and with the input
+        # buffers; GraphRunnerImpl.fill_input_buffers updates it in place.
+        self.kvlen = torch.tensor(0, dtype=torch.long, device=device)
+        model = patch_model(model.to(dtype), self.kvlen)
+        self.model = model
+        self.max_batchs = max_batchs
+        self.max_kvlens = max_kvlens
+        self.device = device
+
+        # Allocated lazily on the first graph-mode call (see forward).
+        self.input_buffers = None
+
+        # Maps a (batch, kvlens) graph key to its captured GraphRunnerImpl.
+        self.impl_map = dict()
+        self.graph_pool_handle = torch.cuda.graph_pool_handle()
+        # True once the first capture has requested its eager warmup run.
+        self.warmuped = False
+
+    def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype):
+        """Allocate the maximum-size input buffers and store them on ``self``.
+
+        Shapes come from ``self.model.config`` (``d_model``,
+        ``decoder_layers``, ``decoder_attention_heads``), the configured
+        maxima, and ``encoder_kvlens``.
+        """
+        max_batchs = self.max_batchs
+        max_kvlens = self.max_kvlens
+        device = self.device
+        config = self.model.config
+
+        d_model = config.d_model
+        decoder_layers = config.decoder_layers
+        num_heads = config.decoder_attention_heads
+
+        head_dim = d_model // num_heads
+        self_attn = self.model.layers[0].self_attn
+        # Key buffers use squeeze_head_dim when the attention layer has one.
+        qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim)
+
+        input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device)
+        attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device)
+        encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device)
+
+        past_key_values = []
+        for _ in range(decoder_layers):
+            # k0/v0 span max_kvlens positions, k1/v1 span encoder_kvlens.
+            k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device)
+            v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device)
+            k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device)
+            v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device)
+
+            past_key_values.append((k0, v0, k1, v1))
+
+        self.input_buffers = dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+            kvlen=self.kvlen
+        )
+
+    @torch.inference_mode()
+    def forward(self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        ):
+        """Run the decoder eagerly or through a cached CUDA graph.
+
+        The wrapped model is called directly when any of these hold: more
+        than one query token, ``past_key_values`` is ``None`` or contains a
+        ``None`` entry, or the batch size / kv length is at or above the
+        configured maximum. Otherwise buffers are created on first use, a
+        graph is captured for the request's key if none exists, and that
+        graph's runner is invoked.
+        """
+        batch_size, qlens = input_ids.size()
+        kvlens = attention_mask.size(1)
+
+        eager_mode = False
+
+        if qlens != 1:
+            eager_mode = True
+
+        if past_key_values is None:
+            eager_mode = True
+        else:
+            for past_key_value in past_key_values:
+                if past_key_value is None:
+                    eager_mode = True
+                    break
+
+        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
+            eager_mode = True
+
+        if eager_mode:
+            return self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,)
+
+        # create buffer if not exists.
+        # The encoder length and dtype are taken from this first call only.
+        if self.input_buffers is None:
+            encoder_kvlens = encoder_hidden_states.size(1)
+            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)
+
+        graph_key = get_graph_key(batch_size, kvlens)
+        if graph_key not in self.impl_map:
+            # Only the very first capture asks for an eager warmup run.
+            warmup = False
+            if not self.warmuped:
+                warmup = True
+                self.warmuped = True
+            impl = GraphRunnerImpl.capture(
+                self.model,
+                self.input_buffers,
+                self.graph_pool_handle,
+                warmup=warmup,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            self.impl_map[graph_key] = impl
+        impl = self.impl_map[graph_key]
+
+        # A freshly captured graph is also replayed here for the same inputs.
+        ret = impl(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+        )
+        return ret

+ 91 - 46
magic_pdf/model/pdf_extract_kit.py

@@ -6,6 +6,7 @@ import shutil
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
+from .mfr_cudagraph import GraphRunner
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
@@ -26,6 +27,7 @@ try:
     from unimernet.common.config import Config
     import unimernet.tasks as tasks
     from unimernet.processors import load_processor
+    from doclayout_yolo import YOLOv10
 
 except ImportError as e:
     logger.exception(e)
@@ -42,7 +44,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == STRUCT_EQTABLE:
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
     else:
         config = {
@@ -68,6 +70,11 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     model = task.build_model(cfg)
     model.to(_device_)
     model.eval()
+    model = model.to(_device_)
+    if 'cuda' in _device_:
+        decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256,
+                                     device=_device_)
+        model.model.model.decoder.model.decoder = decoder_runner
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]
@@ -78,11 +85,16 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
+def doclayout_yolo_model_init(weight):
+    model = YOLOv10(weight)
+    return model
+
+
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
     if lang is not None:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     else:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     return model
 
 
@@ -115,19 +127,27 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
-        if atom_model_name not in self._models:
-            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[atom_model_name]
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]
 
 
 def atom_model_init(model_name: str, **kwargs):
 
     if model_name == AtomicModel.Layout:
-        atom_model = layout_model_init(
-            kwargs.get("layout_weights"),
-            kwargs.get("layout_config_file"),
-            kwargs.get("device")
-        )
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+            )
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get("mfd_weights")
@@ -146,7 +166,7 @@ def atom_model_init(model_name: str, **kwargs):
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
-            kwargs.get("table_model_type"),
+            kwargs.get("table_model_name"),
             kwargs.get("table_model_path"),
             kwargs.get("table_max_time"),
             kwargs.get("device")
@@ -194,23 +214,35 @@ class CustomPEKModel:
         with open(config_path, "r", encoding='utf-8') as f:
             self.configs = yaml.load(f, Loader=yaml.FullLoader)
         # 初始化解析配置
-        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
-        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+
+        # layout config
+        self.layout_config = kwargs.get("layout_config")
+        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
+
+        # formula config
+        self.formula_config = kwargs.get("formula_config")
+        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
+        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
+        self.apply_formula = self.formula_config.get("enable", True)
+
         # table config
-        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
-        self.apply_table = self.table_config.get("is_table_recog_enable", False)
+        self.table_config = kwargs.get("table_config")
+        self.apply_table = self.table_config.get("enable", False)
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_type = self.table_config.get("model", TABLE_MASTER)
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+
+        # ocr config
         self.apply_ocr = ocr
         self.lang = kwargs.get("lang", None)
+
         logger.info(
-            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
-                self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
+            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
+            "apply_table: {}, table_model: {}, lang: {}".format(
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
             )
         )
-        assert self.apply_layout, "DocAnalysis must contain layout model."
         # 初始化解析方案
-        self.device = kwargs.get("device", self.configs["config"]["device"])
+        self.device = kwargs.get("device", "cpu")
         logger.info("using device: {}".format(self.device))
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
@@ -219,17 +251,16 @@ class CustomPEKModel:
 
         # 初始化公式识别
         if self.apply_formula:
+
             # 初始化公式检测模型
-            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
             )
+
             # 初始化公式解析模型
-            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
+            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
             self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
@@ -238,17 +269,20 @@ class CustomPEKModel:
             )
 
         # 初始化layout模型
-        # self.layout_model = Layoutlmv3_Predictor(
-        #     str(os.path.join(models_dir, self.configs['weights']['layout'])),
-        #     str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-        #     device=self.device
-        # )
-        self.layout_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.Layout,
-            layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
-            layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-            device=self.device
-        )
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.LAYOUTLMv3,
+                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+                device=self.device
+            )
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.DocLayout_YOLO,
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+            )
         # 初始化ocr
         if self.apply_ocr:
 
@@ -261,12 +295,10 @@ class CustomPEKModel:
             )
         # init table model
         if self.apply_table:
-            table_model_dir = self.configs["weights"][self.table_model_type]
-            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-            #                                     max_time=self.table_max_time, _device_=self.device)
+            table_model_dir = self.configs["weights"][self.table_model_name]
             self.table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Table,
-                table_model_type=self.table_model_type,
+                table_model_name=self.table_model_name,
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device
@@ -294,7 +326,21 @@ class CustomPEKModel:
 
         # layout检测
         layout_start = time.time()
-        layout_res = self.layout_model(image, ignore_catids=[])
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            layout_res = self.layout_model(image, ignore_catids=[])
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            layout_res = []
+            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 3),
+                }
+                layout_res.append(new_item)
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection time: {layout_cost}")
 
@@ -303,7 +349,7 @@ class CustomPEKModel:
         if self.apply_formula:
             # 公式检测
             mfd_start = time.time()
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
             logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
             for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
@@ -315,7 +361,6 @@ class CustomPEKModel:
                 }
                 layout_res.append(new_item)
                 latex_filling_list.append(new_item)
-                # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
                 bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 mf_image_list.append(bbox_img)
 
@@ -417,7 +462,7 @@ class CustomPEKModel:
                 # logger.info("------------------table recognition processing begins-----------------")
                 latex_code = None
                 html_code = None
-                if self.table_model_type == STRUCT_EQTABLE:
+                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
                         latex_code = self.table_model.image2latex(new_image)[0]
                 else:

+ 2 - 2
magic_pdf/model/ppTableModel.py

@@ -52,11 +52,11 @@ class ppTableModel(object):
         rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
         rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
         device = kwargs.get("device", "cpu")
-        use_gpu = True if device == "cuda" else False
+        use_gpu = True if device.startswith("cuda") else False
         config = {
             "use_gpu": use_gpu,
             "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": TABLE_MASTER,
+            "table_algorithm": "TableMaster",
             "table_model_dir": table_model_dir,
             "table_char_dict_path": table_char_dict_path,
             "det_model_dir": det_model_dir,

Diferenças do arquivo suprimidas por serem muito extensas
+ 122 - 82
magic_pdf/para/para_split_v3.py


+ 5 - 2
magic_pdf/pdf_parse_by_ocr.py

@@ -1,3 +1,5 @@
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
                      end_page_id=None,
                      debug_mode=False,
                      ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "ocr",
+                           SupportedPdfParseMethod.OCR,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,

+ 5 - 2
magic_pdf/pdf_parse_by_txt.py

@@ -1,3 +1,5 @@
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
     end_page_id=None,
     debug_mode=False,
 ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "txt",
+                           SupportedPdfParseMethod.TXT,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,

+ 300 - 156
magic_pdf/pdf_parse_union_core_v2.py

@@ -1,13 +1,14 @@
+import copy
 import os
 import statistics
 import time
-
-from loguru import logger
-
 from typing import List
 
 import torch
+from loguru import logger
 
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.commons import fitz, get_delta_time
 from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
@@ -15,31 +16,39 @@ from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.local_math import float_equal
-from magic_pdf.libs.ocr_content_type import ContentType
+from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
-from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
+from magic_pdf.pre_proc.construct_page_dict import \
+    ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
-from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \
-    combine_chars_to_pymudict
-from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
-from magic_pdf.pre_proc.ocr_dict_merge import  fill_spans_in_blocks, fix_block_spans, fix_discarded_block
-from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
-    remove_overlaps_low_confidence_spans
-from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
+from magic_pdf.pre_proc.equations_replace import (
+    combine_chars_to_pymudict, remove_chars_in_text_blocks,
+    replace_equations_in_textblock)
+from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
+    ocr_prepare_bboxes_for_layout_split_v2
+from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
+                                               fix_block_spans,
+                                               fix_discarded_block, fix_block_spans_v2)
+from magic_pdf.pre_proc.ocr_span_list_modify import (
+    get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
+    remove_overlaps_min_spans)
+from magic_pdf.pre_proc.resolve_bbox_conflict import \
+    check_useful_block_horizontal_overlap
 
 
 def remove_horizontal_overlap_block_which_smaller(all_bboxes):
     useful_blocks = []
     for bbox in all_bboxes:
-        useful_blocks.append({
-            "bbox": bbox[:4]
-        })
-    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks)
+        useful_blocks.append({'bbox': bbox[:4]})
+    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
+        check_useful_block_horizontal_overlap(useful_blocks)
+    )
     if is_useful_block_horz_overlap:
         logger.warning(
-            f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}")
+            f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
+        )  # noqa: E501
         for bbox in all_bboxes.copy():
             if smaller_bbox == bbox[:4]:
                 all_bboxes.remove(bbox)
@@ -47,27 +56,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
     return is_useful_block_horz_overlap, all_bboxes
 
 
-def __replace_STX_ETX(text_str:str):
-    """ Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
-Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
+def __replace_STX_ETX(text_str: str):
+    """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
+    Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
 
-    Args:
-        text_str (str): raw text
+        Args:
+            text_str (str): raw text
 
-    Returns:
-        _type_: replaced text
-    """
+        Returns:
+            _type_: replaced text
+    """  # noqa: E501
     if text_str:
         s = text_str.replace('\u0002', "'")
-        s = s.replace("\u0003", "'")
+        s = s.replace('\u0003', "'")
         return s
     return text_str
 
 
 def txt_spans_extract(pdf_page, inline_equations, interline_equations):
-    text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
-    char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
-        "blocks"
+    text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
+    char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
+        'blocks'
     ]
     text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
     text_blocks = replace_equations_in_textblock(
@@ -77,54 +86,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
     text_blocks = remove_chars_in_text_blocks(text_blocks)
     spans = []
     for v in text_blocks:
-        for line in v["lines"]:
-            for span in line["spans"]:
-                bbox = span["bbox"]
+        for line in v['lines']:
+            for span in line['spans']:
+                bbox = span['bbox']
                 if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
                     continue
-                if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation):
+                if span.get('type') not in (
+                    ContentType.InlineEquation,
+                    ContentType.InterlineEquation,
+                ):
                     spans.append(
                         {
-                            "bbox": list(span["bbox"]),
-                            "content": __replace_STX_ETX(span["text"]),
-                            "type": ContentType.Text,
-                            "score": 1.0,
+                            'bbox': list(span['bbox']),
+                            'content': __replace_STX_ETX(span['text']),
+                            'type': ContentType.Text,
+                            'score': 1.0,
                         }
                     )
     return spans
 
 
 def replace_text_span(pymu_spans, ocr_spans):
-    return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
+    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
 
 
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
+
     if torch.cuda.is_available():
-        device = torch.device("cuda")
+        device = torch.device('cuda')
         if torch.cuda.is_bf16_supported():
             supports_bfloat16 = True
         else:
             supports_bfloat16 = False
     else:
-        device = torch.device("cpu")
+        device = torch.device('cpu')
         supports_bfloat16 = False
 
-    if model_name == "layoutreader":
+    if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
         layoutreader_model_dir = get_local_layoutreader_model_dir()
         if os.path.exists(layoutreader_model_dir):
-            model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
+            model = LayoutLMv3ForTokenClassification.from_pretrained(
+                layoutreader_model_dir
+            )
         else:
             logger.warning(
-                f"local layoutreader model not exists, use online model from huggingface")
-            model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
+                'local layoutreader model not exists, use online model from huggingface'
+            )
+            model = LayoutLMv3ForTokenClassification.from_pretrained(
+                'hantian/layoutreader'
+            )
         # 检查设备是否支持 bfloat16
         if supports_bfloat16:
             model.bfloat16()
         model.to(device).eval()
     else:
-        logger.error("model name not allow")
+        logger.error('model name not allow')
         exit(1)
     return model
 
@@ -145,7 +163,9 @@ class ModelSingleton:
 
 
 def do_predict(boxes: List[List[int]], model) -> List[int]:
-    from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
+    from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
+                                            prepare_inputs)
+
     inputs = boxes2inputs(boxes)
     inputs = prepare_inputs(inputs, model)
     logits = model(**inputs).logits.cpu().squeeze(0)
@@ -154,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
 
 def cal_block_index(fix_blocks, sorted_bboxes):
     for block in fix_blocks:
-        # if block['type'] in ['text', 'title', 'interline_equation']:
-        #     line_index_list = []
-        #     if len(block['lines']) == 0:
-        #         block['index'] = sorted_bboxes.index(block['bbox'])
-        #     else:
-        #         for line in block['lines']:
-        #             line['index'] = sorted_bboxes.index(line['bbox'])
-        #             line_index_list.append(line['index'])
-        #         median_value = statistics.median(line_index_list)
-        #         block['index'] = median_value
-        #
-        # elif block['type'] in ['table', 'image']:
-        #     block['index'] = sorted_bboxes.index(block['bbox'])
 
         line_index_list = []
         if len(block['lines']) == 0:
@@ -178,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
             median_value = statistics.median(line_index_list)
             block['index'] = median_value
 
-        # 删除图表block中的虚拟line信息
-        if block['type'] in ['table', 'image']:
-            del block['lines']
+        # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
+        if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+            block['virtual_lines'] = copy.deepcopy(block['lines'])
+            block['lines'] = copy.deepcopy(block['real_lines'])
+            del block['real_lines']
 
     return fix_blocks
 
@@ -193,21 +202,22 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     block_weight = x1 - x0
 
     # 如果block高度小于n行正文,则直接返回block的bbox
-    if line_height*3 < block_height:
-        if block_height > page_h*0.25 and page_w*0.5 > block_weight > page_w*0.25:  # 可能是双列结构,可以切细点
-            lines = int(block_height/line_height)+1
+    if line_height * 3 < block_height:
+        if (
+            block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
+        ):  # 可能是双列结构,可以切细点
+            lines = int(block_height / line_height) + 1
         else:
-            # 如果block的宽度超过0.4页面宽度,则将block分成3行
-            if block_weight > page_w*0.4:
+            # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
+            if block_weight > page_w * 0.4:
                 line_height = (y1 - y0) / 3
                 lines = 3
-            elif block_weight > page_w*0.25: # 否则将block分成两行
-                line_height = (y1 - y0) / 2
-                lines = 2
-            else: # 判断长宽比
-                if block_height/block_weight > 1.2:  # 细长的不分
+            elif block_weight > page_w * 0.25:  # (可能是三列结构,也切细点)
+                lines = int(block_height / line_height) + 1
+            else:  # 判断长宽比
+                if block_height / block_weight > 1.2:  # 细长的不分
                     return [[x0, y0, x1, y1]]
-                else: # 不细长的还是分成两行
+                else:  # 不细长的还是分成两行
                     line_height = (y1 - y0) / 2
                     lines = 2
 
@@ -229,7 +239,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
     for block in fix_blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             if len(block['lines']) == 0:
                 bbox = block['bbox']
                 lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
@@ -240,8 +254,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
                 for line in block['lines']:
                     bbox = line['bbox']
                     page_line_list.append(bbox)
-        elif block['type'] in ['table', 'image']:
+        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
             bbox = block['bbox']
+            block["real_lines"] = copy.deepcopy(block['lines'])
             lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
             block['lines'] = []
             for line in lines:
@@ -256,19 +271,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     for left, top, right, bottom in page_line_list:
         if left < 0:
             logger.warning(
-                f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             left = 0
         if right > page_w:
             logger.warning(
-                f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             right = page_w
         if top < 0:
             logger.warning(
-                f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             top = 0
         if bottom > page_h:
             logger.warning(
-                f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             bottom = page_h
 
         left = round(left * x_scale)
@@ -276,11 +295,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
         right = round(right * x_scale)
         bottom = round(bottom * y_scale)
         assert (
-                1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
-        ), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
+            1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
+        ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}'  # noqa: E126, E121
         boxes.append([left, top, right, bottom])
     model_manager = ModelSingleton()
-    model = model_manager.get_model("layoutreader")
+    model = model_manager.get_model('layoutreader')
     with torch.no_grad():
         orders = do_predict(boxes, model)
     sorted_bboxes = [page_line_list[i] for i in orders]
@@ -291,149 +310,274 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
 def get_line_height(blocks):
     page_line_height_list = []
     for block in blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             for line in block['lines']:
                 bbox = line['bbox']
-                page_line_height_list.append(int(bbox[3]-bbox[1]))
+                page_line_height_list.append(int(bbox[3] - bbox[1]))
     if len(page_line_height_list) > 0:
         return statistics.median(page_line_height_list)
     else:
         return 10
 
 
-def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode):
+def process_groups(groups, body_key, caption_key, footnote_key):
+    body_blocks = []
+    caption_blocks = []
+    footnote_blocks = []
+    for i, group in enumerate(groups):
+        group[body_key]['group_id'] = i
+        body_blocks.append(group[body_key])
+        for caption_block in group[caption_key]:
+            caption_block['group_id'] = i
+            caption_blocks.append(caption_block)
+        for footnote_block in group[footnote_key]:
+            footnote_block['group_id'] = i
+            footnote_blocks.append(footnote_block)
+    return body_blocks, caption_blocks, footnote_blocks
+
+
+def process_block_list(blocks, body_type, block_type):
+    indices = [block['index'] for block in blocks]
+    median_index = statistics.median(indices)
+
+    body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
+
+    return {
+        'type': block_type,
+        'bbox': body_bbox,
+        'blocks': blocks,
+        'index': median_index,
+    }
+
+
+def revert_group_blocks(blocks):
+    image_groups = {}
+    table_groups = {}
+    new_blocks = []
+    for block in blocks:
+        if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
+            group_id = block['group_id']
+            if group_id not in image_groups:
+                image_groups[group_id] = []
+            image_groups[group_id].append(block)
+        elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
+            group_id = block['group_id']
+            if group_id not in table_groups:
+                table_groups[group_id] = []
+            table_groups[group_id].append(block)
+        else:
+            new_blocks.append(block)
+
+    for group_id, blocks in image_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
+
+    for group_id, blocks in table_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
+
+    return new_blocks
+
+
+def parse_page_core(
+    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+):
     need_drop = False
     drop_reason = []
 
-    '''从magic_model对象中获取后面会用到的区块信息'''
-    img_blocks = magic_model.get_imgs(page_id)
-    table_blocks = magic_model.get_tables(page_id)
+    """从magic_model对象中获取后面会用到的区块信息"""
+    # img_blocks = magic_model.get_imgs(page_id)
+    # table_blocks = magic_model.get_tables(page_id)
+
+    img_groups = magic_model.get_imgs_v2(page_id)
+    table_groups = magic_model.get_tables_v2(page_id)
+
+    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
+        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
+    )
+
+    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
+        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
+    )
+
     discarded_blocks = magic_model.get_discarded(page_id)
     text_blocks = magic_model.get_text_blocks(page_id)
     title_blocks = magic_model.get_title_blocks(page_id)
-    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
+    inline_equations, interline_equations, interline_equation_blocks = (
+        magic_model.get_equations(page_id)
+    )
 
     page_w, page_h = magic_model.get_page_size(page_id)
 
     spans = magic_model.get_all_spans(page_id)
 
-    '''根据parse_mode,构造spans'''
-    if parse_mode == "txt":
+    """根据parse_mode,构造spans"""
+    if parse_mode == SupportedPdfParseMethod.TXT:
         """ocr 中文本类的 span 用 pymu spans 替换!"""
-        pymu_spans = txt_spans_extract(
-            pdf_docs[page_id], inline_equations, interline_equations
-        )
+        pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
         spans = replace_text_span(pymu_spans, spans)
-    elif parse_mode == "ocr":
+    elif parse_mode == SupportedPdfParseMethod.OCR:
         pass
     else:
-        raise Exception("parse_mode must be txt or ocr")
+        raise Exception('parse_mode must be txt or ocr')
 
-    '''删除重叠spans中置信度较低的那些'''
+    """删除重叠spans中置信度较低的那些"""
     spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
-    '''删除重叠spans中较小的那些'''
+    """删除重叠spans中较小的那些"""
     spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
-    '''对image和table截图'''
-    spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
+    """对image和table截图"""
+    spans = ocr_cut_image_and_table(
+        spans, page_doc, page_id, pdf_bytes_md5, imageWriter
+    )
 
-    '''将所有区块的bbox整理到一起'''
+    """将所有区块的bbox整理到一起"""
     # interline_equation_blocks参数不够准,后面切换到interline_equations上
     interline_equation_blocks = []
     if len(interline_equation_blocks) > 0:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
-            interline_equation_blocks, page_w, page_h)
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
+            discarded_blocks,
+            text_blocks,
+            title_blocks,
+            interline_equation_blocks,
+            page_w,
+            page_h,
+        )
     else:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
-            interline_equations, page_w, page_h)
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
+            discarded_blocks,
+            text_blocks,
+            title_blocks,
+            interline_equations,
+            page_w,
+            page_h,
+        )
 
-    '''先处理不需要排版的discarded_blocks'''
-    discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
+    """先处理不需要排版的discarded_blocks"""
+    discarded_block_with_spans, spans = fill_spans_in_blocks(
+        all_discarded_blocks, spans, 0.4
+    )
     fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
 
-    '''如果当前页面没有bbox则跳过'''
+    """如果当前页面没有bbox则跳过"""
     if len(all_bboxes) == 0:
-        logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
-        return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
-                                               [], [], interline_equations, fix_discarded_blocks,
-                                               need_drop, drop_reason)
+        logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
+        return ocr_construct_page_component_v2(
+            [],
+            [],
+            page_id,
+            page_w,
+            page_h,
+            [],
+            [],
+            [],
+            interline_equations,
+            fix_discarded_blocks,
+            need_drop,
+            drop_reason,
+        )
 
-    '''将span填入blocks中'''
-    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.3)
+    """将span填入blocks中"""
+    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
 
-    '''对block进行fix操作'''
-    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
+    """对block进行fix操作"""
+    fix_blocks = fix_block_spans_v2(block_with_spans)
 
-    '''获取所有line并计算正文line的高度'''
+    """获取所有line并计算正文line的高度"""
     line_height = get_line_height(fix_blocks)
 
-    '''获取所有line并对line排序'''
+    """获取所有line并对line排序"""
     sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
 
-    '''根据line的中位数算block的序列关系'''
+    """根据line的中位数算block的序列关系"""
     fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
 
-    '''重排block'''
+    """将image和table的block还原回group形式参与后续流程"""
+    fix_blocks = revert_group_blocks(fix_blocks)
+
+    """重排block"""
     sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
 
-    '''获取QA需要外置的list'''
+    """获取QA需要外置的list"""
     images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
 
-    '''构造pdf_info_dict'''
-    page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h, [],
-                                                images, tables, interline_equations, fix_discarded_blocks,
-                                                need_drop, drop_reason)
+    """构造pdf_info_dict"""
+    page_info = ocr_construct_page_component_v2(
+        sorted_blocks,
+        [],
+        page_id,
+        page_w,
+        page_h,
+        [],
+        images,
+        tables,
+        interline_equations,
+        fix_discarded_blocks,
+        need_drop,
+        drop_reason,
+    )
     return page_info
 
 
-def pdf_parse_union(pdf_bytes,
-                    model_list,
-                    imageWriter,
-                    parse_mode,
-                    start_page_id=0,
-                    end_page_id=None,
-                    debug_mode=False,
-                    ):
-    pdf_bytes_md5 = compute_md5(pdf_bytes)
-    pdf_docs = fitz.open("pdf", pdf_bytes)
+def pdf_parse_union(
+    dataset: Dataset,
+    model_list,
+    imageWriter,
+    parse_mode,
+    start_page_id=0,
+    end_page_id=None,
+    debug_mode=False,
+):
+    pdf_bytes_md5 = compute_md5(dataset.data_bits())
 
-    '''初始化空的pdf_info_dict'''
+    """初始化空的pdf_info_dict"""
     pdf_info_dict = {}
 
-    '''用model_list和docs对象初始化magic_model'''
-    magic_model = MagicModel(model_list, pdf_docs)
+    """用model_list和docs对象初始化magic_model"""
+    magic_model = MagicModel(model_list, dataset)
 
-    '''根据输入的起始范围解析pdf'''
+    """根据输入的起始范围解析pdf"""
     # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
-    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(dataset) - 1
+    )
 
-    if end_page_id > len(pdf_docs) - 1:
-        logger.warning("end_page_id is out of range, use pdf_docs length")
-        end_page_id = len(pdf_docs) - 1
+    if end_page_id > len(dataset) - 1:
+        logger.warning('end_page_id is out of range, use pdf_docs length')
+        end_page_id = len(dataset) - 1
 
-    '''初始化启动时间'''
+    """初始化启动时间"""
     start_time = time.time()
 
-    for page_id, page in enumerate(pdf_docs):
-        '''debug时输出每页解析的耗时'''
+    for page_id, page in enumerate(dataset):
+        """debug时输出每页解析的耗时."""
         if debug_mode:
             time_now = time.time()
             logger.info(
-                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
+                f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
             )
             start_time = time_now
 
-        '''解析pdf中的每一页'''
+        """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
-            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+            page_info = parse_page_core(
+                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+            )
         else:
-            page_w = page.rect.width
-            page_h = page.rect.height
-            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
-                                                [], [], [], [],
-                                                True, "skip page")
-        pdf_info_dict[f"page_{page_id}"] = page_info
+            page_info = page.get_page_info()
+            page_w = page_info.w
+            page_h = page_info.h
+            page_info = ocr_construct_page_component_v2(
+                [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
+            )
+        pdf_info_dict[f'page_{page_id}'] = page_info
 
     """分段"""
     para_split(pdf_info_dict, debug_mode=debug_mode)
@@ -441,7 +585,7 @@ def pdf_parse_union(pdf_bytes,
     """dict转list"""
     pdf_info_list = dict_to_list(pdf_info_dict)
     new_pdf_info_dict = {
-        "pdf_info": pdf_info_list,
+        'pdf_info': pdf_info_list,
     }
 
     clean_memory()

+ 6 - 7
magic_pdf/pipe/AbsPipe.py

@@ -17,7 +17,7 @@ class AbsPipe(ABC):
     PIP_TXT = "txt"
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
+                 start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
@@ -26,6 +26,9 @@ class AbsPipe(ABC):
         self.start_page_id = start_page_id
         self.end_page_id = end_page_id
         self.lang = lang
+        self.layout_model = layout_model
+        self.formula_enable = formula_enable
+        self.table_enable = table_enable
     
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)
@@ -95,9 +98,7 @@ class AbsPipe(ABC):
         """
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_info_list = pdf_mid_data["pdf_info"]
-        parse_type = pdf_mid_data["_parse_type"]
-        lang = pdf_mid_data.get("_lang", None)
-        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path, parse_type, lang)
+        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
         return content_list
 
     @staticmethod
@@ -107,9 +108,7 @@ class AbsPipe(ABC):
         """
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_info_list = pdf_mid_data["pdf_info"]
-        parse_type = pdf_mid_data["_parse_type"]
-        lang = pdf_mid_data.get("_lang", None)
-        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path, parse_type, lang)
+        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
         return md_content
 
 

+ 8 - 4
magic_pdf/pipe/OCRPipe.py

@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
 class OCRPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang)
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 8 - 4
magic_pdf/pipe/TXTPipe.py

@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
 class TXTPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang)
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 10 - 5
magic_pdf/pipe/UNIPipe.py

@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 class UNIPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
+                         lang, layout_model, formula_enable, table_enable)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
         if self.pdf_type == self.PIP_TXT:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                                 is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
                                                 start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                                lang=self.lang)
+                                                lang=self.lang, layout_model=self.layout_model,
+                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                               is_debug=self.is_debug,

+ 55 - 26
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -1,7 +1,7 @@
 from loguru import logger
 
 from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
-    calculate_iou
+    calculate_iou, calculate_vertical_projection_overlap_ratio
 from magic_pdf.libs.drop_tag import DropTag
 from magic_pdf.libs.ocr_content_type import BlockType
 from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
     return all_bboxes, all_discarded_blocks, drop_reasons
 
 
-def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks,
-                                        title_blocks, interline_equation_blocks, page_w, page_h):
-    all_bboxes = []
-    all_discarded_blocks = []
-    for image in img_blocks:
-        x0, y0, x1, y1 = image['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])
-
-    for table in table_blocks:
-        x0, y0, x1, y1 = table['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
+def add_bboxes(blocks, block_type, bboxes):
+    for block in blocks:
+        x0, y0, x1, y1 = block['bbox']
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
+        else:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
 
-    for text in text_blocks:
-        x0, y0, x1, y1 = text['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
 
-    for title in title_blocks:
-        x0, y0, x1, y1 = title['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])
+def ocr_prepare_bboxes_for_layout_split_v2(
+        img_body_blocks, img_caption_blocks, img_footnote_blocks,
+        table_body_blocks, table_caption_blocks, table_footnote_blocks,
+        discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
+):
+    all_bboxes = []
 
-    for interline_equation in interline_equation_blocks:
-        x0, y0, x1, y1 = interline_equation['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])
+    add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
+    add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
+    add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
+    add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
+    add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
+    add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
+    add_bboxes(text_blocks, BlockType.Text, all_bboxes)
+    add_bboxes(title_blocks, BlockType.Title, all_bboxes)
+    add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
 
     '''block嵌套问题解决'''
     '''文本框与标题框重叠,优先信任文本框'''
@@ -96,13 +101,23 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
     '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
     # 通过后续大框套小框逻辑删除
 
-    '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
+    '''discarded_blocks'''
+    all_discarded_blocks = []
+    add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
+
+    '''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
+    footnote_blocks = []
     for discarded in discarded_blocks:
         x0, y0, x1, y1 = discarded['bbox']
-        all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
-        # 将footnote加入到all_bboxes中,用来计算layout
-        # if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
-        #     all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
+        if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
+            footnote_blocks.append([x0, y0, x1, y1])
+
+    '''移除在footnote下面的任何框'''
+    need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
+    if len(need_remove_blocks) > 0:
+        for block in need_remove_blocks:
+            all_bboxes.remove(block)
+            all_discarded_blocks.append(block)
 
     '''经过以上处理后,还存在大框套小框的情况,则删除小框'''
     all_bboxes = remove_overlaps_min_blocks(all_bboxes)
@@ -113,6 +128,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
     return all_bboxes, all_discarded_blocks
 
 
+def find_blocks_under_footnote(all_bboxes, footnote_blocks):
+    need_remove_blocks = []
+    for block in all_bboxes:
+        block_x0, block_y0, block_x1, block_y1 = block[:4]
+        for footnote_bbox in footnote_blocks:
+            footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
+            # 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
+            if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
+                if block not in need_remove_blocks:
+                    need_remove_blocks.append(block)
+                    break
+    return need_remove_blocks
+
+
 def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
     # 先提取所有text和interline block
     text_blocks = []

+ 27 - 1
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -49,7 +49,7 @@ def merge_spans_to_line(spans):
                 continue
 
             # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
-            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.6):
+            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
                 current_line.append(span)
             else:
                 # 否则,开始新行
@@ -153,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
             'type': block_type,
             'bbox': block_bbox,
         }
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            block_dict["group_id"] = block[-1]
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']
@@ -201,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
     return fix_blocks
 
 
+def fix_block_spans_v2(block_with_spans):
+    """1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
+    需要将caption和footnote的text_span放入相应img_block和table_block内的
+    caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
+    fix_blocks = []
+    for block in block_with_spans:
+        block_type = block['type']
+
+        if block_type in [BlockType.Text, BlockType.Title,
+                          BlockType.ImageCaption, BlockType.ImageFootnote,
+                          BlockType.TableCaption, BlockType.TableFootnote
+                          ]:
+            block = fix_text_block(block)
+        elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
+            block = fix_interline_block(block)
+        else:
+            continue
+        fix_blocks.append(block)
+    return fix_blocks
+
+
 def fix_discarded_block(discarded_block_with_spans):
     fix_discarded_blocks = []
     for block in discarded_block_with_spans:

+ 5 - 13
magic_pdf/resources/model_config/model_configs.yaml

@@ -1,15 +1,7 @@
-config:
-  device: cpu
-  layout: True
-  formula: True
-  table_config:
-    model: TableMaster
-    is_table_recog_enable: False
-    max_time: 400
-
 weights:
-  layout: Layout/model_final.pth
-  mfd: MFD/weights.pt
-  mfr: MFR/unimernet_small
+  layoutlmv3: Layout/LayoutLMv3/model_final.pth
+  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
+  yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
+  unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
-  TableMaster: TabRec/TableMaster
+  tablemaster: TabRec/TableMaster

+ 1 - 1
magic_pdf/tools/cli.py

@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""",
     help="""
     Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
     You should input "Abbreviation" with language form url:
-    https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
+    https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
     """,
     default=None,
 )

+ 11 - 6
magic_pdf/tools/common.py

@@ -6,8 +6,8 @@ import click
 from loguru import logger
 
 import magic_pdf.model as model_config
-from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox,
-                                      draw_model_bbox, draw_line_sort_bbox)
+from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
+                                      draw_model_bbox, draw_span_bbox)
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
@@ -46,10 +46,12 @@ def do_parse(
     start_page_id=0,
     end_page_id=None,
     lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
 ):
     if debug_able:
         logger.warning('debug mode is on')
-        # f_dump_content_list = True
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
 
@@ -64,13 +66,16 @@ def do_parse(
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'txt':
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'ocr':
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     else:
         logger.error('unknown parse method')
         exit(1)

+ 13 - 5
magic_pdf/user_api.py

@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         if input_model_is_empty:
-            pdf_models = doc_analyze(pdf_bytes,
-                                     ocr=True,
-                                     start_page_id=start_page_id,
-                                     end_page_id=end_page_id,
-                                     lang=lang)
+            layout_model = kwargs.get("layout_model", None)
+            formula_enable = kwargs.get("formula_enable", None)
+            table_enable = kwargs.get("table_enable", None)
+            pdf_models = doc_analyze(
+                pdf_bytes,
+                ocr=True,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")

+ 0 - 0
magic_pdf/utils/__init__.py


+ 11 - 0
magic_pdf/utils/annotations.py

@@ -0,0 +1,11 @@
+
+from loguru import logger
+
+
+def ImportPIL(f):
+    try:
+        import PIL  # noqa: F401
+    except ImportError:
+        logger.error('Pillow not installed, please install by pip.')
+        exit(1)
+    return f

+ 0 - 0
docs/en/.readthedocs.yaml → next_docs/en/.readthedocs.yaml


+ 0 - 0
docs/en/Makefile → next_docs/en/Makefile


+ 0 - 0
docs/en/_static/image/logo.png → next_docs/en/_static/image/logo.png


+ 9 - 0
next_docs/en/api.rst

@@ -0,0 +1,9 @@
+Data API
+------------------
+
+.. toctree::
+   :maxdepth: 2
+
+   api/dataset.rst
+   api/data_reader_writer.rst
+   api/read_api.rst

+ 44 - 0
next_docs/en/api/data_reader_writer.rst

@@ -0,0 +1,44 @@
+
+Data Reader Writer
+--------------------
+
+.. autoclass:: magic_pdf.data.data_reader_writer.DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.DataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataWriter
+   :members:
+   :inherited-members:
+

+ 22 - 0
next_docs/en/api/dataset.rst

@@ -0,0 +1,22 @@
+Dataset API
+------------------
+
+.. autoclass:: magic_pdf.data.dataset.PageableData
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.Dataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.ImageDataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.PymuDocDataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.Doc
+   :members:
+   :inherited-members:

+ 0 - 0
next_docs/en/api/io.rst


+ 6 - 0
next_docs/en/api/read_api.rst

@@ -0,0 +1,6 @@
+read_api API
+------------------
+
+.. automodule:: magic_pdf.data.read_api
+   :members:
+   :inherited-members:

+ 0 - 0
next_docs/en/api/schemas.rst


+ 1 - 0
next_docs/en/api/utils.rst

@@ -0,0 +1 @@
+

+ 0 - 0
docs/en/conf.py → next_docs/en/conf.py


+ 12 - 0
docs/en/index.rst → next_docs/en/index.rst

@@ -24,3 +24,15 @@ Welcome to the MinerU Documentation
    <a class="github-button" href="https://github.com/opendatalab/MinerU/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
    <a class="github-button" href="https://github.com/opendatalab/MinerU/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
    </p>
+
+
+API Reference
+-------------
+
+If you are looking for information on a specific function, class or
+method, this part of the documentation is for you.
+
+.. toctree::
+   :maxdepth: 2
+
+   api

+ 0 - 0
docs/en/make.bat → next_docs/en/make.bat


+ 5 - 0
docs/requirements.txt → next_docs/requirements.txt

@@ -1,4 +1,9 @@
+boto3>=1.28.43
+loguru>=0.6.0
 myst-parser
+Pillow==8.4.0
+pydantic>=2.7.2,<2.8.0
+PyMuPDF>=1.24.9
 sphinx
 sphinx-argparse
 sphinx-book-theme

+ 0 - 0
docs/zh_cn/.readthedocs.yaml → next_docs/zh_cn/.readthedocs.yaml


+ 0 - 0
docs/zh_cn/Makefile → next_docs/zh_cn/Makefile


+ 0 - 0
docs/zh_cn/_static/image/logo.png → next_docs/zh_cn/_static/image/logo.png


+ 0 - 0
docs/zh_cn/conf.py → next_docs/zh_cn/conf.py


+ 0 - 0
docs/zh_cn/index.rst → next_docs/zh_cn/index.rst


+ 0 - 0
docs/zh_cn/make.bat → next_docs/zh_cn/make.bat


+ 1 - 2
projects/README.md

@@ -6,5 +6,4 @@
 - [gradio_app](./gradio_app/README.md): Build a web app based on gradio
 - [web_demo](./web_demo/README.md): MinerU online [demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/) localized deployment version
 - [web_api](./web_api/README.md): Web API Based on FastAPI
-
-
+- [multi_gpu](./multi_gpu/README.md): Multi-GPU parallel processing based on LitServe

+ 1 - 1
projects/README_zh-CN.md

@@ -6,4 +6,4 @@
 - [gradio_app](./gradio_app/README_zh-CN.md): 基于 Gradio 的 Web 应用
 - [web_demo](./web_demo/README_zh-CN.md): MinerU在线[demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/)本地化部署版本
 - [web_api](./web_api/README.md): 基于 FastAPI 的 Web API
-
+- [multi_gpu](./multi_gpu/README.md): 基于 LitServe 的多 GPU 并行处理

+ 68 - 12
projects/gradio_app/app.py

@@ -3,10 +3,12 @@
 import base64
 import os
 import time
+import uuid
 import zipfile
 from pathlib import Path
 import re
 
+import pymupdf
 from loguru import logger
 
 from magic_pdf.libs.hash_utils import compute_sha256
@@ -23,7 +25,7 @@ def read_fn(path):
     return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
 
 
-def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
+def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
     os.makedirs(output_dir, exist_ok=True)
 
     try:
@@ -42,6 +44,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
             parse_method,
             False,
             end_page_id=end_page_id,
+            layout_model=layout_mode,
+            formula_enable=formula_enable,
+            table_enable=table_enable,
+            lang=language,
         )
         return local_md_dir, file_name
     except Exception as e:
@@ -93,9 +99,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
     return re.sub(pattern, replace, markdown_text)
 
 
-def to_markdown(file_path, end_pages, is_ocr):
+def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table_enable, language):
     # 获取识别的md文件以及压缩包文件路径
-    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr)
+    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
+                                        layout_mode, formula_enable, table_enable, language)
     archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
     if zip_archive_success == 0:
@@ -138,24 +145,71 @@ with open("header.html", "r") as file:
     header = file.read()
 
 
+# Language abbreviations offered in the "Language" dropdown, grouped by script.
+# NOTE(review): presumably these are the PaddleOCR language abbreviations -- confirm.
+latin_lang = [
+        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
+        'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
+        'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
+        'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
+]
+arabic_lang = ['ar', 'fa', 'ug', 'ur']
+cyrillic_lang = [
+        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
+        'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+]
+devanagari_lang = [
+        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
+        'sa', 'bgc'
+]
+other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
+
+# The leading empty string is the dropdown's default value (no language chosen);
+# the remaining entries are the groups above, with `other_lang` listed first.
+all_lang = [""]
+all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
+
+
+def to_pdf(file_path):
+    with pymupdf.open(file_path) as f:
+        if f.is_pdf:
+            return file_path
+        else:
+            pdf_bytes = f.convert_to_pdf()
+            # 将pdfbytes 写入到uuid.pdf中
+            # 生成唯一的文件名
+            unique_filename = f"{uuid.uuid4()}.pdf"
+
+            # 构建完整的文件路径
+            tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
+
+            # 将字节数据写入文件
+            with open(tmp_file_path, 'wb') as tmp_pdf_file:
+                tmp_pdf_file.write(pdf_bytes)
+
+            return tmp_file_path
+
+
 if __name__ == "__main__":
     with gr.Blocks() as demo:
         gr.HTML(header)
         with gr.Row():
             with gr.Column(variant='panel', scale=5):
-                pdf_show = gr.Markdown()
+                file = gr.File(label="Please upload a PDF or image", file_types=[".pdf", ".png", ".jpeg", "jpg"])
                 max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
-                with gr.Row() as bu_flow:
-                    is_ocr = gr.Checkbox(label="Force enable OCR")
+                with gr.Row():
+                    layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3")
+                    language = gr.Dropdown(all_lang, label="Language", value="")
+                with gr.Row():
+                    formula_enable = gr.Checkbox(label="Enable formula recognition", value=True)
+                    is_ocr = gr.Checkbox(label="Force enable OCR", value=False)
+                    table_enable = gr.Checkbox(label="Enable table recognition(test)", value=False)
+                with gr.Row():
                     change_bu = gr.Button("Convert")
-                    clear_bu = gr.ClearButton([pdf_show], value="Clear")
-                pdf_show = PDF(label="Please upload pdf", interactive=True, height=800)
+                    clear_bu = gr.ClearButton(value="Clear")
+                pdf_show = PDF(label="PDF preview", interactive=True, height=800)
                 with gr.Accordion("Examples:"):
                     example_root = os.path.join(os.path.dirname(__file__), "examples")
                     gr.Examples(
                         examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                   _.endswith("pdf")],
-                        inputs=pdf_show,
+                        inputs=pdf_show
                     )
 
             with gr.Column(variant='panel', scale=5):
@@ -166,7 +220,9 @@ if __name__ == "__main__":
                                          latex_delimiters=latex_delimiters, line_breaks=True)
                     with gr.Tab("Markdown text"):
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
-        change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr], outputs=[md, md_text, output_file, pdf_show])
-        clear_bu.add([md, pdf_show, md_text, output_file, is_ocr])
+        file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
+        change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
+                        outputs=[md, md_text, output_file, pdf_show])
+        clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
 
-    demo.launch()
+    demo.launch(server_name="0.0.0.0")

BIN
projects/gradio_app/examples/2list_1table.pdf


BIN
projects/gradio_app/examples/3list_1table.pdf


Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff