Browse Source

Merge pull request #7 from opendatalab/dev

Dev
Kaiwen Liu 1 year ago
commit 6d571e2e2c
100 changed files with 3658 additions and 1027 deletions
  1. .gitignore (+50 -45)
  2. .pre-commit-config.yaml (+3 -2)
  3. .readthedocs.yaml (+16 -0)
  4. README.md (+33 -18)
  5. README_zh-CN.md (+32 -16)
  6. demo/demo.py (+2 -15)
  7. demo/magic_pdf_parse_main.py (+14 -9)
  8. docs/FAQ_en_us.md (+7 -2)
  9. docs/FAQ_zh_cn.md (+13 -3)
  10. docs/README_Ubuntu_CUDA_Acceleration_en_US.md (+79 -76)
  11. docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md (+46 -28)
  12. docs/README_Windows_CUDA_Acceleration_en_US.md (+79 -83)
  13. docs/README_Windows_CUDA_Acceleration_zh_CN.md (+38 -35)
  14. docs/download_models.py (+58 -3)
  15. docs/download_models_hf.py (+65 -2)
  16. docs/how_to_download_models_en.md (+9 -8)
  17. docs/how_to_download_models_zh_cn.md (+10 -9)
  18. magic-pdf.template.json (+13 -3)
  19. magic_pdf/config/__init__.py (+0 -0)
  20. magic_pdf/config/enums.py (+7 -0)
  21. magic_pdf/config/exceptions.py (+32 -0)
  22. magic_pdf/data/__init__.py (+0 -0)
  23. magic_pdf/data/data_reader_writer/__init__.py (+12 -0)
  24. magic_pdf/data/data_reader_writer/base.py (+51 -0)
  25. magic_pdf/data/data_reader_writer/filebase.py (+59 -0)
  26. magic_pdf/data/data_reader_writer/multi_bucket_s3.py (+137 -0)
  27. magic_pdf/data/data_reader_writer/s3.py (+69 -0)
  28. magic_pdf/data/dataset.py (+194 -0)
  29. magic_pdf/data/io/__init__.py (+0 -0)
  30. magic_pdf/data/io/base.py (+42 -0)
  31. magic_pdf/data/io/http.py (+37 -0)
  32. magic_pdf/data/io/s3.py (+114 -0)
  33. magic_pdf/data/read_api.py (+95 -0)
  34. magic_pdf/data/schemas.py (+15 -0)
  35. magic_pdf/data/utils.py (+32 -0)
  36. magic_pdf/dict2md/ocr_mkcontent.py (+47 -237)
  37. magic_pdf/libs/Constants.py (+13 -6)
  38. magic_pdf/libs/boxbase.py (+35 -0)
  39. magic_pdf/libs/config_reader.py (+53 -23)
  40. magic_pdf/libs/draw_bbox.py (+75 -43)
  41. magic_pdf/libs/ocr_content_type.py (+2 -0)
  42. magic_pdf/model/doc_analyze_by_custom_model.py (+71 -31)
  43. magic_pdf/model/magic_model.py (+265 -22)
  44. magic_pdf/model/pdf_extract_kit.py (+86 -47)
  45. magic_pdf/model/ppTableModel.py (+2 -2)
  46. magic_pdf/para/para_split_v3.py (+296 -0)
  47. magic_pdf/pdf_parse_by_ocr.py (+5 -2)
  48. magic_pdf/pdf_parse_by_txt.py (+5 -2)
  49. magic_pdf/pdf_parse_union_core_v2.py (+311 -165)
  50. magic_pdf/pipe/AbsPipe.py (+6 -7)
  51. magic_pdf/pipe/OCRPipe.py (+8 -4)
  52. magic_pdf/pipe/TXTPipe.py (+8 -4)
  53. magic_pdf/pipe/UNIPipe.py (+10 -5)
  54. magic_pdf/pre_proc/ocr_detect_all_bboxes.py (+56 -27)
  55. magic_pdf/pre_proc/ocr_dict_merge.py (+27 -2)
  56. magic_pdf/resources/model_config/model_configs.yaml (+5 -13)
  57. magic_pdf/tools/cli.py (+1 -1)
  58. magic_pdf/tools/common.py (+11 -6)
  59. magic_pdf/user_api.py (+13 -5)
  60. magic_pdf/utils/__init__.py (+0 -0)
  61. magic_pdf/utils/annotations.py (+11 -0)
  62. next_docs/en/.readthedocs.yaml (+16 -0)
  63. next_docs/en/Makefile (+20 -0)
  64. next_docs/en/_static/image/logo.png (BIN)
  65. next_docs/en/api.rst (+9 -0)
  66. next_docs/en/api/data_reader_writer.rst (+44 -0)
  67. next_docs/en/api/dataset.rst (+22 -0)
  68. next_docs/en/api/io.rst (+0 -0)
  69. next_docs/en/api/read_api.rst (+6 -0)
  70. next_docs/en/api/schemas.rst (+0 -0)
  71. next_docs/en/api/utils.rst (+1 -0)
  72. next_docs/en/conf.py (+122 -0)
  73. next_docs/en/index.rst (+38 -0)
  74. next_docs/en/make.bat (+35 -0)
  75. next_docs/requirements.txt (+11 -0)
  76. next_docs/zh_cn/.readthedocs.yaml (+16 -0)
  77. next_docs/zh_cn/Makefile (+20 -0)
  78. next_docs/zh_cn/_static/image/logo.png (BIN)
  79. next_docs/zh_cn/conf.py (+122 -0)
  80. next_docs/zh_cn/index.rst (+26 -0)
  81. next_docs/zh_cn/make.bat (+35 -0)
  82. projects/README.md (+1 -2)
  83. projects/README_zh-CN.md (+1 -1)
  84. projects/gradio_app/app.py (+68 -12)
  85. projects/gradio_app/examples/2list_1table.pdf (BIN)
  86. projects/gradio_app/examples/3list_1table.pdf (BIN)
  87. projects/gradio_app/examples/academic_paper_formula.pdf (BIN)
  88. projects/gradio_app/examples/academic_paper_img_formula.pdf (BIN)
  89. projects/gradio_app/examples/academic_paper_list.pdf (BIN)
  90. projects/gradio_app/examples/complex_layout.pdf (BIN)
  91. projects/gradio_app/examples/complex_layout_para_split_list.pdf (BIN)
  92. projects/gradio_app/examples/garbled_formula.pdf (BIN)
  93. projects/gradio_app/examples/garbled_formula2.pdf (BIN)
  94. projects/gradio_app/examples/garbled_img_formula.pdf (BIN)
  95. projects/gradio_app/examples/magazine_complex_layout_images_list.pdf (BIN)
  96. projects/multi_gpu/README.md (+46 -0)
  97. projects/multi_gpu/client.py (+39 -0)
  98. projects/multi_gpu/server.py (+74 -0)
  99. projects/multi_gpu/small_ocr.pdf (BIN)
  100. requirements-docker.txt (+1 -1)

+ 50 - 45
.gitignore

@@ -1,45 +1,50 @@
-*.tar
-*.tar.gz
-*.zip
-venv*/
-envs/
-slurm_logs/
-
-sync1.sh
-data_preprocess_pj1
-data-preparation1
-__pycache__
-*.log
-*.pyc
-.vscode
-debug/
-*.ipynb
-.idea
-
-# vscode history
-.history
-
-.DS_Store
-.env
-
-bad_words/
-bak/
-
-app/tests/*
-temp/
-tmp/
-tmp
-.vscode
-.vscode/
-ocr_demo
-.coveragerc
-/app/common/__init__.py
-/magic_pdf/config/__init__.py
-source.dev.env
-
-tmp
-
-projects/web/node_modules
-projects/web/dist
-
-projects/web_demo/web_demo/static/
+*.tar
+*.tar.gz
+*.zip
+venv*/
+envs/
+slurm_logs/
+
+sync1.sh
+data_preprocess_pj1
+data-preparation1
+__pycache__
+*.log
+*.pyc
+.vscode
+debug/
+*.ipynb
+.idea
+
+# vscode history
+.history
+
+.DS_Store
+.env
+
+bad_words/
+bak/
+
+app/tests/*
+temp/
+tmp/
+tmp
+.vscode
+.vscode/
+ocr_demo
+.coveragerc
+/app/common/__init__.py
+/magic_pdf/config/__init__.py
+source.dev.env
+
+tmp
+
+projects/web/node_modules
+projects/web/dist
+
+projects/web_demo/web_demo/static/
+cli_debug/
+debug_utils/
+
+# sphinx docs
+_build/

+ 3 - 2
.pre-commit-config.yaml

@@ -3,7 +3,7 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-        args: ["--max-line-length=120", "--ignore=E131,E125,W503,W504,E203"]
+        args: ["--max-line-length=150", "--ignore=E131,E125,W503,W504,E203"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:
@@ -12,11 +12,12 @@ repos:
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: ["--style={based_on_style: google, column_limit: 120, indent_width: 4}"]
+        args: ["--style={based_on_style: google, column_limit: 150, indent_width: 4}"]
   - repo: https://github.com/codespell-project/codespell
     rev: v2.2.1
     hooks:
       - id: codespell
+        args: ['--skip', '*.json']
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.3.0
     hooks:

+ 16 - 0
.readthedocs.yaml

@@ -0,0 +1,16 @@
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
+formats:
+  - epub
+
+python:
+  install:
+    - requirements: docs/zh_cn/requirements.txt
+
+sphinx:
+  configuration: docs/zh_cn/conf.py

File diff suppressed because it is too large
+ 33 - 18
README.md


File diff suppressed because it is too large
+ 32 - 16
README_zh-CN.md


+ 2 - 15
demo/demo.py

@@ -1,35 +1,22 @@
 import os
-import json
 
 from loguru import logger
-
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
-import magic_pdf.model as model_config 
-model_config.__use_inside_model__ = True
 
 try:
     current_script_dir = os.path.dirname(os.path.abspath(__file__))
     demo_name = "demo1"
     pdf_path = os.path.join(current_script_dir, f"{demo_name}.pdf")
-    model_path = os.path.join(current_script_dir, f"{demo_name}.json")
     pdf_bytes = open(pdf_path, "rb").read()
-    # model_json = json.loads(open(model_path, "r", encoding="utf-8").read())
-    model_json = []  # model_json传空list使用内置模型解析
-    jso_useful_key = {"_pdf_type": "", "model_list": model_json}
+    jso_useful_key = {"_pdf_type": "", "model_list": []}
     local_image_dir = os.path.join(current_script_dir, 'images')
     image_dir = str(os.path.basename(local_image_dir))
     image_writer = DiskReaderWriter(local_image_dir)
     pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
     pipe.pipe_classify()
-    """如果没有传入有效的模型数据,则使用内置model解析"""
-    if len(model_json) == 0:
-        if model_config.__use_inside_model__:
-            pipe.pipe_analyze()
-        else:
-            logger.error("need model list input")
-            exit(1)
+    pipe.pipe_analyze()
     pipe.pipe_parse()
     md_content = pipe.pipe_mk_markdown(image_dir, drop_mode="none")
     with open(f"{demo_name}.md", "w", encoding="utf-8") as f:

+ 14 - 9
demo/magic_pdf_parse_main.py

@@ -4,13 +4,12 @@ import copy
 
 from loguru import logger
 
+from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-import magic_pdf.model as model_config
 
-model_config.__use_inside_model__ = True
 
 # todo: 设备类型选择 (?)
 
@@ -47,11 +46,20 @@ def json_md_dump(
     )
 
 
+# 可视化
+def draw_visualization_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name):
+    # 画布局框,附带排序结果
+    draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+    # 画 span 框
+    draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+
+
 def pdf_parse_main(
         pdf_path: str,
         parse_method: str = 'auto',
         model_json_path: str = None,
         is_json_md_dump: bool = True,
+        is_draw_visualization_bbox: bool = True,
         output_dir: str = None
 ):
     """
@@ -108,11 +116,7 @@ def pdf_parse_main(
 
         # 如果没有传入模型数据,则使用内置模型解析
         if not model_json:
-            if model_config.__use_inside_model__:
-                pipe.pipe_analyze()  # 解析
-            else:
-                logger.error("need model list input")
-                exit(1)
+            pipe.pipe_analyze()  # 解析
 
         # 执行解析
         pipe.pipe_parse()
@@ -121,10 +125,11 @@ def pdf_parse_main(
         content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
         md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
 
-
         if is_json_md_dump:
             json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
 
+        if is_draw_visualization_bbox:
+            draw_visualization_bbox(pipe.pdf_mid_data['pdf_info'], pdf_bytes, output_path, pdf_name)
 
     except Exception as e:
         logger.exception(e)
@@ -132,5 +137,5 @@ def pdf_parse_main(
 
 # 测试
 if __name__ == '__main__':
-    pdf_path = r"C:\Users\XYTK2\Desktop\2024-2016-gb-cd-300.pdf"
+    pdf_path = r"D:\project\20240617magicpdf\Magic-PDF\demo\demo1.pdf"
     pdf_parse_main(pdf_path)

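For illustration, a minimal sketch (not part of the diff) of calling the updated entry point with the new visualization switch, assuming it is run from the `demo` directory with `demo1.pdf` alongside:

```python
# Hypothetical caller; parameter names match the signature added above.
from magic_pdf_parse_main import pdf_parse_main

pdf_parse_main(
    'demo1.pdf',                      # input PDF
    parse_method='auto',              # 'auto', 'ocr', or 'txt'
    is_json_md_dump=True,             # dump intermediate JSON/Markdown
    is_draw_visualization_bbox=True,  # new flag: draw layout and span bboxes
)
```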
+ 7 - 2
docs/FAQ_en_us.md

@@ -11,7 +11,7 @@ pip install magic-pdf[full]
 
 ### 2. Encountering the error `pickle.UnpicklingError: invalid load key, 'v'.` during use
 
-This might be due to an incomplete download of the model file. You can try re-downloading the model file and then try again.  
+This might be due to an incomplete download of the model file. You can try re-downloading the model file and then try again.
 Reference: https://github.com/opendatalab/MinerU/issues/143
 
 ### 3. Where should the model files be downloaded and how should the `/models-dir` configuration be set?
@@ -24,7 +24,7 @@ The path for the model files is configured in "magic-pdf.json". just like:
 }
 ```
 
-This path is an absolute path, not a relative path. You can obtain the absolute path in the models directory using the "pwd" command.  
+This path is an absolute path, not a relative path. You can obtain the absolute path in the models directory using the "pwd" command.
 Reference: https://github.com/opendatalab/MinerU/issues/155#issuecomment-2230216874
 
 ### 4. Encountered the error `ImportError: libGL.so.1: cannot open shared object file: No such file or directory` in Ubuntu 22.04 on WSL2
@@ -38,17 +38,22 @@ sudo apt-get install libgl1-mesa-glx
 Reference: https://github.com/opendatalab/MinerU/issues/388
 
 ### 5. Encountered error `ModuleNotFoundError: No module named 'fairscale'`
+
 You need to uninstall the module and reinstall it:
+
 ```bash
 pip uninstall fairscale
 pip install fairscale
 ```
+
 Reference: https://github.com/opendatalab/MinerU/issues/411
 
 ### 6. On some newer devices like the H100, the text parsed during OCR using CUDA acceleration is garbled.
 
 The compatibility of cuda11 with new graphics cards is poor, and the CUDA version used by Paddle needs to be upgraded.
+
 ```bash
 pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
 ```
+
 Reference: https://github.com/opendatalab/MinerU/issues/558

+ 13 - 3
docs/FAQ_zh_cn.md

@@ -1,9 +1,10 @@
 # 常见问题解答
 
-### 1.在较新版本的mac上使用命令安装pip install magic-pdf[full] zsh: no matches found: magic-pdf[full]
+### 1.在较新版本的mac上使用命令安装pip install magic-pdf\[full\] zsh: no matches found: magic-pdf\[full\]
 
 在 macOS 上,默认的 shell 从 Bash 切换到了 Z shell,而 Z shell 对于某些类型的字符串匹配有特殊的处理逻辑,这可能导致no matches found错误。
 可以通过在命令行禁用globbing特性,再尝试运行安装命令
+
 ```bash
 setopt no_nomatch
 pip install magic-pdf[full]
@@ -11,41 +12,50 @@ pip install magic-pdf[full]
 
 ### 2.使用过程中遇到_pickle.UnpicklingError: invalid load key, 'v'.错误
 
-可能是由于模型文件未下载完整导致,可尝试重新下载模型文件后再试  
+可能是由于模型文件未下载完整导致,可尝试重新下载模型文件后再试
 参考:https://github.com/opendatalab/MinerU/issues/143
 
 ### 3.模型文件应该下载到哪里/models-dir的配置应该怎么填
 
 模型文件的路径输入是在"magic-pdf.json"中通过
+
 ```json
 {
   "models-dir": "/tmp/models"
 }
 ```
+
 进行配置的。
-这个路径是绝对路径而不是相对路径,绝对路径的获取可在models目录中通过命令 "pwd" 获取。  
+这个路径是绝对路径而不是相对路径,绝对路径的获取可在models目录中通过命令 "pwd" 获取。
 参考:https://github.com/opendatalab/MinerU/issues/155#issuecomment-2230216874
 
 ### 4.在WSL2的Ubuntu22.04中遇到报错`ImportError: libGL.so.1: cannot open shared object file: No such file or directory`
 
 WSL2的Ubuntu22.04中缺少`libgl`库,可通过以下命令安装`libgl`库解决:
+
 ```bash
 sudo apt-get install libgl1-mesa-glx
 ```
+
 参考:https://github.com/opendatalab/MinerU/issues/388
 
 ### 5.遇到报错 `ModuleNotFoundError : Nomodulenamed 'fairscale'`
+
 需要卸载该模块并重新安装
+
 ```bash
 pip uninstall fairscale
 pip install fairscale
 ```
+
 参考:https://github.com/opendatalab/MinerU/issues/411
 
 ### 6.在部分较新的设备如H100上,使用CUDA加速OCR时解析出的文字乱码。
 
 cuda11对新显卡的兼容性不好,需要升级paddle使用的cuda版本
+
 ```bash
 pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
 ```
+
 参考:https://github.com/opendatalab/MinerU/issues/558

+ 79 - 76
docs/README_Ubuntu_CUDA_Acceleration_en_US.md

@@ -1,96 +1,101 @@
-
 # Ubuntu 22.04 LTS
 
 ### 1. Check if NVIDIA Drivers Are Installed
-   ```sh
-   nvidia-smi
-   ```
-   If you see information similar to the following, it means that the NVIDIA drivers are already installed, and you can skip Step 2.
-   ```plaintext
-   +---------------------------------------------------------------------------------------+
-   | NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
-   |-----------------------------------------+----------------------+----------------------+
-   | GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
-   | Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
-   |                                         |                      |               MIG M. |
-   |=========================================+======================+======================|
-   |   0  NVIDIA GeForce RTX 3060 Ti   WDDM  | 00000000:01:00.0  On |                  N/A |
-   |  0%   51C    P8              12W / 200W |   1489MiB /  8192MiB |      5%      Default |
-   |                                         |                      |                  N/A |
-   +-----------------------------------------+----------------------+----------------------+
-   ```
+
+```sh
+nvidia-smi
+```
+
+If you see information similar to the following, it means that the NVIDIA drivers are already installed, and you can skip Step 2.
+
+Note: `CUDA Version` should be >= 12.1. If the displayed version number is lower than 12.1, please upgrade the driver.
+
+```plaintext
++---------------------------------------------------------------------------------------+
+| NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
+|-----------------------------------------+----------------------+----------------------+
+| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
+|                                         |                      |               MIG M. |
+|=========================================+======================+======================|
+|   0  NVIDIA GeForce RTX 3060 Ti   WDDM  | 00000000:01:00.0  On |                  N/A |
+|  0%   51C    P8              12W / 200W |   1489MiB /  8192MiB |      5%      Default |
+|                                         |                      |                  N/A |
++-----------------------------------------+----------------------+----------------------+
+```
 
 ### 2. Install the Driver
-   If no driver is installed, use the following command:
-   ```sh
-   sudo apt-get update
-   sudo apt-get install nvidia-driver-545
-   ```
-   Install the proprietary driver and restart your computer after installation.
-   ```sh
-   reboot
-   ```
+
+If no driver is installed, use the following command:
+
+```sh
+sudo apt-get update
+sudo apt-get install nvidia-driver-545
+```
+
+Install the proprietary driver and restart your computer after installation.
+
+```sh
+reboot
+```
 
 ### 3. Install Anaconda
-   If Anaconda is already installed, skip this step.
-   ```sh
-   wget https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
-   bash Anaconda3-2024.06-1-Linux-x86_64.sh
-   ```
-   In the final step, enter `yes`, close the terminal, and reopen it.
+
+If Anaconda is already installed, skip this step.
+
+```sh
+wget https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
+bash Anaconda3-2024.06-1-Linux-x86_64.sh
+```
+
+In the final step, enter `yes`, close the terminal, and reopen it.
 
 ### 4. Create an Environment Using Conda
-   Specify Python version 3.10.
-   ```sh
-   conda create -n MinerU python=3.10
-   conda activate MinerU
-   ```
+
+Specify Python version 3.10.
+
+```sh
+conda create -n MinerU python=3.10
+conda activate MinerU
+```
 
 ### 5. Install Applications
-   ```sh
-   pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
-   ```
+
+```sh
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
+```
+
 ❗ After installation, make sure to check the version of `magic-pdf` using the following command:
-   ```sh
-   magic-pdf --version
-   ```
-   If the version number is less than 0.7.0, please report the issue.
+
+```sh
+magic-pdf --version
+```
+
+If the version number is less than 0.7.0, please report the issue.
 
 ### 6. Download Models
-   Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).  
-   After downloading, move the `models` directory to an SSD with more space.
-   
-❗ After downloading the models, ensure they are complete:
-   - Check that the file sizes match the description on the website.
-   - If possible, verify the integrity using SHA256.
-
-### 7. Configuration Before First Run
-   Obtain the configuration template file `magic-pdf.template.json` from the root directory of the repository.
-   
-❗ Execute the following command to copy the configuration file to your home directory, otherwise the program will not run:
-   ```sh
-   wget https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json
-   cp magic-pdf.template.json ~/magic-pdf.json
-   ```
-   Find the `magic-pdf.json` file in your home directory and configure `"models-dir"` to be the directory where the model weights from Step 6 were downloaded.
-   
-❗ Correctly specify the absolute path of the directory containing the model weights; otherwise, the program will fail due to missing model files.
-   ```json
-   {
-     "models-dir": "/tmp/models"
-   }
-   ```
+
+Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
+
+### 7. Understand the Location of the Configuration File
+
+After completing the [6. Download Models](#6-download-models) step, the script will automatically generate a `magic-pdf.json` file in the user directory and configure the default model path.
+You can find the `magic-pdf.json` file in your user directory.
+
+> The user directory for Linux is "/home/username".
 
 ### 8. First Run
-   Download a sample file from the repository and test it.
-   ```sh
-   wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf
-   magic-pdf -p small_ocr.pdf
-   ```
+
+Download a sample file from the repository and test it.
+
+```sh
+wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf
+magic-pdf -p small_ocr.pdf
+```
 
 ### 9. Test CUDA Acceleration
 
-If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA acceleration:
+If your graphics card has at least **8GB** of VRAM, follow these steps to test CUDA acceleration:
 
 1. Modify the value of `"device-mode"` in the `magic-pdf.json` configuration file located in your home directory.
    ```json
@@ -105,8 +110,6 @@ If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA
 
 ### 10. Enable CUDA Acceleration for OCR
 
-❗ The following operations require a graphics card with at least 16GB of VRAM; otherwise, the program may crash or experience reduced performance.
-    
 1. Download `paddlepaddle-gpu`. Installation will automatically enable OCR acceleration.
    ```sh
    python -m pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/

+ 46 - 28
docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md

@@ -1,10 +1,16 @@
 # Ubuntu 22.04 LTS
 
 ## 1. 检测是否已安装nvidia驱动
+
 ```bash
-nvidia-smi 
+nvidia-smi
 ```
+
 如果看到类似如下的信息,说明已经安装了nvidia驱动,可以跳过步骤2
+
+注意:`CUDA Version` 显示的版本号应 >= 12.1,如显示的版本号小于12.1,请升级驱动
+
+```plaintext
-```
 +---------------------------------------------------------------------------------------+
 | NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
@@ -18,96 +24,108 @@ nvidia-smi
 |                                         |                      |                  N/A |
 +-----------------------------------------+----------------------+----------------------+
 ```
+
 ## 2. 安装驱动
+
 如没有驱动,则通过如下命令
+
 ```bash
 sudo apt-get update
 sudo apt-get install nvidia-driver-545
 ```
+
 安装专有驱动,安装完成后,重启电脑
+
 ```bash
 reboot
 ```
+
 ## 3. 安装anacoda
+
 如果已安装conda,可以跳过本步骤
+
 ```bash
 wget -U NoSuchBrowser/1.0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
 bash Anaconda3-2024.06-1-Linux-x86_64.sh
 ```
+
 最后一步输入yes,关闭终端重新打开
+
 ## 4. 使用conda 创建环境
+
 需指定python版本为3.10
+
 ```bash
 conda create -n MinerU python=3.10
 conda activate MinerU
 ```
+
 ## 5. 安装应用
+
 ```bash
 pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple
 ```
+
 > ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
-> 
+>
 > ```bash
 > magic-pdf --version
->```
+> ```
+>
 > 如果版本号小于0.7.0,请到issue中向我们反馈
 
 ## 6. 下载模型
-详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)  
-下载后请将models目录移动到空间较大的ssd磁盘目录  
-> ❗️模型下载后请务必检查模型文件是否下载完整
-> 
-> 请检查目录下的模型文件大小与网页上描述是否一致,如果可以的话,最好通过sha256校验模型是否下载完整
-> 
-## 7. 第一次运行前的配置
-在仓库根目录可以获得 [magic-pdf.template.json](../magic-pdf.template.json) 配置模版文件
-> ❗️务必执行以下命令将配置文件拷贝到【用户目录】下,否则程序将无法运行
->  
-> linux用户目录为 "/home/用户名"
-```bash
-wget https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json
-cp magic-pdf.template.json ~/magic-pdf.json
-```
 
-在用户目录中找到magic-pdf.json文件并配置"models-dir"为[6. 下载模型](#6-下载模型)中下载的模型权重文件所在目录
-> ❗️务必正确配置模型权重文件所在目录的【绝对路径】,否则会因为找不到模型文件而导致程序无法运行
-> 
-```json
-{
-  "models-dir": "/tmp/models"
-}
-```
+详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
+
+## 7. 了解配置文件存放的位置
+
+完成[6.下载模型](#6-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
+您可在【用户目录】下找到magic-pdf.json文件。
+
+> linux用户目录为 "/home/用户名"
 
 ## 8. 第一次运行
+
 从仓库中下载样本文件,并测试
+
 ```bash
 wget https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf
 magic-pdf -p small_ocr.pdf
 ```
+
 ## 9. 测试CUDA加速
-如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
+
+如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
 **1.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
+
 ```json
 {
   "device-mode":"cuda"
 }
 ```
+
 **2.运行以下命令测试cuda加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
+
 > 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
 
 ## 10. 为ocr开启cuda加速
-> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
 
 **1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
+
 ```bash
 python -m pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
 ```
+
 **2.运行以下命令测试ocr加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
+
 > 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。

+ 79 - 83
docs/README_Windows_CUDA_Acceleration_en_US.md

@@ -1,104 +1,100 @@
 # Windows 10/11
 
 ### 1. Install CUDA and cuDNN
+
 Required versions: CUDA 11.8 + cuDNN 8.7.0
-   - CUDA 11.8: https://developer.nvidia.com/cuda-11-8-0-download-archive
-   - cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x: https://developer.nvidia.com/rdp/cudnn-archive
-   
+
+- CUDA 11.8: https://developer.nvidia.com/cuda-11-8-0-download-archive
+- cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x: https://developer.nvidia.com/rdp/cudnn-archive
+
 ### 2. Install Anaconda
-   If Anaconda is already installed, you can skip this step.
-   
+
+If Anaconda is already installed, you can skip this step.
+
 Download link: https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
 
 ### 3. Create an Environment Using Conda
-   Python version must be 3.10.
-   ```
-   conda create -n MinerU python=3.10
-   conda activate MinerU
-   ```
+
+Python version must be 3.10.
+
+```
+conda create -n MinerU python=3.10
+conda activate MinerU
+```
 
 ### 4. Install Applications
-   ```
-   pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
-   ```
-   >❗️After installation, verify the version of `magic-pdf`:
-   >  ```bash
-   >  magic-pdf --version
-   >  ```
-   > If the version number is less than 0.7.0, please report it in the issues section.
-   
+
+```
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
+```
+
+> ❗️After installation, verify the version of `magic-pdf`:
+>
+> ```bash
+> magic-pdf --version
+> ```
+>
+> If the version number is less than 0.7.0, please report it in the issues section.
+
 ### 5. Download Models
-   Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).  
-   After downloading, move the `models` directory to an SSD with more space.
-   
-   >❗ After downloading the models, ensure they are complete:
-   >- Check that the file sizes match the description on the website.
-   >- If possible, verify the integrity using SHA256.
-
-### 6. Configuration Before the First Run
-   Obtain the configuration template file `magic-pdf.template.json` from the repository root directory.
-    
-   >❗️Execute the following command to copy the configuration file to your user directory, or the program will not run.
-   >   
-   > In Windows, user directory is "C:\Users\username"
-   
-   ```powershell
-     (New-Object System.Net.WebClient).DownloadFile('https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json', 'magic-pdf.template.json')
-     cp magic-pdf.template.json ~/magic-pdf.json
+
+Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
+
+### 6. Understand the Location of the Configuration File
+
+After completing the [5. Download Models](#5-download-models) step, the script will automatically generate a `magic-pdf.json` file in the user directory and configure the default model path.
+You can find the `magic-pdf.json` file in your user directory.
+
+> The user directory for Windows is "C:/Users/username".
+
+### 7. First Run
+
+Download a sample file from the repository and test it.
+
+```powershell
+  wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
+  magic-pdf -p small_ocr.pdf
+```
+
+### 8. Test CUDA Acceleration
+
+If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
+
+1. **Overwrite the installation of torch and torchvision** supporting CUDA.
+
+   ```
+   pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
    ```
 
-   Find the `magic-pdf.json` file in your user directory and configure `"models-dir"` to point to the directory where the model weights from step 5 were downloaded.
-   
-   > ❗️Ensure the absolute path of the model weights directory is correctly configured, or the program will fail to run due to not finding the model files.
-   >    
-   > In Windows, this path should include the drive letter and replace all `"\"` to `"/"`.
-   >   
-   > Example: If the models are placed in the root directory of drive D, the value for `model-dir` should be `"D:/models"`.
-   
+   > ❗️Ensure the following versions are specified in the command:
+   >
+   > ```
+   > torch==2.3.1 torchvision==0.18.1
+   > ```
+   >
+   > These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.
+
+2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.
+
    ```json
    {
-     "models-dir": "/tmp/models"
+     "device-mode": "cuda"
    }
    ```
 
-### 7. First Run
-   Download a sample file from the repository and test it.
-   ```powershell
-     (New-Object System.Net.WebClient).DownloadFile('https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf')
-     magic-pdf -p small_ocr.pdf
-   ```
+3. **Run the following command to test CUDA acceleration**:
 
-### 8. Test CUDA Acceleration
-   If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
-   1. **Overwrite the installation of torch and torchvision** supporting CUDA.
-      ```
-      pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
-      ```
-      >❗️Ensure the following versions are specified in the command:
-      >```
-      > torch==2.3.1 torchvision==0.18.1
-      >```
-      >These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.
-   2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.
-     
-      ```json
-      {
-        "device-mode": "cuda"
-      }
-      ```
-   3. **Run the following command to test CUDA acceleration**:
-
-      ```
-      magic-pdf -p small_ocr.pdf
-      ```
+   ```
+   magic-pdf -p small_ocr.pdf
+   ```
 
 ### 9. Enable CUDA Acceleration for OCR
-   >❗️This operation requires at least 16GB of VRAM on your graphics card, otherwise it will cause the program to crash or slow down.
-   1. **Download paddlepaddle-gpu**, which will automatically enable OCR acceleration upon installation.
-      ```
-      pip install paddlepaddle-gpu==2.6.1
-      ```
-   2. **Run the following command to test OCR acceleration**:
-      ```
-      magic-pdf -p small_ocr.pdf
-      ```
+
+1. **Download paddlepaddle-gpu**, which will automatically enable OCR acceleration upon installation.
+   ```
+   pip install paddlepaddle-gpu==2.6.1
+   ```
+2. **Run the following command to test OCR acceleration**:
+   ```
+   magic-pdf -p small_ocr.pdf
+   ```

+ 38 - 35
docs/README_Windows_CUDA_Acceleration_zh_CN.md

@@ -3,103 +3,106 @@
 ## 1. 安装cuda和cuDNN
 
 需要安装的版本 CUDA 11.8 + cuDNN 8.7.0
+
 - CUDA 11.8 https://developer.nvidia.com/cuda-11-8-0-download-archive
 - cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x https://developer.nvidia.com/rdp/cudnn-archive
 
 ## 2. 安装anaconda
+
 如果已安装conda,可以跳过本步骤
 
 下载链接:
 https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
 
 ## 3. 使用conda 创建环境
+
 需指定python版本为3.10
+
 ```bash
 conda create -n MinerU python=3.10
 conda activate MinerU
 ```
+
 ## 4. 安装应用
+
 ```bash
 pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple
 ```
+
 > ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
-> 
+>
 > ```bash
 > magic-pdf --version
->```
+> ```
+>
 > 如果版本号小于0.7.0,请到issue中向我们反馈
 
 ## 5. 下载模型
-详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)  
-下载后请将models目录移动到空间较大的ssd磁盘目录  
-> ❗️模型下载后请务必检查模型文件是否下载完整
-> 
-> 请检查目录下的模型文件大小与网页上描述是否一致,如果可以的话,最好通过sha256校验模型是否下载完整
-
-## 6. 第一次运行前的配置
-在仓库根目录可以获得 [magic-pdf.template.json](../magic-pdf.template.json) 配置模版文件
-> ❗️务必执行以下命令将配置文件拷贝到【用户目录】下,否则程序将无法运行
->  
-> windows用户目录为 "C:\Users\用户名"
-```powershell
-(New-Object System.Net.WebClient).DownloadFile('https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json', 'magic-pdf.template.json')
-cp magic-pdf.template.json ~/magic-pdf.json
-```
 
-在用户目录中找到magic-pdf.json文件并配置"models-dir"为[5. 下载模型](#5-下载模型)中下载的模型权重文件所在目录
-> ❗️务必正确配置模型权重文件所在目录的【绝对路径】,否则会因为找不到模型文件而导致程序无法运行
-> 
-> windows系统中此路径应包含盘符,且需把路径中所有的`"\"`替换为`"/"`,否则会因为转义原因导致json文件语法错误。
-> 
-> 例如:模型放在D盘根目录的models目录,则model-dir的值应为"D:/models"
-```json
-{
-  "models-dir": "/tmp/models"
-}
-```
+详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
+
+## 6. 了解配置文件存放的位置
+
+完成[5.下载模型](#5-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
+您可在【用户目录】下找到magic-pdf.json文件。
+
+> windows用户目录为 "C:/Users/用户名"
 
 ## 7. 第一次运行
+
 从仓库中下载样本文件,并测试
+
 ```powershell
-(New-Object System.Net.WebClient).DownloadFile('https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf')
-magic-pdf -p small_ocr.pdf
+ wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
+ magic-pdf -p small_ocr.pdf
 ```
 
 ## 8. 测试CUDA加速
-如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
+
+如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
 **1.覆盖安装支持cuda的torch和torchvision**
+
 ```bash
 pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
 ```
+
 > ❗️务必在命令中指定以下版本
+>
 > ```bash
-> torch==2.3.1 torchvision==0.18.1 
+> torch==2.3.1 torchvision==0.18.1
 > ```
+>
 > 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行
 
 **2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
+
 ```json
 {
   "device-mode":"cuda"
 }
 ```
+
 **3.运行以下命令测试cuda加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
-> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
+
+> 提示:CUDA加速是否生效可以根据log中输出的各个阶段的耗时来简单判断,通常情况下,`layout detection time` 和 `mfr time` 应提速10倍以上。
 
 ## 9. 为ocr开启cuda加速
-> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
 
 **1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
+
 ```bash
 pip install paddlepaddle-gpu==2.6.1
 ```
+
 **2.运行以下命令测试ocr加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
-> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。
 
+> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr time`应提速10倍以上。

+ 58 - 3
docs/download_models.py

@@ -1,4 +1,59 @@
-# use modelscope sdk download models
+import json
+import os
+
+import requests
 from modelscope import snapshot_download
-model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
-print(f"model dir is: {model_dir}/models")
+
+
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
+def download_and_modify_json(url, local_filename, modifications):
+    if os.path.exists(local_filename):
+        data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
+    else:
+        data = download_json(url)
+
+    # 修改内容
+    for key, value in modifications.items():
+        data[key] = value
+
+    # 保存修改后的内容
+    with open(local_filename, 'w', encoding='utf-8') as f:
+        json.dump(data, f, ensure_ascii=False, indent=4)
+
+
+if __name__ == '__main__':
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
+    layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
+    model_dir = model_dir + '/models'
+    print(f'model_dir is: {model_dir}')
+    print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
+
+    json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
+    config_file_name = 'magic-pdf.json'
+    home_dir = os.path.expanduser('~')
+    config_file = os.path.join(home_dir, config_file_name)
+
+    json_mods = {
+        'models-dir': model_dir,
+        'layoutreader-model-dir': layoutreader_model_dir,
+    }
+
+    download_and_modify_json(json_url, config_file, json_mods)
+    print(f'The configuration file has been configured successfully, the path is: {config_file}')

+ 65 - 2
docs/download_models_hf.py

@@ -1,3 +1,66 @@
+import json
+import os
+
+import requests
 from huggingface_hub import snapshot_download
-model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
-print(f"model dir is: {model_dir}/models")
+
+
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
+def download_and_modify_json(url, local_filename, modifications):
+    if os.path.exists(local_filename):
+        data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
+    else:
+        data = download_json(url)
+
+    # 修改内容
+    for key, value in modifications.items():
+        data[key] = value
+
+    # 保存修改后的内容
+    with open(local_filename, 'w', encoding='utf-8') as f:
+        json.dump(data, f, ensure_ascii=False, indent=4)
+
+
+if __name__ == '__main__':
+
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
+
+    layoutreader_pattern = [
+        "*.json",
+        "*.safetensors",
+    ]
+    layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
+
+    model_dir = model_dir + '/models'
+    print(f'model_dir is: {model_dir}')
+    print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
+
+    json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
+    config_file_name = 'magic-pdf.json'
+    home_dir = os.path.expanduser('~')
+    config_file = os.path.join(home_dir, config_file_name)
+
+    json_mods = {
+        'models-dir': model_dir,
+        'layoutreader-model-dir': layoutreader_model_dir,
+    }
+
+    download_and_modify_json(json_url, config_file, json_mods)
+    print(f'The configuration file has been configured successfully, the path is: {config_file}')

+ 9 - 8
docs/how_to_download_models_en.md

@@ -1,29 +1,30 @@
 Model downloads are divided into initial downloads and updates to the model directory. Please refer to the corresponding documentation for instructions on how to proceed.
 
-
 # Initial download of model files
 
 ### 1. Download the Model from Hugging Face
+
 Use a Python Script to Download Model Files from Hugging Face
+
 ```bash
 pip install huggingface_hub
-wget https://github.com/opendatalab/MinerU/raw/master/docs/download_models_hf.py
+wget https://github.com/opendatalab/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
 python download_models_hf.py
 ```
-After the Python script finishes executing, it will output the directory where the models are downloaded.
 
-### 2. To modify the model path address in the configuration file
-
-Additionally, in `~/magic-pdf.json`, update the model directory path to the absolute path of the `models` directory output by the previous Python script. Otherwise, you will encounter an error indicating that the model cannot be loaded.
+The Python script will automatically download the model files and configure the model directory in the configuration file.
 
+The configuration file can be found in the user directory, with the filename `magic-pdf.json`.
 
 # How to update models previously downloaded
 
 ## 1. Models downloaded via Git LFS
 
->Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
+> Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
+
+For magic-pdf <= 0.8.1: if you previously downloaded the model files via git lfs, you can navigate to the original download directory and update the models using the `git pull` command.
 
-If you previously downloaded model files via git lfs, you can navigate to the previous download directory and use the `git pull` command to update the model.
+> For versions 0.9.x and later, due to the repository change and the addition of the layout sorting model in PDF-Extract-Kit 1.0, the models cannot be updated using the `git pull` command. Instead, a Python script must be used for one-click updates.
 
 ## 2. Models downloaded via Hugging Face or Model Scope
 

+ 10 - 9
docs/how_to_download_models_zh_cn.md

@@ -8,9 +8,8 @@
   <summary>方法一:从 Hugging Face 下载模型</summary>
   <p>使用python脚本 从Hugging Face下载模型文件</p>
   <pre><code>pip install huggingface_hub
-wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py
+wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
 python download_models_hf.py</code></pre>
-  <p>python脚本执行完毕后,会输出模型下载目录</p>
 </details>
 
 ## 方法二:从 ModelScope 下载模型
@@ -19,24 +18,26 @@ python download_models_hf.py</code></pre>
 
 ```bash
 pip install modelscope
-wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py
+wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py -O download_models.py
 python download_models.py
 ```
-python脚本执行完毕后,会输出模型下载目录
 
+python脚本会自动下载模型文件并配置好配置文件中的模型目录
 
-## 下载完成后的操作:修改magic-pdf.json中的模型路径
-在`~/magic-pdf.json`里修改模型的目录指向上一步脚本输出的models目录的绝对路径,否则会报模型无法加载的错误。
-
+配置文件可以在用户目录中找到,文件名为`magic-pdf.json`
 
+> windows的用户目录为 "C:\\Users\\用户名", linux用户目录为 "/home/用户名", macOS用户目录为 "/Users/用户名"
 
 # 此前下载过模型,如何更新
 
 ## 1. 通过git lfs下载过模型
 
->由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
+> 由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
+
+当magic-pdf <= 0.8.1时,如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
+
+> 0.9.x及以后版本由于PDF-Extract-Kit 1.0更换仓库和新增layout排序模型,不能通过`git pull`命令更新,需要使用python脚本一键更新。
 
-如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
 
 ## 2. 通过 Hugging Face 或 Model Scope 下载过模型
 

+ 13 - 3
magic-pdf.template.json

@@ -4,10 +4,20 @@
         "bucket-name-2":["ak", "sk", "endpoint"]
     },
     "models-dir":"/tmp/models",
+    "layoutreader-model-dir":"/tmp/layoutreader",
     "device-mode":"cpu",
+    "layout-config": {
+        "model": "layoutlmv3"
+    },
+    "formula-config": {
+        "mfd_model": "yolo_v8_mfd",
+        "mfr_model": "unimernet_small",
+        "enable": true
+    },
     "table-config": {
-        "model": "TableMaster",
-        "is_table_recog_enable": false,
+        "model": "tablemaster",
+        "enable": false,
         "max_time": 400
-    }
+    },
+    "config_version": "1.0.0"
 }

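A small sketch (not part of the diff) of reading the keys this template revision introduces, assuming the generated `magic-pdf.json` sits in the user's home directory as the updated docs describe:

```python
import json
import os

# Assumed location of the user-level config generated from this template.
config_path = os.path.join(os.path.expanduser('~'), 'magic-pdf.json')
with open(config_path, encoding='utf-8') as f:
    config = json.load(f)

print(config.get('config_version'))        # '1.0.0' with this template
print(config['layout-config']['model'])    # 'layoutlmv3'
print(config['formula-config']['enable'])  # formula parsing switch
print(config['table-config']['enable'])    # renamed from 'is_table_recog_enable'
```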
+ 0 - 0
magic_pdf/config/__init__.py


+ 7 - 0
magic_pdf/config/enums.py

@@ -0,0 +1,7 @@
+
+import enum
+
+
+class SupportedPdfParseMethod(enum.Enum):
+    OCR = 'ocr'
+    TXT = 'txt'

+ 32 - 0
magic_pdf/config/exceptions.py

@@ -0,0 +1,32 @@
+
+class FileNotExisted(Exception):
+
+    def __init__(self, path):
+        self.path = path
+
+    def __str__(self):
+        return f'File {self.path} does not exist.'
+
+
+class InvalidConfig(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid config: {self.msg}'
+
+
+class InvalidParams(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid params: {self.msg}'
+
+
+class EmptyData(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Empty data: {self.msg}'

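A minimal illustration (hypothetical caller, not part of the diff) of the new exception types in use:

```python
import os

from magic_pdf.config.exceptions import FileNotExisted


def load_pdf(path: str) -> bytes:
    # Raise the library's own exception type instead of a bare OSError.
    if not os.path.exists(path):
        raise FileNotExisted(path)
    with open(path, 'rb') as f:
        return f.read()


try:
    load_pdf('/tmp/missing.pdf')
except FileNotExisted as e:
    print(e)  # File /tmp/missing.pdf does not exist.
```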
+ 0 - 0
magic_pdf/data/__init__.py


+ 12 - 0
magic_pdf/data/data_reader_writer/__init__.py

@@ -0,0 +1,12 @@
+from magic_pdf.data.data_reader_writer.filebase import \
+    FileBasedDataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.filebase import \
+    FileBasedDataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
+    MultiBucketS3DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
+    MultiBucketS3DataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.s3 import S3DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.s3 import S3DataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.base import DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.base import DataWriter  # noqa: F401

+ 51 - 0
magic_pdf/data/data_reader_writer/base.py

@@ -0,0 +1,51 @@
+
+from abc import ABC, abstractmethod
+
+
+class DataReader(ABC):
+
+    def read(self, path: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return self.read_at(path)
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file at offset and limit.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the length of bytes want to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of the file
+        """
+        pass
+
+
+class DataWriter(ABC):
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write the data to the file.
+
+        Args:
+            path (str): the target file where to write
+            data (bytes): the data want to write
+        """
+        pass
+
+    def write_string(self, path: str, data: str) -> None:
+        """Write the data to file, the data will be encoded to bytes.
+
+        Args:
+            path (str): the target file where to write
+            data (str): the data want to write
+        """
+        self.write(path, data.encode())

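As a sketch of the contract these base classes define, a hypothetical in-memory pair (not part of the diff) only needs to implement `read_at` and `write`; `read` and `write_string` come for free:

```python
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter


class MemoryReader(DataReader):
    def __init__(self, blobs: dict[str, bytes]):
        self._blobs = blobs

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        data = self._blobs[path][offset:]
        return data if limit == -1 else data[:limit]


class MemoryWriter(DataWriter):
    def __init__(self):
        self.blobs: dict[str, bytes] = {}

    def write(self, path: str, data: bytes) -> None:
        self.blobs[path] = data


writer = MemoryWriter()
writer.write_string('a.txt', 'hello')  # write_string encodes to bytes for us
reader = MemoryReader(writer.blobs)
print(reader.read('a.txt'))            # b'hello', dispatched through read_at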
+ 59 - 0
magic_pdf/data/data_reader_writer/filebase.py

@@ -0,0 +1,59 @@
+import os
+
+from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
+
+
+class FileBasedDataReader(DataReader):
+    def __init__(self, parent_dir: str = ''):
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the length of bytes want to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of file
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        with open(fn_path, 'rb') as f:
+            f.seek(offset)
+            if limit == -1:
+                return f.read()
+            else:
+                return f.read(limit)
+
+
+class FileBasedDataWriter(DataWriter):
+    def __init__(self, parent_dir: str = '') -> None:
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        with open(fn_path, 'wb') as f:
+            f.write(data)

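Typical usage of the file-based pair (illustrative only); relative paths are joined with `parent_dir` as the docstrings above describe:

```python
import os

from magic_pdf.data.data_reader_writer.filebase import (FileBasedDataReader,
                                                        FileBasedDataWriter)

os.makedirs('output', exist_ok=True)

writer = FileBasedDataWriter('output')
writer.write('result.md', '# title'.encode())  # lands in output/result.md

reader = FileBasedDataReader('output')
print(reader.read('result.md'))                # b'# title'
print(reader.read_at('result.md', offset=2))   # b'title'
```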
+ 137 - 0
magic_pdf/data/data_reader_writer/multi_bucket_s3.py

@@ -0,0 +1,137 @@
+from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
+from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
+from magic_pdf.data.io.s3 import S3Reader, S3Writer
+from magic_pdf.data.schemas import S3Config
+from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
+                                       remove_non_official_s3_args)
+
+
+class MultiS3Mixin:
+    def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
+        """Initialized with multiple s3 configs.
+
+        Args:
+            default_bucket (str): the default bucket name of the relative path
+            s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
+
+        Raises:
+            InvalidConfig: default bucket config not in s3_configs
+            InvalidConfig: bucket name not unique in s3_configs
+            InvalidConfig: default bucket must be provided
+        """
+        if len(default_bucket) == 0:
+            raise InvalidConfig('default_bucket must be provided')
+
+        found_default_bucket_config = False
+        for conf in s3_configs:
+            if conf.bucket_name == default_bucket:
+                found_default_bucket_config = True
+                break
+
+        if not found_default_bucket_config:
+            raise InvalidConfig(
+                f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
+            )
+
+        uniq_bucket = set([conf.bucket_name for conf in s3_configs])
+        if len(uniq_bucket) != len(s3_configs):
+            raise InvalidConfig(
+                f'the bucket_name in s3_configs: {s3_configs} must be unique'
+            )
+
+        self.default_bucket = default_bucket
+        self.s3_configs = s3_configs
+        self._s3_clients_h: dict = {}
+
+
+class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
+    def read(self, path: str) -> bytes:
+        """Read the path from s3, select diffect bucket client for each request
+        based on the path, also support range read.
+
+        Args:
+            path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
+            for example: s3://bucket_name/path?0,100
+
+        Returns:
+            bytes: the content of s3 file
+        """
+        may_range_params = parse_s3_range_params(path)
+        if may_range_params is None or 2 != len(may_range_params):
+            byte_start, byte_len = 0, -1
+        else:
+            byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
+        path = remove_non_official_s3_args(path)
+        return self.read_at(path, byte_start, byte_len)
+
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Reader(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file with offset and limit, select diffect bucket client
+        for each request based on the path.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
+
+        Returns:
+            bytes: the file content
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_reader = self.__get_s3_client(bucket_name)
+        else:
+            s3_reader = self.__get_s3_client(self.default_bucket)
+        return s3_reader.read_at(path, offset, limit)
+
+
+class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Writer(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data, also select diffect bucket client for each
+        request based on the path.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_writer = self.__get_s3_client(bucket_name)
+        else:
+            s3_writer = self.__get_s3_client(self.default_bucket)
+        return s3_writer.write(path, data)

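A usage sketch for the multi-bucket reader; the bucket names, keys, and endpoint below are placeholders:

```python
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
    MultiBucketS3DataReader
from magic_pdf.data.schemas import S3Config

configs = [
    S3Config(bucket_name='bucket-a', access_key='ak-a', secret_key='sk-a',
             endpoint_url='https://s3.example.com', addressing_style='auto'),
    S3Config(bucket_name='bucket-b', access_key='ak-b', secret_key='sk-b',
             endpoint_url='https://s3.example.com', addressing_style='auto'),
]
reader = MultiBucketS3DataReader(default_bucket='bucket-a', s3_configs=configs)

data = reader.read('some/key.pdf')                      # default bucket
other = reader.read('s3://bucket-b/some/key.pdf')       # bucket taken from the path
head = reader.read('s3://bucket-b/some/key.pdf?0,100')  # range read: offset 0, 100 bytes
```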
+ 69 - 0
magic_pdf/data/data_reader_writer/s3.py

@@ -0,0 +1,69 @@
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import (
+    MultiBucketS3DataReader, MultiBucketS3DataWriter)
+from magic_pdf.data.schemas import S3Config
+
+
+class S3DataReader(MultiBucketS3DataReader):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
+            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        super().__init__(
+            bucket,
+            [
+                S3Config(
+                    bucket_name=bucket,
+                    access_key=ak,
+                    secret_key=sk,
+                    endpoint_url=endpoint_url,
+                    addressing_style=addressing_style,
+                )
+            ],
+        )
+
+
+class S3DataWriter(MultiBucketS3DataWriter):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 writer client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+                refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        super().__init__(
+            bucket,
+            [
+                S3Config(
+                    bucket_name=bucket,
+                    access_key=ak,
+                    secret_key=sk,
+                    endpoint_url=endpoint_url,
+                    addressing_style=addressing_style,
+                )
+            ],
+        )

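The single-bucket wrappers reduce the same API to one config; a sketch with placeholder credentials:

from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter

reader = S3DataReader('my-bucket', 'my-ak', 'my-sk', 'https://s3.example.com')
pdf_bits = reader.read('papers/sample.pdf')  # key resolved inside 'my-bucket'

writer = S3DataWriter('my-bucket', 'my-ak', 'my-sk', 'https://s3.example.com')
writer.write('papers/sample.copy.pdf', pdf_bits)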
+ 194 - 0
magic_pdf/data/dataset.py

@@ -0,0 +1,194 @@
+from abc import ABC, abstractmethod
+from typing import Iterator
+
+import fitz
+
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.schemas import PageInfo
+from magic_pdf.data.utils import fitz_doc_to_image
+
+
+class PageableData(ABC):
+    @abstractmethod
+    def get_image(self) -> dict:
+        """Transform data to image."""
+        pass
+
+    @abstractmethod
+    def get_doc(self) -> fitz.Page:
+        """Get the pymudoc page."""
+        pass
+
+    @abstractmethod
+    def get_page_info(self) -> PageInfo:
+        """Get the page info of the page.
+
+        Returns:
+            PageInfo: the page info of this page
+        """
+        pass
+
+
+class Dataset(ABC):
+    @abstractmethod
+    def __len__(self) -> int:
+        """The length of the dataset."""
+        pass
+
+    @abstractmethod
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page data."""
+        pass
+
+    @abstractmethod
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """The methods that this dataset support.
+
+        Returns:
+            list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
+        """
+        pass
+
+    @abstractmethod
+    def data_bits(self) -> bytes:
+        """The bits used to create this dataset."""
+        pass
+
+    @abstractmethod
+    def get_page(self, page_id: int) -> PageableData:
+        """Get the page indexed by page_id.
+
+        Args:
+            page_id (int): the index of the page
+
+        Returns:
+            PageableData: the page doc object
+        """
+        pass
+
+
+class PymuDocDataset(Dataset):
+    def __init__(self, bits: bytes):
+        """Initialize the dataset, which wraps the pymudoc documents.
+
+        Args:
+            bits (bytes): the bytes of the pdf
+        """
+        self._records = [Doc(v) for v in fitz.open('pdf', bits)]
+        self._data_bits = bits
+        self._raw_data = bits
+
+    def __len__(self) -> int:
+        """The page number of the pdf."""
+        return len(self._records)
+
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page doc object."""
+        return iter(self._records)
+
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """The method supported by this dataset.
+
+        Returns:
+            list[SupportedPdfParseMethod]: the supported methods
+        """
+        return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
+
+    def data_bits(self) -> bytes:
+        """The pdf bits used to create this dataset."""
+        return self._data_bits
+
+    def get_page(self, page_id: int) -> PageableData:
+        """The page doc object.
+
+        Args:
+            page_id (int): the page doc index
+
+        Returns:
+            PageableData: the page doc object
+        """
+        return self._records[page_id]
+
+
+class ImageDataset(Dataset):
+    def __init__(self, bits: bytes):
+        """Initialize the dataset, which wraps the pymudoc documents.
+
+        Args:
+            bits (bytes): the bytes of the image, which is converted to a pdf first and then to pymudoc pages.
+        """
+        pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
+        self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
+        self._raw_data = bits
+        self._data_bits = pdf_bytes
+
+    def __len__(self) -> int:
+        """The length of the dataset."""
+        return len(self._records)
+
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page object."""
+        return iter(self._records)
+
+    def supported_methods(self):
+        """The method supported by this dataset.
+
+        Returns:
+            list[SupportedPdfParseMethod]: the supported methods
+        """
+        return [SupportedPdfParseMethod.OCR]
+
+    def data_bits(self) -> bytes:
+        """The pdf bits used to create this dataset."""
+        return self._data_bits
+
+    def get_page(self, page_id: int) -> PageableData:
+        """The page doc object.
+
+        Args:
+            page_id (int): the page doc index
+
+        Returns:
+            PageableData: the page doc object
+        """
+        return self._records[page_id]
+
+
+class Doc(PageableData):
+    """Initialized with pymudoc object."""
+    def __init__(self, doc: fitz.Page):
+        self._doc = doc
+
+    def get_image(self):
+        """Return the imge info.
+
+        Returns:
+            dict: {
+                img: np.ndarray,
+                width: int,
+                height: int
+            }
+        """
+        return fitz_doc_to_image(self._doc)
+
+    def get_doc(self) -> fitz.Page:
+        """Get the pymudoc object.
+
+        Returns:
+            fitz.Page: the pymudoc object
+        """
+        return self._doc
+
+    def get_page_info(self) -> PageInfo:
+        """Get the page info of the page.
+
+        Returns:
+            PageInfo: the page info of this page
+        """
+        page_w = self._doc.rect.width
+        page_h = self._doc.rect.height
+        return PageInfo(w=page_w, h=page_h)
+
+    def __getattr__(self, name):
+        if hasattr(self._doc, name):
+            return getattr(self._doc, name)
+        # raise instead of silently returning None for unknown attributes
+        raise AttributeError(name)

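A short sketch of the Dataset API, assuming a local sample.pdf exists:

from magic_pdf.data.dataset import PymuDocDataset

with open('sample.pdf', 'rb') as f:  # hypothetical local file
    ds = PymuDocDataset(f.read())

print(len(ds))              # number of pages
page = ds.get_page(0)
info = page.get_page_info()
print(info.w, info.h)       # page size in points
img = page.get_image()      # {'img': ndarray, 'width': ..., 'height': ...}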
+ 0 - 0
magic_pdf/data/io/__init__.py


+ 42 - 0
magic_pdf/data/io/base.py

@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+
+
+class IOReader(ABC):
+    @abstractmethod
+    def read(self, path: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        pass
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the path of the file; a relative path is joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes to read. Defaults to -1, meaning read to the end.
+
+        Returns:
+            bytes: the content of file
+        """
+        pass
+
+
+class IOWriter(ABC):
+
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the path of the file; a relative path is joined with parent_dir.
+            data (bytes): the data to write
+        """
+        pass

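To illustrate the IOReader contract, a local-filesystem implementation might look like this (a sketch, not part of the commit):

from magic_pdf.data.io.base import IOReader


class LocalIOReader(IOReader):
    def read(self, path: str) -> bytes:
        return self.read_at(path)

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        with open(path, 'rb') as f:
            f.seek(offset)
            return f.read() if limit == -1 else f.read(limit)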
+ 37 - 0
magic_pdf/data/io/http.py

@@ -0,0 +1,37 @@
+
+import io
+
+import requests
+
+from magic_pdf.data.io.base import IOReader, IOWriter
+
+
+class HttpReader(IOReader):
+
+    def read(self, url: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return requests.get(url).content
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Not Implemented."""
+        raise NotImplementedError
+
+
+class HttpWriter(IOWriter):
+    def write(self, url: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        files = {'file': io.BytesIO(data)}
+        response = requests.post(url, files=files)
+        assert 200 <= response.status_code < 300

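HttpReader pairs naturally with the datasets above; a sketch with a placeholder url:

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.io.http import HttpReader

pdf_bits = HttpReader().read('https://example.com/sample.pdf')  # placeholder url
ds = PymuDocDataset(pdf_bits)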
+ 114 - 0
magic_pdf/data/io/s3.py

@@ -0,0 +1,114 @@
+import boto3
+from botocore.config import Config
+
+from magic_pdf.data.io.base import IOReader, IOWriter
+
+
+class S3Reader(IOReader):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+                refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(
+                s3={'addressing_style': addressing_style},
+                retries={'max_attempts': 5, 'mode': 'standard'},
+            ),
+        )
+
+    def read(self, key: str) -> bytes:
+        """Read the file.
+
+        Args:
+            key (str): the object key to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return self.read_at(key)
+
+    def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            key (str): the object key to read
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes to read. Defaults to -1, meaning read to the end.
+
+        Returns:
+            bytes: the content of file
+        """
+        if limit > -1:
+            range_header = f'bytes={offset}-{offset+limit-1}'
+            res = self._s3_client.get_object(
+                Bucket=self._bucket, Key=key, Range=range_header
+            )
+        else:
+            res = self._s3_client.get_object(
+                Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
+            )
+        return res['Body'].read()
+
+
+class S3Writer(IOWriter):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+                refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(
+                s3={'addressing_style': addressing_style},
+                retries={'max_attempts': 5, 'mode': 'standard'},
+            ),
+        )
+
+    def write(self, key: str, data: bytes):
+        """Write file with data.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)

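offset and limit map onto an inclusive HTTP Range header; a sketch with placeholder credentials:

from magic_pdf.data.io.s3 import S3Reader

# offset=100, limit=50 -> Range: bytes=100-149
# offset=100, limit=-1 -> Range: bytes=100-   (read to the end)
reader = S3Reader('my-bucket', 'my-ak', 'my-sk', 'https://s3.example.com')
first_kb = reader.read_at('papers/sample.pdf', offset=0, limit=1024)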
+ 95 - 0
magic_pdf/data/read_api.py

@@ -0,0 +1,95 @@
+import json
+import os
+from pathlib import Path
+
+from magic_pdf.config.exceptions import EmptyData, InvalidParams
+from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
+                                               MultiBucketS3DataReader)
+from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
+
+
+def read_jsonl(
+    s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
+) -> list[PymuDocDataset]:
+    """Read the jsonl file and return the list of PymuDocDataset.
+
+    Args:
+        s3_path_or_local (str): local file or s3 path
+        s3_client (MultiBucketS3DataReader | None, optional): s3 client that supports multiple buckets. Defaults to None.
+
+    Raises:
+        InvalidParams: if s3_path_or_local is an s3 path but s3_client is not provided.
+        EmptyData: if a line of the jsonl file provides no pdf file location.
+        InvalidParams: if a pdf file location inside the jsonl is an s3 path but s3_client is not provided.
+
+    Returns:
+        list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
+    """
+    bits_arr = []
+    if s3_path_or_local.startswith('s3://'):
+        if s3_client is None:
+            raise InvalidParams('s3_client is required when s3_path is provided')
+        jsonl_bits = s3_client.read(s3_path_or_local)
+    else:
+        jsonl_bits = FileBasedDataReader('').read(s3_path_or_local)
+    jsonl_d = [
+        json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
+    ]
+    for d in jsonl_d:
+        pdf_path = d.get('file_location', '') or d.get('path', '')
+        if len(pdf_path) == 0:
+            raise EmptyData('pdf file location is empty')
+        if pdf_path.startswith('s3://'):
+            if s3_client is None:
+                raise InvalidParams('s3_client is required when s3_path is provided')
+            bits_arr.append(s3_client.read(pdf_path))
+        else:
+            bits_arr.append(FileBasedDataReader('').read(pdf_path))
+    return [PymuDocDataset(bits) for bits in bits_arr]
+
+
+def read_local_pdfs(path: str) -> list[PymuDocDataset]:
+    """Read pdf from path or directory.
+
+    Args:
+        path (str): pdf file path or directory that contains pdf files
+
+    Returns:
+        list[PymuDocDataset]: each pdf file will be converted to a PymuDocDataset
+    """
+    if os.path.isdir(path):
+        reader = FileBasedDataReader(path)
+        return [
+            PymuDocDataset(reader.read(doc_path.name))
+            for doc_path in Path(path).glob('*.pdf')
+        ]
+    else:
+        reader = FileBasedDataReader()
+        bits = reader.read(path)
+        return [PymuDocDataset(bits)]
+
+
+def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
+    """Read images from path or directory.
+
+    Args:
+        path (str): image file path or directory that contains image files
+        suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
+
+    Returns:
+        list[ImageDataset]: each image file will be converted to an ImageDataset
+    """
+    if os.path.isdir(path):
+        imgs_bits = []
+        s_suffixes = set(suffixes)
+        reader = FileBasedDataReader(path)
+        for root, _, files in os.walk(path):
+            for file in files:
+                suffix = file.split('.')
+                if suffix[-1] in s_suffixes:
+                    imgs_bits.append(reader.read(file))
+        return [ImageDataset(bits) for bits in imgs_bits]
+    else:
+        reader = FileBasedDataReader()
+        bits = reader.read(path)
+        return [ImageDataset(bits)]

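Typical entry points, assuming the local paths exist; an s3-hosted jsonl additionally needs a MultiBucketS3DataReader passed as s3_client:

from magic_pdf.data.read_api import (read_jsonl, read_local_images,
                                     read_local_pdfs)

pdf_datasets = read_local_pdfs('/tmp/pdfs')   # a directory or a single file
img_datasets = read_local_images('/tmp/imgs', suffixes=['jpg', 'png'])
# jsonl_datasets = read_jsonl('s3://bucket-a/index.jsonl', s3_client=reader)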
+ 15 - 0
magic_pdf/data/schemas.py

@@ -0,0 +1,15 @@
+
+from pydantic import BaseModel, Field
+
+
+class S3Config(BaseModel):
+    bucket_name: str = Field(description='s3 bucket name', min_length=1)
+    access_key: str = Field(description='s3 access key', min_length=1)
+    secret_key: str = Field(description='s3 secret key', min_length=1)
+    endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
+    addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
+
+
+class PageInfo(BaseModel):
+    w: float = Field(description='the width of page')
+    h: float = Field(description='the height of page')

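The min_length=1 constraints mean pydantic rejects empty fields at construction time; for example (placeholder values):

from pydantic import ValidationError

from magic_pdf.data.schemas import S3Config

cfg = S3Config(bucket_name='my-bucket', access_key='ak', secret_key='sk',
               endpoint_url='https://s3.example.com')

try:
    S3Config(bucket_name='', access_key='ak', secret_key='sk',
             endpoint_url='https://s3.example.com')
except ValidationError:
    print('empty bucket_name rejected')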
+ 32 - 0
magic_pdf/data/utils.py

@@ -0,0 +1,32 @@
+
+import fitz
+import numpy as np
+
+from magic_pdf.utils.annotations import ImportPIL
+
+
+@ImportPIL
+def fitz_doc_to_image(doc, dpi=200) -> dict:
+    """Convert fitz.Document to image, Then convert the image to numpy array.
+
+    Args:
+        doc (_type_): pymudoc page
+        dpi (int, optional): reset the dpi of dpi. Defaults to 200.
+
+    Returns:
+        dict:  {'img': numpy array, 'width': width, 'height': height }
+    """
+    from PIL import Image
+    mat = fitz.Matrix(dpi / 72, dpi / 72)
+    pm = doc.get_pixmap(matrix=mat, alpha=False)
+
+    # If the width or height exceeds 9000 after scaling, do not scale further.
+    if pm.width > 9000 or pm.height > 9000:
+        pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+
+    img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
+    img = np.array(img)
+
+    img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
+
+    return img_dict

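The zoom factor is dpi / 72 because PDF user space is 72 points per inch, so at dpi=200 a 612x792 pt US Letter page renders to roughly 1700x2200 px. A usage sketch, assuming a local sample.pdf:

import fitz

from magic_pdf.data.utils import fitz_doc_to_image

doc = fitz.open('sample.pdf')         # hypothetical local file
img_dict = fitz_doc_to_image(doc[0])  # {'img': ndarray, 'width': ..., 'height': ...}
print(img_dict['width'], img_dict['height'])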
+ 47 - 237
magic_pdf/dict2md/ocr_mkcontent.py

@@ -1,6 +1,5 @@
 import re
 
-import wordninja
 from loguru import logger
 
 from magic_pdf.libs.commons import join_path
@@ -8,6 +7,7 @@ from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
 from magic_pdf.libs.ocr_content_type import BlockType, ContentType
+from magic_pdf.para.para_split_v3 import ListLineTag
 
 
 def __is_hyphen_at_line_end(line):
@@ -24,37 +24,6 @@ def __is_hyphen_at_line_end(line):
     return bool(re.search(r'[A-Za-z]+-\s*$', line))
 
 
-def split_long_words(text):
-    segments = text.split(' ')
-    for i in range(len(segments)):
-        words = re.findall(r'\w+|[^\w]', segments[i], re.UNICODE)
-        for j in range(len(words)):
-            if len(words[j]) > 10:
-                words[j] = ' '.join(wordninja.split(words[j]))
-        segments[i] = ''.join(words)
-    return ' '.join(segments)
-
-
-def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
-    markdown = []
-    for page_info in pdf_info_list:
-        paras_of_layout = page_info.get('para_blocks')
-        page_markdown = ocr_mk_markdown_with_para_core_v2(
-            paras_of_layout, 'mm', img_buket_path)
-        markdown.extend(page_markdown)
-    return '\n\n'.join(markdown)
-
-
-def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
-    markdown = []
-    for page_info in pdf_info_dict:
-        paras_of_layout = page_info.get('para_blocks')
-        page_markdown = ocr_mk_markdown_with_para_core_v2(
-            paras_of_layout, 'nlp')
-        markdown.extend(page_markdown)
-    return '\n\n'.join(markdown)
-
-
 def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
                                                 img_buket_path):
     markdown_with_para_and_pagination = []
@@ -67,69 +36,28 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
             paras_of_layout, 'mm', img_buket_path)
         markdown_with_para_and_pagination.append({
             'page_no':
-            page_no,
+                page_no,
             'md_content':
-            '\n\n'.join(page_markdown)
+                '\n\n'.join(page_markdown)
         })
         page_no += 1
     return markdown_with_para_and_pagination
 
 
-def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
-    page_markdown = []
-    for paras in paras_of_layout:
-        for para in paras:
-            para_text = ''
-            for line in para:
-                for span in line['spans']:
-                    span_type = span.get('type')
-                    content = ''
-                    language = ''
-                    if span_type == ContentType.Text:
-                        content = span['content']
-                        language = detect_lang(content)
-                        if (language == 'en'):  # 只对英文长词进行分词处理,中文分词会丢失文本
-                            content = ocr_escape_special_markdown_char(
-                                split_long_words(content))
-                        else:
-                            content = ocr_escape_special_markdown_char(content)
-                    elif span_type == ContentType.InlineEquation:
-                        content = f"${span['content']}$"
-                    elif span_type == ContentType.InterlineEquation:
-                        content = f"\n$$\n{span['content']}\n$$\n"
-                    elif span_type in [ContentType.Image, ContentType.Table]:
-                        if mode == 'mm':
-                            content = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
-                        elif mode == 'nlp':
-                            pass
-                    if content != '':
-                        if language == 'en':  # 英文语境下 content间需要空格分隔
-                            para_text += content + ' '
-                        else:  # 中文语境下,content间不需要空格分隔
-                            para_text += content
-            if para_text.strip() == '':
-                continue
-            else:
-                page_markdown.append(para_text.strip() + '  ')
-    return page_markdown
-
-
 def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                       mode,
                                       img_buket_path='',
-                                      parse_type="auto",
-                                      lang=None
                                       ):
     page_markdown = []
     for para_block in paras_of_layout:
         para_text = ''
         para_type = para_block['type']
-        if para_type == BlockType.Text:
-            para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
+        if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
+            para_text = merge_para_with_text(para_block)
         elif para_type == BlockType.Title:
-            para_text = f'# {merge_para_with_text(para_block, parse_type=parse_type, lang=lang)}'
+            para_text = f'# {merge_para_with_text(para_block)}'
         elif para_type == BlockType.InterlineEquation:
-            para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
+            para_text = merge_para_with_text(para_block)
         elif para_type == BlockType.Image:
             if mode == 'nlp':
                 continue
@@ -142,17 +70,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                     para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 2nd: append image_caption
                     if block['type'] == BlockType.ImageCaption:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
-                for block in para_block['blocks']:  # 2nd.拼image_caption
+                        para_text += merge_para_with_text(block) + '  \n'
+                for block in para_block['blocks']:  # 3rd: append image_footnote
                     if block['type'] == BlockType.ImageFootnote:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                        para_text += merge_para_with_text(block) + '  \n'
         elif para_type == BlockType.Table:
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
                 for block in para_block['blocks']:  # 1st: append table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                        para_text += merge_para_with_text(block) + '  \n'
                 for block in para_block['blocks']:  # 2nd: append table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
@@ -167,7 +95,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd: append table_footnote
                     if block['type'] == BlockType.TableFootnote:
-                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                        para_text += merge_para_with_text(block) + '  \n'
 
         if para_text.strip() == '':
             continue
@@ -177,22 +105,26 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
     return page_markdown
 
 
-def merge_para_with_text(para_block, parse_type="auto", lang=None):
-
-    def detect_language(text):
-        en_pattern = r'[a-zA-Z]+'
-        en_matches = re.findall(en_pattern, text)
-        en_length = sum(len(match) for match in en_matches)
-        if len(text) > 0:
-            if en_length / len(text) >= 0.5:
-                return 'en'
-            else:
-                return 'unknown'
+def detect_language(text):
+    en_pattern = r'[a-zA-Z]+'
+    en_matches = re.findall(en_pattern, text)
+    en_length = sum(len(match) for match in en_matches)
+    if len(text) > 0:
+        if en_length / len(text) >= 0.5:
+            return 'en'
         else:
-            return 'empty'
+            return 'unknown'
+    else:
+        return 'empty'
+
 
+def merge_para_with_text(para_block):
     para_text = ''
-    for line in para_block['lines']:
+    for i, line in enumerate(para_block['lines']):
+
+        if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
+            para_text += '  \n'
+
         line_text = ''
         line_lang = ''
         for span in line['spans']:
@@ -202,21 +134,11 @@ def merge_para_with_text(para_block, parse_type="auto", lang=None):
         if line_text != '':
             line_lang = detect_lang(line_text)
         for span in line['spans']:
+
             span_type = span['type']
             content = ''
             if span_type == ContentType.Text:
-                content = span['content']
-                # language = detect_lang(content)
-                language = detect_language(content)
-                # 判断是否小语种
-                if lang is not None and lang != 'en':
-                    content = ocr_escape_special_markdown_char(content)
-                else:  # 非小语种逻辑
-                    if language == 'en' and parse_type == 'ocr':  # 只对英文长词进行分词处理,中文分词会丢失文本
-                        content = ocr_escape_special_markdown_char(
-                            split_long_words(content))
-                    else:
-                        content = ocr_escape_special_markdown_char(content)
+                content = ocr_escape_special_markdown_char(span['content'])
             elif span_type == ContentType.InlineEquation:
                 content = f" ${span['content']}$ "
             elif span_type == ContentType.InterlineEquation:
@@ -237,74 +159,39 @@ def merge_para_with_text(para_block, parse_type="auto", lang=None):
     return para_text
 
 
-def para_to_standard_format(para, img_buket_path):
-    para_content = {}
-    if len(para) == 1:
-        para_content = line_to_standard_format(para[0], img_buket_path)
-    elif len(para) > 1:
-        para_text = ''
-        inline_equation_num = 0
-        for line in para:
-            for span in line['spans']:
-                language = ''
-                span_type = span.get('type')
-                content = ''
-                if span_type == ContentType.Text:
-                    content = span['content']
-                    language = detect_lang(content)
-                    if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
-                        content = ocr_escape_special_markdown_char(
-                            split_long_words(content))
-                    else:
-                        content = ocr_escape_special_markdown_char(content)
-                elif span_type == ContentType.InlineEquation:
-                    content = f"${span['content']}$"
-                    inline_equation_num += 1
-                if language == 'en':  # 英文语境下 content间需要空格分隔
-                    para_text += content + ' '
-                else:  # 中文语境下,content间不需要空格分隔
-                    para_text += content
-        para_content = {
-            'type': 'text',
-            'text': para_text,
-            'inline_equation_num': inline_equation_num,
-        }
-    return para_content
-
-
-def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type="auto", lang=None, drop_reason=None):
+def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
     para_type = para_block['type']
     para_content = {}
-    if para_type == BlockType.Text:
+    if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
         para_content = {
             'type': 'text',
-            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
+            'text': merge_para_with_text(para_block),
         }
     elif para_type == BlockType.Title:
         para_content = {
             'type': 'text',
-            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
+            'text': merge_para_with_text(para_block),
             'text_level': 1,
         }
     elif para_type == BlockType.InterlineEquation:
         para_content = {
             'type': 'equation',
-            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
+            'text': merge_para_with_text(para_block),
             'text_format': 'latex',
         }
     elif para_type == BlockType.Image:
-        para_content = {'type': 'image'}
+        para_content = {'type': 'image', 'img_caption': [], 'img_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.ImageBody:
                 para_content['img_path'] = join_path(
                     img_buket_path,
                     block['lines'][0]['spans'][0]['image_path'])
             if block['type'] == BlockType.ImageCaption:
-                para_content['img_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['img_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.ImageFootnote:
-                para_content['img_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['img_footnote'].append(merge_para_with_text(block))
     elif para_type == BlockType.Table:
-        para_content = {'type': 'table'}
+        para_content = {'type': 'table', 'table_caption': [], 'table_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
                 if block["lines"][0]["spans"][0].get('latex', ''):
@@ -313,9 +200,9 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type=
                     para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:
-                para_content['table_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['table_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.TableFootnote:
-                para_content['table_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+                para_content['table_footnote'].append(merge_para_with_text(block))
 
     para_content['page_idx'] = page_idx
 
@@ -325,88 +212,11 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type=
     return para_content
 
 
-def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
-    content_list = []
-    for page_info in pdf_info_dict:
-        paras_of_layout = page_info.get('para_blocks')
-        if not paras_of_layout:
-            continue
-        for para_block in paras_of_layout:
-            para_content = para_to_standard_format_v2(para_block,
-                                                      img_buket_path)
-            content_list.append(para_content)
-    return content_list
-
-
-def line_to_standard_format(line, img_buket_path):
-    line_text = ''
-    inline_equation_num = 0
-    for span in line['spans']:
-        if not span.get('content'):
-            if not span.get('image_path'):
-                continue
-            else:
-                if span['type'] == ContentType.Image:
-                    content = {
-                        'type': 'image',
-                        'img_path': join_path(img_buket_path,
-                                              span['image_path']),
-                    }
-                    return content
-                elif span['type'] == ContentType.Table:
-                    content = {
-                        'type': 'table',
-                        'img_path': join_path(img_buket_path,
-                                              span['image_path']),
-                    }
-                    return content
-        else:
-            if span['type'] == ContentType.InterlineEquation:
-                interline_equation = span['content']
-                content = {
-                    'type': 'equation',
-                    'latex': f'$$\n{interline_equation}\n$$'
-                }
-                return content
-            elif span['type'] == ContentType.InlineEquation:
-                inline_equation = span['content']
-                line_text += f'${inline_equation}$'
-                inline_equation_num += 1
-            elif span['type'] == ContentType.Text:
-                text_content = ocr_escape_special_markdown_char(
-                    span['content'])  # 转义特殊符号
-                line_text += text_content
-    content = {
-        'type': 'text',
-        'text': line_text,
-        'inline_equation_num': inline_equation_num,
-    }
-    return content
-
-
-def ocr_mk_mm_standard_format(pdf_info_dict: list):
-    """content_list type         string
-    image/text/table/equation(行间的单独拿出来,行内的和text合并) latex        string
-    latex文本字段。 text         string      纯文本格式的文本数据。 md           string
-    markdown格式的文本数据。 img_path     string      s3://full/path/to/img.jpg."""
-    content_list = []
-    for page_info in pdf_info_dict:
-        blocks = page_info.get('preproc_blocks')
-        if not blocks:
-            continue
-        for block in blocks:
-            for line in block['lines']:
-                content = line_to_standard_format(line)
-                content_list.append(content)
-    return content_list
-
-
 def union_make(pdf_info_dict: list,
                make_mode: str,
                drop_mode: str,
                img_buket_path: str = '',
-               parse_type: str = "auto",
-               lang=None):
+               ):
     output_content = []
     for page_info in pdf_info_dict:
         drop_reason_flag = False
@@ -433,20 +243,20 @@ def union_make(pdf_info_dict: list,
             continue
         if make_mode == MakeMode.MM_MD:
             page_markdown = ocr_mk_markdown_with_para_core_v2(
-                paras_of_layout, 'mm', img_buket_path, parse_type=parse_type, lang=lang)
+                paras_of_layout, 'mm', img_buket_path)
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.NLP_MD:
             page_markdown = ocr_mk_markdown_with_para_core_v2(
-                paras_of_layout, 'nlp', parse_type=parse_type, lang=lang)
+                paras_of_layout, 'nlp')
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.STANDARD_FORMAT:
             for para_block in paras_of_layout:
                 if drop_reason_flag:
                     para_content = para_to_standard_format_v2(
-                        para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang, drop_reason=drop_reason)
+                        para_block, img_buket_path, page_idx)
                 else:
                     para_content = para_to_standard_format_v2(
-                        para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang)
+                        para_block, img_buket_path, page_idx)
                 output_content.append(para_content)
     if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
         return '\n\n'.join(output_content)

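After this change union_make no longer takes parse_type or lang; a minimal call sketch (pdf_info_dict is normally produced by the parse pipeline, and DropMode.NONE is assumed to be a valid drop mode):

from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode

pdf_info_dict = []  # normally produced by the parse pipeline
md = union_make(pdf_info_dict, MakeMode.MM_MD, DropMode.NONE, 'images')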
+ 13 - 6
magic_pdf/libs/Constants.py

@@ -10,18 +10,12 @@ block维度自定义字段
 # whether the lines in a block have been deleted
 LINES_DELETED = "lines_deleted"
 
-# struct eqtable
-STRUCT_EQTABLE = "struct_eqtable"
-
 # table recognition max time default value
 TABLE_MAX_TIME_VALUE = 400
 
 # pp_table_result_max_length
 TABLE_MAX_LEN = 480
 
-# pp table structure algorithm
-TABLE_MASTER = "TableMaster"
-
 # table master structure dict
 TABLE_MASTER_DICT = "table_master_structure_dict.txt"
 
@@ -38,3 +32,16 @@ REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
 REC_CHAR_DICT = "ppocr_keys_v1.txt"
 
 
+class MODEL_NAME:
+    # pp table structure algorithm
+    TABLE_MASTER = "tablemaster"
+    # struct eqtable
+    STRUCT_EQTABLE = "struct_eqtable"
+
+    DocLayout_YOLO = "doclayout_yolo"
+
+    LAYOUTLMv3 = "layoutlmv3"
+
+    YOLO_V8_MFD = "yolo_v8_mfd"
+
+    UniMerNet_v2_Small = "unimernet_small"

+ 35 - 0
magic_pdf/libs/boxbase.py

@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
 
     # The area of overlap area
     return (x_right - x_left) * (y_bottom - y_top)
+
+
+def calculate_vertical_projection_overlap_ratio(block1, block2):
+    """
+    Calculate how much of block1's horizontal extent is covered by the vertical projection of block2.
+
+    Args:
+        block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
+        block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
+
+    Returns:
+        float: the length of the shared x-interval divided by block1's width (0.0 if disjoint or degenerate).
+    """
+    x0_1, _, x1_1, _ = block1
+    x0_2, _, x1_2, _ = block2
+
+    # Calculate the intersection of the x-coordinates
+    x_left = max(x0_1, x0_2)
+    x_right = min(x1_1, x1_2)
+
+    if x_right < x_left:
+        return 0.0
+
+    # Length of the intersection
+    intersection_length = x_right - x_left
+
+    # Length of the x-axis projection of the first block
+    block1_length = x1_1 - x0_1
+
+    if block1_length == 0:
+        return 0.0
+
+    # Proportion of the x-axis covered by the intersection
+    # logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
+    return intersection_length / block1_length

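A worked example of the new helper: the x-intervals [0, 100] and [50, 150] intersect over [50, 100], so the ratio relative to block1 is 50 / 100 = 0.5:

from magic_pdf.libs.boxbase import calculate_vertical_projection_overlap_ratio

block1 = (0, 0, 100, 20)
block2 = (50, 30, 150, 50)
print(calculate_vertical_projection_overlap_ratio(block1, block2))  # 0.5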
+ 53 - 23
magic_pdf/libs/config_reader.py

@@ -1,46 +1,44 @@
-"""
-根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
-
-"""
+"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
 
 import json
 import os
 
 from loguru import logger
 
+from magic_pdf.libs.Constants import MODEL_NAME
 from magic_pdf.libs.commons import parse_bucket_key
 
 # config file name constant
-CONFIG_FILE_NAME = "magic-pdf.json"
+CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
 
 
 def read_config():
-    home_dir = os.path.expanduser("~")
-
-    config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
+    if os.path.isabs(CONFIG_FILE_NAME):
+        config_file = CONFIG_FILE_NAME
+    else:
+        home_dir = os.path.expanduser('~')
+        config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
 
     if not os.path.exists(config_file):
-        raise FileNotFoundError(f"{config_file} not found")
+        raise FileNotFoundError(f'{config_file} not found')
 
-    with open(config_file, "r", encoding="utf-8") as f:
+    with open(config_file, 'r', encoding='utf-8') as f:
         config = json.load(f)
     return config
 
 
 def get_s3_config(bucket_name: str):
-    """
-    ~/magic-pdf.json 读出来
-    """
+    """~/magic-pdf.json 读出来."""
     config = read_config()
 
-    bucket_info = config.get("bucket_info")
+    bucket_info = config.get('bucket_info')
     if bucket_name not in bucket_info:
-        access_key, secret_key, storage_endpoint = bucket_info["[default]"]
+        access_key, secret_key, storage_endpoint = bucket_info['[default]']
     else:
         access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
 
     if access_key is None or secret_key is None or storage_endpoint is None:
-        raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
+        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
 
     # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
 
@@ -49,7 +47,7 @@ def get_s3_config(bucket_name: str):
 
 def get_s3_config_dict(path: str):
     access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
-    return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
+    return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
 
 
 def get_bucket_name(path):
@@ -59,33 +57,65 @@ def get_bucket_name(path):
 
 def get_local_models_dir():
     config = read_config()
-    models_dir = config.get("models-dir")
+    models_dir = config.get('models-dir')
     if models_dir is None:
         logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
-        return "/tmp/models"
+        return '/tmp/models'
     else:
         return models_dir
 
 
+def get_local_layoutreader_model_dir():
+    config = read_config()
+    layoutreader_model_dir = config.get('layoutreader-model-dir')
+    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
+        home_dir = os.path.expanduser('~')
+        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
+        logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
+        return layoutreader_at_modelscope_dir_path
+    else:
+        return layoutreader_model_dir
+
+
 def get_device():
     config = read_config()
-    device = config.get("device-mode")
+    device = config.get('device-mode')
     if device is None:
         logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
-        return "cpu"
+        return 'cpu'
     else:
         return device
 
 
 def get_table_recog_config():
     config = read_config()
-    table_config = config.get("table-config")
+    table_config = config.get('table-config')
     if table_config is None:
         logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
-        return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
+        return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
     else:
         return table_config
 
 
+def get_layout_config():
+    config = read_config()
+    layout_config = config.get("layout-config")
+    if layout_config is None:
+        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
+        return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
+    else:
+        return layout_config
+
+
+def get_formula_config():
+    config = read_config()
+    formula_config = config.get("formula-config")
+    if formula_config is None:
+        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
+        return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
+    else:
+        return formula_config
+
+
 if __name__ == "__main__":
     ak, sk, endpoint = get_s3_config("llm-raw")

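For reference, a sketch of the magic-pdf.json sections the new getters read, written as a Python dict; the values mirror the fallbacks above:

config_defaults = {
    'layout-config': {'model': 'layoutlmv3'},
    'formula-config': {'mfd_model': 'yolo_v8_mfd',
                       'mfr_model': 'unimernet_small',
                       'enable': True},
    'table-config': {'model': 'tablemaster', 'enable': False, 'max_time': 400},
}
# The config file location can be overridden via the environment:
#   export MINERU_TOOLS_CONFIG_JSON=/abs/path/to/magic-pdf.json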
+ 75 - 43
magic_pdf/libs/draw_bbox.py

@@ -1,3 +1,4 @@
+from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.libs.commons import fitz  # PyMuPDF
 from magic_pdf.libs.Constants import CROSS_PAGE
 from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
@@ -62,7 +63,7 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox
                     overlay=True,
                 )  # Draw the rectangle
         page.insert_text(
-            (x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
+            (x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
         )  # Insert the index in the top left corner of the rectangle
 
 
@@ -75,6 +76,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
     titles_list = []
     texts_list = []
     interequations_list = []
+    lists_list = []
+    indexs_list = []
     for page in pdf_info:
 
         page_dropped_list = []
@@ -83,6 +86,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         titles = []
         texts = []
         interequations = []
+        lists = []
+        indices = []
 
         for dropped_bbox in page['discarded_blocks']:
             page_dropped_list.append(dropped_bbox['bbox'])
@@ -115,6 +120,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
                 texts.append(bbox)
             elif block['type'] == BlockType.InterlineEquation:
                 interequations.append(bbox)
+            elif block['type'] == BlockType.List:
+                lists.append(bbox)
+            elif block['type'] == BlockType.Index:
+                indices.append(bbox)
+
         tables_list.append(tables)
         tables_body_list.append(tables_body)
         tables_caption_list.append(tables_caption)
@@ -126,42 +136,62 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         titles_list.append(titles)
         texts_list.append(texts)
         interequations_list.append(interequations)
+        lists_list.append(lists)
+        indexs_list.append(indices)
 
     layout_bbox_list = []
 
+    table_type_order = {
+        'table_caption': 1,
+        'table_body': 2,
+        'table_footnote': 3
+    }
     for page in pdf_info:
         page_block_list = []
         for block in page['para_blocks']:
-            bbox = block['bbox']
-            page_block_list.append(bbox)
+            if block['type'] in [
+                BlockType.Text,
+                BlockType.Title,
+                BlockType.InterlineEquation,
+                BlockType.List,
+                BlockType.Index,
+            ]:
+                bbox = block['bbox']
+                page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Image]:
+                for sub_block in block['blocks']:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Table]:
+                sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
+                for sub_block in sorted_blocks:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+
         layout_bbox_list.append(page_block_list)
 
     pdf_docs = fitz.open('pdf', pdf_bytes)
 
     for i, page in enumerate(pdf_docs):
 
-        draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
-                                 True)
-        draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
-                                 True)  # color !
-        draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0],
-                                 True)
-        draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102],
-                                 True)
-        draw_bbox_without_number(i, tables_footnote_list, page,
-                                 [229, 255, 204], True)
-        draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
+        draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
+        # draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True)  # color !
+        draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
+        draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
+        draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
+        # draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
         draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
-        draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
-                                 True)
-        draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102],
-                              True),
+        draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
+        draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True)
         draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
         draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
-        draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
-                                 True)
+        draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
+        draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
+        draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
 
-        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)
+        draw_bbox_with_number(
+            i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
+        )
 
     # Save the PDF
     pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
@@ -224,6 +254,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
                 BlockType.Text,
                 BlockType.Title,
                 BlockType.InterlineEquation,
+                BlockType.List,
+                BlockType.Index,
             ]:
                 for line in block['lines']:
                     for span in line['spans']:
@@ -260,7 +292,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
     texts_list = []
     interequations_list = []
     pdf_docs = fitz.open('pdf', pdf_bytes)
-    magic_model = MagicModel(model_list, pdf_docs)
+    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
     for i in range(len(model_list)):
         page_dropped_list = []
         tables_body, tables_caption, tables_footnote = [], [], []
@@ -286,8 +318,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
                 imgs_body.append(bbox)
             elif layout_det['category_id'] == CategoryId.ImageCaption:
                 imgs_caption.append(bbox)
-            elif layout_det[
-                'category_id'] == CategoryId.InterlineEquation_YOLO:
+            elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
                 interequations.append(bbox)
             elif layout_det['category_id'] == CategoryId.Abandon:
                 page_dropped_list.append(bbox)
@@ -306,18 +337,15 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
         imgs_footnote_list.append(imgs_footnote)
 
     for i, page in enumerate(pdf_docs):
-        draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158],
-                              True)  # color !
+        draw_bbox_with_number(
+            i, dropped_bbox_list, page, [158, 158, 158], True
+        )  # color !
         draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
-        draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102],
-                              True)
-        draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204],
-                              True)
+        draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
+        draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
         draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
-        draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255],
-                              True)
-        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
-                              True)
+        draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
+        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
         draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
         draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
         draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
@@ -332,19 +360,23 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in ['text', 'title', 'interline_equation']:
+            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in ['table', 'image']:
-                bbox = block['bbox']
-                index = block['index']
-                page_line_list.append({'index': index, 'bbox': bbox})
-            # for line in block['lines']:
-            #     bbox = line['bbox']
-            #     index = line['index']
-            #     page_line_list.append({'index': index, 'bbox': bbox})
+            if block['type'] in [BlockType.Image, BlockType.Table]:
+                for sub_block in block['blocks']:
+                    if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+                        for line in sub_block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                    elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
+                        for line in sub_block['lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
         sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
         layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
     pdf_docs = fitz.open('pdf', pdf_bytes)

+ 2 - 0
magic_pdf/libs/ocr_content_type.py

@@ -20,6 +20,8 @@ class BlockType:
     InterlineEquation = 'interline_equation'
     Footnote = 'footnote'
     Discarded = 'discarded'
+    List = 'list'
+    Index = 'index'
 
 
 class CategoryId:

+ 71 - 31
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -4,7 +4,9 @@ import fitz
 import numpy as np
 from loguru import logger
 
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
+from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
+    get_formula_config
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst):
     return unique_dicts
 
 
-def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
+def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
     try:
         from PIL import Image
     except ImportError:
@@ -32,18 +34,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
 
     images = []
     with fitz.open("pdf", pdf_bytes) as doc:
+        pdf_page_num = doc.page_count
+        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+        if end_page_id > pdf_page_num - 1:
+            logger.warning("end_page_id is out of range, use images length")
+            end_page_id = pdf_page_num - 1
+
         for index in range(0, doc.page_count):
-            page = doc[index]
-            mat = fitz.Matrix(dpi / 72, dpi / 72)
-            pm = page.get_pixmap(matrix=mat, alpha=False)
+            if start_page_id <= index <= end_page_id:
+                page = doc[index]
+                mat = fitz.Matrix(dpi / 72, dpi / 72)
+                pm = page.get_pixmap(matrix=mat, alpha=False)
+
+                # If the width or height exceeds 9000 after scaling, do not scale further.
+                if pm.width > 9000 or pm.height > 9000:
+                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-            # If the width or height exceeds 9000 after scaling, do not scale further.
-            if pm.width > 9000 or pm.height > 9000:
-                pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+                img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
+                img = np.array(img)
+                img_dict = {"img": img, "width": pm.width, "height": pm.height}
+            else:
+                img_dict = {"img": [], "width": 0, "height": 0}
 
-            img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
-            img = np.array(img)
-            img_dict = {"img": img, "width": pm.width, "height": pm.height}
             images.append(img_dict)
     return images
 
@@ -57,14 +69,17 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, ocr: bool, show_log: bool, lang=None):
-        key = (ocr, show_log, lang)
+    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
+        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
         if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
+            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
+                                                  formula_enable=formula_enable, table_enable=table_enable)
         return self._models[key]
 
 
-def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
+def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
+                      layout_model=None, formula_enable=None, table_enable=None):
+
     model = None
 
     if model_config.__model_mode__ == "lite":
@@ -84,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
             # read model-dir and device from the config file
             local_models_dir = get_local_models_dir()
             device = get_device()
+
+            layout_config = get_layout_config()
+            if layout_model is not None:
+                layout_config["model"] = layout_model
+
+            formula_config = get_formula_config()
+            if formula_enable is not None:
+                formula_config["enable"] = formula_enable
+
             table_config = get_table_recog_config()
-            model_input = {"ocr": ocr,
-                           "show_log": show_log,
-                           "models_dir": local_models_dir,
-                           "device": device,
-                           "table_config": table_config,
-                           "lang": lang,
-                           }
+            if table_enable is not None:
+                table_config["enable"] = table_enable
+
+            model_input = {
+                "ocr": ocr,
+                "show_log": show_log,
+                "models_dir": local_models_dir,
+                "device": device,
+                "table_config": table_config,
+                "layout_config": layout_config,
+                "formula_config": formula_config,
+                "lang": lang,
+            }
+
             custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")
@@ -106,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
 
 
 def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None, lang=None):
+                start_page_id=0, end_page_id=None, lang=None,
+                layout_model=None, formula_enable=None, table_enable=None):
 
-    model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang)
+    if lang == "":
+        lang = None
 
-    images = load_images_from_pdf(pdf_bytes)
+    model_manager = ModelSingleton()
+    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
 
-    # end_page_id = end_page_id if end_page_id else len(images) - 1
-    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(images) - 1
+    with fitz.open("pdf", pdf_bytes) as doc:
+        pdf_page_num = doc.page_count
+        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+        if end_page_id > pdf_page_num - 1:
+            logger.warning("end_page_id is out of range, use images length")
+            end_page_id = pdf_page_num - 1
 
-    if end_page_id > len(images) - 1:
-        logger.warning("end_page_id is out of range, use images length")
-        end_page_id = len(images) - 1
+    images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
 
     model_json = []
     doc_analyze_start = time.time()
@@ -135,6 +170,11 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
         page_dict = {"layout_dets": result, "page_info": page_info}
         model_json.append(page_dict)
 
+    gc_start = time.time()
+    clean_memory()
+    gc_time = round(time.time() - gc_start, 2)
+    logger.info(f"gc time: {gc_time}")
+
     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
     doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
     logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"

+ 265 - 22
magic_pdf/model/magic_model.py

@@ -1,5 +1,6 @@
 import json
 
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                     bbox_relative_pos, box_area, calculate_iou,
                                     calculate_overlap_area_in_bbox1_area_ratio,
@@ -9,6 +10,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.local_math import float_gt
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
+from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
@@ -24,7 +26,7 @@ class MagicModel:
             need_remove_list = []
             page_no = model_page_info['page_info']['page_no']
             horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
-                model_page_info, self.__docs[page_no]
+                model_page_info, self.__docs.get_page(page_no)
             )
             layout_dets = model_page_info['layout_dets']
             for layout_det in layout_dets:
@@ -99,7 +101,7 @@ class MagicModel:
             for need_remove in need_remove_list:
                 layout_dets.remove(need_remove)
 
-    def __init__(self, model_list: list, docs: fitz.Document):
+    def __init__(self, model_list: list, docs: Dataset):
         self.__model_list = model_list
         self.__docs = docs
         """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
@@ -119,15 +121,13 @@ class MagicModel:
         if left or right:
             l1 = bbox1[3] - bbox1[1]
             l2 = bbox2[3] - bbox2[1]
-            minL, maxL = min(l1, l2), max(l1, l2)
-            if (maxL - minL) / minL > 0.5:
-                return float('inf')
-        if bottom or top:
+        else:
             l1 = bbox1[2] - bbox1[0]
             l2 = bbox2[2] - bbox2[0]
-            minL, maxL = min(l1, l2), max(l1, l2)
-            if (maxL - minL) / minL > 0.5:
-                return float('inf')
+
+        if l2 > l1 and (l2 - l1) / l1 > 0.3:
+            return float('inf')
+
         return bbox_distance(bbox1, bbox2)
 
     def __fix_footnote(self):
@@ -215,9 +215,8 @@ class MagicModel:
         Select all subjects that overlap the merged bbox with an overlap area larger than the object's own area,
         then compute the shortest distance between those subjects and the object.
         """
-        def search_overlap_between_boxes(
-            subject_idx, object_idx
-        ):
+
+        def search_overlap_between_boxes(subject_idx, object_idx):
             idxes = [subject_idx, object_idx]
             x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
             y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
@@ -245,9 +244,9 @@ class MagicModel:
             for other_object in other_objects:
                 ratio = max(
                     ratio,
-                    get_overlap_area(
-                        merged_bbox, other_object['bbox']
-                    ) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
+                    get_overlap_area(merged_bbox, other_object['bbox'])
+                    * 1.0
+                    / box_area(all_bboxes[object_idx]['bbox']),
                 )
                 if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
                     break
@@ -365,12 +364,17 @@ class MagicModel:
                 if all_bboxes[j]['category_id'] == subject_category_id:
                     subject_idx, object_idx = j, i
 
-                if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
+                if (
+                    search_overlap_between_boxes(subject_idx, object_idx)
+                    >= MERGE_BOX_OVERLAP_AREA_RATIO
+                ):
                     dis[i][j] = float('inf')
                     dis[j][i] = dis[i][j]
                     continue
 
-                dis[i][j] = self._bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
+                dis[i][j] = self._bbox_distance(
+                    all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
+                )
                 dis[j][i] = dis[i][j]
 
         used = set()
@@ -461,7 +465,7 @@ class MagicModel:
 
                     if is_nearest:
                         nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
-                        n_dis = self._bbox_distance(
+                        n_dis = bbox_distance(
                             all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
                         )
                         if float_gt(dis[i][j], n_dis):
@@ -557,7 +561,7 @@ class MagicModel:
         # compute the distances of the already-matched pairs
         for i in subject_object_relation_map.keys():
             for j in subject_object_relation_map[i]:
-                total_subject_object_dis += self._bbox_distance(
+                total_subject_object_dis += bbox_distance(
                     all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
                 )
 
@@ -586,6 +590,245 @@ class MagicModel:
                 with_caption_subject.add(j)
         return ret, total_subject_object_dis
 
+    def __tie_up_category_by_distance_v2(
+        self, page_no, subject_category_id, object_category_id
+    ):
+
+        AXIS_MULTIPLICITY = 0.5
+        subjects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == subject_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+
+        objects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == object_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+        M = len(objects)
+
+        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+
+        sub_obj_map_h = {i: [] for i in range(len(subjects))}
+
+        dis_by_directions = {
+            'top': [[-1, float('inf')]] * M,
+            'bottom': [[-1, float('inf')]] * M,
+            'left': [[-1, float('inf')]] * M,
+            'right': [[-1, float('inf')]] * M,
+        }
+
+        for i, obj in enumerate(objects):
+            l_x_axis, l_y_axis = (
+                obj['bbox'][2] - obj['bbox'][0],
+                obj['bbox'][3] - obj['bbox'][1],
+            )
+            axis_unit = min(l_x_axis, l_y_axis)
+            for j, sub in enumerate(subjects):
+
+                bbox1, bbox2, _ = _remove_overlap_between_bbox(
+                    objects[i]['bbox'], subjects[j]['bbox']
+                )
+                left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
+                flags = [left, right, bottom, top]
+                if sum([1 if v else 0 for v in flags]) > 1:
+                    continue
+
+                if left:
+                    if dis_by_directions['left'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['left'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if right:
+                    if dis_by_directions['right'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['right'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if bottom:
+                    if dis_by_directions['bottom'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['bottom'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if top:
+                    if dis_by_directions['top'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['top'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+            if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
+                'right'
+            ][i][1] != float('inf'):
+                if dis_by_directions['left'][i][1] != float(
+                    'inf'
+                ) and dis_by_directions['right'][i][1] != float('inf'):
+                    if AXIS_MULTIPLICITY * axis_unit >= abs(
+                        dis_by_directions['left'][i][1]
+                        - dis_by_directions['right'][i][1]
+                    ):
+                        left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
+                            'bbox'
+                        ]
+                        right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
+                            'bbox'
+                        ]
+
+                        left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
+                        right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
+
+                        if (
+                            abs(left_sub_bbox_y_axis - l_y_axis)
+                            + dis_by_directions['left'][i][1]
+                            > abs(right_sub_bbox_y_axis - l_y_axis)
+                            + dis_by_directions['right'][i][1]
+                        ):
+                            left_or_right = dis_by_directions['right'][i]
+                        else:
+                            left_or_right = dis_by_directions['left'][i]
+                    else:
+                        left_or_right = dis_by_directions['left'][i]
+                        if left_or_right[1] > dis_by_directions['right'][i][1]:
+                            left_or_right = dis_by_directions['right'][i]
+                else:
+                    left_or_right = dis_by_directions['left'][i]
+                    if left_or_right[1] == float('inf'):
+                        left_or_right = dis_by_directions['right'][i]
+            else:
+                left_or_right = [-1, float('inf')]
+
+            if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
+                'bottom'
+            ][i][1] != float('inf'):
+                if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
+                    'bottom'
+                ][i][1] != float('inf'):
+                    if AXIS_MULTIPLICITY * axis_unit >= abs(
+                        dis_by_directions['top'][i][1]
+                        - dis_by_directions['bottom'][i][1]
+                    ):
+                        top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
+                        bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
+
+                        top_bottom_x_axis = top_bottom[2] - top_bottom[0]
+                        bottom_top_x_axis = bottom_top[2] - bottom_top[0]
+                        if abs(top_bottom_x_axis - l_x_axis) + dis_by_directions['bottom'][i][1] > abs(
+                            bottom_top_x_axis - l_x_axis
+                        ) + dis_by_directions['top'][i][1]:
+                            top_or_bottom = dis_by_directions['top'][i]
+                        else:
+                            top_or_bottom = dis_by_directions['bottom'][i]
+                    else:
+                        top_or_bottom = dis_by_directions['top'][i]
+                        if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
+                            top_or_bottom = dis_by_directions['bottom'][i]
+                else:
+                    top_or_bottom = dis_by_directions['top'][i]
+                    if top_or_bottom[1] == float('inf'):
+                        top_or_bottom = dis_by_directions['bottom'][i]
+            else:
+                top_or_bottom = [-1, float('inf')]
+
+            if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
+                if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
+                    'inf'
+                ):
+                    if AXIS_MULTIPLICITY * axis_unit >= abs(
+                        left_or_right[1] - top_or_bottom[1]
+                    ):
+                        y_axis_bbox = subjects[left_or_right[0]]['bbox']
+                        x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
+
+                        if (
+                            abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
+                            > abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
+                            / l_y_axis
+                        ):
+                            sub_obj_map_h[left_or_right[0]].append(i)
+                        else:
+                            sub_obj_map_h[top_or_bottom[0]].append(i)
+                    else:
+                        if left_or_right[1] > top_or_bottom[1]:
+                            sub_obj_map_h[top_or_bottom[0]].append(i)
+                        else:
+                            sub_obj_map_h[left_or_right[0]].append(i)
+                else:
+                    if left_or_right[1] != float('inf'):
+                        sub_obj_map_h[left_or_right[0]].append(i)
+                    else:
+                        sub_obj_map_h[top_or_bottom[0]].append(i)
+        ret = []
+        for i in sub_obj_map_h.keys():
+            ret.append(
+                {
+                    'sub_bbox': {
+                        'bbox': subjects[i]['bbox'],
+                        'score': subjects[i]['score'],
+                    },
+                    'obj_bboxes': [
+                        {'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
+                        for j in sub_obj_map_h[i]
+                    ],
+                    'sub_idx': i,
+                }
+            )
+        return ret
+
+    def get_imgs_v2(self, page_no: int):
+        with_captions = self.__tie_up_category_by_distance_v2(page_no, 3, 4)
+        with_footnotes = self.__tie_up_category_by_distance_v2(
+            page_no, 3, CategoryId.ImageFootnote
+        )
+        ret = []
+        for v in with_captions:
+            record = {
+                'image_body': v['sub_bbox'],
+                'image_caption_list': v['obj_bboxes'],
+            }
+            filter_idx = v['sub_idx']
+            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
+            record['image_footnote_list'] = d['obj_bboxes']
+            ret.append(record)
+        return ret
+
+    def get_tables_v2(self, page_no: int) -> list:
+        with_captions = self.__tie_up_category_by_distance_v2(page_no, 5, 6)
+        with_footnotes = self.__tie_up_category_by_distance_v2(page_no, 5, 7)
+        ret = []
+        for v in with_captions:
+            record = {
+                'table_body': v['sub_bbox'],
+                'table_caption_list': v['obj_bboxes'],
+            }
+            filter_idx = v['sub_idx']
+            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
+            record['table_footnote_list'] = d['obj_bboxes']
+            ret.append(record)
+        return ret
+
     def get_imgs(self, page_no: int):
         with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
         with_footnotes, _ = self.__tie_up_category_by_distance(
@@ -719,10 +962,10 @@ class MagicModel:
 
     def get_page_size(self, page_no: int):  # get page width and height
         # get the page object of the current page
-        page = self.__docs[page_no]
+        page = self.__docs.get_page(page_no).get_page_info()
         # get the width and height of the current page
-        page_w = page.rect.width
-        page_h = page.rect.height
+        page_w = page.w
+        page_h = page.h
         return page_w, page_h
 
     def __get_blocks_by_type(
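A sketch of consuming the new v2 grouping API; the record keys come from the methods above, while treating model_list as one entry per page is an assumption:

    magic_model = MagicModel(model_list, dataset)  # dataset: magic_pdf.data.dataset.Dataset
    for page_no in range(len(model_list)):
        for img in magic_model.get_imgs_v2(page_no):
            body = img['image_body']['bbox']
            captions = [c['bbox'] for c in img['image_caption_list']]
            footnotes = [f['bbox'] for f in img['image_footnote_list']]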

+ 86 - 47
magic_pdf/model/pdf_extract_kit.py

@@ -26,6 +26,7 @@ try:
     from unimernet.common.config import Config
     import unimernet.tasks as tasks
     from unimernet.processors import load_processor
+    from doclayout_yolo import YOLOv10
 
 except ImportError as e:
     logger.exception(e)
@@ -42,7 +43,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == STRUCT_EQTABLE:
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
     else:
         config = {
@@ -83,11 +84,16 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
+def doclayout_yolo_model_init(weight):
+    model = YOLOv10(weight)
+    return model
+
+
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
     if lang is not None:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     else:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     return model
 
 
@@ -120,19 +126,27 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
-        if atom_model_name not in self._models:
-            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[atom_model_name]
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]
 
 
 def atom_model_init(model_name: str, **kwargs):
 
     if model_name == AtomicModel.Layout:
-        atom_model = layout_model_init(
-            kwargs.get("layout_weights"),
-            kwargs.get("layout_config_file"),
-            kwargs.get("device")
-        )
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+            )
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get("mfd_weights")
@@ -151,7 +165,7 @@ def atom_model_init(model_name: str, **kwargs):
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
-            kwargs.get("table_model_type"),
+            kwargs.get("table_model_name"),
             kwargs.get("table_model_path"),
             kwargs.get("table_max_time"),
             kwargs.get("device")
@@ -199,23 +213,35 @@ class CustomPEKModel:
         with open(config_path, "r", encoding='utf-8') as f:
             self.configs = yaml.load(f, Loader=yaml.FullLoader)
         # initialize the parsing configuration
-        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
-        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+
+        # layout config
+        self.layout_config = kwargs.get("layout_config")
+        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
+
+        # formula config
+        self.formula_config = kwargs.get("formula_config")
+        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
+        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
+        self.apply_formula = self.formula_config.get("enable", True)
+
         # table config
-        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
-        self.apply_table = self.table_config.get("is_table_recog_enable", False)
+        self.table_config = kwargs.get("table_config")
+        self.apply_table = self.table_config.get("enable", False)
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_type = self.table_config.get("model", TABLE_MASTER)
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+
+        # ocr config
         self.apply_ocr = ocr
         self.lang = kwargs.get("lang", None)
+
         logger.info(
-            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
-                self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
+            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
+            "apply_table: {}, table_model: {}, lang: {}".format(
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
             )
         )
-        assert self.apply_layout, "DocAnalysis must contain layout model."
         # initialize the parsing scheme
-        self.device = kwargs.get("device", self.configs["config"]["device"])
+        self.device = kwargs.get("device", "cpu")
         logger.info("using device: {}".format(self.device))
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
@@ -224,17 +250,16 @@ class CustomPEKModel:
 
         # initialize formula recognition
         if self.apply_formula:
+
             # initialize the formula detection (MFD) model
-            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
             )
+
             # initialize the formula recognition (MFR) model
-            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
+            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
             self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
@@ -243,17 +268,20 @@ class CustomPEKModel:
             )
 
         # initialize the layout model
-        # self.layout_model = Layoutlmv3_Predictor(
-        #     str(os.path.join(models_dir, self.configs['weights']['layout'])),
-        #     str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-        #     device=self.device
-        # )
-        self.layout_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.Layout,
-            layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
-            layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-            device=self.device
-        )
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.LAYOUTLMv3,
+                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+                device=self.device
+            )
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.DocLayout_YOLO,
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+            )
         # initialize OCR
         if self.apply_ocr:
 
@@ -266,12 +294,10 @@ class CustomPEKModel:
             )
         # init table model
         if self.apply_table:
-            table_model_dir = self.configs["weights"][self.table_model_type]
-            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-            #                                     max_time=self.table_max_time, _device_=self.device)
+            table_model_dir = self.configs["weights"][self.table_model_name]
             self.table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Table,
-                table_model_type=self.table_model_type,
+                table_model_name=self.table_model_name,
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device
@@ -288,7 +314,21 @@ class CustomPEKModel:
 
         # layout detection
         layout_start = time.time()
-        layout_res = self.layout_model(image, ignore_catids=[])
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            layout_res = self.layout_model(image, ignore_catids=[])
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            layout_res = []
+            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 3),
+                }
+                layout_res.append(new_item)
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection time: {layout_cost}")
 
@@ -297,7 +337,7 @@ class CustomPEKModel:
         if self.apply_formula:
             # 公式检测
             mfd_start = time.time()
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
             logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
             for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
@@ -309,7 +349,6 @@ class CustomPEKModel:
                 }
                 layout_res.append(new_item)
                 latex_filling_list.append(new_item)
-                # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
                 bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 mf_image_list.append(bbox_img)
 
@@ -346,7 +385,7 @@ class CustomPEKModel:
         if torch.cuda.is_available():
             properties = torch.cuda.get_device_properties(self.device)
             total_memory = properties.total_memory / (1024 ** 3)  # convert bytes to GB
-            if total_memory <= 8:
+            if total_memory <= 10:
                 gc_start = time.time()
                 clean_memory()
                 gc_time = round(time.time() - gc_start, 2)
@@ -411,7 +450,7 @@ class CustomPEKModel:
                 # logger.info("------------------table recognition processing begins-----------------")
                 latex_code = None
                 html_code = None
-                if self.table_model_type == STRUCT_EQTABLE:
+                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
                         latex_code = self.table_model.image2latex(new_image)[0]
                 else:
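The DocLayout-YOLO branch above emits the same 8-point 'poly' convention used elsewhere in the pipeline; a tiny illustrative helper shows the round trip:

    def poly_to_bbox(poly):
        # poly is [x0, y0, x1, y0, x1, y1, x0, y1]; recover the axis-aligned bbox.
        xs, ys = poly[0::2], poly[1::2]
        return [min(xs), min(ys), max(xs), max(ys)]

    assert poly_to_bbox([10, 20, 90, 20, 90, 80, 10, 80]) == [10, 20, 90, 80]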

+ 2 - 2
magic_pdf/model/ppTableModel.py

@@ -52,11 +52,11 @@ class ppTableModel(object):
         rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
         rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
         device = kwargs.get("device", "cpu")
-        use_gpu = True if device == "cuda" else False
+        use_gpu = True if device.startswith("cuda") else False
         config = {
             "use_gpu": use_gpu,
             "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": TABLE_MASTER,
+            "table_algorithm": "TableMaster",
             "table_model_dir": table_model_dir,
             "table_char_dict_path": table_char_dict_path,
             "det_model_dir": det_model_dir,

+ 296 - 0
magic_pdf/para/para_split_v3.py

@@ -0,0 +1,296 @@
+import copy
+
+from loguru import logger
+
+from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
+from magic_pdf.libs.ocr_content_type import BlockType, ContentType
+
+LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
+LIST_END_FLAG = ('.', '。', ';', ';')
+
+
+class ListLineTag:
+    IS_LIST_START_LINE = "is_list_start_line"
+    IS_LIST_END_LINE = "is_list_end_line"
+
+
+def __process_blocks(blocks):
+    # preprocess all blocks:
+    # 1. split blocks into groups at title and interline_equation boundaries
+    # 2. reset each block's bbox from its line info
+
+    result = []
+    current_group = []
+
+    for i in range(len(blocks)):
+        current_block = blocks[i]
+
+        # if the current block is a text block
+        if current_block['type'] == 'text':
+            current_block["bbox_fs"] = copy.deepcopy(current_block["bbox"])
+            if 'lines' in current_block and len(current_block["lines"]) > 0:
+                current_block['bbox_fs'] = [min([line['bbox'][0] for line in current_block['lines']]),
+                                            min([line['bbox'][1] for line in current_block['lines']]),
+                                            max([line['bbox'][2] for line in current_block['lines']]),
+                                            max([line['bbox'][3] for line in current_block['lines']])]
+            current_group.append(current_block)
+
+        # check whether a next block exists
+        if i + 1 < len(blocks):
+            next_block = blocks[i + 1]
+            # if the next block is a title or interline_equation block (i.e. not text)
+            if next_block['type'] in ['title', 'interline_equation']:
+                result.append(current_group)
+                current_group = []
+
+    # handle the last group
+    if current_group:
+        result.append(current_group)
+
+    return result
+
+
+def __is_list_or_index_block(block):
+    # A block is a list block if it simultaneously satisfies one of these feature sets:
+    # 1. multiple lines; 2. several lines flush with the left edge; 3. several lines not flush on the right (ragged edge)
+    # 1. multiple lines; 2. several lines flush with the left edge; 3. several lines ending with an end flag
+    # 1. multiple lines; 2. several lines flush with the left edge; 3. several lines not flush on the left
+
+    # An index block is a special kind of list block.
+    # A block is an index block if it simultaneously satisfies:
+    # 1. multiple lines; 2. several lines flush with both edges; 3. lines starting or ending with digits
+    if len(block['lines']) >= 2:
+        first_line = block['lines'][0]
+        line_height = first_line['bbox'][3] - first_line['bbox'][1]
+        block_width = block['bbox_fs'][2] - block['bbox_fs'][0]
+
+        left_close_num = 0
+        left_not_close_num = 0
+        right_not_close_num = 0
+        right_close_num = 0
+        lines_text_list = []
+
+        multiple_para_flag = False
+        last_line = block['lines'][-1]
+        # first line not flush left but flush right, last line flush left but not flush right (the first line may be allowed not to reach the right edge)
+        if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
+                # block['bbox_fs'][2] - first_line['bbox'][2] < line_height and
+                abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
+                block['bbox_fs'][2] - last_line['bbox'][2] > line_height
+        ):
+            multiple_para_flag = True
+
+        for line in block['lines']:
+
+            line_text = ""
+
+            for span in line['spans']:
+                span_type = span['type']
+                if span_type == ContentType.Text:
+                    line_text += span['content'].strip()
+
+            lines_text_list.append(line_text)
+
+            # count lines flush with the left edge (need more than 2); flush means abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2
+            if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
+                left_close_num += 1
+            elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
+                # logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
+                left_not_close_num += 1
+
+            # check whether the line is flush with the right edge
+            if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
+                right_close_num += 1
+            else:
+                # when not flush right, require a clear gap; 0.3 of the block width is an empirical threshold
+                closed_area = 0.3 * block_width
+                # closed_area = 5 * line_height
+                if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
+                    right_not_close_num += 1
+
+        # check whether more than 80% of the lines end with a LIST_END_FLAG character
+        line_end_flag = False
+        # check whether more than 80% of the lines start with a digit or end with a digit
+        line_num_flag = False
+        num_start_count = 0
+        num_end_count = 0
+        flag_end_count = 0
+        if len(lines_text_list) > 0:
+            for line_text in lines_text_list:
+                if len(line_text) > 0:
+                    if line_text[-1] in LIST_END_FLAG:
+                        flag_end_count += 1
+                    if line_text[0].isdigit():
+                        num_start_count += 1
+                    if line_text[-1].isdigit():
+                        num_end_count += 1
+
+            if flag_end_count / len(lines_text_list) >= 0.8:
+                line_end_flag = True
+
+            if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
+                line_num_flag = True
+
+        # some tables of contents are not flush on the right; for now, a block with one fully flush edge (left or right) that matches the digit rule is treated as an index
+        if ((left_close_num/len(block['lines']) >= 0.8 or right_close_num/len(block['lines']) >= 0.8)
+                and line_num_flag
+        ):
+            for line in block['lines']:
+                line[ListLineTag.IS_LIST_START_LINE] = True
+            return BlockType.Index
+
+        elif left_close_num >= 2 and (
+                right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2) and not multiple_para_flag:
+            # handle a special list without indentation: all lines are flush left, so the right-side gap marks item ends
+            if left_close_num / len(block['lines']) > 0.9:
+                # a list of short single-line items, all flush left
+                if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
+                    for line in block['lines']:
+                        if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
+                            line[ListLineTag.IS_LIST_START_LINE] = True
+                # most line items carry an end flag; split items on that flag
+                elif line_end_flag:
+                    for i, line in enumerate(block['lines']):
+                        if lines_text_list[i][-1] in LIST_END_FLAG:
+                            line[ListLineTag.IS_LIST_END_LINE] = True
+                            if i + 1 < len(block['lines']):
+                                block['lines'][i+1][ListLineTag.IS_LIST_START_LINE] = True
+                # line items have neither end flags nor indentation; use the right-side gap to find item ends
+                else:
+                    line_start_flag = False
+                    for i, line in enumerate(block['lines']):
+                        if line_start_flag:
+                            line[ListLineTag.IS_LIST_START_LINE] = True
+                            line_start_flag = False
+                        elif abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
+                            line[ListLineTag.IS_LIST_END_LINE] = True
+                            line_start_flag = True
+            # a special indented ordered list: start lines begin with a digit and are not flush left,
+            # end lines finish with a LIST_END_FLAG and match the start lines in number
+            elif num_start_count >= 2 and num_start_count == flag_end_count:  # keep it simple: ignore the not-flush-left case for now
+                for i, line in enumerate(block['lines']):
+                    if lines_text_list[i][0].isdigit():
+                        line[ListLineTag.IS_LIST_START_LINE] = True
+                    if lines_text_list[i][-1] in LIST_END_FLAG:
+                        line[ListLineTag.IS_LIST_END_LINE] = True
+            else:
+                # normal handling of an indented list
+                for line in block['lines']:
+                    if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
+                        line[ListLineTag.IS_LIST_START_LINE] = True
+                    if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
+                        line[ListLineTag.IS_LIST_END_LINE] = True
+
+            return BlockType.List
+        else:
+            return BlockType.Text
+    else:
+        return BlockType.Text
+
+
+def __merge_2_text_blocks(block1, block2):
+    if len(block1['lines']) > 0:
+        first_line = block1['lines'][0]
+        line_height = first_line['bbox'][3] - first_line['bbox'][1]
+        block1_width = block1['bbox'][2] - block1['bbox'][0]
+        block2_width = block2['bbox'][2] - block2['bbox'][0]
+        min_block_width = min(block1_width, block2_width)
+        if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
+            last_line = block2['lines'][-1]
+            if len(last_line['spans']) > 0:
+                last_span = last_line['spans'][-1]
+                line_height = last_line['bbox'][3] - last_line['bbox'][1]
+                if (abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height and
+                        not last_span['content'].endswith(LINE_STOP_FLAG) and
+                        # do not merge if the two block widths differ by more than a factor of 2
+                        abs(block1_width - block2_width) < min_block_width
+                ):
+                    if block1['page_num'] != block2['page_num']:
+                        for line in block1['lines']:
+                            for span in line['spans']:
+                                span[CROSS_PAGE] = True
+                    block2['lines'].extend(block1['lines'])
+                    block1['lines'] = []
+                    block1[LINES_DELETED] = True
+
+    return block1, block2
+
+
+def __merge_2_list_blocks(block1, block2):
+    if block1['page_num'] != block2['page_num']:
+        for line in block1['lines']:
+            for span in line['spans']:
+                span[CROSS_PAGE] = True
+    block2['lines'].extend(block1['lines'])
+    block1['lines'] = []
+    block1[LINES_DELETED] = True
+
+    return block1, block2
+
+
+def __is_list_group(text_blocks_group):
+    # a list group requires every block in the group to satisfy:
+    # 1. at most 3 lines per block; 2. similar left edges across blocks (rule skipped for now to keep the logic simple)
+    for block in text_blocks_group:
+        if len(block['lines']) > 3:
+            return False
+    return True
+
+
+def __para_merge_page(blocks):
+    page_text_blocks_groups = __process_blocks(blocks)
+    for text_blocks_group in page_text_blocks_groups:
+
+        if len(text_blocks_group) > 0:
+            # before merging, classify every block as a list, index, or text block
+            for block in text_blocks_group:
+                block_type = __is_list_or_index_block(block)
+                block['type'] = block_type
+                # logger.info(f"{block['type']}:{block}")
+
+        if len(text_blocks_group) > 1:
+
+            # decide whether this group is a list group before merging
+            is_list_group = __is_list_group(text_blocks_group)
+
+            # iterate in reverse order
+            for i in range(len(text_blocks_group) - 1, -1, -1):
+                current_block = text_blocks_group[i]
+
+                # check whether a previous block exists
+                if i - 1 >= 0:
+                    prev_block = text_blocks_group[i - 1]
+
+                    if current_block['type'] == 'text' and prev_block['type'] == 'text' and not is_list_group:
+                        __merge_2_text_blocks(current_block, prev_block)
+                    elif (
+                            (current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List) or
+                            (current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index)
+                    ):
+                        __merge_2_list_blocks(current_block, prev_block)
+
+        else:
+            continue
+
+
+def para_split(pdf_info_dict, debug_mode=False):
+    all_blocks = []
+    for page_num, page in pdf_info_dict.items():
+        blocks = copy.deepcopy(page['preproc_blocks'])
+        for block in blocks:
+            block['page_num'] = page_num
+        all_blocks.extend(blocks)
+
+    __para_merge_page(all_blocks)
+    for page_num, page in pdf_info_dict.items():
+        page['para_blocks'] = []
+        for block in all_blocks:
+            if block['page_num'] == page_num:
+                page['para_blocks'].append(block)
+
+
+if __name__ == '__main__':
+    input_blocks = []
+    # call the grouping helper
+    groups = __process_blocks(input_blocks)
+    for group_index, group in enumerate(groups):
+        print(f"Group {group_index}: {group}")

+ 5 - 2
magic_pdf/pdf_parse_by_ocr.py

@@ -1,3 +1,5 @@
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
                      end_page_id=None,
                      debug_mode=False,
                      ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "ocr",
+                           SupportedPdfParseMethod.OCR,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,

+ 5 - 2
magic_pdf/pdf_parse_by_txt.py

@@ -1,3 +1,5 @@
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
     end_page_id=None,
     debug_mode=False,
 ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "txt",
+                           SupportedPdfParseMethod.TXT,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,
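Both thin wrappers now route through the same Dataset seam. A minimal sketch of the new entry point, assuming pdf_bytes, model_list, and imageWriter already exist:

    from magic_pdf.config.enums import SupportedPdfParseMethod
    from magic_pdf.data.dataset import PymuDocDataset

    dataset = PymuDocDataset(pdf_bytes)   # wraps the fitz document behind get_page()
    result = pdf_parse_union(dataset, model_list, imageWriter,
                             SupportedPdfParseMethod.OCR, start_page_id=0)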

+ 311 - 165
magic_pdf/pdf_parse_union_core_v2.py

@@ -1,42 +1,54 @@
+import copy
+import os
 import statistics
 import time
-
-from loguru import logger
-
 from typing import List
 
 import torch
+from loguru import logger
 
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.commons import fitz, get_delta_time
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.local_math import float_equal
-from magic_pdf.libs.ocr_content_type import ContentType
+from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 from magic_pdf.model.magic_model import MagicModel
+from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
-from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
+from magic_pdf.pre_proc.construct_page_dict import \
+    ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
-from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \
-    combine_chars_to_pymudict
-from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
-from magic_pdf.pre_proc.ocr_dict_merge import  fill_spans_in_blocks, fix_block_spans, fix_discarded_block
-from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
-    remove_overlaps_low_confidence_spans
-from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
+from magic_pdf.pre_proc.equations_replace import (
+    combine_chars_to_pymudict, remove_chars_in_text_blocks,
+    replace_equations_in_textblock)
+from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
+    ocr_prepare_bboxes_for_layout_split_v2
+from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
+                                               fix_block_spans,
+                                               fix_discarded_block, fix_block_spans_v2)
+from magic_pdf.pre_proc.ocr_span_list_modify import (
+    get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
+    remove_overlaps_min_spans)
+from magic_pdf.pre_proc.resolve_bbox_conflict import \
+    check_useful_block_horizontal_overlap
 
 
 def remove_horizontal_overlap_block_which_smaller(all_bboxes):
     useful_blocks = []
     for bbox in all_bboxes:
-        useful_blocks.append({
-            "bbox": bbox[:4]
-        })
-    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks)
+        useful_blocks.append({'bbox': bbox[:4]})
+    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
+        check_useful_block_horizontal_overlap(useful_blocks)
+    )
     if is_useful_block_horz_overlap:
         logger.warning(
-            f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}")
+            f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
+        )  # noqa: E501
         for bbox in all_bboxes.copy():
             if smaller_bbox == bbox[:4]:
                 all_bboxes.remove(bbox)
@@ -44,27 +56,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
     return is_useful_block_horz_overlap, all_bboxes
 
 
-def __replace_STX_ETX(text_str:str):
-    """ Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
-Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
+def __replace_STX_ETX(text_str: str):
+    """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
+    Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
 
-    Args:
-        text_str (str): raw text
+        Args:
+            text_str (str): raw text
 
-    Returns:
-        _type_: replaced text
-    """
+        Returns:
+            _type_: replaced text
+    """  # noqa: E501
     if text_str:
         s = text_str.replace('\u0002', "'")
-        s = s.replace("\u0003", "'")
+        s = s.replace('\u0003', "'")
         return s
     return text_str
 
 
 def txt_spans_extract(pdf_page, inline_equations, interline_equations):
-    text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
-    char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
-        "blocks"
+    text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
+    char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
+        'blocks'
     ]
     text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
     text_blocks = replace_equations_in_textblock(
@@ -74,50 +86,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
     text_blocks = remove_chars_in_text_blocks(text_blocks)
     spans = []
     for v in text_blocks:
-        for line in v["lines"]:
-            for span in line["spans"]:
-                bbox = span["bbox"]
+        for line in v['lines']:
+            for span in line['spans']:
+                bbox = span['bbox']
                 if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
                     continue
-                if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation):
+                if span.get('type') not in (
+                    ContentType.InlineEquation,
+                    ContentType.InterlineEquation,
+                ):
                     spans.append(
                         {
-                            "bbox": list(span["bbox"]),
-                            "content": __replace_STX_ETX(span["text"]),
-                            "type": ContentType.Text,
-                            "score": 1.0,
+                            'bbox': list(span['bbox']),
+                            'content': __replace_STX_ETX(span['text']),
+                            'type': ContentType.Text,
+                            'score': 1.0,
                         }
                     )
     return spans
 
 
 def replace_text_span(pymu_spans, ocr_spans):
-    return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
+    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
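
Illustrative only: the filter keeps every non-text OCR span (equations, images, tables) and appends the pymupdf text spans, so text content always comes from the native PDF layer when it is available. A toy run, using plain strings as stand-ins for the ContentType constants:

ocr_spans = [{'type': 'text', 'content': 'blurry OCR text'},
             {'type': 'inline_equation', 'content': 'E=mc^2'}]
pymu_spans = [{'type': 'text', 'content': 'native PDF text'}]
merged = [s for s in ocr_spans if s['type'] != 'text'] + pymu_spans
# merged -> the OCR equation span plus the native text span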
 
 
-def model_init(model_name: str, local_path=None):
+def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
+
     if torch.cuda.is_available():
-        device = torch.device("cuda")
+        device = torch.device('cuda')
         if torch.cuda.is_bf16_supported():
             supports_bfloat16 = True
         else:
             supports_bfloat16 = False
     else:
-        device = torch.device("cpu")
+        device = torch.device('cpu')
         supports_bfloat16 = False
 
-    if model_name == "layoutreader":
-        if local_path:
-            model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
+    if model_name == 'layoutreader':
+        # Check whether the modelscope cache directory exists
+        layoutreader_model_dir = get_local_layoutreader_model_dir()
+        if os.path.exists(layoutreader_model_dir):
+            model = LayoutLMv3ForTokenClassification.from_pretrained(
+                layoutreader_model_dir
+            )
         else:
-            model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
+            logger.warning(
+                'local layoutreader model does not exist, using online model from huggingface'
+            )
+            model = LayoutLMv3ForTokenClassification.from_pretrained(
+                'hantian/layoutreader'
+            )
         # Check whether the device supports bfloat16
         if supports_bfloat16:
             model.bfloat16()
         model.to(device).eval()
     else:
-        logger.error("model name not allow")
+        logger.error('model name not allowed')
         exit(1)
     return model
 
@@ -131,17 +156,16 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, model_name: str, local_path=None):
+    def get_model(self, model_name: str):
         if model_name not in self._models:
-            if local_path:
-                self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
-            else:
-                self._models[model_name] = model_init(model_name=model_name)
+            self._models[model_name] = model_init(model_name=model_name)
         return self._models[model_name]
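
Intended call pattern, as a sketch: the singleton guarantees the layoutreader weights are loaded once per process, and every later call returns the cached instance.

manager = ModelSingleton()
model_a = manager.get_model('layoutreader')            # loads weights on first call
model_b = ModelSingleton().get_model('layoutreader')   # same singleton, cached model
assert model_a is model_b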
 
 
 def do_predict(boxes: List[List[int]], model) -> List[int]:
-    from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
+    from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
+                                            prepare_inputs)
+
     inputs = boxes2inputs(boxes)
     inputs = prepare_inputs(inputs, model)
     logits = model(**inputs).logits.cpu().squeeze(0)
@@ -150,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
 
 def cal_block_index(fix_blocks, sorted_bboxes):
     for block in fix_blocks:
-        # if block['type'] in ['text', 'title', 'interline_equation']:
-        #     line_index_list = []
-        #     if len(block['lines']) == 0:
-        #         block['index'] = sorted_bboxes.index(block['bbox'])
-        #     else:
-        #         for line in block['lines']:
-        #             line['index'] = sorted_bboxes.index(line['bbox'])
-        #             line_index_list.append(line['index'])
-        #         median_value = statistics.median(line_index_list)
-        #         block['index'] = median_value
-        #
-        # elif block['type'] in ['table', 'image']:
-        #     block['index'] = sorted_bboxes.index(block['bbox'])
 
         line_index_list = []
         if len(block['lines']) == 0:
@@ -174,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
             median_value = statistics.median(line_index_list)
             block['index'] = median_value
 
-        # Delete the virtual line info in figure/table blocks
-        if block['type'] in ['table', 'image']:
-            del block['lines']
+        # Delete the virtual line info in figure/table body blocks and backfill it from real_lines
+        if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+            block['virtual_lines'] = copy.deepcopy(block['lines'])
+            block['lines'] = copy.deepcopy(block['real_lines'])
+            del block['real_lines']
 
     return fix_blocks
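
The ordering rule in one toy example: a block inherits the median index of its lines, so a single mis-sorted line cannot drag the whole block out of place (the indices are made up):

import statistics

line_indices = [4, 5, 6, 41]                   # hypothetical: one line sorted far away
block_index = statistics.median(line_indices)  # 5.5 -> the block still sits near its lines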
 
@@ -189,21 +202,22 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     block_weight = x1 - x0
 
     # If the block height is less than n lines of body text, return the block's bbox directly
-    if line_height*3 < block_height:
-        if block_height > page_h*0.25 and page_w*0.5 > block_weight > page_w*0.25:  # might be a two-column layout, can split finely
-            lines = int(block_height/line_height)+1
+    if line_height * 3 < block_height:
+        if (
+            block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
+        ):  # might be a two-column layout, can split finely
+            lines = int(block_height / line_height) + 1
         else:
-            # If the block is wider than 0.4 of the page width, split it into 3 lines
-            if block_weight > page_w*0.4:
+            # If the block is wider than 0.4 of the page width, split it into 3 lines (a complex layout; figures must not be sliced too finely)
+            if block_weight > page_w * 0.4:
                 line_height = (y1 - y0) / 3
                 lines = 3
-            elif block_weight > page_w*0.25: # otherwise split the block into two lines
-                line_height = (y1 - y0) / 2
-                lines = 2
-            else: # check the aspect ratio
-                if block_height/block_weight > 1.2:  # tall-and-narrow blocks are not split
+            elif block_weight > page_w * 0.25:  # (might be a three-column layout, also split finely)
+                lines = int(block_height / line_height) + 1
+            else:  # check the aspect ratio
+                if block_height / block_weight > 1.2:  # tall-and-narrow blocks are not split
                     return [[x0, y0, x1, y1]]
-                else: # otherwise still split into two lines
+                else:  # otherwise still split into two lines
                     line_height = (y1 - y0) / 2
                     lines = 2
 
@@ -225,7 +239,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
     for block in fix_blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             if len(block['lines']) == 0:
                 bbox = block['bbox']
                 lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
@@ -236,8 +254,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
                 for line in block['lines']:
                     bbox = line['bbox']
                     page_line_list.append(bbox)
-        elif block['type'] in ['table', 'image']:
+        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
             bbox = block['bbox']
+            block["real_lines"] = copy.deepcopy(block['lines'])
             lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
             block['lines'] = []
             for line in lines:
@@ -252,19 +271,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     for left, top, right, bottom in page_line_list:
         if left < 0:
             logger.warning(
-                f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             left = 0
         if right > page_w:
             logger.warning(
-                f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             right = page_w
         if top < 0:
             logger.warning(
-                f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             top = 0
         if bottom > page_h:
             logger.warning(
-                f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+                f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
             bottom = page_h
 
         left = round(left * x_scale)
@@ -272,11 +295,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
         right = round(right * x_scale)
         bottom = round(bottom * y_scale)
         assert (
-                1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
-        ), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
+            1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
+        ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}'  # noqa: E126, E121
         boxes.append([left, top, right, bottom])
     model_manager = ModelSingleton()
-    model = model_manager.get_model("layoutreader")
+    model = model_manager.get_model('layoutreader')
     with torch.no_grad():
         orders = do_predict(boxes, model)
     sorted_bboxes = [page_line_list[i] for i in orders]
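
For reference, the normalisation the asserts above enforce: LayoutReader works on a 0-1000 grid, so x_scale and y_scale are presumably 1000/page_w and 1000/page_h. A sketch under that assumption:

def scale_to_layoutreader(bbox, page_w, page_h):
    # Assumed scaling, matching the 0..1000 bounds asserted above.
    x0, y0, x1, y1 = bbox
    return [round(x0 * 1000 / page_w), round(y0 * 1000 / page_h),
            round(x1 * 1000 / page_w), round(y1 * 1000 / page_h)]

print(scale_to_layoutreader([59.5, 84.0, 555.2, 96.1], 612, 792))  # [97, 106, 907, 121]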
@@ -287,159 +310,282 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
 def get_line_height(blocks):
     page_line_height_list = []
     for block in blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             for line in block['lines']:
                 bbox = line['bbox']
-                page_line_height_list.append(int(bbox[3]-bbox[1]))
+                page_line_height_list.append(int(bbox[3] - bbox[1]))
     if len(page_line_height_list) > 0:
         return statistics.median(page_line_height_list)
     else:
         return 10
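
Why the median and not the mean: a few tall title or equation lines would skew an average upward, while the median stays at the body-text height (toy numbers):

import statistics

heights = [11, 12, 12, 12, 13, 38]    # hypothetical: 38 is an interline equation
print(statistics.mean(heights))       # ~16.3, pulled up by the outlier
print(statistics.median(heights))     # 12.0, the actual body line height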
 
 
-def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode):
+def process_groups(groups, body_key, caption_key, footnote_key):
+    body_blocks = []
+    caption_blocks = []
+    footnote_blocks = []
+    for i, group in enumerate(groups):
+        group[body_key]['group_id'] = i
+        body_blocks.append(group[body_key])
+        for caption_block in group[caption_key]:
+            caption_block['group_id'] = i
+            caption_blocks.append(caption_block)
+        for footnote_block in group[footnote_key]:
+            footnote_block['group_id'] = i
+            footnote_blocks.append(footnote_block)
+    return body_blocks, caption_blocks, footnote_blocks
+
+
+def process_block_list(blocks, body_type, block_type):
+    indices = [block['index'] for block in blocks]
+    median_index = statistics.median(indices)
+
+    body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
+
+    return {
+        'type': block_type,
+        'bbox': body_bbox,
+        'blocks': blocks,
+        'index': median_index,
+    }
+
+
+def revert_group_blocks(blocks):
+    image_groups = {}
+    table_groups = {}
+    new_blocks = []
+    for block in blocks:
+        if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
+            group_id = block['group_id']
+            if group_id not in image_groups:
+                image_groups[group_id] = []
+            image_groups[group_id].append(block)
+        elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
+            group_id = block['group_id']
+            if group_id not in table_groups:
+                table_groups[group_id] = []
+            table_groups[group_id].append(block)
+        else:
+            new_blocks.append(block)
+
+    for group_id, blocks in image_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
+
+    for group_id, blocks in table_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
+
+    return new_blocks
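
How the two halves fit together, as a sketch: process_groups flattens each image/table group into separate body/caption/footnote blocks that share a group_id, so they can be laid out and ordered individually, and revert_group_blocks reassembles them afterwards. Toy data for the flattening side:

groups = [{'image_body': {'bbox': [0, 0, 100, 80]},
           'image_caption_list': [{'bbox': [0, 82, 100, 95]}],
           'image_footnote_list': []}]
bodies, captions, footnotes = process_groups(
    groups, 'image_body', 'image_caption_list', 'image_footnote_list')
assert bodies[0]['group_id'] == captions[0]['group_id'] == 0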
+
+
+def parse_page_core(
+    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+):
     need_drop = False
     drop_reason = []
 
-    '''Fetch the block info needed later from the magic_model object'''
-    img_blocks = magic_model.get_imgs(page_id)
-    table_blocks = magic_model.get_tables(page_id)
+    """Fetch the block info needed later from the magic_model object"""
+    # img_blocks = magic_model.get_imgs(page_id)
+    # table_blocks = magic_model.get_tables(page_id)
+
+    img_groups = magic_model.get_imgs_v2(page_id)
+    table_groups = magic_model.get_tables_v2(page_id)
+
+    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
+        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
+    )
+
+    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
+        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
+    )
+
     discarded_blocks = magic_model.get_discarded(page_id)
     text_blocks = magic_model.get_text_blocks(page_id)
     title_blocks = magic_model.get_title_blocks(page_id)
-    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
+    inline_equations, interline_equations, interline_equation_blocks = (
+        magic_model.get_equations(page_id)
+    )
 
     page_w, page_h = magic_model.get_page_size(page_id)
 
     spans = magic_model.get_all_spans(page_id)
 
-    '''Build spans according to parse_mode'''
-    if parse_mode == "txt":
+    """Build spans according to parse_mode"""
+    if parse_mode == SupportedPdfParseMethod.TXT:
         """Replace text-type spans from OCR with pymu spans!"""
-        pymu_spans = txt_spans_extract(
-            pdf_docs[page_id], inline_equations, interline_equations
-        )
+        pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
         spans = replace_text_span(pymu_spans, spans)
-    elif parse_mode == "ocr":
+    elif parse_mode == SupportedPdfParseMethod.OCR:
         pass
     else:
-        raise Exception("parse_mode must be txt or ocr")
+        raise Exception('parse_mode must be txt or ocr')
 
-    '''Among overlapping spans, drop the lower-confidence ones'''
+    """Among overlapping spans, drop the lower-confidence ones"""
     spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
-    '''Among overlapping spans, drop the smaller ones'''
+    """Among overlapping spans, drop the smaller ones"""
     spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
-    '''Crop screenshots of images and tables'''
-    spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
+    """Crop screenshots of images and tables"""
+    spans = ocr_cut_image_and_table(
+        spans, page_doc, page_id, pdf_bytes_md5, imageWriter
+    )
 
-    '''Gather the bboxes of all blocks together'''
+    """Gather the bboxes of all blocks together"""
     # interline_equation_blocks is not accurate enough; switch to interline_equations later
     interline_equation_blocks = []
     if len(interline_equation_blocks) > 0:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
-            interline_equation_blocks, page_w, page_h)
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
+            discarded_blocks,
+            text_blocks,
+            title_blocks,
+            interline_equation_blocks,
+            page_w,
+            page_h,
+        )
     else:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
-            interline_equations, page_w, page_h)
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
+            discarded_blocks,
+            text_blocks,
+            title_blocks,
+            interline_equations,
+            page_w,
+            page_h,
+        )
 
-    '''First handle the discarded_blocks, which need no layout'''
-    discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
+    """First handle the discarded_blocks, which need no layout"""
+    discarded_block_with_spans, spans = fill_spans_in_blocks(
+        all_discarded_blocks, spans, 0.4
+    )
     fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
 
-    '''Skip this page if it has no bbox'''
+    """Skip this page if it has no bbox"""
     if len(all_bboxes) == 0:
-        logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
-        return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
-                                               [], [], interline_equations, fix_discarded_blocks,
-                                               need_drop, drop_reason)
+        logger.warning(f'skip this page, no useful bbox found, page_id: {page_id}')
+        return ocr_construct_page_component_v2(
+            [],
+            [],
+            page_id,
+            page_w,
+            page_h,
+            [],
+            [],
+            [],
+            interline_equations,
+            fix_discarded_blocks,
+            need_drop,
+            drop_reason,
+        )
 
-    '''Fill the spans into the blocks'''
-    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.3)
+    """Fill the spans into the blocks"""
+    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
 
-    '''Apply fix operations to the blocks'''
-    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
+    """Apply fix operations to the blocks"""
+    fix_blocks = fix_block_spans_v2(block_with_spans)
 
-    '''Collect all lines and compute the body-text line height'''
+    """Collect all lines and compute the body-text line height"""
     line_height = get_line_height(fix_blocks)
 
-    '''Collect all lines and sort them'''
+    """Collect all lines and sort them"""
     sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
 
-    '''Compute each block's ordering from the median of its line indices'''
+    """Compute each block's ordering from the median of its line indices"""
     fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
 
-    '''Re-sort the blocks'''
+    """Restore image/table blocks to group form for the downstream pipeline"""
+    fix_blocks = revert_group_blocks(fix_blocks)
+
+    """Re-sort the blocks"""
     sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
 
-    '''Get the lists that QA needs exported'''
+    """Get the lists that QA needs exported"""
     images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
 
-    '''Build pdf_info_dict'''
-    page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h, [],
-                                                images, tables, interline_equations, fix_discarded_blocks,
-                                                need_drop, drop_reason)
+    """Build pdf_info_dict"""
+    page_info = ocr_construct_page_component_v2(
+        sorted_blocks,
+        [],
+        page_id,
+        page_w,
+        page_h,
+        [],
+        images,
+        tables,
+        interline_equations,
+        fix_discarded_blocks,
+        need_drop,
+        drop_reason,
+    )
     return page_info
 
 
-def pdf_parse_union(pdf_bytes,
-                    model_list,
-                    imageWriter,
-                    parse_mode,
-                    start_page_id=0,
-                    end_page_id=None,
-                    debug_mode=False,
-                    ):
-    pdf_bytes_md5 = compute_md5(pdf_bytes)
-    pdf_docs = fitz.open("pdf", pdf_bytes)
+def pdf_parse_union(
+    dataset: Dataset,
+    model_list,
+    imageWriter,
+    parse_mode,
+    start_page_id=0,
+    end_page_id=None,
+    debug_mode=False,
+):
+    pdf_bytes_md5 = compute_md5(dataset.data_bits())
 
-    '''Initialize an empty pdf_info_dict'''
+    """Initialize an empty pdf_info_dict"""
     pdf_info_dict = {}
 
-    '''Initialize magic_model with model_list and the docs object'''
-    magic_model = MagicModel(model_list, pdf_docs)
+    """Initialize magic_model with model_list and the docs object"""
+    magic_model = MagicModel(model_list, dataset)
 
-    '''Parse the pdf over the requested page range'''
+    """Parse the pdf over the requested page range"""
     # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
-    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(dataset) - 1
+    )
 
-    if end_page_id > len(pdf_docs) - 1:
-        logger.warning("end_page_id is out of range, use pdf_docs length")
-        end_page_id = len(pdf_docs) - 1
+    if end_page_id > len(dataset) - 1:
+        logger.warning('end_page_id is out of range, using the dataset length')
+        end_page_id = len(dataset) - 1
 
-    '''Initialize the start time'''
+    """Initialize the start time"""
     start_time = time.time()
 
-    for page_id, page in enumerate(pdf_docs):
-        '''In debug mode, log the time spent parsing each page'''
+    for page_id, page in enumerate(dataset):
+        """In debug mode, log the time spent parsing each page."""
         if debug_mode:
             time_now = time.time()
             logger.info(
-                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
+                f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
             )
             start_time = time_now
 
-        '''Parse each page of the pdf'''
+        """Parse each page of the pdf"""
         if start_page_id <= page_id <= end_page_id:
-            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+            page_info = parse_page_core(
+                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+            )
         else:
-            page_w = page.rect.width
-            page_h = page.rect.height
-            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
-                                                [], [], [], [],
-                                                True, "skip page")
-        pdf_info_dict[f"page_{page_id}"] = page_info
+            page_info = page.get_page_info()
+            page_w = page_info.w
+            page_h = page_info.h
+            page_info = ocr_construct_page_component_v2(
+                [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
+            )
+        pdf_info_dict[f'page_{page_id}'] = page_info
 
     """分段"""
-    # para_split(pdf_info_dict, debug_mode=debug_mode)
-    for page_num, page in pdf_info_dict.items():
-        page['para_blocks'] = page['preproc_blocks']
+    para_split(pdf_info_dict, debug_mode=debug_mode)
 
     """dict转list"""
     pdf_info_list = dict_to_list(pdf_info_dict)
     new_pdf_info_dict = {
-        "pdf_info": pdf_info_list,
+        'pdf_info': pdf_info_list,
     }
 
     clean_memory()
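
The page-range handling above, condensed into a standalone restatement: None or a negative end_page_id means "through the last page", and anything past the end is clamped (with a warning in the real code):

def resolve_end_page(end_page_id, num_pages):
    if end_page_id is None or end_page_id < 0:
        return num_pages - 1
    return min(end_page_id, num_pages - 1)  # out-of-range values are clamped

assert resolve_end_page(None, 10) == 9
assert resolve_end_page(-1, 10) == 9
assert resolve_end_page(99, 10) == 9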

+ 6 - 7
magic_pdf/pipe/AbsPipe.py

@@ -17,7 +17,7 @@ class AbsPipe(ABC):
     PIP_TXT = "txt"
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
+                 start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
@@ -26,6 +26,9 @@ class AbsPipe(ABC):
         self.start_page_id = start_page_id
         self.end_page_id = end_page_id
         self.lang = lang
+        self.layout_model = layout_model
+        self.formula_enable = formula_enable
+        self.table_enable = table_enable
     
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)
@@ -95,9 +98,7 @@ class AbsPipe(ABC):
         """
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_info_list = pdf_mid_data["pdf_info"]
-        parse_type = pdf_mid_data["_parse_type"]
-        lang = pdf_mid_data.get("_lang", None)
-        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path, parse_type, lang)
+        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
         return content_list
 
     @staticmethod
@@ -107,9 +108,7 @@ class AbsPipe(ABC):
         """
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_info_list = pdf_mid_data["pdf_info"]
-        parse_type = pdf_mid_data["_parse_type"]
-        lang = pdf_mid_data.get("_lang", None)
-        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path, parse_type, lang)
+        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
         return md_content
 
 

+ 8 - 4
magic_pdf/pipe/OCRPipe.py

@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
 class OCRPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang)
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
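
A hypothetical caller for the widened signature; the three new knobs simply flow through to doc_analyze and parse_ocr_pdf. The 'doclayout_yolo' value and the pre-existing pdf_bytes/writer variables are assumptions for illustration:

pipe = OCRPipe(pdf_bytes, model_list=[], image_writer=writer, is_debug=False,
               layout_model='doclayout_yolo',  # assumed key, cf. model_configs.yaml below
               formula_enable=True, table_enable=False)
pipe.pipe_analyze()
pipe.pipe_parse()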

+ 8 - 4
magic_pdf/pipe/TXTPipe.py

@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
 class TXTPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang)
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 10 - 5
magic_pdf/pipe/UNIPipe.py

@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 class UNIPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
+                         lang, layout_model, formula_enable, table_enable)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
         if self.pdf_type == self.PIP_TXT:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                                 is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
                                                 start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                                lang=self.lang)
+                                                lang=self.lang, layout_model=self.layout_model,
+                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                               is_debug=self.is_debug,

+ 56 - 27
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -1,7 +1,7 @@
 from loguru import logger
 
 from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
-    calculate_iou
+    calculate_iou, calculate_vertical_projection_overlap_ratio
 from magic_pdf.libs.drop_tag import DropTag
 from magic_pdf.libs.ocr_content_type import BlockType
 from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
     return all_bboxes, all_discarded_blocks, drop_reasons
 
 
-def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks,
-                                        title_blocks, interline_equation_blocks, page_w, page_h):
-    all_bboxes = []
-    all_discarded_blocks = []
-    for image in img_blocks:
-        x0, y0, x1, y1 = image['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])
+def add_bboxes(blocks, block_type, bboxes):
+    for block in blocks:
+        x0, y0, x1, y1 = block['bbox']
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
+        else:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
 
-    for table in table_blocks:
-        x0, y0, x1, y1 = table['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
 
-    for text in text_blocks:
-        x0, y0, x1, y1 = text['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
+def ocr_prepare_bboxes_for_layout_split_v2(
+        img_body_blocks, img_caption_blocks, img_footnote_blocks,
+        table_body_blocks, table_caption_blocks, table_footnote_blocks,
+        discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
+):
+    all_bboxes = []
 
-    for title in title_blocks:
-        x0, y0, x1, y1 = title['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])
-
-    for interline_equation in interline_equation_blocks:
-        x0, y0, x1, y1 = interline_equation['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])
+    add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
+    add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
+    add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
+    add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
+    add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
+    add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
+    add_bboxes(text_blocks, BlockType.Text, all_bboxes)
+    add_bboxes(title_blocks, BlockType.Title, all_bboxes)
+    add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
 
     '''Resolve block nesting issues'''
     '''When a text box overlaps a title box, trust the text box'''
@@ -96,23 +101,47 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
     '''When an interline_equation box is contained inside a text-type box and is much smaller than the text block, trust the text box and discard the equation box'''
     # Removed by the later big-box-contains-small-box logic
 
-    '''In discarded_blocks, keep only boxes wider than 1/3 of the page width, taller than 10, and located in the bottom 50% of the page (footnotes only)'''
+    '''discarded_blocks'''
+    all_discarded_blocks = []
+    add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
+
+    '''Footnote detection: wider than 1/3 of the page width, taller than 10, located in the bottom 50% of the page'''
+    footnote_blocks = []
     for discarded in discarded_blocks:
         x0, y0, x1, y1 = discarded['bbox']
-        all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
-        # Add footnotes to all_bboxes for the layout computation
-        # if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
-        #     all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
+        if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
+            footnote_blocks.append([x0, y0, x1, y1])
+
+    '''Remove any box that sits below a footnote'''
+    need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
+    if len(need_remove_blocks) > 0:
+        for block in need_remove_blocks:
+            all_bboxes.remove(block)
+            all_discarded_blocks.append(block)
 
     '''If, after the above, a big box still contains a small box, delete the small one'''
     all_bboxes = remove_overlaps_min_blocks(all_bboxes)
     all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
     '''Separate the remaining bboxes to keep the later layout split from failing'''
-    # all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
+    all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
 
     return all_bboxes, all_discarded_blocks
 
 
+def find_blocks_under_footnote(all_bboxes, footnote_blocks):
+    need_remove_blocks = []
+    for block in all_bboxes:
+        block_x0, block_y0, block_x1, block_y1 = block[:4]
+        for footnote_bbox in footnote_blocks:
+            footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
+            # If the footnote's vertical projection covers >= 80% of the block's and the block's y0 is >= the footnote's y1
+            if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
+                if block not in need_remove_blocks:
+                    need_remove_blocks.append(block)
+                    break
+    return need_remove_blocks
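
calculate_vertical_projection_overlap_ratio is imported but not shown in this diff; the footnote rule above needs roughly this behaviour, i.e. the share of the block's x-extent covered by the footnote's x-extent. An assumed sketch, not the library's actual code:

def vertical_projection_overlap_ratio(block_bbox, footnote_bbox):
    # Assumed semantics: overlap of the two x-intervals, relative to the
    # block's own width; 1.0 means the footnote spans the whole block.
    b_x0, _, b_x1, _ = block_bbox
    f_x0, _, f_x1, _ = footnote_bbox
    overlap = min(b_x1, f_x1) - max(b_x0, f_x0)
    width = b_x1 - b_x0
    if width <= 0 or overlap <= 0:
        return 0.0
    return overlap / width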
+
+
 def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
     # First collect all text and interline blocks
     text_blocks = []

+ 27 - 2
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
                 continue
 
             # If the current span overlaps the last span of the current line on the y-axis, add it to that line
-            if __is_overlaps_y_exceeds_threshold(span['bbox'],
-                                                 current_line[-1]['bbox']):
+            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
                 current_line.append(span)
             else:
                 # Otherwise, start a new line
@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
             'type': block_type,
             'bbox': block_bbox,
         }
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            block_dict["group_id"] = block[-1]
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']
@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
     return fix_blocks
 
 
+def fix_block_spans_v2(block_with_spans):
+    """1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
+    需要将caption和footnote的text_span放入相应img_block和table_block内的
+    caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
+    fix_blocks = []
+    for block in block_with_spans:
+        block_type = block['type']
+
+        if block_type in [BlockType.Text, BlockType.Title,
+                          BlockType.ImageCaption, BlockType.ImageFootnote,
+                          BlockType.TableCaption, BlockType.TableFootnote
+                          ]:
+            block = fix_text_block(block)
+        elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
+            block = fix_interline_block(block)
+        else:
+            continue
+        fix_blocks.append(block)
+    return fix_blocks
+
+
 def fix_discarded_block(discarded_block_with_spans):
     fix_discarded_blocks = []
     for block in discarded_block_with_spans:

+ 5 - 13
magic_pdf/resources/model_config/model_configs.yaml

@@ -1,15 +1,7 @@
-config:
-  device: cpu
-  layout: True
-  formula: True
-  table_config:
-    model: TableMaster
-    is_table_recog_enable: False
-    max_time: 400
-
 weights:
-  layout: Layout/model_final.pth
-  mfd: MFD/weights.pt
-  mfr: MFR/unimernet_small
+  layoutlmv3: Layout/LayoutLMv3/model_final.pth
+  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
+  yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
+  unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
-  TableMaster: TabRec/TableMaster
+  tablemaster: TabRec/TableMaster
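
With the config block removed, the file is now a flat name-to-path map keyed by model name. Loading it is a one-liner (sketch assumes PyYAML and the repo-relative path):

import yaml

with open('magic_pdf/resources/model_config/model_configs.yaml') as f:
    weights = yaml.safe_load(f)['weights']
print(weights['doclayout_yolo'])  # -> Layout/YOLO/doclayout_yolo_ft.pt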

+ 1 - 1
magic_pdf/tools/cli.py

@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""",
     help="""
     Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
     You should input "Abbreviation" with language form url:
-    https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
+    https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
     """,
     default=None,
 )

+ 11 - 6
magic_pdf/tools/common.py

@@ -6,8 +6,8 @@ import click
 from loguru import logger
 
 import magic_pdf.model as model_config
-from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox,
-                                      draw_model_bbox, draw_line_sort_bbox)
+from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
+                                      draw_model_bbox, draw_span_bbox)
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
@@ -46,10 +46,12 @@ def do_parse(
     start_page_id=0,
     end_page_id=None,
     lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
 ):
     if debug_able:
         logger.warning('debug mode is on')
-        # f_dump_content_list = True
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
 
@@ -64,13 +66,16 @@ def do_parse(
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'txt':
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'ocr':
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     else:
         logger.error('unknown parse method')
         exit(1)

+ 13 - 5
magic_pdf/user_api.py

@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         if input_model_is_empty:
-            pdf_models = doc_analyze(pdf_bytes,
-                                     ocr=True,
-                                     start_page_id=start_page_id,
-                                     end_page_id=end_page_id,
-                                     lang=lang)
+            layout_model = kwargs.get("layout_model", None)
+            formula_enable = kwargs.get("formula_enable", None)
+            table_enable = kwargs.get("table_enable", None)
+            pdf_models = doc_analyze(
+                pdf_bytes,
+                ocr=True,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")

+ 0 - 0
magic_pdf/utils/__init__.py


+ 11 - 0
magic_pdf/utils/annotations.py

@@ -0,0 +1,11 @@
+
+from loguru import logger
+
+
+def ImportPIL(f):
+    try:
+        import PIL  # noqa: F401
+    except ImportError:
+        logger.error('Pillow is not installed, please install it via pip.')
+        exit(1)
+    return f
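
Usage sketch for the decorator: it runs at decoration time, so a module fails fast if Pillow is missing, and otherwise returns the function unchanged (render_thumbnail is a hypothetical example):

@ImportPIL
def render_thumbnail(page):
    from PIL import Image  # safe: the decorator already verified Pillow imports
    ...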

+ 16 - 0
next_docs/en/.readthedocs.yaml

@@ -0,0 +1,16 @@
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
+formats:
+  - epub
+
+python:
+  install:
+    - requirements: docs/requirements.txt
+
+sphinx:
+  configuration: docs/en/conf.py

+ 20 - 0
next_docs/en/Makefile

@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

BIN
next_docs/en/_static/image/logo.png


+ 9 - 0
next_docs/en/api.rst

@@ -0,0 +1,9 @@
+Data API
+------------------
+
+.. toctree::
+   :maxdepth: 2
+
+   api/dataset.rst
+   api/data_reader_writer.rst
+   api/read_api.rst

+ 44 - 0
next_docs/en/api/data_reader_writer.rst

@@ -0,0 +1,44 @@
+
+Data Reader Writer
+--------------------
+
+.. autoclass:: magic_pdf.data.data_reader_writer.DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.DataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataWriter
+   :members:
+   :inherited-members:
+

+ 22 - 0
next_docs/en/api/dataset.rst

@@ -0,0 +1,22 @@
+Dataset API
+------------------
+
+.. autoclass:: magic_pdf.data.dataset.PageableData
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.Dataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.ImageDataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.PymuDocDataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.Doc
+   :members:
+   :inherited-members:

+ 0 - 0
next_docs/en/api/io.rst


+ 6 - 0
next_docs/en/api/read_api.rst

@@ -0,0 +1,6 @@
+read_api API
+------------------
+
+.. automodule:: magic_pdf.data.read_api
+   :members:
+   :inherited-members:

+ 0 - 0
next_docs/en/api/schemas.rst


+ 1 - 0
next_docs/en/api/utils.rst

@@ -0,0 +1 @@
+

+ 122 - 0
next_docs/en/conf.py

@@ -0,0 +1,122 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import os
+import subprocess
+import sys
+
+from sphinx.ext import autodoc
+
+
+def install(package):
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+
+
+requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
+if os.path.exists(requirements_path):
+    with open(requirements_path) as f:
+        packages = f.readlines()
+    for package in packages:
+        install(package.strip())
+
+sys.path.insert(0, os.path.abspath('../..'))
+
+# -- Project information -----------------------------------------------------
+
+project = 'MinerU'
+copyright = '2024, MinerU Contributors'
+author = 'OpenDataLab'
+
+# The full version, including alpha/beta/rc tags
+version_file = '../../magic_pdf/libs/version.py'
+with open(version_file) as f:
+    exec(compile(f.read(), version_file, 'exec'))
+__version__ = locals()['__version__']
+# The short X.Y version
+version = __version__
+# The full version, including alpha/beta/rc tags
+release = __version__
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.napoleon',
+    'sphinx.ext.viewcode',
+    'sphinx.ext.intersphinx',
+    'sphinx_copybutton',
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'myst_parser',
+    'sphinxarg.ext',
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# Exclude the prompt "$" when copying code
+copybutton_prompt_text = r'\$ '
+copybutton_prompt_is_regexp = True
+
+language = 'en'
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_book_theme'
+html_logo = '_static/image/logo.png'
+html_theme_options = {
+    'path_to_docs': 'docs/en',
+    'repository_url': 'https://github.com/opendatalab/MinerU',
+    'use_repository_button': True,
+}
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ['_static']
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+    'cpuinfo',
+    'torch',
+    'transformers',
+    'psutil',
+    'prometheus_client',
+    'sentencepiece',
+    'vllm.cuda_utils',
+    'vllm._C',
+    'numpy',
+    'tqdm',
+]
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+    """Remove note about base class when a class is derived from object."""
+
+    def add_line(self, line: str, source: str, *lineno: int) -> None:
+        if line == '   Bases: :py:class:`object`':
+            return
+        super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
+
+navigation_with_keys = False

+ 38 - 0
next_docs/en/index.rst

@@ -0,0 +1,38 @@
+.. MinerU documentation master file, created by
+   sphinx-quickstart.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+Welcome to the MinerU Documentation
+==============================================
+
+.. figure:: ./_static/image/logo.png
+  :align: center
+  :alt: mineru
+  :class: no-scaled-link
+
+.. raw:: html
+
+   <p style="text-align:center">
+   <strong>A one-stop, open-source, high-quality data extraction tool
+   </strong>
+   </p>
+
+   <p style="text-align:center">
+   <script async defer src="https://buttons.github.io/buttons.js"></script>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU" data-show-count="true" data-size="large" aria-label="Star">Star</a>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
+   </p>
+
+
+API Reference
+-------------
+
+If you are looking for information on a specific function, class or
+method, this part of the documentation is for you.
+
+.. toctree::
+   :maxdepth: 2
+
+   api

+ 35 - 0
next_docs/en/make.bat

@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.https://www.sphinx-doc.org/
+	exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd

+ 11 - 0
next_docs/requirements.txt

@@ -0,0 +1,11 @@
+boto3>=1.28.43
+loguru>=0.6.0
+myst-parser
+Pillow==8.4.0
+pydantic>=2.7.2,<2.8.0
+PyMuPDF>=1.24.9
+sphinx
+sphinx-argparse
+sphinx-book-theme
+sphinx-copybutton
+sphinx_rtd_theme

+ 16 - 0
next_docs/zh_cn/.readthedocs.yaml

@@ -0,0 +1,16 @@
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
+formats:
+  - epub
+
+python:
+  install:
+    - requirements: docs/requirements.txt
+
+sphinx:
+  configuration: docs/zh_cn/conf.py

+ 20 - 0
next_docs/zh_cn/Makefile

@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

BIN
next_docs/zh_cn/_static/image/logo.png


+ 122 - 0
next_docs/zh_cn/conf.py

@@ -0,0 +1,122 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import os
+import subprocess
+import sys
+
+from sphinx.ext import autodoc
+
+
+def install(package):
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+
+
+requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
+if os.path.exists(requirements_path):
+    with open(requirements_path) as f:
+        packages = f.readlines()
+    for package in packages:
+        install(package.strip())
+
+sys.path.insert(0, os.path.abspath('../..'))
+
+# -- Project information -----------------------------------------------------
+
+project = 'MinerU'
+copyright = '2024, OpenDataLab'
+author = 'MinerU Contributors'
+
+# The full version, including alpha/beta/rc tags
+version_file = '../../magic_pdf/libs/version.py'
+with open(version_file) as f:
+    exec(compile(f.read(), version_file, 'exec'))
+__version__ = locals()['__version__']
+# The short X.Y version
+version = __version__
+# The full version, including alpha/beta/rc tags
+release = __version__
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.napoleon',
+    'sphinx.ext.viewcode',
+    'sphinx.ext.intersphinx',
+    'sphinx_copybutton',
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'myst_parser',
+    'sphinxarg.ext',
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# Exclude the prompt "$" when copying code
+copybutton_prompt_text = r'\$ '
+copybutton_prompt_is_regexp = True
+
+language = 'zh_CN'
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_book_theme'
+html_logo = '_static/image/logo.png'
+html_theme_options = {
+    'path_to_docs': 'next_docs/zh_cn',
+    'repository_url': 'https://github.com/opendatalab/MinerU',
+    'use_repository_button': True,
+}
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ['_static']
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+    'cpuinfo',
+    'torch',
+    'transformers',
+    'psutil',
+    'prometheus_client',
+    'sentencepiece',
+    'vllm.cuda_utils',
+    'vllm._C',
+    'numpy',
+    'tqdm',
+]
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+    """Remove note about base class when a class is derived from object."""
+
+    def add_line(self, line: str, source: str, *lineno: int) -> None:
+        if line == '   Bases: :py:class:`object`':
+            return
+        super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
+
+navigation_with_keys = False
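
The autodoc_mock_imports list lets autodoc import magic_pdf modules without the heavy runtime dependencies installed: Sphinx substitutes a mock object for each listed name at import time. A minimal sketch of the effect, using Sphinx's own mock helper (an autodoc-internal utility, shown here purely for illustration):

```python
from sphinx.ext.autodoc.mock import mock

# Inside the context manager, 'torch' resolves to a mock module,
# so the import succeeds even if PyTorch is not installed.
with mock(['torch']):
    import torch
    print(torch.cuda.is_available)  # arbitrary attribute access works on the mock
```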

+ 26 - 0
next_docs/zh_cn/index.rst

@@ -0,0 +1,26 @@
+.. MinerU documentation master file, created by
+   sphinx-quickstart on Tue Jan  9 16:33:06 2024.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+欢迎来到 MinerU 的中文文档
+==============================================
+
+.. figure:: ./_static/image/logo.png
+  :align: center
+  :alt: mineru
+  :class: no-scaled-link
+
+.. raw:: html
+
+   <p style="text-align:center">
+   <strong> 一站式开源高质量数据提取工具
+   </strong>
+   </p>
+
+   <p style="text-align:center">
+   <script async defer src="https://buttons.github.io/buttons.js"></script>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU" data-show-count="true" data-size="large" aria-label="Star">Star</a>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
+   </p>

+ 35 - 0
next_docs/zh_cn/make.bat

@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.https://www.sphinx-doc.org/
+	exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd

+ 1 - 2
projects/README.md

@@ -6,5 +6,4 @@
 - [gradio_app](./gradio_app/README.md): Build a web app based on gradio
 - [web_demo](./web_demo/README.md): MinerU online [demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/) localized deployment version
 - [web_api](./web_api/README.md): Web API Based on FastAPI
-
-
+- [multi_gpu](./multi_gpu/README.md): Multi-GPU parallel processing based on LitServe

+ 1 - 1
projects/README_zh-CN.md

@@ -6,4 +6,4 @@
 - [gradio_app](./gradio_app/README_zh-CN.md): 基于 Gradio 的 Web 应用
 - [web_demo](./web_demo/README_zh-CN.md): MinerU在线[demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/)本地化部署版本
 - [web_api](./web_api/README.md): 基于 FastAPI 的 Web API
-
+- [multi_gpu](./multi_gpu/README.md): 基于 LitServe 的多 GPU 并行处理

+ 68 - 12
projects/gradio_app/app.py

@@ -3,10 +3,12 @@
 import base64
 import os
 import time
+import uuid
 import zipfile
 from pathlib import Path
 import re
 
+import pymupdf
 from loguru import logger
 
 from magic_pdf.libs.hash_utils import compute_sha256
@@ -23,7 +25,7 @@ def read_fn(path):
     return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
 
 
-def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
+def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
     os.makedirs(output_dir, exist_ok=True)
 
     try:
@@ -42,6 +44,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
             parse_method,
             False,
             end_page_id=end_page_id,
+            layout_model=layout_mode,
+            formula_enable=formula_enable,
+            table_enable=table_enable,
+            lang=language,
         )
         return local_md_dir, file_name
     except Exception as e:
@@ -93,9 +99,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
     return re.sub(pattern, replace, markdown_text)
 
 
-def to_markdown(file_path, end_pages, is_ocr):
+def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table_enable, language):
+    # Get the recognized markdown file and the zip archive path
-    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr)
+    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
+                                        layout_mode, formula_enable, table_enable, language)
     archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
     if zip_archive_success == 0:
@@ -138,24 +145,71 @@ with open("header.html", "r") as file:
     header = file.read()
 
 
+latin_lang = [
+        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
+        'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
+        'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
+        'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
+]
+arabic_lang = ['ar', 'fa', 'ug', 'ur']
+cyrillic_lang = [
+        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
+        'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+]
+devanagari_lang = [
+        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
+        'sa', 'bgc'
+]
+other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
+
+all_lang = [""]
+all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
+
+
+def to_pdf(file_path):
+    with pymupdf.open(file_path) as f:
+        if f.is_pdf:
+            return file_path
+        else:
+            pdf_bytes = f.convert_to_pdf()
+            # Write the PDF bytes out as <uuid>.pdf
+            # Generate a unique file name
+            unique_filename = f"{uuid.uuid4()}.pdf"
+
+            # Build the full file path
+            tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
+
+            # Write the byte data to the file
+            with open(tmp_file_path, 'wb') as tmp_pdf_file:
+                tmp_pdf_file.write(pdf_bytes)
+
+            return tmp_file_path
+
+
 if __name__ == "__main__":
     with gr.Blocks() as demo:
         gr.HTML(header)
         with gr.Row():
             with gr.Column(variant='panel', scale=5):
-                pdf_show = gr.Markdown()
+                file = gr.File(label="Please upload a PDF or image", file_types=[".pdf", ".png", ".jpeg", ".jpg"])
                 max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
-                with gr.Row() as bu_flow:
-                    is_ocr = gr.Checkbox(label="Force enable OCR")
+                with gr.Row():
+                    layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3")
+                    language = gr.Dropdown(all_lang, label="Language", value="")
+                with gr.Row():
+                    formula_enable = gr.Checkbox(label="Enable formula recognition", value=True)
+                    is_ocr = gr.Checkbox(label="Force enable OCR", value=False)
+                    table_enable = gr.Checkbox(label="Enable table recognition (test)", value=False)
+                with gr.Row():
                     change_bu = gr.Button("Convert")
-                    clear_bu = gr.ClearButton([pdf_show], value="Clear")
-                pdf_show = PDF(label="Please upload pdf", interactive=True, height=800)
+                    clear_bu = gr.ClearButton(value="Clear")
+                pdf_show = PDF(label="PDF preview", interactive=True, height=800)
                 with gr.Accordion("Examples:"):
                     example_root = os.path.join(os.path.dirname(__file__), "examples")
                     gr.Examples(
                         examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                   _.endswith("pdf")],
-                        inputs=pdf_show,
+                        inputs=pdf_show
                     )
 
             with gr.Column(variant='panel', scale=5):
@@ -166,7 +220,9 @@ if __name__ == "__main__":
                                          latex_delimiters=latex_delimiters, line_breaks=True)
                     with gr.Tab("Markdown text"):
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
-        change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr], outputs=[md, md_text, output_file, pdf_show])
-        clear_bu.add([md, pdf_show, md_text, output_file, is_ocr])
+        file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
+        change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
+                        outputs=[md, md_text, output_file, pdf_show])
+        clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
 
-    demo.launch()
+    demo.launch(server_name="0.0.0.0")

BIN
projects/gradio_app/examples/2list_1table.pdf


BIN
projects/gradio_app/examples/3list_1table.pdf


BIN
projects/gradio_app/examples/academic_paper_formula.pdf


BIN
projects/gradio_app/examples/academic_paper_img_formula.pdf


BIN
projects/gradio_app/examples/academic_paper_list.pdf


BIN
projects/gradio_app/examples/complex_layout.pdf


BIN
projects/gradio_app/examples/complex_layout_para_split_list.pdf


BIN
projects/gradio_app/examples/garbled_formula.pdf


BIN
projects/gradio_app/examples/garbled_formula2.pdf


BIN
projects/gradio_app/examples/garbled_img_formula.pdf


BIN
projects/gradio_app/examples/magazine_complex_layout_images_list.pdf


+ 46 - 0
projects/multi_gpu/README.md

@@ -0,0 +1,46 @@
+## Project Overview
+This project provides a multi-GPU parallel processing solution based on LitServe. LitServe is a simple, flexible serving engine for AI models built on FastAPI; it adds batching, streaming, and GPU autoscaling on top of FastAPI, so you do not need to rebuild a FastAPI server for every model.
+
+## Environment Setup
+Set up the required environment with the following commands:
+```bash
+pip install -U litserve python-multipart filetype
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
+pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118
+```
+
+## Quick Start
+### 1. Start the server
+The following example shows how to start the server; the settings can be customized:
+```python
+server = ls.LitServer(
+    MinerUAPI(output_dir='/tmp'),  # customizable output directory
+    accelerator='cuda',  # enable GPU acceleration
+    devices='auto',  # "auto" uses all available GPUs
+    workers_per_device=1,  # one serving instance per GPU
+    timeout=False  # False disables request timeouts
+)
+server.run(port=8000)  # serve on port 8000
+```
+
+Start the server with:
+```bash
+python server.py
+```
+
+### 2. Start the client
+The following code shows how to use the client; adjust the configuration as needed:
+```python
+files = ['demo/small_ocr.pdf']  # replace with your file paths; jpg/jpeg, png and pdf files are supported
+n_jobs = np.clip(len(files), 1, 8)  # number of concurrent threads, capped at 8 here; adjust as needed
+results = Parallel(n_jobs, prefer='threads', verbose=10)(
+    delayed(do_parse)(p) for p in files
+)
+print(results)
+```
+
+Start the client with:
+```bash
+python client.py
+```
+That's it! Your files will be processed in parallel across multiple GPUs. 🍻🍻🍻

+ 39 - 0
projects/multi_gpu/client.py

@@ -0,0 +1,39 @@
+import base64
+import requests
+import numpy as np
+from loguru import logger
+from joblib import Parallel, delayed
+
+
+def to_b64(file_path):
+    try:
+        with open(file_path, 'rb') as f:
+            return base64.b64encode(f.read()).decode('utf-8')
+    except Exception as e:
+        raise Exception(f'File: {file_path} - Info: {e}')
+
+
+def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs):
+    try:
+        response = requests.post(url, json={
+            'file': to_b64(file_path),
+            'kwargs': kwargs
+        })
+
+        if response.status_code == 200:
+            output = response.json()
+            output['file_path'] = file_path
+            return output
+        else:
+            raise Exception(response.text)
+    except Exception as e:
+        logger.error(f'File: {file_path} - Info: {e}')
+
+
+if __name__ == '__main__':
+    files = ['small_ocr.pdf']
+    n_jobs = np.clip(len(files), 1, 8)
+    results = Parallel(n_jobs, prefer='threads', verbose=10)(
+        delayed(do_parse)(p) for p in files
+    )
+    print(results)
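
For a quick manual test, the same endpoint can be exercised without joblib; a minimal sketch of the wire format (the 'file'/'kwargs' JSON shape and the /predict route mirror what do_parse above sends, and 'ocr' is just an illustrative value for the parse_method option that server.py forwards on):

```python
import base64

import requests

# Encode the input exactly as to_b64 above does
with open('small_ocr.pdf', 'rb') as f:
    payload = {
        'file': base64.b64encode(f.read()).decode('utf-8'),
        'kwargs': {'parse_method': 'ocr'},  # optional; the server defaults to 'auto'
    }

resp = requests.post('http://127.0.0.1:8000/predict', json=payload)
resp.raise_for_status()
print(resp.json())  # e.g. {'output_dir': '<uuid>'} per encode_response in server.py
```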

+ 74 - 0
projects/multi_gpu/server.py

@@ -0,0 +1,74 @@
+import os
+import fitz
+import torch
+import base64
+import litserve as ls
+from uuid import uuid4
+from fastapi import HTTPException
+from filetype import guess_extension
+from magic_pdf.tools.common import do_parse
+from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+
+
+class MinerUAPI(ls.LitAPI):
+    def __init__(self, output_dir='/tmp'):
+        self.output_dir = output_dir
+
+    def setup(self, device):
+        if device.startswith('cuda'):
+            os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
+            if torch.cuda.device_count() > 1:
+                raise RuntimeError("CUDA was already initialized; do not run any CUDA operations before 'CUDA_VISIBLE_DEVICES' is set.")
+
+        model_manager = ModelSingleton()
+        model_manager.get_model(True, False)
+        model_manager.get_model(False, False)
+        print(f'Model initialization complete on {device}!')
+
+    def decode_request(self, request):
+        file = request['file']
+        file = self.to_pdf(file)
+        opts = request.get('kwargs', {})
+        opts.setdefault('debug_able', False)
+        opts.setdefault('parse_method', 'auto')
+        return file, opts
+
+    def predict(self, inputs):
+        try:
+            do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1])
+            return pdf_name
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+        finally:
+            self.clean_memory()
+
+    def encode_response(self, response):
+        return {'output_dir': response}
+
+    def clean_memory(self):
+        import gc
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+        gc.collect()
+
+    def to_pdf(self, file_base64):
+        try:
+            file_bytes = base64.b64decode(file_base64)
+            file_ext = guess_extension(file_bytes)
+            with fitz.open(stream=file_bytes, filetype=file_ext) as f:
+                if f.is_pdf: return f.tobytes()
+                return f.convert_to_pdf()
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+
+if __name__ == '__main__':
+    server = ls.LitServer(
+        MinerUAPI(output_dir='/tmp'),
+        accelerator='cuda',
+        devices='auto',
+        workers_per_device=1,
+        timeout=False
+    )
+    server.run(port=8000)

BIN
projects/multi_gpu/small_ocr.pdf


+ 1 - 1
requirements-docker.txt

@@ -5,7 +5,6 @@ PyMuPDF>=1.24.9
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 fast-langdetect==0.2.0
-wordninja>=2.0.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
 unimernet==0.2.1
@@ -15,4 +14,5 @@ paddleocr==2.7.3
 paddlepaddle==3.0.0b1
 pypandoc
 struct-eqtable==0.1.0
+doclayout-yolo==0.0.2
 detectron2

Some files were not shown because too many files changed in this diff