Merge pull request #838 from opendatalab/release-0.9.0

Release 0.9.0
Xiaomeng Zhao 1 year ago
parent
commit
3a42ebbf57
100 changed files with 4,250 additions and 829 deletions
  1. +1 -1    .github/workflows/cla.yml
  2. +9 -11   .github/workflows/cli.yml
  3. +55 -0   .github/workflows/daily.yml
  4. +61 -0   .github/workflows/huigui.yml
  5. +0 -22   .github/workflows/update_base.yml
  6. +50 -39  .gitignore
  7. +3 -2    .pre-commit-config.yaml
  8. +16 -0   .readthedocs.yaml
  9. +1 -1    Dockerfile
  10. +1 -0   LICENSE.md
  11. +0 -0   README.md
  12. +15 -6  README_ja-JP.md
  13. +35 -16 README_zh-CN.md
  14. +2 -15  demo/demo.py
  15. +14 -9  demo/magic_pdf_parse_main.py
  16. +7 -2   docs/FAQ_en_us.md
  17. +13 -3  docs/FAQ_zh_cn.md
  18. +76 -54 docs/README_Ubuntu_CUDA_Acceleration_en_US.md
  19. +45 -7  docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md
  20. +78 -56 docs/README_Windows_CUDA_Acceleration_en_US.md
  21. +35 -10 docs/README_Windows_CUDA_Acceleration_zh_CN.md
  22. +34 -14 docs/download_models.py
  23. +41 -16 docs/download_models_hf.py
  24. +8 -4   docs/how_to_download_models_en.md
  25. +7 -10  docs/how_to_download_models_zh_cn.md
  26. Binary  docs/images/web_demo_1.png
  27. +3 -0   docs/output_file_en_us.md
  28. +8 -5   docs/output_file_zh_cn.md
  29. +12 -3  magic-pdf.template.json
  30. +0 -0   magic_pdf/config/__init__.py
  31. +7 -0   magic_pdf/config/enums.py
  32. +32 -0  magic_pdf/config/exceptions.py
  33. +0 -0   magic_pdf/data/__init__.py
  34. +12 -0  magic_pdf/data/data_reader_writer/__init__.py
  35. +51 -0  magic_pdf/data/data_reader_writer/base.py
  36. +59 -0  magic_pdf/data/data_reader_writer/filebase.py
  37. +137 -0 magic_pdf/data/data_reader_writer/multi_bucket_s3.py
  38. +69 -0  magic_pdf/data/data_reader_writer/s3.py
  39. +194 -0 magic_pdf/data/dataset.py
  40. +0 -0   magic_pdf/data/io/__init__.py
  41. +42 -0  magic_pdf/data/io/base.py
  42. +37 -0  magic_pdf/data/io/http.py
  43. +114 -0 magic_pdf/data/io/s3.py
  44. +95 -0  magic_pdf/data/read_api.py
  45. +15 -0  magic_pdf/data/schemas.py
  46. +32 -0  magic_pdf/data/utils.py
  47. +74 -234 magic_pdf/dict2md/ocr_mkcontent.py
  48. +21 -8  magic_pdf/libs/Constants.py
  49. +1 -0   magic_pdf/libs/MakeContentConfig.py
  50. Binary  magic_pdf/libs/__pycache__/__init__.cpython-312.pyc
  51. Binary  magic_pdf/libs/__pycache__/version.cpython-312.pyc
  52. +35 -0  magic_pdf/libs/boxbase.py
  53. +10 -0  magic_pdf/libs/clean_memory.py
  54. +53 -23 magic_pdf/libs/config_reader.py
  55. +150 -65 magic_pdf/libs/draw_bbox.py
  56. +2 -0   magic_pdf/libs/ocr_content_type.py
  57. +1 -1   magic_pdf/libs/version.py
  58. +77 -32 magic_pdf/model/doc_analyze_by_custom_model.py
  59. +331 -15 magic_pdf/model/magic_model.py
  60. +164 -80 magic_pdf/model/pdf_extract_kit.py
  61. +8 -1   magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py
  62. +2 -2   magic_pdf/model/ppTableModel.py
  63. +5 -2   magic_pdf/model/pp_structure_v2.py
  64. +0 -0   magic_pdf/model/v3/__init__.py
  65. +125 -0 magic_pdf/model/v3/helpers.py
  66. +296 -0 magic_pdf/para/para_split_v3.py
  67. +6 -3   magic_pdf/pdf_parse_by_ocr.py
  68. +6 -3   magic_pdf/pdf_parse_by_txt.py
  69. +644 -0 magic_pdf/pdf_parse_union_core_v2.py
  70. +5 -1   magic_pdf/pipe/AbsPipe.py
  71. +10 -4  magic_pdf/pipe/OCRPipe.py
  72. +10 -4  magic_pdf/pipe/TXTPipe.py
  73. +16 -7  magic_pdf/pipe/UNIPipe.py
  74. +83 -1  magic_pdf/pre_proc/ocr_detect_all_bboxes.py
  75. +27 -2  magic_pdf/pre_proc/ocr_dict_merge.py
  76. +7 -7   magic_pdf/resources/model_config/UniMERNet/demo.yaml
  77. +5 -13  magic_pdf/resources/model_config/model_configs.yaml
  78. +14 -1  magic_pdf/tools/cli.py
  79. +18 -8  magic_pdf/tools/common.py
  80. +25 -6  magic_pdf/user_api.py
  81. +0 -0   magic_pdf/utils/__init__.py
  82. +11 -0  magic_pdf/utils/annotations.py
  83. +16 -0  next_docs/en/.readthedocs.yaml
  84. +20 -0  next_docs/en/Makefile
  85. Binary  next_docs/en/_static/image/logo.png
  86. +9 -0   next_docs/en/api.rst
  87. +44 -0  next_docs/en/api/data_reader_writer.rst
  88. +22 -0  next_docs/en/api/dataset.rst
  89. +0 -0   next_docs/en/api/io.rst
  90. +6 -0   next_docs/en/api/read_api.rst
  91. +0 -0   next_docs/en/api/schemas.rst
  92. +1 -0   next_docs/en/api/utils.rst
  93. +122 -0 next_docs/en/conf.py
  94. +38 -0  next_docs/en/index.rst
  95. +35 -0  next_docs/en/make.bat
  96. +11 -0  next_docs/requirements.txt
  97. +16 -0  next_docs/zh_cn/.readthedocs.yaml
  98. +20 -0  next_docs/zh_cn/Makefile
  99. Binary  next_docs/zh_cn/_static/image/logo.png
  100. +122 -0 next_docs/zh_cn/conf.py

+ 1 - 1
.github/workflows/cla.yml

@@ -29,7 +29,7 @@ jobs:
           path-to-document: 'https://github.com/opendatalab/MinerU/blob/master/MinerU_CLA.md' # e.g. a CLA or a DCO document
           # branch should not be protected
           branch: 'master'
-          allowlist: myhloli,dt-yy,Focusshang,renpengli01,icecraft,drunkpig,wangbinDL,qiangqiang199,GDDGCZ518,papayalove,conghui,quyuan
+          allowlist: myhloli,dt-yy,Focusshang,renpengli01,icecraft,drunkpig,wangbinDL,qiangqiang199,GDDGCZ518,papayalove,conghui,quyuan,LollipopsAndWine
 
          # the followings are the optional inputs - If the optional inputs are not given, then default values will be taken
           #remote-organization-name: enter the remote organization name where the signatures should be stored (Default is storing the signatures in the same repository)

+ 9 - 11
.github/workflows/cli.yml

@@ -10,7 +10,6 @@ on:
     paths-ignore:
       - "cmds/**"
       - "**.md"
-      - "**.yml"
   pull_request:
     branches:
       - "master"
@@ -18,12 +17,11 @@ on:
     paths-ignore:
       - "cmds/**"
       - "**.md"
-      - "**.yml"
   workflow_dispatch:
 jobs:
   cli-test:
     runs-on: pdf
-    timeout-minutes: 120
+    timeout-minutes: 240
     strategy:
       fail-fast: true
 
@@ -33,16 +31,16 @@ jobs:
       with:
         fetch-depth: 2
 
-    - name: install
+    - name: install&test
       run: |
-        echo $GITHUB_WORKSPACE && sh tests/retry_env.sh
-    - name: unit test
-      run: |        
-        cd $GITHUB_WORKSPACE && export PYTHONPATH=. && coverage run -m  pytest  tests/test_unit.py --cov=magic_pdf/ --cov-report term-missing --cov-report html
+        source activate mineru
+        conda env list
+        pip show coverage
+        # cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
+        cd $GITHUB_WORKSPACE && python tests/clean_coverage.py      
+        cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/  --cov-report html --cov-report term-missing
         cd $GITHUB_WORKSPACE && python tests/get_coverage.py
-    - name: cli test
-      run: |
-        cd $GITHUB_WORKSPACE &&  pytest -s -v tests/test_cli/test_cli_sdk.py
+        cd $GITHUB_WORKSPACE && pytest -m P0 -s -v tests/test_cli/test_cli_sdk.py
 
   notify_to_feishu:
     if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}

+ 55 - 0
.github/workflows/daily.yml

@@ -0,0 +1,55 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: mineru
+on:
+  schedule:
+    - cron: '0 22 * * *'  # 每天晚上 10 点执行
+jobs:
+  cli-test:
+    runs-on: pdf
+    timeout-minutes: 240
+    strategy:
+      fail-fast: true
+
+    steps:
+    - name: PDF cli
+      uses: actions/checkout@v3
+      with:
+        fetch-depth: 2
+
+    - name: install&test
+      run: |
+        source activate mineru
+        conda env list
+        pip show coverage
+        # cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
+        cd $GITHUB_WORKSPACE && python tests/clean_coverage.py      
+        cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/  --cov-report html --cov-report term-missing
+        cd $GITHUB_WORKSPACE && python tests/get_coverage.py
+        cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli_sdk.py
+
+  notify_to_feishu:
+    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
+    needs: cli-test
+    runs-on: pdf
+    steps:
+    - name: get_actor
+      run: |
+          metion_list="dt-yy"
+          echo $GITHUB_ACTOR
+          if [[ $GITHUB_ACTOR == "drunkpig" ]]; then
+            metion_list="xuchao"
+          elif [[ $GITHUB_ACTOR == "myhloli" ]]; then
+            metion_list="zhaoxiaomeng"
+          elif [[ $GITHUB_ACTOR == "icecraft" ]]; then
+            metion_list="xurui1"
+          fi
+          echo $metion_list
+          echo "METIONS=$metion_list" >> "$GITHUB_ENV"
+          echo ${{ env.METIONS }}
+
+    - name: notify
+      run: |
+        echo ${{ secrets.USER_ID }}
+        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}'  ${{ secrets.WEBHOOK_URL }}

+ 61 - 0
.github/workflows/huigui.yml

@@ -0,0 +1,61 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: mineru
+on:
+  push:
+    branches:
+      - "master"
+      - "dev"
+    paths-ignore:
+      - "cmds/**"
+      - "**.md"
+  workflow_dispatch:
+jobs:
+  cli-test:
+    runs-on: pdf
+    timeout-minutes: 240
+    strategy:
+      fail-fast: true
+
+    steps:
+    - name: PDF cli
+      uses: actions/checkout@v3
+      with:
+        fetch-depth: 2
+
+    - name: install&test
+      run: |
+        source activate mineru
+        conda env list
+        pip show coverage
+        # cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
+        cd $GITHUB_WORKSPACE && python tests/clean_coverage.py      
+        cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/  --cov-report html --cov-report term-missing
+        cd $GITHUB_WORKSPACE && python tests/get_coverage.py
+        cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli_sdk.py
+
+  notify_to_feishu:
+    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
+    needs: cli-test
+    runs-on: pdf
+    steps:
+    - name: get_actor
+      run: |
+          metion_list="dt-yy"
+          echo $GITHUB_ACTOR
+          if [[ $GITHUB_ACTOR == "drunkpig" ]]; then
+            metion_list="xuchao"
+          elif [[ $GITHUB_ACTOR == "myhloli" ]]; then
+            metion_list="zhaoxiaomeng"
+          elif [[ $GITHUB_ACTOR == "icecraft" ]]; then
+            metion_list="xurui1"
+          fi
+          echo $metion_list
+          echo "METIONS=$metion_list" >> "$GITHUB_ENV"
+          echo ${{ env.METIONS }}
+
+    - name: notify
+      run: |
+        echo ${{ secrets.USER_ID }}
+        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}'  ${{ secrets.WEBHOOK_URL }}

+ 0 - 22
.github/workflows/update_base.yml

@@ -1,22 +0,0 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-
-name: update-base
-on:
-  push:
-    tags:
-      - '*released'
-  workflow_dispatch:
-jobs:
-  pdf-test:
-    runs-on: pdf
-    timeout-minutes: 40
-
-
-    steps:
-    - name: update-base
-      uses: actions/checkout@v3
-    - name: start-update
-      run: |
-        echo "start test"
-  

+ 50 - 39
.gitignore

@@ -1,39 +1,50 @@
-*.tar
-*.tar.gz
-venv*/
-envs/
-slurm_logs/
-
-sync1.sh
-data_preprocess_pj1
-data-preparation1
-__pycache__
-*.log
-*.pyc
-.vscode
-debug/
-*.ipynb
-.idea
-
-# vscode history
-.history
-
-.DS_Store
-.env
-
-bad_words/
-bak/
-
-app/tests/*
-temp/
-tmp/
-tmp
-.vscode
-.vscode/
-ocr_demo
-
-/app/common/__init__.py
-/magic_pdf/config/__init__.py
-source.dev.env
-
-tmp
+*.tar
+*.tar.gz
+*.zip
+venv*/
+envs/
+slurm_logs/
+
+sync1.sh
+data_preprocess_pj1
+data-preparation1
+__pycache__
+*.log
+*.pyc
+.vscode
+debug/
+*.ipynb
+.idea
+
+# vscode history
+.history
+
+.DS_Store
+.env
+
+bad_words/
+bak/
+
+app/tests/*
+temp/
+tmp/
+tmp
+.vscode
+.vscode/
+ocr_demo
+.coveragerc
+/app/common/__init__.py
+/magic_pdf/config/__init__.py
+source.dev.env
+
+tmp
+
+projects/web/node_modules
+projects/web/dist
+
+projects/web_demo/web_demo/static/
+cli_debug/
+debug_utils/
+
+# sphinx docs
+_build/

+ 3 - 2
.pre-commit-config.yaml

@@ -3,7 +3,7 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-        args: ["--max-line-length=120", "--ignore=E131,E125,W503,W504,E203"]
+        args: ["--max-line-length=150", "--ignore=E131,E125,W503,W504,E203"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:
@@ -12,11 +12,12 @@ repos:
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: ["--style={based_on_style: google, column_limit: 120, indent_width: 4}"]
+        args: ["--style={based_on_style: google, column_limit: 150, indent_width: 4}"]
   - repo: https://github.com/codespell-project/codespell
     rev: v2.2.1
     hooks:
       - id: codespell
+        args: ['--skip', '*.json']
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.3.0
     hooks:

+ 16 - 0
.readthedocs.yaml

@@ -0,0 +1,16 @@
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
+formats:
+  - epub
+
+python:
+  install:
+    - requirements: docs/zh_cn/requirements.txt
+
+sphinx:
+  configuration: docs/zh_cn/conf.py

+ 1 - 1
Dockerfile

@@ -31,7 +31,7 @@ RUN python3 -m venv /opt/mineru_venv
 RUN /bin/bash -c "source /opt/mineru_venv/bin/activate && \
     pip3 install --upgrade pip && \
     wget https://gitee.com/myhloli/MinerU/raw/master/requirements-docker.txt && \
-    pip3 install -r requirements-docker.txt --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple && \
+    pip3 install -r requirements-docker.txt --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple && \
     pip3 install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/"
 
 # Copy the configuration file template and install magic-pdf latest

+ 1 - 0
LICENSE.md

@@ -659,3 +659,4 @@ specific requirements.
 if any, to sign a "copyright disclaimer" for the program, if necessary.
 For more information on this, and how to apply and follow the GNU AGPL, see
 <https://www.gnu.org/licenses/>.
+

File diff suppressed because it is too large
+ 0 - 0
README.md


+ 15 - 6
README_ja-JP.md

@@ -290,14 +290,23 @@ https://github.com/opendatalab/MinerU/assets/11393164/20438a02-ce6c-4af8-9dde-d7
 # 引用
 
 ```bibtex
-@misc{2024mineru,
-    title={MinerU: A One-stop, Open-source, High-quality Data Extraction Tool},
-    author={MinerU Contributors},
-    howpublished = {\url{https://github.com/opendatalab/MinerU}},
-    year={2024}
+@misc{wang2024mineruopensourcesolutionprecise,
+      title={MinerU: An Open-Source Solution for Precise Document Content Extraction}, 
+      author={Bin Wang and Chao Xu and Xiaomeng Zhao and Linke Ouyang and Fan Wu and Zhiyuan Zhao and Rui Xu and Kaiwen Liu and Yuan Qu and Fukai Shang and Bo Zhang and Liqun Wei and Zhihao Sui and Wei Li and Botian Shi and Yu Qiao and Dahua Lin and Conghui He},
+      year={2024},
+      eprint={2409.18839},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV},
+      url={https://arxiv.org/abs/2409.18839}, 
 }
-```
 
+@article{he2024opendatalab,
+  title={Opendatalab: Empowering general artificial intelligence with open datasets},
+  author={He, Conghui and Li, Wei and Jin, Zhenjiang and Xu, Chao and Wang, Bin and Lin, Dahua},
+  journal={arXiv preprint arXiv:2407.13773},
+  year={2024}
+}
+```
 
 # スター履歴
 

File diff suppressed because it is too large
+ 35 - 16
README_zh-CN.md


+ 2 - 15
demo/demo.py

@@ -1,35 +1,22 @@
 import os
-import json
 
 from loguru import logger
-
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
-import magic_pdf.model as model_config 
-model_config.__use_inside_model__ = True
 
 try:
     current_script_dir = os.path.dirname(os.path.abspath(__file__))
     demo_name = "demo1"
     pdf_path = os.path.join(current_script_dir, f"{demo_name}.pdf")
-    model_path = os.path.join(current_script_dir, f"{demo_name}.json")
     pdf_bytes = open(pdf_path, "rb").read()
-    # model_json = json.loads(open(model_path, "r", encoding="utf-8").read())
-    model_json = []  # model_json传空list使用内置模型解析
-    jso_useful_key = {"_pdf_type": "", "model_list": model_json}
+    jso_useful_key = {"_pdf_type": "", "model_list": []}
     local_image_dir = os.path.join(current_script_dir, 'images')
     image_dir = str(os.path.basename(local_image_dir))
     image_writer = DiskReaderWriter(local_image_dir)
     pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
     pipe.pipe_classify()
-    """如果没有传入有效的模型数据,则使用内置model解析"""
-    if len(model_json) == 0:
-        if model_config.__use_inside_model__:
-            pipe.pipe_analyze()
-        else:
-            logger.error("need model list input")
-            exit(1)
+    pipe.pipe_analyze()
     pipe.pipe_parse()
     md_content = pipe.pipe_mk_markdown(image_dir, drop_mode="none")
     with open(f"{demo_name}.md", "w", encoding="utf-8") as f:

+ 14 - 9
demo/magic_pdf_parse_main.py

@@ -4,13 +4,12 @@ import copy
 
 from loguru import logger
 
+from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-import magic_pdf.model as model_config
 
-model_config.__use_inside_model__ = True
 
 # todo: 设备类型选择 (?)
 
@@ -47,11 +46,20 @@ def json_md_dump(
     )
 
 
+# 可视化
+def draw_visualization_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name):
+    # 画布局框,附带排序结果
+    draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+    # 画 span 框
+    draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+
+
 def pdf_parse_main(
         pdf_path: str,
         parse_method: str = 'auto',
         model_json_path: str = None,
         is_json_md_dump: bool = True,
+        is_draw_visualization_bbox: bool = True,
         output_dir: str = None
 ):
     """
@@ -108,11 +116,7 @@ def pdf_parse_main(
 
         # 如果没有传入模型数据,则使用内置模型解析
         if not model_json:
-            if model_config.__use_inside_model__:
-                pipe.pipe_analyze()  # 解析
-            else:
-                logger.error("need model list input")
-                exit(1)
+            pipe.pipe_analyze()  # 解析
 
         # 执行解析
         pipe.pipe_parse()
@@ -121,10 +125,11 @@ def pdf_parse_main(
         content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
         md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
 
-
         if is_json_md_dump:
             json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
 
+        if is_draw_visualization_bbox:
+            draw_visualization_bbox(pipe.pdf_mid_data['pdf_info'], pdf_bytes, output_path, pdf_name)
 
     except Exception as e:
         logger.exception(e)
@@ -132,5 +137,5 @@ def pdf_parse_main(
 
 # 测试
 if __name__ == '__main__':
-    pdf_path = r"C:\Users\XYTK2\Desktop\2024-2016-gb-cd-300.pdf"
+    pdf_path = r"D:\project\20240617magicpdf\Magic-PDF\demo\demo1.pdf"
     pdf_parse_main(pdf_path)

+ 7 - 2
docs/FAQ_en_us.md

@@ -11,7 +11,7 @@ pip install magic-pdf[full]
 
 ### 2. Encountering the error `pickle.UnpicklingError: invalid load key, 'v'.` during use
 
-This might be due to an incomplete download of the model file. You can try re-downloading the model file and then try again.  
+This might be due to an incomplete download of the model file. You can try re-downloading the model file and then try again.
 Reference: https://github.com/opendatalab/MinerU/issues/143
 
 ### 3. Where should the model files be downloaded and how should the `/models-dir` configuration be set?
@@ -24,7 +24,7 @@ The path for the model files is configured in "magic-pdf.json". just like:
 }
 ```
 
-This path is an absolute path, not a relative path. You can obtain the absolute path in the models directory using the "pwd" command.  
+This path is an absolute path, not a relative path. You can obtain the absolute path in the models directory using the "pwd" command.
 Reference: https://github.com/opendatalab/MinerU/issues/155#issuecomment-2230216874
 
 ### 4. Encountered the error `ImportError: libGL.so.1: cannot open shared object file: No such file or directory` in Ubuntu 22.04 on WSL2
@@ -38,17 +38,22 @@ sudo apt-get install libgl1-mesa-glx
 Reference: https://github.com/opendatalab/MinerU/issues/388
 
 ### 5. Encountered error `ModuleNotFoundError: No module named 'fairscale'`
+
 You need to uninstall the module and reinstall it:
+
 ```bash
 pip uninstall fairscale
 pip install fairscale
 ```
+
 Reference: https://github.com/opendatalab/MinerU/issues/411
 
 ### 6. On some newer devices like the H100, the text parsed during OCR using CUDA acceleration is garbled.
 
 The compatibility of cuda11 with new graphics cards is poor, and the CUDA version used by Paddle needs to be upgraded.
+
 ```bash
 pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
 ```
+
 Reference: https://github.com/opendatalab/MinerU/issues/558

+ 13 - 3
docs/FAQ_zh_cn.md

@@ -1,9 +1,10 @@
 # 常见问题解答
 
-### 1.在较新版本的mac上使用命令安装pip install magic-pdf[full] zsh: no matches found: magic-pdf[full]
+### 1.在较新版本的mac上使用命令安装pip install magic-pdf\[full\] zsh: no matches found: magic-pdf\[full\]
 
 在 macOS 上,默认的 shell 从 Bash 切换到了 Z shell,而 Z shell 对于某些类型的字符串匹配有特殊的处理逻辑,这可能导致no matches found错误。
 可以通过在命令行禁用globbing特性,再尝试运行安装命令
+
 ```bash
 setopt no_nomatch
 pip install magic-pdf[full]
@@ -11,41 +12,50 @@ pip install magic-pdf[full]
 
 ### 2.使用过程中遇到_pickle.UnpicklingError: invalid load key, 'v'.错误
 
-可能是由于模型文件未下载完整导致,可尝试重新下载模型文件后再试  
+可能是由于模型文件未下载完整导致,可尝试重新下载模型文件后再试
 参考:https://github.com/opendatalab/MinerU/issues/143
 
 ### 3.模型文件应该下载到哪里/models-dir的配置应该怎么填
 
 模型文件的路径输入是在"magic-pdf.json"中通过
+
 ```json
 {
   "models-dir": "/tmp/models"
 }
 ```
+
 进行配置的。
-这个路径是绝对路径而不是相对路径,绝对路径的获取可在models目录中通过命令 "pwd" 获取。  
+这个路径是绝对路径而不是相对路径,绝对路径的获取可在models目录中通过命令 "pwd" 获取。
 参考:https://github.com/opendatalab/MinerU/issues/155#issuecomment-2230216874
 
 ### 4.在WSL2的Ubuntu22.04中遇到报错`ImportError: libGL.so.1: cannot open shared object file: No such file or directory`
 
 WSL2的Ubuntu22.04中缺少`libgl`库,可通过以下命令安装`libgl`库解决:
+
 ```bash
 sudo apt-get install libgl1-mesa-glx
 ```
+
 参考:https://github.com/opendatalab/MinerU/issues/388
 
 ### 5.遇到报错 `ModuleNotFoundError : Nomodulenamed 'fairscale'`
+
 需要卸载该模块并重新安装
+
 ```bash
 pip uninstall fairscale
 pip install fairscale
 ```
+
 参考:https://github.com/opendatalab/MinerU/issues/411
 
 ### 6.在部分较新的设备如H100上,使用CUDA加速OCR时解析出的文字乱码。
 
 cuda11对新显卡的兼容性不好,需要升级paddle使用的cuda版本
+
 ```bash
 pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
 ```
+
 参考:https://github.com/opendatalab/MinerU/issues/558

+ 76 - 54
docs/README_Ubuntu_CUDA_Acceleration_en_US.md

@@ -1,80 +1,104 @@
-
 # Ubuntu 22.04 LTS
 
 ### 1. Check if NVIDIA Drivers Are Installed
-   ```sh
-   nvidia-smi
-   ```
-   If you see information similar to the following, it means that the NVIDIA drivers are already installed, and you can skip Step 2.
-   ```plaintext
-   +---------------------------------------------------------------------------------------+
-   | NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
-   |-----------------------------------------+----------------------+----------------------+
-   | GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
-   | Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
-   |                                         |                      |               MIG M. |
-   |=========================================+======================+======================|
-   |   0  NVIDIA GeForce RTX 3060 Ti   WDDM  | 00000000:01:00.0  On |                  N/A |
-   |  0%   51C    P8              12W / 200W |   1489MiB /  8192MiB |      5%      Default |
-   |                                         |                      |                  N/A |
-   +-----------------------------------------+----------------------+----------------------+
-   ```
+
+```sh
+nvidia-smi
+```
+
+If you see information similar to the following, it means that the NVIDIA drivers are already installed, and you can skip Step 2.
+
+Notice: `CUDA Version` should be >= 12.1. If the displayed version number is less than 12.1, please upgrade the driver.
+
+```plaintext
++---------------------------------------------------------------------------------------+
+| NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
+|-----------------------------------------+----------------------+----------------------+
+| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
+|                                         |                      |               MIG M. |
+|=========================================+======================+======================|
+|   0  NVIDIA GeForce RTX 3060 Ti   WDDM  | 00000000:01:00.0  On |                  N/A |
+|  0%   51C    P8              12W / 200W |   1489MiB /  8192MiB |      5%      Default |
+|                                         |                      |                  N/A |
++-----------------------------------------+----------------------+----------------------+
+```
 
 ### 2. Install the Driver
-   If no driver is installed, use the following command:
-   ```sh
-   sudo apt-get update
-   sudo apt-get install nvidia-driver-545
-   ```
-   Install the proprietary driver and restart your computer after installation.
-   ```sh
-   reboot
-   ```
+
+If no driver is installed, use the following command:
+
+```sh
+sudo apt-get update
+sudo apt-get install nvidia-driver-545
+```
+
+Install the proprietary driver and restart your computer after installation.
+
+```sh
+reboot
+```
 
 ### 3. Install Anaconda
-   If Anaconda is already installed, skip this step.
-   ```sh
-   wget https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
-   bash Anaconda3-2024.06-1-Linux-x86_64.sh
-   ```
-   In the final step, enter `yes`, close the terminal, and reopen it.
+
+If Anaconda is already installed, skip this step.
+
+```sh
+wget https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
+bash Anaconda3-2024.06-1-Linux-x86_64.sh
+```
+
+In the final step, enter `yes`, close the terminal, and reopen it.
 
 ### 4. Create an Environment Using Conda
-   Specify Python version 3.10.
-   ```sh
-   conda create -n MinerU python=3.10
-   conda activate MinerU
-   ```
+
+Specify Python version 3.10.
+
+```sh
+conda create -n MinerU python=3.10
+conda activate MinerU
+```
 
 ### 5. Install Applications
-   ```sh
-   pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
-   ```
+
+```sh
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
+```
+
 ❗ After installation, make sure to check the version of `magic-pdf` using the following command:
-   ```sh
-   magic-pdf --version
-   ```
-   If the version number is less than 0.7.0, please report the issue.
+
+```sh
+magic-pdf --version
+```
+
+If the version number is less than 0.7.0, please report the issue.
 
 ### 6. Download Models
-   Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
+
+
+Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
+
 
 ## 7. Understand the Location of the Configuration File
 
 After completing the [6. Download Models](#6-download-models) step, the script will automatically generate a `magic-pdf.json` file in the user directory and configure the default model path.
 You can find the `magic-pdf.json` file in your user directory.
+
 > The user directory for Linux is "/home/username".
 
+
 ### 8. First Run
-   Download a sample file from the repository and test it.
-   ```sh
-   wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf
-   magic-pdf -p small_ocr.pdf
-   ```
+
+Download a sample file from the repository and test it.
+
+```sh
+wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf
+magic-pdf -p small_ocr.pdf
+```
 
 ### 9. Test CUDA Acceleration
 
-If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA acceleration:
+If your graphics card has at least **8GB** of VRAM, follow these steps to test CUDA acceleration:
 
 1. Modify the value of `"device-mode"` in the `magic-pdf.json` configuration file located in your home directory.
    ```json
@@ -89,8 +113,6 @@ If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA
 
 ### 10. Enable CUDA Acceleration for OCR
 
-❗ The following operations require a graphics card with at least 16GB of VRAM; otherwise, the program may crash or experience reduced performance.
-    
 1. Download `paddlepaddle-gpu`. Installation will automatically enable OCR acceleration.
    ```sh
    python -m pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/

+ 45 - 7
docs/README_Ubuntu_CUDA_Acceleration_zh_CN.md

@@ -1,10 +1,16 @@
 # Ubuntu 22.04 LTS
 
 ## 1. 检测是否已安装nvidia驱动
+
 ```bash
-nvidia-smi 
+nvidia-smi
 ```
+
 如果看到类似如下的信息,说明已经安装了nvidia驱动,可以跳过步骤2
+
+注意:`CUDA Version` 显示的版本号应 >= 12.1,如显示的版本号小于12.1,请升级驱动
+
+```plaintext
 ```
 +---------------------------------------------------------------------------------------+
 | NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
@@ -18,78 +24,110 @@ nvidia-smi
 |                                         |                      |                  N/A |
 +-----------------------------------------+----------------------+----------------------+
 ```
+
 ## 2. 安装驱动
+
 如没有驱动,则通过如下命令
+
 ```bash
 sudo apt-get update
 sudo apt-get install nvidia-driver-545
 ```
+
 安装专有驱动,安装完成后,重启电脑
+
 ```bash
 reboot
 ```
+
 ## 3. 安装anacoda
+
 如果已安装conda,可以跳过本步骤
+
 ```bash
 wget -U NoSuchBrowser/1.0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
 bash Anaconda3-2024.06-1-Linux-x86_64.sh
 ```
+
 最后一步输入yes,关闭终端重新打开
+
 ## 4. 使用conda 创建环境
+
 需指定python版本为3.10
+
 ```bash
 conda create -n MinerU python=3.10
 conda activate MinerU
 ```
+
 ## 5. 安装应用
+
 ```bash
-pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple
 ```
+
 > ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
-> 
+>
 > ```bash
 > magic-pdf --version
->```
+> ```
+>
 > 如果版本号小于0.7.0,请到issue中向我们反馈
 
 ## 6. 下载模型
+
+
 详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
 
 ## 7. 了解配置文件存放的位置
+
 完成[6.下载模型](#6-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
-您可在【用户目录】下找到magic-pdf.json文件。 
+您可在【用户目录】下找到magic-pdf.json文件。
+
+
 > linux用户目录为 "/home/用户名"
 
 ## 8. 第一次运行
+
 从仓库中下载样本文件,并测试
+
 ```bash
 wget https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf
 magic-pdf -p small_ocr.pdf
 ```
+
 ## 9. 测试CUDA加速
-如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
+
+如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
 **1.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
+
 ```json
 {
   "device-mode":"cuda"
 }
 ```
+
 **2.运行以下命令测试cuda加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
+
 > 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
 
 ## 10. 为ocr开启cuda加速
-> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
 
 **1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
+
 ```bash
 python -m pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
 ```
+
 **2.运行以下命令测试ocr加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
+
 > 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。

+ 78 - 56
docs/README_Windows_CUDA_Acceleration_en_US.md

@@ -1,79 +1,101 @@
 # Windows 10/11
 
 ### 1. Install CUDA and cuDNN
+
 Required versions: CUDA 11.8 + cuDNN 8.7.0
-   - CUDA 11.8: https://developer.nvidia.com/cuda-11-8-0-download-archive
-   - cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x: https://developer.nvidia.com/rdp/cudnn-archive
-   
+
+- CUDA 11.8: https://developer.nvidia.com/cuda-11-8-0-download-archive
+- cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x: https://developer.nvidia.com/rdp/cudnn-archive
+
 ### 2. Install Anaconda
-   If Anaconda is already installed, you can skip this step.
-   
+
+If Anaconda is already installed, you can skip this step.
+
 Download link: https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
 
 ### 3. Create an Environment Using Conda
-   Python version must be 3.10.
-   ```
-   conda create -n MinerU python=3.10
-   conda activate MinerU
-   ```
+
+Python version must be 3.10.
+
+```
+conda create -n MinerU python=3.10
+conda activate MinerU
+```
 
 ### 4. Install Applications
-   ```
-   pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
-   ```
-   >❗️After installation, verify the version of `magic-pdf`:
-   >  ```bash
-   >  magic-pdf --version
-   >  ```
-   > If the version number is less than 0.7.0, please report it in the issues section.
-   
+
+```
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
+```
+
+> ❗️After installation, verify the version of `magic-pdf`:
+>
+> ```bash
+> magic-pdf --version
+> ```
+>
+> If the version number is less than 0.7.0, please report it in the issues section.
+
 ### 5. Download Models
-   Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
+
+Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
 
 ### 6. Understand the Location of the Configuration File
 
 After completing the [5. Download Models](#5-download-models) step, the script will automatically generate a `magic-pdf.json` file in the user directory and configure the default model path.
 You can find the `magic-pdf.json` file in your 【user directory】 .
+
 > The user directory for Windows is "C:/Users/username".
 
 ### 7. First Run
-   Download a sample file from the repository and test it.
-   ```powershell
-     (New-Object System.Net.WebClient).DownloadFile('https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf')
-     magic-pdf -p small_ocr.pdf
-   ```
+
+Download a sample file from the repository and test it.
+
+```powershell
+  wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
+  magic-pdf -p small_ocr.pdf
+```
 
 ### 8. Test CUDA Acceleration
-   If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
-   1. **Overwrite the installation of torch and torchvision** supporting CUDA.
-      ```
-      pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
-      ```
-      >❗️Ensure the following versions are specified in the command:
-      >```
-      > torch==2.3.1 torchvision==0.18.1
-      >```
-      >These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.
-   2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.
-     
-      ```json
-      {
-        "device-mode": "cuda"
-      }
-      ```
-   3. **Run the following command to test CUDA acceleration**:
-
-      ```
-      magic-pdf -p small_ocr.pdf
-      ```
+
+If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
+
+1. **Overwrite the installation of torch and torchvision** supporting CUDA.
+
+   ```
+   pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
+   ```
+
+   > ❗️Ensure the following versions are specified in the command:
+   >
+   > ```
+   > torch==2.3.1 torchvision==0.18.1
+   > ```
+   >
+   > These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.
+
+2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.
+
+   ```json
+   {
+     "device-mode": "cuda"
+   }
+   ```
+
+
+3. **Run the following command to test CUDA acceleration**:
+
+   ```
+   magic-pdf -p small_ocr.pdf
+   ```
 
 ### 9. Enable CUDA Acceleration for OCR
-   >❗️This operation requires at least 16GB of VRAM on your graphics card, otherwise it will cause the program to crash or slow down.
-   1. **Download paddlepaddle-gpu**, which will automatically enable OCR acceleration upon installation.
-      ```
-      pip install paddlepaddle-gpu==2.6.1
-      ```
-   2. **Run the following command to test OCR acceleration**:
-      ```
-      magic-pdf -p small_ocr.pdf
-      ```
+
+1. **Download paddlepaddle-gpu**, which will automatically enable OCR acceleration upon installation.
+   ```
+   pip install paddlepaddle-gpu==2.6.1
+   ```
+2. **Run the following command to test OCR acceleration**:
+   ```
+   magic-pdf -p small_ocr.pdf
+   ```

+ 35 - 10
docs/README_Windows_CUDA_Acceleration_zh_CN.md

@@ -3,82 +3,107 @@
 ## 1. 安装cuda和cuDNN
 
 需要安装的版本 CUDA 11.8 + cuDNN 8.7.0
+
 - CUDA 11.8 https://developer.nvidia.com/cuda-11-8-0-download-archive
 - cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x https://developer.nvidia.com/rdp/cudnn-archive
 
 ## 2. 安装anaconda
+
 如果已安装conda,可以跳过本步骤
 
 下载链接:
 https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
 
 ## 3. 使用conda 创建环境
+
 需指定python版本为3.10
+
 ```bash
 conda create -n MinerU python=3.10
 conda activate MinerU
 ```
+
 ## 4. 安装应用
+
 ```bash
-pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple
+pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple
 ```
+
 > ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
-> 
+>
 > ```bash
 > magic-pdf --version
->```
+> ```
+>
 > 如果版本号小于0.7.0,请到issue中向我们反馈
 
 ## 5. 下载模型
+
 详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
 
 ## 6. 了解配置文件存放的位置
+
 完成[5.下载模型](#5-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
 您可在【用户目录】下找到magic-pdf.json文件。
+
+
 > windows用户目录为 "C:/Users/用户名"
 
 ## 7. 第一次运行
+
 从仓库中下载样本文件,并测试
+
 ```powershell
-(New-Object System.Net.WebClient).DownloadFile('https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf')
-magic-pdf -p small_ocr.pdf
+ wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
+ magic-pdf -p small_ocr.pdf
 ```
 
 ## 8. 测试CUDA加速
-如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
+
+如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
 
 **1.覆盖安装支持cuda的torch和torchvision**
+
 ```bash
 pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
 ```
+
 > ❗️务必在命令中指定以下版本
+>
 > ```bash
-> torch==2.3.1 torchvision==0.18.1 
+> torch==2.3.1 torchvision==0.18.1
 > ```
+>
 > 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行
 
 **2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
+
 ```json
 {
   "device-mode":"cuda"
 }
 ```
+
 **3.运行以下命令测试cuda加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
-> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
+
+> 提示:CUDA加速是否生效可以根据log中输出的各个阶段的耗时来简单判断,通常情况下,`layout detection time` 和 `mfr time` 应提速10倍以上。
 
 ## 9. 为ocr开启cuda加速
-> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
 
 **1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
+
 ```bash
 pip install paddlepaddle-gpu==2.6.1
 ```
+
 **2.运行以下命令测试ocr加速效果**
+
 ```bash
 magic-pdf -p small_ocr.pdf
 ```
-> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。
 
+> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr time`应提速10倍以上。

+ 34 - 14
docs/download_models.py

@@ -1,19 +1,27 @@
+
+import json
 import os
+
 import requests
-import json
 from modelscope import snapshot_download
 
 
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
 def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
     else:
-        # 下载JSON文件
-        response = requests.get(url)
-        response.raise_for_status()  # 检查请求是否成功
+        data = download_json(url)
 
-        # 解析JSON内容
-        data = response.json()
 
     # 修改内容
     for key, value in modifications.items():
@@ -25,15 +33,25 @@ def download_and_modify_json(url, local_filename, modifications):
 
 
 if __name__ == '__main__':
-    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
+
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
     layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
-    model_dir = model_dir + "/models"
-    print(f"model_dir is: {model_dir}")
-    print(f"layoutreader_model_dir is: {layoutreader_model_dir}")
+    model_dir = model_dir + '/models'
+    print(f'model_dir is: {model_dir}')
+    print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
+
+    json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
+    config_file_name = 'magic-pdf.json'
+    home_dir = os.path.expanduser('~')
 
-    json_url = 'https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json'
-    config_file_name = "magic-pdf.json"
-    home_dir = os.path.expanduser("~")
     config_file = os.path.join(home_dir, config_file_name)
 
     json_mods = {
@@ -42,4 +60,6 @@ if __name__ == '__main__':
     }
 
     download_and_modify_json(json_url, config_file, json_mods)
-    print(f"The configuration file has been configured successfully, the path is: {config_file}")
+
+    print(f'The configuration file has been configured successfully, the path is: {config_file}')
+
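
The version gate added to `download_and_modify_json` above re-downloads the template only when the local `config_version` predates `1.0.0`. A standalone sketch of that logic (the `'0.0.0'` fallback and the string comparison are taken verbatim from the diff; the caveat about multi-digit segments is our observation, not from the source):

```python
# Sketch of the config_version gate from download_models.py above.
def needs_refresh(local_config: dict) -> bool:
    # A missing key defaults to '0.0.0', exactly as in the diff.
    config_version = local_config.get('config_version', '0.0.0')
    # Lexicographic comparison works here ('0.9.0' < '1.0.0') but would
    # misorder multi-digit segments: '10.0.0' < '9.0.0' is also True.
    return config_version < '1.0.0'

print(needs_refresh({}))                           # True  -> re-download template
print(needs_refresh({'config_version': '1.0.0'}))  # False -> keep local config
```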

+ 41 - 16
docs/download_models_hf.py

@@ -1,19 +1,26 @@
+import json
 import os
+
 import requests
-import json
 from huggingface_hub import snapshot_download
 
 
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
 def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
     else:
-        # 下载JSON文件
-        response = requests.get(url)
-        response.raise_for_status()  # 检查请求是否成功
+        data = download_json(url)
 
-        # 解析JSON内容
-        data = response.json()
 
     # 修改内容
     for key, value in modifications.items():
@@ -25,15 +32,31 @@ def download_and_modify_json(url, local_filename, modifications):
 
 
 if __name__ == '__main__':
-    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
-    layoutreader_model_dir = snapshot_download('hantian/layoutreader')
-    model_dir = model_dir + "/models"
-    print(f"model_dir is: {model_dir}")
-    print(f"layoutreader_model_dir is: {layoutreader_model_dir}")
-
-    json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json'
-    config_file_name = "magic-pdf.json"
-    home_dir = os.path.expanduser("~")
+
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
+
+    layoutreader_pattern = [
+        "*.json",
+        "*.safetensors",
+    ]
+    layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
+
+    model_dir = model_dir + '/models'
+    print(f'model_dir is: {model_dir}')
+    print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
+
+    json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
+    config_file_name = 'magic-pdf.json'
+    home_dir = os.path.expanduser('~')
+
     config_file = os.path.join(home_dir, config_file_name)
 
     json_mods = {
@@ -42,4 +65,6 @@ if __name__ == '__main__':
     }
 
     download_and_modify_json(json_url, config_file, json_mods)
-    print(f"The configuration file has been configured successfully, the path is: {config_file}")
+
+    print(f'The configuration file has been configured successfully, the path is: {config_file}')
+

+ 8 - 4
docs/how_to_download_models_en.md

@@ -3,7 +3,8 @@ Model downloads are divided into initial downloads and updates to the model dire
 
 # Initial download of model files
 
-### 1. Download the Model from Hugging Face
+### Download the Model from Hugging Face
+
 Use a Python Script to Download Model Files from Hugging Face
 ```bash
 pip install huggingface_hub
@@ -14,14 +15,17 @@ The Python script will automatically download the model files and configure the
 
 The configuration file can be found in the user directory, with the filename `magic-pdf.json`.
 
+
 # How to update models previously downloaded
 
 ## 1. Models downloaded via Git LFS
 
->Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
+> Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
+
+When magic-pdf <= 0.8.1, if you have previously downloaded the model files via git lfs, you can navigate to the previous download directory and update the models using the `git pull` command.
 
-If you previously downloaded model files via git lfs, you can navigate to the previous download directory and use the `git pull` command to update the model.
+> For versions 0.9.x and later, due to the repository change and the addition of the layout sorting model in PDF-Extract-Kit 1.0, the models cannot be updated using the `git pull` command. Instead, a Python script must be used for one-click updates.
 
 ## 2. Models downloaded via Hugging Face or Model Scope
 
-If you previously downloaded models via Hugging Face or Model Scope, you can rerun the Python script used for the initial download. This will automatically update the model directory to the latest version.
+If you previously downloaded models via Hugging Face or Model Scope, you can rerun the Python script used for the initial download. This will automatically update the model directory to the latest version.

+ 7 - 10
docs/how_to_download_models_zh_cn.md

@@ -10,7 +10,7 @@
   <pre><code>pip install huggingface_hub
 wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
 python download_models_hf.py</code></pre>
-  <p>python脚本执行完毕后,会输出模型下载目录</p>
+  <p>python脚本会自动下载模型文件并配置好配置文件中的模型目录</p>
 </details>
 
 ## 方法二:从 ModelScope 下载模型
@@ -25,6 +25,7 @@ python download_models.py
 python脚本会自动下载模型文件并配置好配置文件中的模型目录
 
 配置文件可以在用户目录中找到,文件名为`magic-pdf.json`
+
 > windows的用户目录为 "C:\\Users\\用户名", linux用户目录为 "/home/用户名", macOS用户目录为 "/Users/用户名"
 
 
@@ -32,17 +33,13 @@ python脚本会自动下载模型文件并配置好配置文件中的模型目
 
 ## 1. 通过git lfs下载过模型
 
->由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
+> 由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
+
+当magic-pdf <= 0.8.1时,如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
 
-如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
+> 0.9.x及以后版本由于PDF-Extract-Kit 1.0更换仓库和新增layout排序模型,不能通过`git pull`命令更新,需要使用python脚本一键更新
 
-> 0.9.x及以后版本由于新增layout排序模型,且该模型和此前的模型不在同一仓库,不能通过`git pull`命令更新,需要单独下载。
-> 
->``` 
->from modelscope import snapshot_download
->snapshot_download('ppaanngggg/layoutreader')
->```
 
 ## 2. 通过 Hugging Face 或 Model Scope 下载过模型
 
-如此前通过 HuggingFace 或 Model Scope 下载过模型,可以重复执行此前的模型下载python脚本,将会自动将模型目录更新到最新版本。
+如此前通过 HuggingFace 或 Model Scope 下载过模型,可以重复执行此前的模型下载python脚本,将会自动将模型目录更新到最新版本。

Binary
docs/images/web_demo_1.png


+ 3 - 0
docs/output_file_en_us.md

@@ -175,11 +175,14 @@ Detailed explanation of second-level block types
 | :----------------- | :--------------------- |
 | image_body         | Main body of the image |
 | image_caption      | Image description text |
+| image_footnote     | Image footnote         |
 | table_body         | Main body of the table |
 | table_caption      | Table description text |
 | table_footnote     | Table footnote         |
 | text               | Text block             |
 | title              | Title block            |
+| index              | Index block            |
+| list               | List block             |
 | interline_equation | Block formula          |
 
 <br>

+ 8 - 5
docs/output_file_zh_cn.md

@@ -174,12 +174,15 @@ poly 坐标的格式 \[x0, y0, x1, y1, x2, y2, x3, y3\], 分别表示左上、
 | :----------------- | :------------- |
 | image_body         | 图像的本体     |
 | image_caption      | 图像的描述文本 |
-| table_body         | 表格本体       |
+| image_footnote     | 图像的脚注   |
+| table_body         | 表格本体    |
 | table_caption      | 表格的描述文本 |
-| table_footnote     | 表格的脚注     |
-| text               | 文本块         |
-| title              | 标题块         |
-| interline_equation | 行间公式块     |
+| table_footnote     | 表格的脚注   |
+| text               | 文本块     |
+| title              | 标题块     |
+| index              | 目录块     |
+| list               | 列表块     |
+| interline_equation | 行间公式块   |
 
 <br>
 

+ 12 - 3
magic-pdf.template.json

@@ -6,9 +6,18 @@
     "models-dir":"/tmp/models",
     "layoutreader-model-dir":"/tmp/layoutreader",
     "device-mode":"cpu",
+    "layout-config": {
+        "model": "layoutlmv3"
+    },
+    "formula-config": {
+        "mfd_model": "yolo_v8_mfd",
+        "mfr_model": "unimernet_small",
+        "enable": true
+    },
     "table-config": {
-        "model": "TableMaster",
-        "is_table_recog_enable": false,
+        "model": "tablemaster",
+        "enable": false,
         "max_time": 400
-    }
+    },
+    "config_version": "1.0.0"
 }

+ 0 - 0
magic_pdf/config/__init__.py


+ 7 - 0
magic_pdf/config/enums.py

@@ -0,0 +1,7 @@
+
+import enum
+
+
+class SupportedPdfParseMethod(enum.Enum):
+    OCR = 'ocr'
+    TXT = 'txt'
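
A quick usage sketch for the new enum (member names and values come straight from the diff; the calling code is hypothetical):

```python
from magic_pdf.config.enums import SupportedPdfParseMethod

method = SupportedPdfParseMethod.OCR
print(method.value)                    # 'ocr'
print(SupportedPdfParseMethod('txt'))  # SupportedPdfParseMethod.TXT (lookup by value)
```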

+ 32 - 0
magic_pdf/config/exceptions.py

@@ -0,0 +1,32 @@
+
+class FileNotExisted(Exception):
+
+    def __init__(self, path):
+        self.path = path
+
+    def __str__(self):
+        return f'File {self.path} does not exist.'
+
+
+class InvalidConfig(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid config: {self.msg}'
+
+
+class InvalidParams(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid params: {self.msg}'
+
+
+class EmptyData(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Empty data: {self.msg}'
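
Each exception carries its context into `__str__`, so instances print readable messages; a minimal sketch (the raising code is hypothetical):

```python
from magic_pdf.config.exceptions import FileNotExisted, InvalidParams

try:
    raise FileNotExisted('/tmp/missing.pdf')
except FileNotExisted as e:
    print(e)  # File /tmp/missing.pdf does not exist.

try:
    raise InvalidParams('offset must be >= 0')
except InvalidParams as e:
    print(e)  # Invalid params: offset must be >= 0
```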

+ 0 - 0
magic_pdf/data/__init__.py


+ 12 - 0
magic_pdf/data/data_reader_writer/__init__.py

@@ -0,0 +1,12 @@
+from magic_pdf.data.data_reader_writer.filebase import \
+    FileBasedDataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.filebase import \
+    FileBasedDataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
+    MultiBucketS3DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
+    MultiBucketS3DataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.s3 import S3DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.s3 import S3DataWriter  # noqa: F401
+from magic_pdf.data.data_reader_writer.base import DataReader  # noqa: F401
+from magic_pdf.data.data_reader_writer.base import DataWriter  # noqa: F401

+ 51 - 0
magic_pdf/data/data_reader_writer/base.py

@@ -0,0 +1,51 @@
+
+from abc import ABC, abstractmethod
+
+
+class DataReader(ABC):
+
+    def read(self, path: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return self.read_at(path)
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file at offset and limit.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the length of bytes want to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of the file
+        """
+        pass
+
+
+class DataWriter(ABC):
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write the data to the file.
+
+        Args:
+            path (str): the target file where to write
+            data (bytes): the data want to write
+        """
+        pass
+
+    def write_string(self, path: str, data: str) -> None:
+        """Write the data to file, the data will be encoded to bytes.
+
+        Args:
+            path (str): the target file where to write
+            data (str): the data want to write
+        """
+        self.write(path, data.encode())
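
The abstract pair above only requires subclasses to implement `read_at` and `write`; `read` and `write_string` are derived from them. A toy in-memory implementation to illustrate the contract (entirely hypothetical, not part of the release):

```python
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter


class MemoryReader(DataReader):
    def __init__(self, files: dict[str, bytes]):
        self._files = files

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        data = self._files[path][offset:]
        return data if limit == -1 else data[:limit]


class MemoryWriter(DataWriter):
    def __init__(self):
        self.files: dict[str, bytes] = {}

    def write(self, path: str, data: bytes) -> None:
        self.files[path] = data


reader = MemoryReader({'a.txt': b'hello world'})
print(reader.read('a.txt'))         # b'hello world' (read delegates to read_at)
print(reader.read_at('a.txt', 6))   # b'world'

writer = MemoryWriter()
writer.write_string('b.txt', 'hi')  # write_string encodes, then calls write
print(writer.files['b.txt'])        # b'hi'
```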

+ 59 - 0
magic_pdf/data/data_reader_writer/filebase.py

@@ -0,0 +1,59 @@
+import os
+
+from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
+
+
+class FileBasedDataReader(DataReader):
+    def __init__(self, parent_dir: str = ''):
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the length of bytes want to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of file
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        with open(fn_path, 'rb') as f:
+            f.seek(offset)
+            if limit == -1:
+                return f.read()
+            else:
+                return f.read(limit)
+
+
+class FileBasedDataWriter(DataWriter):
+    def __init__(self, parent_dir: str = '') -> None:
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the file path; a relative path is joined with parent_dir.
+            data (bytes): the data to write
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        with open(fn_path, 'wb') as f:
+            f.write(data)

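Usage sketch for the file-based pair (paths hypothetical); relative paths resolve against parent_dir, absolute paths are used as-is:

import os

from magic_pdf.data.data_reader_writer.filebase import (FileBasedDataReader,
                                                        FileBasedDataWriter)

os.makedirs('/tmp/output', exist_ok=True)  # write() does not create directories
writer = FileBasedDataWriter('/tmp/output')
writer.write_string('result.md', '# title')  # -> /tmp/output/result.md

reader = FileBasedDataReader('/tmp/output')
assert reader.read('result.md') == b'# title'
assert reader.read_at('result.md', offset=2, limit=5) == b'title'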
+ 137 - 0
magic_pdf/data/data_reader_writer/multi_bucket_s3.py

@@ -0,0 +1,137 @@
+from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
+from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
+from magic_pdf.data.io.s3 import S3Reader, S3Writer
+from magic_pdf.data.schemas import S3Config
+from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
+                                       remove_non_official_s3_args)
+
+
+class MultiS3Mixin:
+    def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
+        """Initialized with multiple s3 configs.
+
+        Args:
+            default_bucket (str): the bucket used for paths that do not name a bucket explicitly
+            s3_configs (list[S3Config]): list of s3 configs; bucket_name must be unique within the list.
+
+        Raises:
+            InvalidConfig: default bucket config not in s3_configs
+            InvalidConfig: bucket name not unique in s3_configs
+            InvalidConfig: default bucket must be provided
+        """
+        if len(default_bucket) == 0:
+            raise InvalidConfig('default_bucket must be provided')
+
+        found_default_bucket_config = False
+        for conf in s3_configs:
+            if conf.bucket_name == default_bucket:
+                found_default_bucket_config = True
+                break
+
+        if not found_default_bucket_config:
+            raise InvalidConfig(
+                f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
+            )
+
+        uniq_bucket = set([conf.bucket_name for conf in s3_configs])
+        if len(uniq_bucket) != len(s3_configs):
+            raise InvalidConfig(
+                f'the bucket_name in s3_configs: {s3_configs} must be unique'
+            )
+
+        self.default_bucket = default_bucket
+        self.s3_configs = s3_configs
+        self._s3_clients_h: dict = {}
+
+
+class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
+    def read(self, path: str) -> bytes:
+        """Read the path from s3, select diffect bucket client for each request
+        based on the path, also support range read.
+
+        Args:
+            path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
+            for example: s3://bucket_name/path?0,100
+
+        Returns:
+            bytes: the content of s3 file
+        """
+        may_range_params = parse_s3_range_params(path)
+        if may_range_params is None or len(may_range_params) != 2:
+            byte_start, byte_len = 0, -1
+        else:
+            byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
+        path = remove_non_official_s3_args(path)
+        return self.read_at(path, byte_start, byte_len)
+
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Reader(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file with offset and limit, select diffect bucket client
+        for each request based on the path.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
+
+        Returns:
+            bytes: the file content
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_reader = self.__get_s3_client(bucket_name)
+        else:
+            s3_reader = self.__get_s3_client(self.default_bucket)
+        return s3_reader.read_at(path, offset, limit)
+
+
+class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Writer(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data, also select diffect bucket client for each
+        request based on the path.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_writer = self.__get_s3_client(bucket_name)
+        else:
+            s3_writer = self.__get_s3_client(self.default_bucket)
+        return s3_writer.write(path, data)

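Usage sketch with hypothetical credentials: clients are created lazily per bucket, plain keys go to default_bucket, s3:// paths pick their own bucket, and the ?offset,limit suffix triggers a range read:

from magic_pdf.data.data_reader_writer import MultiBucketS3DataReader
from magic_pdf.data.schemas import S3Config

reader = MultiBucketS3DataReader(
    default_bucket='bucket-a',
    s3_configs=[
        S3Config(bucket_name='bucket-a', access_key='AK_A', secret_key='SK_A',
                 endpoint_url='https://s3.example.com'),
        S3Config(bucket_name='bucket-b', access_key='AK_B', secret_key='SK_B',
                 endpoint_url='https://s3.example.com'),
    ],
)
data = reader.read('some/key.pdf')                        # default bucket
other = reader.read('s3://bucket-b/other/key.pdf')        # explicit bucket
head = reader.read('s3://bucket-b/other/key.pdf?0,1024')  # first 1024 bytes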
+ 69 - 0
magic_pdf/data/data_reader_writer/s3.py

@@ -0,0 +1,69 @@
+from magic_pdf.data.data_reader_writer.multi_bucket_s3 import (
+    MultiBucketS3DataReader, MultiBucketS3DataWriter)
+from magic_pdf.data.schemas import S3Config
+
+
+class S3DataReader(MultiBucketS3DataReader):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+                see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        super().__init__(
+            bucket,
+            [
+                S3Config(
+                    bucket_name=bucket,
+                    access_key=ak,
+                    secret_key=sk,
+                    endpoint_url=endpoint_url,
+                    addressing_style=addressing_style,
+                )
+            ],
+        )
+
+
+class S3DataWriter(MultiBucketS3DataWriter):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 writer client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+                see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        super().__init__(
+            bucket,
+            [
+                S3Config(
+                    bucket_name=bucket,
+                    access_key=ak,
+                    secret_key=sk,
+                    endpoint_url=endpoint_url,
+                    addressing_style=addressing_style,
+                )
+            ],
+        )

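These wrappers reduce the multi-bucket machinery to a single bucket. A usage sketch (credentials hypothetical):

from magic_pdf.data.data_reader_writer import S3DataReader, S3DataWriter

writer = S3DataWriter('my-bucket', 'AK', 'SK', 'https://s3.example.com')
writer.write_string('out/demo.md', '# demo')

reader = S3DataReader('my-bucket', 'AK', 'SK', 'https://s3.example.com')
assert reader.read('out/demo.md') == b'# demo'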
+ 194 - 0
magic_pdf/data/dataset.py

@@ -0,0 +1,194 @@
+from abc import ABC, abstractmethod
+from typing import Iterator
+
+import fitz
+
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.schemas import PageInfo
+from magic_pdf.data.utils import fitz_doc_to_image
+
+
+class PageableData(ABC):
+    @abstractmethod
+    def get_image(self) -> dict:
+        """Transform data to image."""
+        pass
+
+    @abstractmethod
+    def get_doc(self) -> fitz.Page:
+        """Get the pymudoc page."""
+        pass
+
+    @abstractmethod
+    def get_page_info(self) -> PageInfo:
+        """Get the page info of the page.
+
+        Returns:
+            PageInfo: the page info of this page
+        """
+        pass
+
+
+class Dataset(ABC):
+    @abstractmethod
+    def __len__(self) -> int:
+        """The length of the dataset."""
+        pass
+
+    @abstractmethod
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page data."""
+        pass
+
+    @abstractmethod
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """The methods that this dataset support.
+
+        Returns:
+            list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
+        """
+        pass
+
+    @abstractmethod
+    def data_bits(self) -> bytes:
+        """The bits used to create this dataset."""
+        pass
+
+    @abstractmethod
+    def get_page(self, page_id: int) -> PageableData:
+        """Get the page indexed by page_id.
+
+        Args:
+            page_id (int): the index of the page
+
+        Returns:
+            PageableData: the page doc object
+        """
+        pass
+
+
+class PymuDocDataset(Dataset):
+    def __init__(self, bits: bytes):
+        """Initialize the dataset, which wraps the pymudoc documents.
+
+        Args:
+            bits (bytes): the bytes of the pdf
+        """
+        self._records = [Doc(v) for v in fitz.open('pdf', bits)]
+        self._data_bits = bits
+        self._raw_data = bits
+
+    def __len__(self) -> int:
+        """The page number of the pdf."""
+        return len(self._records)
+
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page doc object."""
+        return iter(self._records)
+
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """The method supported by this dataset.
+
+        Returns:
+            list[SupportedPdfParseMethod]: the supported methods
+        """
+        return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
+
+    def data_bits(self) -> bytes:
+        """The pdf bits used to create this dataset."""
+        return self._data_bits
+
+    def get_page(self, page_id: int) -> PageableData:
+        """The page doc object.
+
+        Args:
+            page_id (int): the page doc index
+
+        Returns:
+            PageableData: the page doc object
+        """
+        return self._records[page_id]
+
+
+class ImageDataset(Dataset):
+    def __init__(self, bits: bytes):
+        """Initialize the dataset, which wraps the pymudoc documents.
+
+        Args:
+            bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
+        """
+        pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
+        self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
+        self._raw_data = bits
+        self._data_bits = pdf_bytes
+
+    def __len__(self) -> int:
+        """The length of the dataset."""
+        return len(self._records)
+
+    def __iter__(self) -> Iterator[PageableData]:
+        """Yield the page object."""
+        return iter(self._records)
+
+    def supported_methods(self) -> list[SupportedPdfParseMethod]:
+        """The methods supported by this dataset.
+
+        Returns:
+            list[SupportedPdfParseMethod]: the supported methods
+        """
+        return [SupportedPdfParseMethod.OCR]
+
+    def data_bits(self) -> bytes:
+        """The pdf bits used to create this dataset."""
+        return self._data_bits
+
+    def get_page(self, page_id: int) -> PageableData:
+        """The page doc object.
+
+        Args:
+            page_id (int): the page doc index
+
+        Returns:
+            PageableData: the page doc object
+        """
+        return self._records[page_id]
+
+
+class Doc(PageableData):
+    """Initialized with pymudoc object."""
+    def __init__(self, doc: fitz.Page):
+        self._doc = doc
+
+    def get_image(self):
+        """Return the imge info.
+
+        Returns:
+            dict: {
+                img: np.ndarray,
+                width: int,
+                height: int
+            }
+        """
+        return fitz_doc_to_image(self._doc)
+
+    def get_doc(self) -> fitz.Page:
+        """Get the pymudoc object.
+
+        Returns:
+            fitz.Page: the pymudoc object
+        """
+        return self._doc
+
+    def get_page_info(self) -> PageInfo:
+        """Get the page info of the page.
+
+        Returns:
+            PageInfo: the page info of this page
+        """
+        page_w = self._doc.rect.width
+        page_h = self._doc.rect.height
+        return PageInfo(w=page_w, h=page_h)
+
+    def __getattr__(self, name):
+        if hasattr(self._doc, name):
+            return getattr(self._doc, name)
+        raise AttributeError(name)

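Usage sketch: wrap raw pdf bytes and walk the pages (file path hypothetical):

from magic_pdf.data.data_reader_writer.filebase import FileBasedDataReader
from magic_pdf.data.dataset import PymuDocDataset

bits = FileBasedDataReader('').read('demo.pdf')
ds = PymuDocDataset(bits)
print(len(ds), ds.supported_methods())  # page count, [OCR, TXT]
for page in ds:                         # each item is a Doc (PageableData)
    info = page.get_page_info()
    img = page.get_image()              # {'img': ndarray, 'width': ..., 'height': ...}
    print(info.w, info.h, img['width'], img['height'])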
+ 0 - 0
magic_pdf/data/io/__init__.py


+ 42 - 0
magic_pdf/data/io/base.py

@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+
+
+class IOReader(ABC):
+    @abstractmethod
+    def read(self, path: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        pass
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the file path; a relative path is joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes to read. Defaults to -1, meaning read to the end.
+
+        Returns:
+            bytes: the content of file
+        """
+        pass
+
+
+class IOWriter(ABC):
+
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the file path; a relative path is joined with parent_dir.
+            data (bytes): the data to write
+        """
+        pass

+ 37 - 0
magic_pdf/data/io/http.py

@@ -0,0 +1,37 @@
+
+import io
+
+import requests
+
+from magic_pdf.data.io.base import IOReader, IOWriter
+
+
+class HttpReader(IOReader):
+
+    def read(self, url: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return requests.get(url).content
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Not Implemented."""
+        raise NotImplementedError
+
+
+class HttpWriter(IOWriter):
+    def write(self, url: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        files = {'file': io.BytesIO(data)}
+        response = requests.post(url, files=files)
+        assert 300 > response.status_code and response.status_code > 199

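Usage sketch (URLs hypothetical); HttpReader only fetches whole files, ranged reads are deliberately unimplemented:

from magic_pdf.data.io.http import HttpReader, HttpWriter

pdf_bytes = HttpReader().read('https://example.com/sample.pdf')
HttpWriter().write('https://example.com/upload', pdf_bytes)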
+ 114 - 0
magic_pdf/data/io/s3.py

@@ -0,0 +1,114 @@
+import boto3
+from botocore.config import Config
+
+from magic_pdf.data.io.base import IOReader, IOWriter
+
+
+class S3Reader(IOReader):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+                see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(
+                s3={'addressing_style': addressing_style},
+                retries={'max_attempts': 5, 'mode': 'standard'},
+            ),
+        )
+
+    def read(self, key: str) -> bytes:
+        """Read the file.
+
+        Args:
+            key (str): the object key to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return self.read_at(key)
+
+    def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            key (str): the object key to read
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the number of bytes to read. Defaults to -1, meaning read to the end.
+
+        Returns:
+            bytes: the content of file
+        """
+        if limit > -1:
+            range_header = f'bytes={offset}-{offset+limit-1}'
+            res = self._s3_client.get_object(
+                Bucket=self._bucket, Key=key, Range=range_header
+            )
+        else:
+            res = self._s3_client.get_object(
+                Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
+            )
+        return res['Body'].read()
+
+
+class S3Writer(IOWriter):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+                see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(
+                s3={'addressing_style': addressing_style},
+                retries={'max_attempts': 5, 'mode': 'standard'},
+            ),
+        )
+
+    def write(self, key: str, data: bytes):
+        """Write file with data.
+
+        Args:
+            key (str): the object key to write to
+            data (bytes): the data to write
+        """
+        self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)

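The offset/limit pair maps directly onto an HTTP Range header: offset=100 with limit=50 becomes 'bytes=100-149', while limit=-1 becomes 'bytes=100-'. A sketch with hypothetical credentials:

from magic_pdf.data.io.s3 import S3Reader

reader = S3Reader('my-bucket', 'AK', 'SK', 'https://s3.example.com')
chunk = reader.read_at('some/key.bin', offset=100, limit=50)  # Range: bytes=100-149
tail = reader.read_at('some/key.bin', offset=100)             # Range: bytes=100-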
+ 95 - 0
magic_pdf/data/read_api.py

@@ -0,0 +1,95 @@
+import json
+import os
+from pathlib import Path
+
+from magic_pdf.config.exceptions import EmptyData, InvalidParams
+from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
+                                               MultiBucketS3DataReader)
+from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
+
+
+def read_jsonl(
+    s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
+) -> list[PymuDocDataset]:
+    """Read the jsonl file and return the list of PymuDocDataset.
+
+    Args:
+        s3_path_or_local (str): local file or s3 path
+        s3_client (MultiBucketS3DataReader | None, optional): s3 client that supports multiple buckets. Defaults to None.
+
+    Raises:
+        InvalidParams: if s3_path_or_local, or a pdf location inside the jsonl, is an s3 path but s3_client is not provided
+        EmptyData: if a line of the jsonl file provides no pdf file location
+
+    Returns:
+        list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
+    """
+    bits_arr = []
+    if s3_path_or_local.startswith('s3://'):
+        if s3_client is None:
+            raise InvalidParams('s3_client is required when s3_path is provided')
+        jsonl_bits = s3_client.read(s3_path_or_local)
+    else:
+        jsonl_bits = FileBasedDataReader('').read(s3_path_or_local)
+    jsonl_d = [
+        json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
+    ]
+    for d in jsonl_d:
+        pdf_path = d.get('file_location', '') or d.get('path', '')
+        if len(pdf_path) == 0:
+            raise EmptyData('pdf file location is empty')
+        if pdf_path.startswith('s3://'):
+            if s3_client is None:
+                raise InvalidParams('s3_client is required when s3_path is provided')
+            bits_arr.append(s3_client.read(pdf_path))
+        else:
+            bits_arr.append(FileBasedDataReader('').read(pdf_path))
+    return [PymuDocDataset(bits) for bits in bits_arr]
+
+
+def read_local_pdfs(path: str) -> list[PymuDocDataset]:
+    """Read pdf from path or directory.
+
+    Args:
+        path (str): pdf file path or directory that contains pdf files
+
+    Returns:
+        list[PymuDocDataset]: each pdf file will be converted to a PymuDocDataset
+    """
+    if os.path.isdir(path):
+        reader = FileBasedDataReader(path)
+        return [
+            PymuDocDataset(reader.read(doc_path.name))
+            for doc_path in Path(path).glob('*.pdf')
+        ]
+    else:
+        reader = FileBasedDataReader()
+        bits = reader.read(path)
+        return [PymuDocDataset(bits)]
+
+
+def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
+    """Read images from path or directory.
+
+    Args:
+        path (str): image file path or directory that contains image files
+        suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
+
+    Returns:
+        list[ImageDataset]: each image file will converted to a ImageDataset
+    """
+    if os.path.isdir(path):
+        imgs_bits = []
+        s_suffixes = set(suffixes)
+        reader = FileBasedDataReader()
+        for root, _, files in os.walk(path):
+            for file in files:
+                suffix = file.split('.')
+                if suffix[-1] in s_suffixes:
+                    # join with root so files in subdirectories resolve correctly
+                    imgs_bits.append(reader.read(os.path.join(root, file)))
+        return [ImageDataset(bits) for bits in imgs_bits]
+    else:
+        reader = FileBasedDataReader()
+        bits = reader.read(path)
+        return [ImageDataset(bits)]

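Usage sketch (paths hypothetical); each helper returns a list of datasets ready for the parse pipeline:

from magic_pdf.data.read_api import (read_jsonl, read_local_images,
                                     read_local_pdfs)

datasets = read_local_pdfs('docs/')                    # every *.pdf in the directory
images = read_local_images('scans/', suffixes=['jpg', 'png'])
from_index = read_jsonl('index.jsonl')                 # local jsonl, no s3 client needed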
+ 15 - 0
magic_pdf/data/schemas.py

@@ -0,0 +1,15 @@
+
+from pydantic import BaseModel, Field
+
+
+class S3Config(BaseModel):
+    bucket_name: str = Field(description='s3 bucket name', min_length=1)
+    access_key: str = Field(description='s3 access key', min_length=1)
+    secret_key: str = Field(description='s3 secret key', min_length=1)
+    endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
+    addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
+
+
+class PageInfo(BaseModel):
+    w: float = Field(description='the width of page')
+    h: float = Field(description='the height of page')

+ 32 - 0
magic_pdf/data/utils.py

@@ -0,0 +1,32 @@
+
+import fitz
+import numpy as np
+
+from magic_pdf.utils.annotations import ImportPIL
+
+
+@ImportPIL
+def fitz_doc_to_image(doc, dpi=200) -> dict:
+    """Convert fitz.Document to image, Then convert the image to numpy array.
+
+    Args:
+        doc (_type_): pymudoc page
+        dpi (int, optional): reset the dpi of dpi. Defaults to 200.
+
+    Returns:
+        dict:  {'img': numpy array, 'width': width, 'height': height }
+    """
+    from PIL import Image
+    mat = fitz.Matrix(dpi / 72, dpi / 72)
+    pm = doc.get_pixmap(matrix=mat, alpha=False)
+
+    # If the width or height exceeds 9000 after scaling, do not scale further.
+    if pm.width > 9000 or pm.height > 9000:
+        pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+
+    img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
+    img = np.array(img)
+
+    img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
+
+    return img_dict

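Usage sketch: render the first page of a pdf at the default 200 dpi (path hypothetical):

import fitz

from magic_pdf.data.utils import fitz_doc_to_image

doc = fitz.open('demo.pdf')
page_img = fitz_doc_to_image(doc[0])
print(page_img['width'], page_img['height'], page_img['img'].shape)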
+ 74 - 234
magic_pdf/dict2md/ocr_mkcontent.py

@@ -1,6 +1,5 @@
 import re
 
-import wordninja
 from loguru import logger
 
 from magic_pdf.libs.commons import join_path
@@ -8,6 +7,7 @@ from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
 from magic_pdf.libs.ocr_content_type import BlockType, ContentType
+from magic_pdf.para.para_split_v3 import ListLineTag
 
 
 def __is_hyphen_at_line_end(line):
@@ -24,37 +24,6 @@ def __is_hyphen_at_line_end(line):
     return bool(re.search(r'[A-Za-z]+-\s*$', line))
 
 
-def split_long_words(text):
-    segments = text.split(' ')
-    for i in range(len(segments)):
-        words = re.findall(r'\w+|[^\w]', segments[i], re.UNICODE)
-        for j in range(len(words)):
-            if len(words[j]) > 10:
-                words[j] = ' '.join(wordninja.split(words[j]))
-        segments[i] = ''.join(words)
-    return ' '.join(segments)
-
-
-def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
-    markdown = []
-    for page_info in pdf_info_list:
-        paras_of_layout = page_info.get('para_blocks')
-        page_markdown = ocr_mk_markdown_with_para_core_v2(
-            paras_of_layout, 'mm', img_buket_path)
-        markdown.extend(page_markdown)
-    return '\n\n'.join(markdown)
-
-
-def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
-    markdown = []
-    for page_info in pdf_info_dict:
-        paras_of_layout = page_info.get('para_blocks')
-        page_markdown = ocr_mk_markdown_with_para_core_v2(
-            paras_of_layout, 'nlp')
-        markdown.extend(page_markdown)
-    return '\n\n'.join(markdown)
-
-
 def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
                                                 img_buket_path):
     markdown_with_para_and_pagination = []
@@ -67,61 +36,23 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
             paras_of_layout, 'mm', img_buket_path)
         markdown_with_para_and_pagination.append({
             'page_no':
-            page_no,
+                page_no,
             'md_content':
-            '\n\n'.join(page_markdown)
+                '\n\n'.join(page_markdown)
         })
         page_no += 1
     return markdown_with_para_and_pagination
 
 
-def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
-    page_markdown = []
-    for paras in paras_of_layout:
-        for para in paras:
-            para_text = ''
-            for line in para:
-                for span in line['spans']:
-                    span_type = span.get('type')
-                    content = ''
-                    language = ''
-                    if span_type == ContentType.Text:
-                        content = span['content']
-                        language = detect_lang(content)
-                        if (language == 'en'):  # 只对英文长词进行分词处理,中文分词会丢失文本
-                            content = ocr_escape_special_markdown_char(
-                                split_long_words(content))
-                        else:
-                            content = ocr_escape_special_markdown_char(content)
-                    elif span_type == ContentType.InlineEquation:
-                        content = f"${span['content']}$"
-                    elif span_type == ContentType.InterlineEquation:
-                        content = f"\n$$\n{span['content']}\n$$\n"
-                    elif span_type in [ContentType.Image, ContentType.Table]:
-                        if mode == 'mm':
-                            content = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
-                        elif mode == 'nlp':
-                            pass
-                    if content != '':
-                        if language == 'en':  # 英文语境下 content间需要空格分隔
-                            para_text += content + ' '
-                        else:  # 中文语境下,content间不需要空格分隔
-                            para_text += content
-            if para_text.strip() == '':
-                continue
-            else:
-                page_markdown.append(para_text.strip() + '  ')
-    return page_markdown
-
-
 def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                       mode,
-                                      img_buket_path=''):
+                                      img_buket_path='',
+                                      ):
     page_markdown = []
     for para_block in paras_of_layout:
         para_text = ''
         para_type = para_block['type']
-        if para_type == BlockType.Text:
+        if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
             para_text = merge_para_with_text(para_block)
         elif para_type == BlockType.Title:
             para_text = f'# {merge_para_with_text(para_block)}'
@@ -136,20 +67,21 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                         for line in block['lines']:
                             for span in line['spans']:
                                 if span['type'] == ContentType.Image:
-                                    para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
+                                    if span.get('image_path', ''):
+                                        para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                for block in para_block['blocks']:  # 2nd: append image_caption
                     if block['type'] == BlockType.ImageCaption:
-                        para_text += merge_para_with_text(block)
-                for block in para_block['blocks']:  # 2nd.拼image_caption
+                        para_text += merge_para_with_text(block) + '  \n'
+                for block in para_block['blocks']:  # 3rd: append image_footnote
                     if block['type'] == BlockType.ImageFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
         elif para_type == BlockType.Table:
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
                for block in para_block['blocks']:  # 1st: append table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
                for block in para_block['blocks']:  # 2nd: append table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
@@ -160,11 +92,11 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                         para_text += f"\n\n$\n {span['latex']}\n$\n\n"
                                     elif span.get('html', ''):
                                         para_text += f"\n\n{span['html']}\n\n"
-                                    else:
+                                    elif span.get('image_path', ''):
                                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                for block in para_block['blocks']:  # 3rd: append table_footnote
                     if block['type'] == BlockType.TableFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
 
         if para_text.strip() == '':
             continue
@@ -174,22 +106,26 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
     return page_markdown
 
 
-def merge_para_with_text(para_block):
-
-    def detect_language(text):
-        en_pattern = r'[a-zA-Z]+'
-        en_matches = re.findall(en_pattern, text)
-        en_length = sum(len(match) for match in en_matches)
-        if len(text) > 0:
-            if en_length / len(text) >= 0.5:
-                return 'en'
-            else:
-                return 'unknown'
+def detect_language(text):
+    en_pattern = r'[a-zA-Z]+'
+    en_matches = re.findall(en_pattern, text)
+    en_length = sum(len(match) for match in en_matches)
+    if len(text) > 0:
+        if en_length / len(text) >= 0.5:
+            return 'en'
         else:
-            return 'empty'
+            return 'unknown'
+    else:
+        return 'empty'
 
+
+def merge_para_with_text(para_block):
     para_text = ''
-    for line in para_block['lines']:
+    for i, line in enumerate(para_block['lines']):
+
+        if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
+            para_text += '  \n'
+
         line_text = ''
         line_lang = ''
         for span in line['spans']:
@@ -199,17 +135,11 @@ def merge_para_with_text(para_block):
         if line_text != '':
             line_lang = detect_lang(line_text)
         for span in line['spans']:
+
             span_type = span['type']
             content = ''
             if span_type == ContentType.Text:
-                content = span['content']
-                # language = detect_lang(content)
-                language = detect_language(content)
-                if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
-                    content = ocr_escape_special_markdown_char(
-                        split_long_words(content))
-                else:
-                    content = ocr_escape_special_markdown_char(content)
+                content = ocr_escape_special_markdown_char(span['content'])
             elif span_type == ContentType.InlineEquation:
                 content = f" ${span['content']}$ "
             elif span_type == ContentType.InterlineEquation:
@@ -230,177 +160,83 @@ def merge_para_with_text(para_block):
     return para_text
 
 
-def para_to_standard_format(para, img_buket_path):
-    para_content = {}
-    if len(para) == 1:
-        para_content = line_to_standard_format(para[0], img_buket_path)
-    elif len(para) > 1:
-        para_text = ''
-        inline_equation_num = 0
-        for line in para:
-            for span in line['spans']:
-                language = ''
-                span_type = span.get('type')
-                content = ''
-                if span_type == ContentType.Text:
-                    content = span['content']
-                    language = detect_lang(content)
-                    if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
-                        content = ocr_escape_special_markdown_char(
-                            split_long_words(content))
-                    else:
-                        content = ocr_escape_special_markdown_char(content)
-                elif span_type == ContentType.InlineEquation:
-                    content = f"${span['content']}$"
-                    inline_equation_num += 1
-                if language == 'en':  # 英文语境下 content间需要空格分隔
-                    para_text += content + ' '
-                else:  # 中文语境下,content间不需要空格分隔
-                    para_text += content
-        para_content = {
-            'type': 'text',
-            'text': para_text,
-            'inline_equation_num': inline_equation_num,
-        }
-    return para_content
-
-
-def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
+def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
     para_type = para_block['type']
-    if para_type == BlockType.Text:
+    para_content = {}
+    if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
         para_content = {
             'type': 'text',
             'text': merge_para_with_text(para_block),
-            'page_idx': page_idx,
         }
     elif para_type == BlockType.Title:
         para_content = {
             'type': 'text',
             'text': merge_para_with_text(para_block),
             'text_level': 1,
-            'page_idx': page_idx,
         }
     elif para_type == BlockType.InterlineEquation:
         para_content = {
             'type': 'equation',
             'text': merge_para_with_text(para_block),
             'text_format': 'latex',
-            'page_idx': page_idx,
         }
     elif para_type == BlockType.Image:
-        para_content = {'type': 'image', 'page_idx': page_idx}
+        para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.ImageBody:
-                para_content['img_path'] = join_path(
-                    img_buket_path,
-                    block['lines'][0]['spans'][0]['image_path'])
+                for line in block['lines']:
+                    for span in line['spans']:
+                        if span['type'] == ContentType.Image:
+                            if span.get('image_path', ''):
+                                para_content['img_path'] = join_path(img_buket_path, span['image_path'])
             if block['type'] == BlockType.ImageCaption:
-                para_content['img_caption'] = merge_para_with_text(block)
+                para_content['img_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.ImageFootnote:
-                para_content['img_footnote'] = merge_para_with_text(block)
+                para_content['img_footnote'].append(merge_para_with_text(block))
     elif para_type == BlockType.Table:
-        para_content = {'type': 'table', 'page_idx': page_idx}
+        para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
-                if block["lines"][0]["spans"][0].get('latex', ''):
-                    para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['latex']}\n$\n\n"
-                elif block["lines"][0]["spans"][0].get('html', ''):
-                    para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
-                para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
-            if block['type'] == BlockType.TableCaption:
-                para_content['table_caption'] = merge_para_with_text(block)
-            if block['type'] == BlockType.TableFootnote:
-                para_content['table_footnote'] = merge_para_with_text(block)
-
-    return para_content
+                for line in block['lines']:
+                    for span in line['spans']:
+                        if span['type'] == ContentType.Table:
 
+                            if span.get('latex', ''):
+                                para_content['table_body'] = f"\n\n$\n {span['latex']}\n$\n\n"
+                            elif span.get('html', ''):
+                                para_content['table_body'] = f"\n\n{span['html']}\n\n"
 
-def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
-    content_list = []
-    for page_info in pdf_info_dict:
-        paras_of_layout = page_info.get('para_blocks')
-        if not paras_of_layout:
-            continue
-        for para_block in paras_of_layout:
-            para_content = para_to_standard_format_v2(para_block,
-                                                      img_buket_path)
-            content_list.append(para_content)
-    return content_list
+                            if span.get('image_path', ''):
+                                para_content['img_path'] = join_path(img_buket_path, span['image_path'])
 
+            if block['type'] == BlockType.TableCaption:
+                para_content['table_caption'].append(merge_para_with_text(block))
+            if block['type'] == BlockType.TableFootnote:
+                para_content['table_footnote'].append(merge_para_with_text(block))
 
-def line_to_standard_format(line, img_buket_path):
-    line_text = ''
-    inline_equation_num = 0
-    for span in line['spans']:
-        if not span.get('content'):
-            if not span.get('image_path'):
-                continue
-            else:
-                if span['type'] == ContentType.Image:
-                    content = {
-                        'type': 'image',
-                        'img_path': join_path(img_buket_path,
-                                              span['image_path']),
-                    }
-                    return content
-                elif span['type'] == ContentType.Table:
-                    content = {
-                        'type': 'table',
-                        'img_path': join_path(img_buket_path,
-                                              span['image_path']),
-                    }
-                    return content
-        else:
-            if span['type'] == ContentType.InterlineEquation:
-                interline_equation = span['content']
-                content = {
-                    'type': 'equation',
-                    'latex': f'$$\n{interline_equation}\n$$'
-                }
-                return content
-            elif span['type'] == ContentType.InlineEquation:
-                inline_equation = span['content']
-                line_text += f'${inline_equation}$'
-                inline_equation_num += 1
-            elif span['type'] == ContentType.Text:
-                text_content = ocr_escape_special_markdown_char(
-                    span['content'])  # 转义特殊符号
-                line_text += text_content
-    content = {
-        'type': 'text',
-        'text': line_text,
-        'inline_equation_num': inline_equation_num,
-    }
-    return content
+    para_content['page_idx'] = page_idx
 
+    if drop_reason is not None:
+        para_content['drop_reason'] = drop_reason
 
-def ocr_mk_mm_standard_format(pdf_info_dict: list):
-    """content_list type         string
-    image/text/table/equation(行间的单独拿出来,行内的和text合并) latex        string
-    latex文本字段。 text         string      纯文本格式的文本数据。 md           string
-    markdown格式的文本数据。 img_path     string      s3://full/path/to/img.jpg."""
-    content_list = []
-    for page_info in pdf_info_dict:
-        blocks = page_info.get('preproc_blocks')
-        if not blocks:
-            continue
-        for block in blocks:
-            for line in block['lines']:
-                content = line_to_standard_format(line)
-                content_list.append(content)
-    return content_list
+    return para_content
 
 
 def union_make(pdf_info_dict: list,
                make_mode: str,
                drop_mode: str,
-               img_buket_path: str = ''):
+               img_buket_path: str = '',
+               ):
     output_content = []
     for page_info in pdf_info_dict:
+        drop_reason_flag = False
+        drop_reason = None
         if page_info.get('need_drop', False):
             drop_reason = page_info.get('drop_reason')
             if drop_mode == DropMode.NONE:
                 pass
+            elif drop_mode == DropMode.NONE_WITH_REASON:
+                drop_reason_flag = True
             elif drop_mode == DropMode.WHOLE_PDF:
                 raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
                                  f'drop_reason is {drop_reason}'))
@@ -425,8 +261,12 @@ def union_make(pdf_info_dict: list,
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.STANDARD_FORMAT:
             for para_block in paras_of_layout:
-                para_content = para_to_standard_format_v2(
-                    para_block, img_buket_path, page_idx)
+                if drop_reason_flag:
+                    para_content = para_to_standard_format_v2(
+                        para_block, img_buket_path, page_idx, drop_reason)
+                else:
+                    para_content = para_to_standard_format_v2(
+                        para_block, img_buket_path, page_idx)
                 output_content.append(para_content)
     if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
         return '\n\n'.join(output_content)

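Usage sketch for the rewritten union_make (the pdf_info_dict placeholder stands in for the 'pdf_info' structure normally produced by the parse pipeline):

from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode

pdf_info_dict = []  # normally produced by the parse pipeline

md = union_make(pdf_info_dict, MakeMode.MM_MD, DropMode.NONE, 'images')
content_list = union_make(pdf_info_dict, MakeMode.STANDARD_FORMAT,
                          DropMode.NONE_WITH_REASON)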
+ 21 - 8
magic_pdf/libs/Constants.py

@@ -10,18 +10,12 @@ block维度自定义字段
# whether the lines in a block have been deleted
 LINES_DELETED = "lines_deleted"
 
-# struct eqtable
-STRUCT_EQTABLE = "struct_eqtable"
-
 # table recognition max time default value
 TABLE_MAX_TIME_VALUE = 400
 
 # pp_table_result_max_length
 TABLE_MAX_LEN = 480
 
-# pp table structure algorithm
-TABLE_MASTER = "TableMaster"
-
 # table master structure dict
 TABLE_MASTER_DICT = "table_master_structure_dict.txt"
 
@@ -29,12 +23,31 @@ TABLE_MASTER_DICT = "table_master_structure_dict.txt"
 TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
 
 # pp detect model dir
-DETECT_MODEL_DIR = "ch_PP-OCRv3_det_infer"
+DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer"
 
 # pp rec model dir
-REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
+REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer"
 
 # pp rec char dict path
 REC_CHAR_DICT = "ppocr_keys_v1.txt"
 
+# pp rec copy rec directory
+PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
+
+# pp rec copy det directory
+PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
+
+
+class MODEL_NAME:
+    # pp table structure algorithm
+    TABLE_MASTER = "tablemaster"
+    # struct eqtable
+    STRUCT_EQTABLE = "struct_eqtable"
+
+    DocLayout_YOLO = "doclayout_yolo"
+
+    LAYOUTLMv3 = "layoutlmv3"
+
+    YOLO_V8_MFD = "yolo_v8_mfd"
 
+    UniMerNet_v2_Small = "unimernet_small"

+ 1 - 0
magic_pdf/libs/MakeContentConfig.py

@@ -8,3 +8,4 @@ class DropMode:
     WHOLE_PDF = "whole_pdf"
     SINGLE_PAGE = "single_page"
     NONE = "none"
+    NONE_WITH_REASON = "none_with_reason"

二進制
magic_pdf/libs/__pycache__/__init__.cpython-312.pyc


二進制
magic_pdf/libs/__pycache__/version.cpython-312.pyc


+ 35 - 0
magic_pdf/libs/boxbase.py

@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
 
     # The area of overlap area
     return (x_right - x_left) * (y_bottom - y_top)
+
+
+def calculate_vertical_projection_overlap_ratio(block1, block2):
+    """
+    Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
+
+    Args:
+        block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
+        block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
+
+    Returns:
+        float: The proportion of the x-axis covered by the vertical projection of the two blocks.
+    """
+    x0_1, _, x1_1, _ = block1
+    x0_2, _, x1_2, _ = block2
+
+    # Calculate the intersection of the x-coordinates
+    x_left = max(x0_1, x0_2)
+    x_right = min(x1_1, x1_2)
+
+    if x_right < x_left:
+        return 0.0
+
+    # Length of the intersection
+    intersection_length = x_right - x_left
+
+    # Length of the x-axis projection of the first block
+    block1_length = x1_1 - x0_1
+
+    if block1_length == 0:
+        return 0.0
+
+    # Proportion of the x-axis covered by the intersection
+    # logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
+    return intersection_length / block1_length

+ 10 - 0
magic_pdf/libs/clean_memory.py

@@ -0,0 +1,10 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import torch
+import gc
+
+
+def clean_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    gc.collect()

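Typical call site (a sketch): free the CUDA cache and force a GC pass after a heavy inference step:

from magic_pdf.libs.clean_memory import clean_memory

# ... run a model / release references ...
clean_memory()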
+ 53 - 23
magic_pdf/libs/config_reader.py

@@ -1,46 +1,44 @@
-"""
-根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
-
-"""
+"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
 
 import json
 import os
 
 from loguru import logger
 
+from magic_pdf.libs.Constants import MODEL_NAME
 from magic_pdf.libs.commons import parse_bucket_key
 
# define the config file name constant
-CONFIG_FILE_NAME = "magic-pdf.json"
+CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
 
 
 def read_config():
-    home_dir = os.path.expanduser("~")
-
-    config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
+    if os.path.isabs(CONFIG_FILE_NAME):
+        config_file = CONFIG_FILE_NAME
+    else:
+        home_dir = os.path.expanduser('~')
+        config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
 
     if not os.path.exists(config_file):
-        raise FileNotFoundError(f"{config_file} not found")
+        raise FileNotFoundError(f'{config_file} not found')
 
-    with open(config_file, "r", encoding="utf-8") as f:
+    with open(config_file, 'r', encoding='utf-8') as f:
         config = json.load(f)
     return config
 
 
 def get_s3_config(bucket_name: str):
-    """
-    ~/magic-pdf.json 读出来
-    """
+    """~/magic-pdf.json 读出来."""
     config = read_config()
 
-    bucket_info = config.get("bucket_info")
+    bucket_info = config.get('bucket_info')
     if bucket_name not in bucket_info:
-        access_key, secret_key, storage_endpoint = bucket_info["[default]"]
+        access_key, secret_key, storage_endpoint = bucket_info['[default]']
     else:
         access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
 
     if access_key is None or secret_key is None or storage_endpoint is None:
-        raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
+        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
 
     # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
 
@@ -49,7 +47,7 @@ def get_s3_config(bucket_name: str):
 
 def get_s3_config_dict(path: str):
     access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
-    return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
+    return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
 
 
 def get_bucket_name(path):
@@ -59,33 +57,65 @@ def get_bucket_name(path):
 
 def get_local_models_dir():
     config = read_config()
-    models_dir = config.get("models-dir")
+    models_dir = config.get('models-dir')
     if models_dir is None:
         logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
-        return "/tmp/models"
+        return '/tmp/models'
     else:
         return models_dir
 
 
+def get_local_layoutreader_model_dir():
+    config = read_config()
+    layoutreader_model_dir = config.get('layoutreader-model-dir')
+    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
+        home_dir = os.path.expanduser('~')
+        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
+        logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
+        return layoutreader_at_modelscope_dir_path
+    else:
+        return layoutreader_model_dir
+
+
 def get_device():
     config = read_config()
-    device = config.get("device-mode")
+    device = config.get('device-mode')
     if device is None:
         logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
-        return "cpu"
+        return 'cpu'
     else:
         return device
 
 
 def get_table_recog_config():
     config = read_config()
-    table_config = config.get("table-config")
+    table_config = config.get('table-config')
     if table_config is None:
         logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
-        return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
+        return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
     else:
         return table_config
 
 
+def get_layout_config():
+    config = read_config()
+    layout_config = config.get("layout-config")
+    if layout_config is None:
+        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
+        return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
+    else:
+        return layout_config
+
+
+def get_formula_config():
+    config = read_config()
+    formula_config = config.get("formula-config")
+    if formula_config is None:
+        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
+        return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
+    else:
+        return formula_config
+
+
 if __name__ == "__main__":
     ak, sk, endpoint = get_s3_config("llm-raw")

+ 150 - 65
magic_pdf/libs/draw_bbox.py

@@ -1,3 +1,4 @@
+from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.libs.commons import fitz  # PyMuPDF
 from magic_pdf.libs.Constants import CROSS_PAGE
 from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
@@ -33,7 +34,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
             )  # Draw the rectangle
 
 
-def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
+def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
     new_rgb = []
     for item in rgb_config:
         item = float(item) / 255
@@ -42,31 +43,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
     for j, bbox in enumerate(page_data):
         x0, y0, x1, y1 = bbox
         rect_coords = fitz.Rect(x0, y0, x1, y1)  # Define the rectangle
-        if fill_config:
-            page.draw_rect(
-                rect_coords,
-                color=None,
-                fill=new_rgb,
-                fill_opacity=0.3,
-                width=0.5,
-                overlay=True,
-            )  # Draw the rectangle
-        else:
-            page.draw_rect(
-                rect_coords,
-                color=new_rgb,
-                fill=None,
-                fill_opacity=1,
-                width=0.5,
-                overlay=True,
-            )  # Draw the rectangle
+        if draw_bbox:
+            if fill_config:
+                page.draw_rect(
+                    rect_coords,
+                    color=None,
+                    fill=new_rgb,
+                    fill_opacity=0.3,
+                    width=0.5,
+                    overlay=True,
+                )  # Draw the rectangle
+            else:
+                page.draw_rect(
+                    rect_coords,
+                    color=new_rgb,
+                    fill=None,
+                    fill_opacity=1,
+                    width=0.5,
+                    overlay=True,
+                )  # Draw the rectangle
         page.insert_text(
-            (x0, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
+            (x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
         )  # Insert the index in the top left corner of the rectangle
 
 
 def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
-    layout_bbox_list = []
     dropped_bbox_list = []
     tables_list, tables_body_list = [], []
     tables_caption_list, tables_footnote_list = [], []
@@ -75,17 +76,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
     titles_list = []
     texts_list = []
     interequations_list = []
+    lists_list = []
+    indexs_list = []
     for page in pdf_info:
-        page_layout_list = []
+
         page_dropped_list = []
         tables, tables_body, tables_caption, tables_footnote = [], [], [], []
         imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
         titles = []
         texts = []
         interequations = []
-        for layout in page['layout_bboxes']:
-            page_layout_list.append(layout['layout_bbox'])
-        layout_bbox_list.append(page_layout_list)
+        lists = []
+        indices = []
+
         for dropped_bbox in page['discarded_blocks']:
             page_dropped_list.append(dropped_bbox['bbox'])
         dropped_bbox_list.append(page_dropped_list)
@@ -117,6 +120,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
                 texts.append(bbox)
             elif block['type'] == BlockType.InterlineEquation:
                 interequations.append(bbox)
+            elif block['type'] == BlockType.List:
+                lists.append(bbox)
+            elif block['type'] == BlockType.Index:
+                indices.append(bbox)
+
         tables_list.append(tables)
         tables_body_list.append(tables_body)
         tables_caption_list.append(tables_caption)
@@ -128,30 +136,62 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         titles_list.append(titles)
         texts_list.append(texts)
         interequations_list.append(interequations)
+        lists_list.append(lists)
+    indices_list.append(indices)
+
+    layout_bbox_list = []
+
+    table_type_order = {
+        'table_caption': 1,
+        'table_body': 2,
+        'table_footnote': 3
+    }
+    for page in pdf_info:
+        page_block_list = []
+        for block in page['para_blocks']:
+            if block['type'] in [
+                BlockType.Text,
+                BlockType.Title,
+                BlockType.InterlineEquation,
+                BlockType.List,
+                BlockType.Index,
+            ]:
+                bbox = block['bbox']
+                page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Image]:
+                for sub_block in block['blocks']:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+            elif block['type'] in [BlockType.Table]:
+                sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
+                for sub_block in sorted_blocks:
+                    bbox = sub_block['bbox']
+                    page_block_list.append(bbox)
+
+        layout_bbox_list.append(page_block_list)
 
     pdf_docs = fitz.open('pdf', pdf_bytes)
+
     for i, page in enumerate(pdf_docs):
-        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
-        draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
-                                 True)
-        draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
-                                 True)  # color !
-        draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0],
-                                 True)
-        draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102],
-                                 True)
-        draw_bbox_without_number(i, tables_footnote_list, page,
-                                 [229, 255, 204], True)
-        draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
+
+        draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
+        # draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True)  # color !
+        draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
+        draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
+        draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
+        # draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
         draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
-        draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
-                                 True)
-        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
-                              True),
+        draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
+        draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
         draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
         draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
-        draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
-                                 True)
+        draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
+        draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
+        draw_bbox_without_number(i, indices_list, page, [40, 169, 92], True)
+
+        draw_bbox_with_number(
+            i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
+        )
 
     # Save the PDF
     pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
@@ -209,11 +249,14 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
                         page_dropped_list.append(span['bbox'])
         dropped_list.append(page_dropped_list)
         # Build the remaining useful_list entries
-        for block in page['para_blocks']:
+        # for block in page['para_blocks']:  # spans can use the blocks from before paragraph merging directly
+        for block in page['preproc_blocks']:
             if block['type'] in [
-                    BlockType.Text,
-                    BlockType.Title,
-                    BlockType.InterlineEquation,
+                BlockType.Text,
+                BlockType.Title,
+                BlockType.InterlineEquation,
+                BlockType.List,
+                BlockType.Index,
             ]:
                 for line in block['lines']:
                     for span in line['spans']:
@@ -232,10 +275,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
     for i, page in enumerate(pdf_docs):
         # Get the data for the current page
         draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
-        draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0],
-                                 False)
-        draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255],
-                                 False)
+        draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
+        draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
         draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
         draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
         draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
@@ -244,7 +285,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
     pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
 
 
-def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
+def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
     dropped_bbox_list = []
     tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
     imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
@@ -252,7 +293,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
     texts_list = []
     interequations_list = []
     pdf_docs = fitz.open('pdf', pdf_bytes)
-    magic_model = MagicModel(model_list, pdf_docs)
+    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
     for i in range(len(model_list)):
         page_dropped_list = []
         tables_body, tables_caption, tables_footnote = [], [], []
@@ -278,8 +319,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
                 imgs_body.append(bbox)
             elif layout_det['category_id'] == CategoryId.ImageCaption:
                 imgs_caption.append(bbox)
-            elif layout_det[
-                    'category_id'] == CategoryId.InterlineEquation_YOLO:
+            elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
                 interequations.append(bbox)
             elif layout_det['category_id'] == CategoryId.Abandon:
                 page_dropped_list.append(bbox)
@@ -298,21 +338,66 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
         imgs_footnote_list.append(imgs_footnote)
 
     for i, page in enumerate(pdf_docs):
-        draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158],
-                              True)  # color !
+        draw_bbox_with_number(
+            i, dropped_bbox_list, page, [158, 158, 158], True
+        )  # color !
         draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
-        draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102],
-                              True)
-        draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204],
-                              True)
+        draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
+        draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
         draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
-        draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255],
-                              True)
-        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
-                              True)
+        draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
+        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
         draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
         draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
         draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
 
     # Save the PDF
     pdf_docs.save(f'{out_path}/{filename}_model.pdf')
+
+
+def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
+    layout_bbox_list = []
+
+    for page in pdf_info:
+        page_line_list = []
+        for block in page['preproc_blocks']:
+            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    index = line['index']
+                    page_line_list.append({'index': index, 'bbox': bbox})
+            if block['type'] in [BlockType.Image, BlockType.Table]:
+                for sub_block in block['blocks']:
+                    if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+                        for line in sub_block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                    elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
+                        for line in sub_block['lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+        sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
+        layout_bbox_list.append([sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes])
+    pdf_docs = fitz.open('pdf', pdf_bytes)
+    for i, page in enumerate(pdf_docs):
+        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
+
+    pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')
+
+
+def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
+    layout_bbox_list = []
+
+    for page in pdf_info:
+        page_block_list = []
+        for block in page['para_blocks']:
+            bbox = block['bbox']
+            page_block_list.append(bbox)
+        layout_bbox_list.append(page_block_list)
+    pdf_docs = fitz.open('pdf', pdf_bytes)
+    for i, page in enumerate(pdf_docs):
+        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
+
+    pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
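
The drawers above are debug visualizers that overlay colored boxes on a copy of the source PDF. A hedged usage sketch, assuming pdf_info is the per-page middle JSON produced by the parsing pipeline and model_list is the raw output of doc_analyze (file paths are illustrative):

    from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_model_bbox,
                                          draw_span_bbox)

    with open('demo.pdf', 'rb') as f:   # illustrative input path
        pdf_bytes = f.read()

    # pdf_info and model_list are assumed to come from a prior pipeline run.
    draw_layout_bbox(pdf_info, pdf_bytes, '/tmp/out', 'demo')   # -> demo_layout.pdf
    draw_span_bbox(pdf_info, pdf_bytes, '/tmp/out', 'demo')     # -> demo_spans.pdf
    draw_model_bbox(model_list, pdf_bytes, '/tmp/out', 'demo')  # -> demo_model.pdf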

+ 2 - 0
magic_pdf/libs/ocr_content_type.py

@@ -20,6 +20,8 @@ class BlockType:
     InterlineEquation = 'interline_equation'
     Footnote = 'footnote'
     Discarded = 'discarded'
+    List = 'list'
+    Index = 'index'
 
 
 class CategoryId:

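The two new constants extend the existing string-valued scheme, so downstream code can branch on them like any other block type. A trivial sketch mirroring the draw_layout_bbox changes above:

    from magic_pdf.libs.ocr_content_type import BlockType

    LINE_LEVEL_TYPES = (BlockType.Text, BlockType.Title,
                        BlockType.InterlineEquation, BlockType.List, BlockType.Index)

    def is_line_level(block_type: str) -> bool:
        # List and Index blocks carry lines, like ordinary text blocks.
        return block_type in LINE_LEVEL_TYPES

    assert is_line_level(BlockType.Index)
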
+ 1 - 1
magic_pdf/libs/version.py

@@ -1 +1 @@
-__version__ = "0.8.1"
+__version__ = "0.9.0"

+ 77 - 32
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -4,7 +4,9 @@ import fitz
 import numpy as np
 from loguru import logger
 
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
+from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
+    get_formula_config
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst):
     return unique_dicts
 
 
-def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
+def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
     try:
         from PIL import Image
     except ImportError:
@@ -32,18 +34,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
 
     images = []
     with fitz.open("pdf", pdf_bytes) as doc:
+        pdf_page_num = doc.page_count
+        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+        if end_page_id > pdf_page_num - 1:
+            logger.warning("end_page_id is out of range, use images length")
+            end_page_id = pdf_page_num - 1
+
         for index in range(0, doc.page_count):
-            page = doc[index]
-            mat = fitz.Matrix(dpi / 72, dpi / 72)
-            pm = page.get_pixmap(matrix=mat, alpha=False)
+            if start_page_id <= index <= end_page_id:
+                page = doc[index]
+                mat = fitz.Matrix(dpi / 72, dpi / 72)
+                pm = page.get_pixmap(matrix=mat, alpha=False)
+
+                # If the width or height exceeds 9000 after scaling, do not scale further.
+                if pm.width > 9000 or pm.height > 9000:
+                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-            # If the width or height exceeds 9000 after scaling, do not scale further.
-            if pm.width > 9000 or pm.height > 9000:
-                pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+                img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
+                img = np.array(img)
+                img_dict = {"img": img, "width": pm.width, "height": pm.height}
+            else:
+                img_dict = {"img": [], "width": 0, "height": 0}
 
-            img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
-            img = np.array(img)
-            img_dict = {"img": img, "width": pm.width, "height": pm.height}
             images.append(img_dict)
     return images
 
@@ -57,14 +69,17 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, ocr: bool, show_log: bool):
-        key = (ocr, show_log)
+    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
+        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
         if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
+            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
+                                                  formula_enable=formula_enable, table_enable=table_enable)
         return self._models[key]
 
 
-def custom_model_init(ocr: bool = False, show_log: bool = False):
+def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
+                      layout_model=None, formula_enable=None, table_enable=None):
+
     model = None
 
     if model_config.__model_mode__ == "lite":
@@ -78,18 +93,36 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
         model_init_start = time.time()
         if model == MODEL.Paddle:
             from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
-            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
+            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
         elif model == MODEL.PEK:
             from magic_pdf.model.pdf_extract_kit import CustomPEKModel
             # Read model-dir and device from the config file
             local_models_dir = get_local_models_dir()
             device = get_device()
+
+            layout_config = get_layout_config()
+            if layout_model is not None:
+                layout_config["model"] = layout_model
+
+            formula_config = get_formula_config()
+            if formula_enable is not None:
+                formula_config["enable"] = formula_enable
+
             table_config = get_table_recog_config()
-            model_input = {"ocr": ocr,
-                           "show_log": show_log,
-                           "models_dir": local_models_dir,
-                           "device": device,
-                           "table_config": table_config}
+            if table_enable is not None:
+                table_config["enable"] = table_enable
+
+            model_input = {
+                "ocr": ocr,
+                "show_log": show_log,
+                "models_dir": local_models_dir,
+                "device": device,
+                "table_config": table_config,
+                "layout_config": layout_config,
+                "formula_config": formula_config,
+                "lang": lang,
+            }
+
             custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")
@@ -104,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
 
 
 def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None):
+                start_page_id=0, end_page_id=None, lang=None,
+                layout_model=None, formula_enable=None, table_enable=None):
 
-    model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log)
+    if lang == "":
+        lang = None
 
-    images = load_images_from_pdf(pdf_bytes)
+    model_manager = ModelSingleton()
+    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
 
-    # end_page_id = end_page_id if end_page_id else len(images) - 1
-    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(images) - 1
+    with fitz.open("pdf", pdf_bytes) as doc:
+        pdf_page_num = doc.page_count
+        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+        if end_page_id > pdf_page_num - 1:
+            logger.warning("end_page_id is out of range, use images length")
+            end_page_id = pdf_page_num - 1
 
-    if end_page_id > len(images) - 1:
-        logger.warning("end_page_id is out of range, use images length")
-        end_page_id = len(images) - 1
+    images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
 
     model_json = []
     doc_analyze_start = time.time()
@@ -132,7 +169,15 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
         page_info = {"page_no": index, "height": page_height, "width": page_width}
         page_dict = {"layout_dets": result, "page_info": page_info}
         model_json.append(page_dict)
-    doc_analyze_cost = time.time() - doc_analyze_start
-    logger.info(f"doc analyze cost: {doc_analyze_cost}")
+
+    gc_start = time.time()
+    clean_memory()
+    gc_time = round(time.time() - gc_start, 2)
+    logger.info(f"gc time: {gc_time}")
+
+    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
+    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    logger.info(f"doc analyze time: {doc_analyze_time},"
+                f" speed: {doc_analyze_speed} pages/second")
 
     return model_json
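
With the new keyword arguments, callers can restrict the page range and override per-run model settings without editing the config file. A hedged call sketch (assumes models and config are already installed; the file path is illustrative):

    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

    with open('demo.pdf', 'rb') as f:   # illustrative input path
        pdf_bytes = f.read()

    # Analyze only pages 0-4, hint the OCR language, and disable table recognition;
    # passing None for layout_model/formula_enable keeps the config-file values.
    model_json = doc_analyze(pdf_bytes, ocr=True,
                             start_page_id=0, end_page_id=4,
                             lang='en', layout_model=None,
                             formula_enable=None, table_enable=False)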

+ 331 - 15
magic_pdf/model/magic_model.py

@@ -1,5 +1,7 @@
+import enum
 import json
 
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                     bbox_relative_pos, box_area, calculate_iou,
                                     calculate_overlap_area_in_bbox1_area_ratio,
@@ -9,6 +11,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.local_math import float_gt
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
+from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
@@ -16,6 +19,14 @@ CAPATION_OVERLAP_AREA_RATIO = 0.6
 MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
 
 
+class PosRelationEnum(enum.Enum):
+    LEFT = 'left'
+    RIGHT = 'right'
+    UP = 'up'
+    BOTTOM = 'bottom'
+    ALL = 'all'
+
+
 class MagicModel:
     """每个函数没有得到元素的时候返回空list."""
 
@@ -24,7 +35,7 @@ class MagicModel:
             need_remove_list = []
             page_no = model_page_info['page_info']['page_no']
             horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
-                model_page_info, self.__docs[page_no]
+                model_page_info, self.__docs.get_page(page_no)
             )
             layout_dets = model_page_info['layout_dets']
             for layout_det in layout_dets:
@@ -99,7 +110,7 @@ class MagicModel:
             for need_remove in need_remove_list:
                 layout_dets.remove(need_remove)
 
-    def __init__(self, model_list: list, docs: fitz.Document):
+    def __init__(self, model_list: list, docs: Dataset):
         self.__model_list = model_list
         self.__docs = docs
         """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
@@ -110,6 +121,24 @@ class MagicModel:
         self.__fix_by_remove_high_iou_and_low_confidence()
         self.__fix_footnote()
 
+    def _bbox_distance(self, bbox1, bbox2):
+        left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
+        flags = [left, right, bottom, top]
+        count = sum([1 if v else 0 for v in flags])
+        if count > 1:
+            return float('inf')
+        if left or right:
+            l1 = bbox1[3] - bbox1[1]
+            l2 = bbox2[3] - bbox2[1]
+        else:
+            l1 = bbox1[2] - bbox1[0]
+            l2 = bbox2[2] - bbox2[0]
+
+        if l2 > l1 and (l2 - l1) / l1 > 0.3:
+            return float('inf')
+
+        return bbox_distance(bbox1, bbox2)
+
     def __fix_footnote(self):
         # 3: figure, 5: table, 7: footnote
         for model_page_info in self.__model_list:
@@ -144,7 +173,7 @@ class MagicModel:
                     if pos_flag_count > 1:
                         continue
                     dis_figure_footnote[i] = min(
-                        bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
+                        self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
                         dis_figure_footnote.get(i, float('inf')),
                     )
             for i in range(len(footnotes)):
@@ -163,7 +192,7 @@ class MagicModel:
                         continue
 
                     dis_table_footnote[i] = min(
-                        bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
+                        self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
                         dis_table_footnote.get(i, float('inf')),
                     )
             for i in range(len(footnotes)):
@@ -195,9 +224,8 @@ class MagicModel:
         Select all subjects that overlap the merged bbox with an overlap area greater than the object's area.
         Then compute the shortest distance between the selected subjects and the object.
         """
-        def search_overlap_between_boxes(
-            subject_idx, object_idx
-        ):
+
+        def search_overlap_between_boxes(subject_idx, object_idx):
             idxes = [subject_idx, object_idx]
             x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
             y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
@@ -225,9 +253,9 @@ class MagicModel:
             for other_object in other_objects:
                 ratio = max(
                     ratio,
-                    get_overlap_area(
-                        merged_bbox, other_object['bbox']
-                    ) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
+                    get_overlap_area(merged_bbox, other_object['bbox'])
+                    * 1.0
+                    / box_area(all_bboxes[object_idx]['bbox']),
                 )
                 if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
                     break
@@ -345,12 +373,17 @@ class MagicModel:
                 if all_bboxes[j]['category_id'] == subject_category_id:
                     subject_idx, object_idx = j, i
 
-                if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
+                if (
+                    search_overlap_between_boxes(subject_idx, object_idx)
+                    >= MERGE_BOX_OVERLAP_AREA_RATIO
+                ):
                     dis[i][j] = float('inf')
                     dis[j][i] = dis[i][j]
                     continue
 
-                dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
+                dis[i][j] = self._bbox_distance(
+                    all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
+                )
                 dis[j][i] = dis[i][j]
 
         used = set()
@@ -566,6 +599,289 @@ class MagicModel:
                 with_caption_subject.add(j)
         return ret, total_subject_object_dis
 
+    def __tie_up_category_by_distance_v2(
+        self,
+        page_no: int,
+        subject_category_id: int,
+        object_category_id: int,
+        priority_pos: PosRelationEnum,
+    ):
+        """_summary_
+
+        Args:
+            page_no (int): _description_
+            subject_category_id (int): _description_
+            object_category_id (int): _description_
+            priority_pos (PosRelationEnum): _description_
+
+        Returns:
+            _type_: _description_
+        """
+        AXIS_MULTIPLICITY = 0.5
+        subjects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == subject_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+
+        objects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == object_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+        M = len(objects)
+
+        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+
+        sub_obj_map_h = {i: [] for i in range(len(subjects))}
+
+        dis_by_directions = {
+            'top': [[-1, float('inf')]] * M,
+            'bottom': [[-1, float('inf')]] * M,
+            'left': [[-1, float('inf')]] * M,
+            'right': [[-1, float('inf')]] * M,
+        }
+
+        for i, obj in enumerate(objects):
+            l_x_axis, l_y_axis = (
+                obj['bbox'][2] - obj['bbox'][0],
+                obj['bbox'][3] - obj['bbox'][1],
+            )
+            axis_unit = min(l_x_axis, l_y_axis)
+            for j, sub in enumerate(subjects):
+
+                bbox1, bbox2, _ = _remove_overlap_between_bbox(
+                    objects[i]['bbox'], subjects[j]['bbox']
+                )
+                left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
+                flags = [left, right, bottom, top]
+                if sum([1 if v else 0 for v in flags]) > 1:
+                    continue
+
+                if left:
+                    if dis_by_directions['left'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['left'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if right:
+                    if dis_by_directions['right'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['right'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if bottom:
+                    if dis_by_directions['bottom'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['bottom'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+                if top:
+                    if dis_by_directions['top'][i][1] > bbox_distance(
+                        obj['bbox'], sub['bbox']
+                    ):
+                        dis_by_directions['top'][i] = [
+                            j,
+                            bbox_distance(obj['bbox'], sub['bbox']),
+                        ]
+
+            if (
+                dis_by_directions['top'][i][1] != float('inf')
+                and dis_by_directions['bottom'][i][1] != float('inf')
+                and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
+            ):
+                RATIO = 3
+                if (
+                    abs(
+                        dis_by_directions['top'][i][1]
+                        - dis_by_directions['bottom'][i][1]
+                    )
+                    < RATIO * axis_unit
+                ):
+
+                    if priority_pos == PosRelationEnum.BOTTOM:
+                        sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
+                    else:
+                        sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
+                    continue
+
+            if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
+                'right'
+            ][i][1] != float('inf'):
+                if dis_by_directions['left'][i][1] != float(
+                    'inf'
+                ) and dis_by_directions['right'][i][1] != float('inf'):
+                    if AXIS_MULTIPLICITY * axis_unit >= abs(
+                        dis_by_directions['left'][i][1]
+                        - dis_by_directions['right'][i][1]
+                    ):
+                        left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
+                            'bbox'
+                        ]
+                        right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
+                            'bbox'
+                        ]
+
+                        left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
+                        right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
+
+                        if (
+                            abs(left_sub_bbox_y_axis - l_y_axis)
+                            + dis_by_directions['left'][i][1]
+                            > abs(right_sub_bbox_y_axis - l_y_axis)
+                            + dis_by_directions['right'][i][1]
+                        ):
+                            left_or_right = dis_by_directions['right'][i]
+                        else:
+                            left_or_right = dis_by_directions['left'][i]
+                    else:
+                        left_or_right = dis_by_directions['left'][i]
+                        if left_or_right[1] > dis_by_directions['right'][i][1]:
+                            left_or_right = dis_by_directions['right'][i]
+                else:
+                    left_or_right = dis_by_directions['left'][i]
+                    if left_or_right[1] == float('inf'):
+                        left_or_right = dis_by_directions['right'][i]
+            else:
+                left_or_right = [-1, float('inf')]
+
+            if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
+                'bottom'
+            ][i][1] != float('inf'):
+                if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
+                    'bottom'
+                ][i][1] != float('inf'):
+                    if AXIS_MULTIPLICITY * axis_unit >= abs(
+                        dis_by_directions['top'][i][1]
+                        - dis_by_directions['bottom'][i][1]
+                    ):
+                        top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
+                        bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
+
+                        top_bottom_x_axis = top_bottom[2] - top_bottom[0]
+                        bottom_top_x_axis = bottom_top[2] - bottom_top[0]
+                        if (
+                            abs(top_bottom_x_axis - l_x_axis)
+                            + dis_by_directions['bottom'][i][1]
+                            > abs(bottom_top_x_axis - l_x_axis)
+                            + dis_by_directions['top'][i][1]
+                        ):
+                            top_or_bottom = dis_by_directions['top'][i]
+                        else:
+                            top_or_bottom = dis_by_directions['bottom'][i]
+                    else:
+                        top_or_bottom = dis_by_directions['top'][i]
+                        if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
+                            top_or_bottom = dis_by_directions['bottom'][i]
+                else:
+                    top_or_bottom = dis_by_directions['top'][i]
+                    if top_or_bottom[1] == float('inf'):
+                        top_or_bottom = dis_by_directions['bottom'][i]
+            else:
+                top_or_bottom = [-1, float('inf')]
+
+            if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
+                if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
+                    'inf'
+                ):
+                    if AXIS_MULTIPLICITY * axis_unit >= abs(
+                        left_or_right[1] - top_or_bottom[1]
+                    ):
+                        y_axis_bbox = subjects[left_or_right[0]]['bbox']
+                        x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
+
+                        if (
+                            abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
+                            > abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
+                            / l_y_axis
+                        ):
+                            sub_obj_map_h[left_or_right[0]].append(i)
+                        else:
+                            sub_obj_map_h[top_or_bottom[0]].append(i)
+                    else:
+                        if left_or_right[1] > top_or_bottom[1]:
+                            sub_obj_map_h[top_or_bottom[0]].append(i)
+                        else:
+                            sub_obj_map_h[left_or_right[0]].append(i)
+                else:
+                    if left_or_right[1] != float('inf'):
+                        sub_obj_map_h[left_or_right[0]].append(i)
+                    else:
+                        sub_obj_map_h[top_or_bottom[0]].append(i)
+        ret = []
+        for i in sub_obj_map_h.keys():
+            ret.append(
+                {
+                    'sub_bbox': {
+                        'bbox': subjects[i]['bbox'],
+                        'score': subjects[i]['score'],
+                    },
+                    'obj_bboxes': [
+                        {'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
+                        for j in sub_obj_map_h[i]
+                    ],
+                    'sub_idx': i,
+                }
+            )
+        return ret
+
+    def get_imgs_v2(self, page_no: int):
+        with_captions = self.__tie_up_category_by_distance_v2(
+            page_no, 3, 4, PosRelationEnum.BOTTOM
+        )
+        with_footnotes = self.__tie_up_category_by_distance_v2(
+            page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
+        )
+        ret = []
+        for v in with_captions:
+            record = {
+                'image_body': v['sub_bbox'],
+                'image_caption_list': v['obj_bboxes'],
+            }
+            filter_idx = v['sub_idx']
+            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
+            record['image_footnote_list'] = d['obj_bboxes']
+            ret.append(record)
+        return ret
+
+    def get_tables_v2(self, page_no: int) -> list:
+        with_captions = self.__tie_up_category_by_distance_v2(
+            page_no, 5, 6, PosRelationEnum.UP
+        )
+        with_footnotes = self.__tie_up_category_by_distance_v2(
+            page_no, 5, 7, PosRelationEnum.ALL
+        )
+        ret = []
+        for v in with_captions:
+            record = {
+                'table_body': v['sub_bbox'],
+                'table_caption_list': v['obj_bboxes'],
+            }
+            filter_idx = v['sub_idx']
+            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
+            record['table_footnote_list'] = d['obj_bboxes']
+            ret.append(record)
+        return ret
+
     def get_imgs(self, page_no: int):
         with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
         with_footnotes, _ = self.__tie_up_category_by_distance(
@@ -699,10 +1015,10 @@ class MagicModel:
 
     def get_page_size(self, page_no: int):  # get page width and height
         # Get the page object for the current page
-        page = self.__docs[page_no]
+        page = self.__docs.get_page(page_no).get_page_info()
         # Get the width and height of the current page
-        page_w = page.rect.width
-        page_h = page.rect.height
+        page_w = page.w
+        page_h = page.h
         return page_w, page_h
 
     def __get_blocks_by_type(

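MagicModel now wraps a Dataset instead of a raw fitz.Document, and the v2 getters return one record per subject with its attached captions and footnotes. A hedged sketch, assuming model_list was produced by doc_analyze over the same bytes:

    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.model.magic_model import MagicModel

    with open('demo.pdf', 'rb') as f:   # illustrative input path
        pdf_bytes = f.read()

    # model_list: raw doc_analyze output for the same PDF (assumed available).
    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
    for record in magic_model.get_tables_v2(page_no=0):
        # Each record ties one table body to its caption/footnote candidates.
        print(record['table_body']['bbox'],
              len(record['table_caption_list']),
              len(record['table_footnote_list']))
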
+ 164 - 80
magic_pdf/model/pdf_extract_kit.py

@@ -1,11 +1,14 @@
 from loguru import logger
 import os
 import time
-
+from pathlib import Path
+import shutil
 from magic_pdf.libs.Constants import *
+from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
 try:
     import cv2
     import yaml
@@ -23,6 +26,7 @@ try:
     from unimernet.common.config import Config
     import unimernet.tasks as tasks
     from unimernet.processors import load_processor
+    from doclayout_yolo import YOLOv10
 
 except ImportError as e:
     logger.exception(e)
@@ -32,21 +36,26 @@ except ImportError as e:
     exit(1)
 
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
+from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
+# from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
-    else:
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
+        # table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
+        logger.error("StructEqTable is under upgrade, the current version does not support it.")
+        exit(1)
+    elif table_model_type == MODEL_NAME.TABLE_MASTER:
         config = {
             "model_dir": model_path,
             "device": _device_
         }
         table_model = ppTableModel(config)
+    else:
+        logger.error("table model type not allow")
+        exit(1)
     return table_model
 
 
@@ -58,12 +67,13 @@ def mfd_model_init(weight):
 def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     args = argparse.Namespace(cfg_path=cfg_path, options=None)
     cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
+    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
     cfg.config.model.model_config.model_name = weight_dir
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)
     model = task.build_model(cfg)
-    model = model.to(_device_)
+    model.to(_device_)
+    model.eval()
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]
@@ -74,8 +84,16 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
-    model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+def doclayout_yolo_model_init(weight):
+    model = YOLOv10(weight)
+    return model
+
+
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
+    if lang is not None:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
+    else:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     return model
 
 
@@ -108,19 +126,27 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
-        if atom_model_name not in self._models:
-            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[atom_model_name]
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]
 
 
 def atom_model_init(model_name: str, **kwargs):
 
     if model_name == AtomicModel.Layout:
-        atom_model = layout_model_init(
-            kwargs.get("layout_weights"),
-            kwargs.get("layout_config_file"),
-            kwargs.get("device")
-        )
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+            )
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get("mfd_weights")
@@ -134,11 +160,12 @@ def atom_model_init(model_name: str, **kwargs):
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
             kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh")
+            kwargs.get("det_db_box_thresh"),
+            kwargs.get("lang")
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
-            kwargs.get("table_model_type"),
+            kwargs.get("table_model_name"),
             kwargs.get("table_model_path"),
             kwargs.get("table_max_time"),
             kwargs.get("device")
@@ -150,6 +177,23 @@ def atom_model_init(model_name: str, **kwargs):
     return atom_model
 
 
+# Unified image-cropping helper
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+    # Create a white background padded by crop_paste_x / crop_paste_y on each side
+    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+    # Crop image
+    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+    cropped_img = input_pil_img.crop(crop_box)
+    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    return return_image, return_list
+
+
 class CustomPEKModel:
 
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
@@ -169,22 +213,35 @@ class CustomPEKModel:
         with open(config_path, "r", encoding='utf-8') as f:
             self.configs = yaml.load(f, Loader=yaml.FullLoader)
         # Initialize parsing configuration
-        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
-        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+
+        # layout config
+        self.layout_config = kwargs.get("layout_config")
+        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
+
+        # formula config
+        self.formula_config = kwargs.get("formula_config")
+        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
+        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
+        self.apply_formula = self.formula_config.get("enable", True)
+
         # table config
-        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
-        self.apply_table = self.table_config.get("is_table_recog_enable", False)
+        self.table_config = kwargs.get("table_config")
+        self.apply_table = self.table_config.get("enable", False)
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_type = self.table_config.get("model", TABLE_MASTER)
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+
+        # ocr config
         self.apply_ocr = ocr
+        self.lang = kwargs.get("lang", None)
+
         logger.info(
-            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
-                self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
+            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
+            "apply_table: {}, table_model: {}, lang: {}".format(
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
             )
         )
-        assert self.apply_layout, "DocAnalysis must contain layout model."
         # Initialize the parsing scheme
-        self.device = kwargs.get("device", self.configs["config"]["device"])
+        self.device = kwargs.get("device", "cpu")
         logger.info("using device: {}".format(self.device))
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
@@ -193,17 +250,16 @@ class CustomPEKModel:
 
         # Initialize formula recognition
         if self.apply_formula:
+
             # Initialize the formula detection (MFD) model
-            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
             )
+
             # Initialize the formula recognition (MFR) model
-            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
+            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
             self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
@@ -212,17 +268,20 @@ class CustomPEKModel:
             )
 
         # Initialize the layout model
-        # self.layout_model = Layoutlmv3_Predictor(
-        #     str(os.path.join(models_dir, self.configs['weights']['layout'])),
-        #     str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-        #     device=self.device
-        # )
-        self.layout_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.Layout,
-            layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
-            layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-            device=self.device
-        )
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.LAYOUTLMv3,
+                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+                device=self.device
+            )
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.DocLayout_YOLO,
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+            )
         # Initialize OCR
         if self.apply_ocr:
 
@@ -230,37 +289,67 @@ class CustomPEKModel:
             self.ocr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.OCR,
                 ocr_show_log=show_log,
-                det_db_box_thresh=0.3
+                det_db_box_thresh=0.3,
+                lang=self.lang
             )
         # init table model
         if self.apply_table:
-            table_model_dir = self.configs["weights"][self.table_model_type]
-            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-            #                                     max_time=self.table_max_time, _device_=self.device)
+            table_model_dir = self.configs["weights"][self.table_model_name]
             self.table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Table,
-                table_model_type=self.table_model_type,
+                table_model_name=self.table_model_name,
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device
             )
 
+            home_directory = Path.home()
+            det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
+            rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
+            det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
+            rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
+
+            if not os.path.exists(det_dest_dir):
+                shutil.copytree(det_source, det_dest_dir)
+            if not os.path.exists(rec_dest_dir):
+                shutil.copytree(rec_source, rec_dest_dir)
+
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
 
+        page_start = time.time()
+
         latex_filling_list = []
         mf_image_list = []
 
         # Layout detection
         layout_start = time.time()
-        layout_res = self.layout_model(image, ignore_catids=[])
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            layout_res = self.layout_model(image, ignore_catids=[])
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            layout_res = []
+            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 3),
+                }
+                layout_res.append(new_item)
         layout_cost = round(time.time() - layout_start, 2)
-        logger.info(f"layout detection cost: {layout_cost}")
+        logger.info(f"layout detection time: {layout_cost}")
+
+        pil_img = Image.fromarray(image)
 
         if self.apply_formula:
             # Formula detection
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+            mfd_start = time.time()
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
             for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                 new_item = {
@@ -271,7 +360,7 @@ class CustomPEKModel:
                 }
                 layout_res.append(new_item)
                 latex_filling_list.append(new_item)
-                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 mf_image_list.append(bbox_img)
 
             # Formula recognition
@@ -281,7 +370,8 @@ class CustomPEKModel:
             mfr_res = []
             for mf_img in dataloader:
                 mf_img = mf_img.to(self.device)
-                output = self.mfr_model.generate({'image': mf_img})
+                with torch.no_grad():
+                    output = self.mfr_model.generate({'image': mf_img})
                 mfr_res.extend(output['pred_str'])
             for res, latex in zip(latex_filling_list, mfr_res):
                 res['latex'] = latex_rm_whitespace(latex)
@@ -303,23 +393,14 @@ class CustomPEKModel:
             elif int(res['category_id']) in [5]:
                 table_res_list.append(res)
 
-        #  Unified crop img logic
-        def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
-            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
-            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-            # Create a white background with an additional width and height of 50
-            crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
-            crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
-
-            # Crop image
-            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-            cropped_img = input_pil_img.crop(crop_box)
-            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-            return return_image, return_list
-
-        pil_img = Image.fromarray(image)
+        if torch.cuda.is_available():
+            properties = torch.cuda.get_device_properties(self.device)
+            total_memory = properties.total_memory / (1024 ** 3)  # convert bytes to GB
+            if total_memory <= 10:
+                gc_start = time.time()
+                clean_memory()
+                gc_time = round(time.time() - gc_start, 2)
+                logger.info(f"gc time: {gc_time}")
 
         # OCR recognition
         if self.apply_ocr:
@@ -369,7 +450,7 @@ class CustomPEKModel:
                         })
 
             ocr_cost = round(time.time() - ocr_start, 2)
-            logger.info(f"ocr cost: {ocr_cost}")
+            logger.info(f"ocr time: {ocr_cost}")
 
         # table recognition
         if self.apply_table:
@@ -377,17 +458,17 @@ class CustomPEKModel:
             for res in table_res_list:
                 new_image, _ = crop_img(res, pil_img)
                 single_table_start_time = time.time()
-                logger.info("------------------table recognition processing begins-----------------")
+                # logger.info("------------------table recognition processing begins-----------------")
                 latex_code = None
                 html_code = None
-                if self.table_model_type == STRUCT_EQTABLE:
+                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
                         latex_code = self.table_model.image2latex(new_image)[0]
                 else:
                     html_code = self.table_model.img2html(new_image)
 
                 run_time = time.time() - single_table_start_time
-                logger.info(f"------------table recognition processing ends within {run_time}s-----")
+                # logger.info(f"------------table recognition processing ends within {run_time}s-----")
                 if run_time > self.table_max_time:
                     logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
                 # check whether the model returned a valid result
@@ -398,12 +479,15 @@ class CustomPEKModel:
                     if expected_ending:
                         res["latex"] = latex_code
                     else:
-                        logger.warning(f"------------table recognition processing fails----------")
+                        logger.warning("table recognition failed: expected LaTeX table ending not found")
                 elif html_code:
                     res["html"] = html_code
                 else:
-                    logger.warning(f"------------table recognition processing fails----------")
-            table_cost = round(time.time() - table_start, 2)
-            logger.info(f"table cost: {table_cost}")
+                    logger.warning("table recognition failed: no latex or html returned")
+            logger.info(f"table time: {round(time.time() - table_start, 2)}")
+
+        logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
 
         return layout_res
+
+

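The DocLayout-YOLO branch above converts each detection from xyxy form into the 8-point `poly` format the rest of the pipeline consumes. A minimal sketch of that conversion, with hypothetical box values:

def xyxy_to_layout_item(xyxy, conf, cla):
    xmin, ymin, xmax, ymax = [int(v) for v in xyxy]
    return {
        'category_id': int(cla),
        # corners clockwise from top-left: (x0,y0) (x1,y0) (x1,y1) (x0,y1)
        'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
        'score': round(float(conf), 3),
    }

print(xyxy_to_layout_item((12.3, 40.8, 200.1, 90.5), 0.8712, 1))
# {'category_id': 1, 'poly': [12, 40, 200, 40, 200, 90, 12, 90], 'score': 0.871}
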
+ 8 - 1
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py

@@ -1,5 +1,12 @@
-from struct_eqtable.model import StructTable
+from loguru import logger
+
+try:
+    from struct_eqtable.model import StructTable
+except ImportError:
+    logger.error("StructEqTable is being upgraded; the current version does not support it.")
 from pypandoc import convert_text
+
+
 class StructTableModel:
     def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
         # init

+ 2 - 2
magic_pdf/model/ppTableModel.py

@@ -52,11 +52,11 @@ class ppTableModel(object):
         rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
         rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
         device = kwargs.get("device", "cpu")
-        use_gpu = True if device == "cuda" else False
+        use_gpu = True if device.startswith("cuda") else False
         config = {
             "use_gpu": use_gpu,
             "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": TABLE_MASTER,
+            "table_algorithm": "TableMaster",
             "table_model_dir": table_model_dir,
             "table_char_dict_path": table_char_dict_path,
             "det_model_dir": det_model_dir,

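The switch to `device.startswith("cuda")` matters because torch-style device strings may carry an index; the old equality test silently fell back to CPU for them. Illustrative values:

for device in ('cpu', 'cuda', 'cuda:0', 'cuda:1'):
    print(device, '->', device.startswith('cuda'))
# cpu -> False, cuda -> True, cuda:0 -> True, cuda:1 -> True
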
+ 5 - 2
magic_pdf/model/pp_structure_v2.py

@@ -18,8 +18,11 @@ def region_to_bbox(region):
 
 
 class CustomPaddleModel:
-    def __init__(self, ocr: bool = False, show_log: bool = False):
-        self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
+    def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
+        if lang is not None:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
+        else:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
 
     def __call__(self, img):
         try:

+ 0 - 0
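A hedged usage sketch of the new `lang` passthrough: when `lang` is omitted, PPStructure keeps its own default, so existing callers are unaffected (the language code below is illustrative):

from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

model_default = CustomPaddleModel(ocr=True)             # PPStructure default language
model_english = CustomPaddleModel(ocr=True, lang='en')  # explicit OCR language
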
magic_pdf/model/v3/__init__.py


+ 125 - 0
magic_pdf/model/v3/helpers.py

@@ -0,0 +1,125 @@
+from collections import defaultdict
+from typing import List, Dict
+
+import torch
+from transformers import LayoutLMv3ForTokenClassification
+
+MAX_LEN = 510
+CLS_TOKEN_ID = 0
+UNK_TOKEN_ID = 3
+EOS_TOKEN_ID = 2
+
+
+class DataCollator:
+    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
+        bbox = []
+        labels = []
+        input_ids = []
+        attention_mask = []
+
+        # clip bbox and labels to max length, build input_ids and attention_mask
+        for feature in features:
+            _bbox = feature["source_boxes"]
+            if len(_bbox) > MAX_LEN:
+                _bbox = _bbox[:MAX_LEN]
+            _labels = feature["target_index"]
+            if len(_labels) > MAX_LEN:
+                _labels = _labels[:MAX_LEN]
+            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
+            _attention_mask = [1] * len(_bbox)
+            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
+            bbox.append(_bbox)
+            labels.append(_labels)
+            input_ids.append(_input_ids)
+            attention_mask.append(_attention_mask)
+
+        # add CLS and EOS tokens
+        for i in range(len(bbox)):
+            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
+            labels[i] = [-100] + labels[i] + [-100]
+            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
+            attention_mask[i] = [1] + attention_mask[i] + [1]
+
+        # padding to max length
+        max_len = max(len(x) for x in bbox)
+        for i in range(len(bbox)):
+            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
+            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
+            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
+            attention_mask[i] = attention_mask[i] + [0] * (
+                max_len - len(attention_mask[i])
+            )
+
+        ret = {
+            "bbox": torch.tensor(bbox),
+            "attention_mask": torch.tensor(attention_mask),
+            "labels": torch.tensor(labels),
+            "input_ids": torch.tensor(input_ids),
+        }
+        # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
+        ret["labels"][ret["labels"] > MAX_LEN] = -100
+        # set label > 0 to label-1, because original labels are 1-indexed
+        ret["labels"][ret["labels"] > 0] -= 1
+        return ret
+
+
+def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
+    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
+    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
+    attention_mask = [1] + [1] * len(boxes) + [1]
+    return {
+        "bbox": torch.tensor([bbox]),
+        "attention_mask": torch.tensor([attention_mask]),
+        "input_ids": torch.tensor([input_ids]),
+    }
+
+
+def prepare_inputs(
+    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
+) -> Dict[str, torch.Tensor]:
+    ret = {}
+    for k, v in inputs.items():
+        v = v.to(model.device)
+        if torch.is_floating_point(v):
+            v = v.to(model.dtype)
+        ret[k] = v
+    return ret
+
+
+def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
+    """
+    parse logits to orders
+
+    :param logits: logits from model
+    :param length: input length
+    :return: orders
+    """
+    logits = logits[1 : length + 1, :length]
+    orders = logits.argsort(descending=False).tolist()
+    ret = [o.pop() for o in orders]
+    while True:
+        order_to_idxes = defaultdict(list)
+        for idx, order in enumerate(ret):
+            order_to_idxes[order].append(idx)
+        # keep only the orders that were assigned to more than one index
+        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
+        if not order_to_idxes:
+            break
+        # resolve each conflicting order
+        for order, idxes in order_to_idxes.items():
+            # find original logits of idxes
+            idxes_to_logit = {}
+            for idx in idxes:
+                idxes_to_logit[idx] = logits[idx, order]
+            idxes_to_logit = sorted(
+                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
+            )
+            # keep the highest logit as order, set others to next candidate
+            for idx, _ in idxes_to_logit[1:]:
+                ret[idx] = orders[idx].pop()
+
+    return ret
+
+
+def check_duplicate(a: List[int]) -> bool:
+    return len(a) != len(set(a))

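Putting the helpers together, a minimal reading-order sketch (the bbox values are hypothetical and must already be scaled into LayoutLMv3's 0-1000 grid, as sort_lines_by_model in pdf_parse_union_core_v2.py does):

import torch
from transformers import LayoutLMv3ForTokenClassification

from magic_pdf.model.v3.helpers import boxes2inputs, parse_logits, prepare_inputs

model = LayoutLMv3ForTokenClassification.from_pretrained('hantian/layoutreader')
model.eval()

boxes = [[100, 500, 900, 540],   # a line near the bottom of the page
         [100, 100, 900, 140],   # a line near the top
         [100, 300, 900, 340]]   # a line in the middle
inputs = prepare_inputs(boxes2inputs(boxes), model)
with torch.no_grad():
    logits = model(**inputs).logits.cpu().squeeze(0)
print(parse_logits(logits, len(boxes)))  # e.g. [1, 2, 0]: input indices in reading order
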
+ 296 - 0
magic_pdf/para/para_split_v3.py

@@ -0,0 +1,296 @@
+import copy
+
+from loguru import logger
+
+from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
+from magic_pdf.libs.ocr_content_type import BlockType, ContentType
+
+LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
+LIST_END_FLAG = ('.', '。', ';', ';')
+
+
+class ListLineTag:
+    IS_LIST_START_LINE = "is_list_start_line"
+    IS_LIST_END_LINE = "is_list_end_line"
+
+
+def __process_blocks(blocks):
+    # Preprocess all blocks:
+    # 1. group blocks, using title and interline_equation blocks as separators
+    # 2. reset each block's bbox boundaries from its line info
+
+    result = []
+    current_group = []
+
+    for i in range(len(blocks)):
+        current_block = blocks[i]
+
+        # if the current block is a text block
+        if current_block['type'] == 'text':
+            current_block["bbox_fs"] = copy.deepcopy(current_block["bbox"])
+            if 'lines' in current_block and len(current_block["lines"]) > 0:
+                current_block['bbox_fs'] = [min([line['bbox'][0] for line in current_block['lines']]),
+                                            min([line['bbox'][1] for line in current_block['lines']]),
+                                            max([line['bbox'][2] for line in current_block['lines']]),
+                                            max([line['bbox'][3] for line in current_block['lines']])]
+            current_group.append(current_block)
+
+        # check whether a next block exists
+        if i + 1 < len(blocks):
+            next_block = blocks[i + 1]
+            # if the next block is a title or interline_equation block
+            if next_block['type'] in ['title', 'interline_equation']:
+                result.append(current_group)
+                current_group = []
+
+    # handle the last group
+    if current_group:
+        result.append(current_group)
+
+    return result
+
+
+def __is_list_or_index_block(block):
+    # A block is treated as a list block when it matches one of these patterns:
+    # 1. multiple lines  2. several lines flush with the left edge  3. several lines not flush on the right (jagged edge)
+    # 1. multiple lines  2. several lines flush with the left edge  3. several lines ending with an end flag
+    # 1. multiple lines  2. several lines flush with the left edge  3. several lines not flush on the left
+
+    # An index block (e.g. a table of contents) is a special kind of list block
+    # and must satisfy all of the following:
+    # 1. multiple lines  2. several lines flush on both edges  3. lines start or end with digits
+    if len(block['lines']) >= 2:
+        first_line = block['lines'][0]
+        line_height = first_line['bbox'][3] - first_line['bbox'][1]
+        block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
+
+        left_close_num = 0
+        left_not_close_num = 0
+        right_not_close_num = 0
+        right_close_num = 0
+        lines_text_list = []
+
+        multiple_para_flag = False
+        last_line = block['lines'][-1]
+        # first line indented on the left but flush on the right; last line flush on the left but not on the right (the first line may be allowed to be non-flush on the right)
+        if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
+                # block['bbox_fs'][2] - first_line['bbox'][2] < line_height and
+                abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
+                block['bbox_fs'][2] - last_line['bbox'][2] > line_height
+        ):
+            multiple_para_flag = True
+
+        for line in block['lines']:
+
+            line_text = ""
+
+            for span in line['spans']:
+                span_type = span['type']
+                if span_type == ContentType.Text:
+                    line_text += span['content'].strip()
+
+            lines_text_list.append(line_text)
+
+            # count lines flush with the left edge; "flush" means abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2
+            if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
+                left_close_num += 1
+            elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
+                # logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
+                left_not_close_num += 1
+
+            # check whether the line is flush with the right edge
+            if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
+                right_close_num += 1
+            else:
+                # when not flush on the right, require a clear gap; 0.3 of the block width is a rough empirical threshold
+                closed_area = 0.3 * block_weight
+                # closed_area = 5 * line_height
+                if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
+                    right_not_close_num += 1
+
+        # check whether more than 80% of the lines end with a LIST_END_FLAG
+        line_end_flag = False
+        # check whether more than 80% of the lines start or end with a digit
+        line_num_flag = False
+        num_start_count = 0
+        num_end_count = 0
+        flag_end_count = 0
+        if len(lines_text_list) > 0:
+            for line_text in lines_text_list:
+                if len(line_text) > 0:
+                    if line_text[-1] in LIST_END_FLAG:
+                        flag_end_count += 1
+                    if line_text[0].isdigit():
+                        num_start_count += 1
+                    if line_text[-1].isdigit():
+                        num_end_count += 1
+
+            if flag_end_count / len(lines_text_list) >= 0.8:
+                line_end_flag = True
+
+            if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
+                line_num_flag = True
+
+        # some tables of contents are not flush on the right; currently a block counts as an index when either edge is fully flush and the digit rule holds
+        if ((left_close_num/len(block['lines']) >= 0.8 or right_close_num/len(block['lines']) >= 0.8)
+                and line_num_flag
+        ):
+            for line in block['lines']:
+                line[ListLineTag.IS_LIST_START_LINE] = True
+            return BlockType.Index
+
+        elif left_close_num >= 2 and (
+                right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2) and not multiple_para_flag:
+            # handle a special un-indented list: every line is flush left, and the right-side gap marks item ends
+            if left_close_num / len(block['lines']) > 0.9:
+                # a list of short one-line items, all flush left
+                if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
+                    for line in block['lines']:
+                        if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
+                            line[ListLineTag.IS_LIST_START_LINE] = True
+                # most line items end with an end flag; split items by those flags
+                elif line_end_flag:
+                    for i, line in enumerate(block['lines']):
+                        if lines_text_list[i][-1] in LIST_END_FLAG:
+                            line[ListLineTag.IS_LIST_END_LINE] = True
+                            if i + 1 < len(block['lines']):
+                                block['lines'][i+1][ListLineTag.IS_LIST_START_LINE] = True
+                # items have almost no end flags and no indentation; use the right-side gap to find item ends
+                else:
+                    line_start_flag = False
+                    for i, line in enumerate(block['lines']):
+                        if line_start_flag:
+                            line[ListLineTag.IS_LIST_START_LINE] = True
+                            line_start_flag = False
+                        elif abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
+                            line[ListLineTag.IS_LIST_END_LINE] = True
+                            line_start_flag = True
+            # a special indented ordered list: start lines are indented and begin with a digit,
+            # and the end lines (ending with a LIST_END_FLAG) match the start lines in number
+            elif num_start_count >= 2 and num_start_count == flag_end_count:  # keep it simple: skip the non-flush-left case for now
+                for i, line in enumerate(block['lines']):
+                    if lines_text_list[i][0].isdigit():
+                        line[ListLineTag.IS_LIST_START_LINE] = True
+                    if lines_text_list[i][-1] in LIST_END_FLAG:
+                        line[ListLineTag.IS_LIST_END_LINE] = True
+            else:
+                # normal handling for indented lists
+                for line in block['lines']:
+                    if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
+                        line[ListLineTag.IS_LIST_START_LINE] = True
+                    if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
+                        line[ListLineTag.IS_LIST_END_LINE] = True
+
+            return BlockType.List
+        else:
+            return BlockType.Text
+    else:
+        return BlockType.Text
+
+
+def __merge_2_text_blocks(block1, block2):
+    if len(block1['lines']) > 0:
+        first_line = block1['lines'][0]
+        line_height = first_line['bbox'][3] - first_line['bbox'][1]
+        block1_weight = block1['bbox'][2] - block1['bbox'][0]
+        block2_weight = block2['bbox'][2] - block2['bbox'][0]
+        min_block_weight = min(block1_weight, block2_weight)
+        if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
+            last_line = block2['lines'][-1]
+            if len(last_line['spans']) > 0:
+                last_span = last_line['spans'][-1]
+                line_height = last_line['bbox'][3] - last_line['bbox'][1]
+                if (abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height and
+                        not last_span['content'].endswith(LINE_STOP_FLAG) and
+                        # do not merge when the two block widths differ by more than 2x
+                        abs(block1_weight - block2_weight) < min_block_weight
+                ):
+                    if block1['page_num'] != block2['page_num']:
+                        for line in block1['lines']:
+                            for span in line['spans']:
+                                span[CROSS_PAGE] = True
+                    block2['lines'].extend(block1['lines'])
+                    block1['lines'] = []
+                    block1[LINES_DELETED] = True
+
+    return block1, block2
+
+
+def __merge_2_list_blocks(block1, block2):
+    if block1['page_num'] != block2['page_num']:
+        for line in block1['lines']:
+            for span in line['spans']:
+                span[CROSS_PAGE] = True
+    block2['lines'].extend(block1['lines'])
+    block1['lines'] = []
+    block1[LINES_DELETED] = True
+
+    return block1, block2
+
+
+def __is_list_group(text_blocks_group):
+    # a list group is one where every block in the group satisfies:
+    # 1. no block has more than 3 lines  2. the blocks' left edges are close together (rule 2 is skipped for now to keep the logic simple)
+    for block in text_blocks_group:
+        if len(block['lines']) > 3:
+            return False
+    return True
+
+
+def __para_merge_page(blocks):
+    page_text_blocks_groups = __process_blocks(blocks)
+    for text_blocks_group in page_text_blocks_groups:
+
+        if len(text_blocks_group) > 0:
+            # before merging, classify every block as list / index / text
+            for block in text_blocks_group:
+                block_type = __is_list_or_index_block(block)
+                block['type'] = block_type
+                # logger.info(f"{block['type']}:{block}")
+
+        if len(text_blocks_group) > 1:
+
+            # before merging, check whether this group is a list group
+            is_list_group = __is_list_group(text_blocks_group)
+
+            # iterate in reverse order
+            for i in range(len(text_blocks_group) - 1, -1, -1):
+                current_block = text_blocks_group[i]
+
+                # check whether a previous block exists
+                if i - 1 >= 0:
+                    prev_block = text_blocks_group[i - 1]
+
+                    if current_block['type'] == 'text' and prev_block['type'] == 'text' and not is_list_group:
+                        __merge_2_text_blocks(current_block, prev_block)
+                    elif (
+                            (current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List) or
+                            (current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index)
+                    ):
+                        __merge_2_list_blocks(current_block, prev_block)
+
+        else:
+            continue
+
+
+def para_split(pdf_info_dict, debug_mode=False):
+    all_blocks = []
+    for page_num, page in pdf_info_dict.items():
+        blocks = copy.deepcopy(page['preproc_blocks'])
+        for block in blocks:
+            block['page_num'] = page_num
+        all_blocks.extend(blocks)
+
+    __para_merge_page(all_blocks)
+    for page_num, page in pdf_info_dict.items():
+        page['para_blocks'] = []
+        for block in all_blocks:
+            if block['page_num'] == page_num:
+                page['para_blocks'].append(block)
+
+
+if __name__ == '__main__':
+    input_blocks = []
+    # call the function
+    groups = __process_blocks(input_blocks)
+    for group_index, group in enumerate(groups):
+        print(f"Group {group_index}: {group}")
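A slightly fuller smoke test than the empty-input one above, using hypothetical mock blocks to show how a title closes the current group (`__process_blocks` is a private helper, imported here for illustration only):

from magic_pdf.para.para_split_v3 import __process_blocks

mock_blocks = [
    {'type': 'text', 'bbox': [50, 100, 500, 140],
     'lines': [{'bbox': [50, 100, 500, 120]}, {'bbox': [50, 120, 480, 140]}]},
    {'type': 'title', 'bbox': [50, 160, 300, 180], 'lines': []},
    {'type': 'text', 'bbox': [50, 200, 500, 220],
     'lines': [{'bbox': [50, 200, 500, 220]}]},
]
groups = __process_blocks(mock_blocks)
print(len(groups))  # 2: the title ends the first group (only text blocks are collected)
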

+ 6 - 3
magic_pdf/pdf_parse_by_ocr.py

@@ -1,4 +1,6 @@
-from magic_pdf.pdf_parse_union_core import pdf_parse_union
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
 def parse_pdf_by_ocr(pdf_bytes,
@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
                      end_page_id=None,
                      debug_mode=False,
                      ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "ocr",
+                           SupportedPdfParseMethod.OCR,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,

+ 6 - 3
magic_pdf/pdf_parse_by_txt.py

@@ -1,4 +1,6 @@
-from magic_pdf.pdf_parse_union_core import pdf_parse_union
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
 def parse_pdf_by_txt(
@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
     end_page_id=None,
     debug_mode=False,
 ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "txt",
+                           SupportedPdfParseMethod.TXT,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,

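Both wrappers now share the same dataset-based entry point. A hedged sketch of the new calling convention (the input path is a hypothetical placeholder):

from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset

with open('example.pdf', 'rb') as f:
    pdf_bytes = f.read()

dataset = PymuDocDataset(pdf_bytes)
print(len(dataset))  # page count, which pdf_parse_union uses for page-range checks
# pdf_parse_union(dataset, model_list, imageWriter, SupportedPdfParseMethod.TXT, ...)
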
+ 644 - 0
magic_pdf/pdf_parse_union_core_v2.py

@@ -0,0 +1,644 @@
+import copy
+import os
+import statistics
+import time
+from typing import List
+
+import torch
+from loguru import logger
+
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import Dataset, PageableData
+from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
+from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.libs.commons import fitz, get_delta_time
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
+from magic_pdf.libs.convert_utils import dict_to_list
+from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.libs.hash_utils import compute_md5
+from magic_pdf.libs.local_math import float_equal
+from magic_pdf.libs.ocr_content_type import ContentType, BlockType
+from magic_pdf.model.magic_model import MagicModel
+from magic_pdf.para.para_split_v3 import para_split
+from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
+from magic_pdf.pre_proc.construct_page_dict import \
+    ocr_construct_page_component_v2
+from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
+from magic_pdf.pre_proc.equations_replace import (
+    combine_chars_to_pymudict, remove_chars_in_text_blocks,
+    replace_equations_in_textblock)
+from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
+    ocr_prepare_bboxes_for_layout_split_v2
+from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
+                                               fix_block_spans,
+                                               fix_discarded_block, fix_block_spans_v2)
+from magic_pdf.pre_proc.ocr_span_list_modify import (
+    get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
+    remove_overlaps_min_spans)
+from magic_pdf.pre_proc.resolve_bbox_conflict import \
+    check_useful_block_horizontal_overlap
+
+
+def remove_horizontal_overlap_block_which_smaller(all_bboxes):
+    useful_blocks = []
+    for bbox in all_bboxes:
+        useful_blocks.append({'bbox': bbox[:4]})
+    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
+        check_useful_block_horizontal_overlap(useful_blocks)
+    )
+    if is_useful_block_horz_overlap:
+        logger.warning(
+            f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
+        )  # noqa: E501
+        for bbox in all_bboxes.copy():
+            if smaller_bbox == bbox[:4]:
+                all_bboxes.remove(bbox)
+
+    return is_useful_block_horz_overlap, all_bboxes
+
+
+def __replace_STX_ETX(text_str: str):
+    """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
+    Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
+
+        Args:
+            text_str (str): raw text
+
+        Returns:
+            _type_: replaced text
+    """  # noqa: E501
+    if text_str:
+        s = text_str.replace('\u0002', "'")
+        s = s.replace('\u0003', "'")
+        return s
+    return text_str
+
+
+def txt_spans_extract(pdf_page, inline_equations, interline_equations):
+    text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
+    char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
+        'blocks'
+    ]
+    text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
+    text_blocks = replace_equations_in_textblock(
+        text_blocks, inline_equations, interline_equations
+    )
+    text_blocks = remove_citation_marker(text_blocks)
+    text_blocks = remove_chars_in_text_blocks(text_blocks)
+    spans = []
+    for v in text_blocks:
+        for line in v['lines']:
+            for span in line['spans']:
+                bbox = span['bbox']
+                if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
+                    continue
+                if span.get('type') not in (
+                    ContentType.InlineEquation,
+                    ContentType.InterlineEquation,
+                ):
+                    spans.append(
+                        {
+                            'bbox': list(span['bbox']),
+                            'content': __replace_STX_ETX(span['text']),
+                            'type': ContentType.Text,
+                            'score': 1.0,
+                        }
+                    )
+    return spans
+
+
+def replace_text_span(pymu_spans, ocr_spans):
+    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
+
+
+def model_init(model_name: str):
+    from transformers import LayoutLMv3ForTokenClassification
+
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+        if torch.cuda.is_bf16_supported():
+            supports_bfloat16 = True
+        else:
+            supports_bfloat16 = False
+    else:
+        device = torch.device('cpu')
+        supports_bfloat16 = False
+
+    if model_name == 'layoutreader':
+        # check whether the local modelscope cache directory exists
+        layoutreader_model_dir = get_local_layoutreader_model_dir()
+        if os.path.exists(layoutreader_model_dir):
+            model = LayoutLMv3ForTokenClassification.from_pretrained(
+                layoutreader_model_dir
+            )
+        else:
+            logger.warning(
+                'local layoutreader model does not exist, using online model from huggingface'
+            )
+            model = LayoutLMv3ForTokenClassification.from_pretrained(
+                'hantian/layoutreader'
+            )
+        # check whether the device supports bfloat16
+        if supports_bfloat16:
+            model.bfloat16()
+        model.to(device).eval()
+    else:
+        logger.error('model name not allowed')
+        exit(1)
+    return model
+
+
+class ModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_model(self, model_name: str):
+        if model_name not in self._models:
+            self._models[model_name] = model_init(model_name=model_name)
+        return self._models[model_name]
+
+
+def do_predict(boxes: List[List[int]], model) -> List[int]:
+    from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
+                                            prepare_inputs)
+
+    inputs = boxes2inputs(boxes)
+    inputs = prepare_inputs(inputs, model)
+    logits = model(**inputs).logits.cpu().squeeze(0)
+    return parse_logits(logits, len(boxes))
+
+
+def cal_block_index(fix_blocks, sorted_bboxes):
+    for block in fix_blocks:
+
+        line_index_list = []
+        if len(block['lines']) == 0:
+            block['index'] = sorted_bboxes.index(block['bbox'])
+        else:
+            for line in block['lines']:
+                line['index'] = sorted_bboxes.index(line['bbox'])
+                line_index_list.append(line['index'])
+            median_value = statistics.median(line_index_list)
+            block['index'] = median_value
+
+        # drop the virtual line info in image/table body blocks and restore it from real_lines
+        if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+            block['virtual_lines'] = copy.deepcopy(block['lines'])
+            block['lines'] = copy.deepcopy(block['real_lines'])
+            del block['real_lines']
+
+    return fix_blocks
+
+
+def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
+    # block_bbox is a tuple (x0, y0, x1, y1); (x0, y0) is the top-left corner and (x1, y1) the bottom-right (page y grows downward)
+    x0, y0, x1, y1 = block_bbox
+
+    block_height = y1 - y0
+    block_weight = x1 - x0
+
+    # if the block is no taller than three text lines, return its bbox directly
+    if line_height * 3 < block_height:
+        if (
+            block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
+        ):  # probably a two-column layout, safe to split finely
+            lines = int(block_height / line_height) + 1
+        else:
+            # if the block is wider than 0.4 page width, split it into 3 lines (a complex layout; figures must not be cut too finely)
+            if block_weight > page_w * 0.4:
+                line_height = (y1 - y0) / 3
+                lines = 3
+            elif block_weight > page_w * 0.25:  # (probably a three-column layout, also split finely)
+                lines = int(block_height / line_height) + 1
+            else:  # check the aspect ratio
+                if block_height / block_weight > 1.2:  # tall and narrow: do not split
+                    return [[x0, y0, x1, y1]]
+                else:  # otherwise still split into two lines
+                    line_height = (y1 - y0) / 2
+                    lines = 2
+
+        # determine the y position to start drawing lines from
+        current_y = y0
+
+        # stores line positions as [x0, y0, x1, y1] entries
+        lines_positions = []
+
+        for i in range(lines):
+            lines_positions.append([x0, current_y, x1, current_y + line_height])
+            current_y += line_height
+        return lines_positions
+
+    else:
+        return [[x0, y0, x1, y1]]
+
+
+def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
+    page_line_list = []
+    for block in fix_blocks:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            if len(block['lines']) == 0:
+                bbox = block['bbox']
+                lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
+                for line in lines:
+                    block['lines'].append({'bbox': line, 'spans': []})
+                page_line_list.extend(lines)
+            else:
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    page_line_list.append(bbox)
+        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+            bbox = block['bbox']
+            block["real_lines"] = copy.deepcopy(block['lines'])
+            lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
+            block['lines'] = []
+            for line in lines:
+                block['lines'].append({'bbox': line, 'spans': []})
+            page_line_list.extend(lines)
+
+    # sort with layoutreader
+    x_scale = 1000.0 / page_w
+    y_scale = 1000.0 / page_h
+    boxes = []
+    # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
+    for left, top, right, bottom in page_line_list:
+        if left < 0:
+            logger.warning(
+                f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            left = 0
+        if right > page_w:
+            logger.warning(
+                f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            right = page_w
+        if top < 0:
+            logger.warning(
+                f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            top = 0
+        if bottom > page_h:
+            logger.warning(
+                f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            bottom = page_h
+
+        left = round(left * x_scale)
+        top = round(top * y_scale)
+        right = round(right * x_scale)
+        bottom = round(bottom * y_scale)
+        assert (
+            1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
+        ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}'  # noqa: E126, E121
+        boxes.append([left, top, right, bottom])
+    model_manager = ModelSingleton()
+    model = model_manager.get_model('layoutreader')
+    with torch.no_grad():
+        orders = do_predict(boxes, model)
+    sorted_bboxes = [page_line_list[i] for i in orders]
+
+    return sorted_bboxes
+
+
+def get_line_height(blocks):
+    page_line_height_list = []
+    for block in blocks:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            for line in block['lines']:
+                bbox = line['bbox']
+                page_line_height_list.append(int(bbox[3] - bbox[1]))
+    if len(page_line_height_list) > 0:
+        return statistics.median(page_line_height_list)
+    else:
+        return 10
+
+
+def process_groups(groups, body_key, caption_key, footnote_key):
+    body_blocks = []
+    caption_blocks = []
+    footnote_blocks = []
+    for i, group in enumerate(groups):
+        group[body_key]['group_id'] = i
+        body_blocks.append(group[body_key])
+        for caption_block in group[caption_key]:
+            caption_block['group_id'] = i
+            caption_blocks.append(caption_block)
+        for footnote_block in group[footnote_key]:
+            footnote_block['group_id'] = i
+            footnote_blocks.append(footnote_block)
+    return body_blocks, caption_blocks, footnote_blocks
+
+
+def process_block_list(blocks, body_type, block_type):
+    indices = [block['index'] for block in blocks]
+    median_index = statistics.median(indices)
+
+    body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
+
+    return {
+        'type': block_type,
+        'bbox': body_bbox,
+        'blocks': blocks,
+        'index': median_index,
+    }
+
+
+def revert_group_blocks(blocks):
+    image_groups = {}
+    table_groups = {}
+    new_blocks = []
+    for block in blocks:
+        if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
+            group_id = block['group_id']
+            if group_id not in image_groups:
+                image_groups[group_id] = []
+            image_groups[group_id].append(block)
+        elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
+            group_id = block['group_id']
+            if group_id not in table_groups:
+                table_groups[group_id] = []
+            table_groups[group_id].append(block)
+        else:
+            new_blocks.append(block)
+
+    for group_id, blocks in image_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
+
+    for group_id, blocks in table_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
+
+    return new_blocks
+
+
+def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
+    def get_block_bboxes(blocks, block_type_list):
+        return [block[0:4] for block in blocks if block[7] in block_type_list]
+
+    image_bboxes = get_block_bboxes(all_bboxes, [BlockType.ImageBody])
+    table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TableBody])
+    other_block_type = []
+    for block_type in BlockType.__dict__.values():
+        if not isinstance(block_type, str):
+            continue
+        if block_type not in [BlockType.ImageBody, BlockType.TableBody]:
+            other_block_type.append(block_type)
+    other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
+    discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.Discarded])
+
+    new_spans = []
+
+    for span in spans:
+        span_bbox = span['bbox']
+        span_type = span['type']
+
+        if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
+               discarded_block_bboxes):
+            new_spans.append(span)
+            continue
+
+        if span_type == ContentType.Image:
+            if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
+                   image_bboxes):
+                new_spans.append(span)
+        elif span_type == ContentType.Table:
+            if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
+                   table_bboxes):
+                new_spans.append(span)
+        else:
+            if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
+                   other_block_bboxes):
+                new_spans.append(span)
+
+    return new_spans
+
+
+def parse_page_core(
+    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+):
+    need_drop = False
+    drop_reason = []
+
+    """从magic_model对象中获取后面会用到的区块信息"""
+    # img_blocks = magic_model.get_imgs(page_id)
+    # table_blocks = magic_model.get_tables(page_id)
+
+    img_groups = magic_model.get_imgs_v2(page_id)
+    table_groups = magic_model.get_tables_v2(page_id)
+
+    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
+        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
+    )
+
+    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
+        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
+    )
+
+    discarded_blocks = magic_model.get_discarded(page_id)
+    text_blocks = magic_model.get_text_blocks(page_id)
+    title_blocks = magic_model.get_title_blocks(page_id)
+    inline_equations, interline_equations, interline_equation_blocks = (
+        magic_model.get_equations(page_id)
+    )
+
+    page_w, page_h = magic_model.get_page_size(page_id)
+
+    """将所有区块的bbox整理到一起"""
+    # interline_equation_blocks参数不够准,后面切换到interline_equations上
+    interline_equation_blocks = []
+    if len(interline_equation_blocks) > 0:
+        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
+            discarded_blocks,
+            text_blocks,
+            title_blocks,
+            interline_equation_blocks,
+            page_w,
+            page_h,
+        )
+    else:
+        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
+            discarded_blocks,
+            text_blocks,
+            title_blocks,
+            interline_equations,
+            page_w,
+            page_h,
+        )
+
+    spans = magic_model.get_all_spans(page_id)
+
+    """根据parse_mode,构造spans"""
+    if parse_mode == SupportedPdfParseMethod.TXT:
+        """ocr 中文本类的 span 用 pymu spans 替换!"""
+        pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
+        spans = replace_text_span(pymu_spans, spans)
+    elif parse_mode == SupportedPdfParseMethod.OCR:
+        pass
+    else:
+        raise Exception('parse_mode must be txt or ocr')
+
+    """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
+    """顺便删除大水印并保留abandon的span"""
+    spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
+
+    """删除重叠spans中置信度较低的那些"""
+    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
+    """删除重叠spans中较小的那些"""
+    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
+    """对image和table截图"""
+    spans = ocr_cut_image_and_table(
+        spans, page_doc, page_id, pdf_bytes_md5, imageWriter
+    )
+
+    """先处理不需要排版的discarded_blocks"""
+    discarded_block_with_spans, spans = fill_spans_in_blocks(
+        all_discarded_blocks, spans, 0.4
+    )
+    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
+
+    """如果当前页面没有bbox则跳过"""
+    if len(all_bboxes) == 0:
+        logger.warning(f'skip this page, no useful bbox found, page_id: {page_id}')
+        return ocr_construct_page_component_v2(
+            [],
+            [],
+            page_id,
+            page_w,
+            page_h,
+            [],
+            [],
+            [],
+            interline_equations,
+            fix_discarded_blocks,
+            need_drop,
+            drop_reason,
+        )
+
+    """将span填入blocks中"""
+    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
+
+    """对block进行fix操作"""
+    fix_blocks = fix_block_spans_v2(block_with_spans)
+
+    """获取所有line并计算正文line的高度"""
+    line_height = get_line_height(fix_blocks)
+
+    """获取所有line并对line排序"""
+    sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
+
+    """根据line的中位数算block的序列关系"""
+    fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
+
+    """将image和table的block还原回group形式参与后续流程"""
+    fix_blocks = revert_group_blocks(fix_blocks)
+
+    """重排block"""
+    sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
+
+    """获取QA需要外置的list"""
+    images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
+
+    """构造pdf_info_dict"""
+    page_info = ocr_construct_page_component_v2(
+        sorted_blocks,
+        [],
+        page_id,
+        page_w,
+        page_h,
+        [],
+        images,
+        tables,
+        interline_equations,
+        fix_discarded_blocks,
+        need_drop,
+        drop_reason,
+    )
+    return page_info
+
+
+def pdf_parse_union(
+    dataset: Dataset,
+    model_list,
+    imageWriter,
+    parse_mode,
+    start_page_id=0,
+    end_page_id=None,
+    debug_mode=False,
+):
+    pdf_bytes_md5 = compute_md5(dataset.data_bits())
+
+    """初始化空的pdf_info_dict"""
+    pdf_info_dict = {}
+
+    """用model_list和docs对象初始化magic_model"""
+    magic_model = MagicModel(model_list, dataset)
+
+    """根据输入的起始范围解析pdf"""
+    # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(dataset) - 1
+    )
+
+    if end_page_id > len(dataset) - 1:
+        logger.warning('end_page_id is out of range, using dataset length')
+        end_page_id = len(dataset) - 1
+
+    """初始化启动时间"""
+    start_time = time.time()
+
+    for page_id, page in enumerate(dataset):
+        """debug时输出每页解析的耗时."""
+        if debug_mode:
+            time_now = time.time()
+            logger.info(
+                f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
+            )
+            start_time = time_now
+
+        """解析pdf中的每一页"""
+        if start_page_id <= page_id <= end_page_id:
+            page_info = parse_page_core(
+                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+            )
+        else:
+            page_info = page.get_page_info()
+            page_w = page_info.w
+            page_h = page_info.h
+            page_info = ocr_construct_page_component_v2(
+                [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
+            )
+        pdf_info_dict[f'page_{page_id}'] = page_info
+
+    """分段"""
+    para_split(pdf_info_dict, debug_mode=debug_mode)
+
+    """dict转list"""
+    pdf_info_list = dict_to_list(pdf_info_dict)
+    new_pdf_info_dict = {
+        'pdf_info': pdf_info_list,
+    }
+
+    clean_memory()
+
+    return new_pdf_info_dict
+
+
+if __name__ == '__main__':
+    pass

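A worked example of the insert_lines_into_block heuristic above, on a hypothetical 612x792 pt page: a wide, tall block is cut into three virtual lines, while a short block comes back unchanged:

from magic_pdf.pdf_parse_union_core_v2 import insert_lines_into_block

page_w, page_h = 612, 792
line_height = 12

# taller than 3 lines and wider than 0.4 * page_w -> split into 3 virtual lines
print(insert_lines_into_block((100, 100, 450, 400), line_height, page_w, page_h))
# [[100, 100, 450, 200.0], [100, 200.0, 450, 300.0], [100, 300.0, 450, 400.0]]

# no taller than 3 lines -> returned as a single segment
print(insert_lines_into_block((100, 100, 450, 130), line_height, page_w, page_h))
# [[100, 100, 450, 130]]
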
+ 5 - 1
magic_pdf/pipe/AbsPipe.py

@@ -17,7 +17,7 @@ class AbsPipe(ABC):
     PIP_TXT = "txt"
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
+                 start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
@@ -25,6 +25,10 @@ class AbsPipe(ABC):
         self.is_debug = is_debug
         self.start_page_id = start_page_id
         self.end_page_id = end_page_id
+        self.lang = lang
+        self.layout_model = layout_model
+        self.formula_enable = formula_enable
+        self.table_enable = table_enable
     
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)

+ 10 - 4
magic_pdf/pipe/OCRPipe.py

@@ -10,19 +10,25 @@ from magic_pdf.user_api import parse_ocr_pdf
 class OCRPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 10 - 4
magic_pdf/pipe/TXTPipe.py

@@ -11,19 +11,25 @@ from magic_pdf.user_api import parse_txt_pdf
 class TXTPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 16 - 7
magic_pdf/pipe/UNIPipe.py

@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 class UNIPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
+                         lang, layout_model, formula_enable, table_enable)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -28,22 +30,29 @@ class UNIPipe(AbsPipe):
     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                                 is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
-                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                                lang=self.lang, layout_model=self.layout_model,
+                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                               is_debug=self.is_debug,
-                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                              lang=self.lang)
 
-    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
+    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         logger.info("uni_pipe mk content list finished")
         return result

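A hedged end-to-end sketch of the extended pipe options (all paths and values are illustrative; DiskReaderWriter stands in for any AbsReaderWriter implementation available in your setup):

from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter

with open('example.pdf', 'rb') as f:   # hypothetical input file
    pdf_bytes = f.read()

image_writer = DiskReaderWriter('/tmp/mineru_images')     # hypothetical output dir
jso_useful_key = {'_pdf_type': '', 'model_list': []}      # empty list -> pipe_analyze runs the models

pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
               lang='en',            # OCR language hint
               layout_model=None,    # None keeps the configured default
               formula_enable=True,
               table_enable=False)
pipe.pipe_classify()
pipe.pipe_analyze()
pipe.pipe_parse()
content_list = pipe.pipe_mk_uni_format('/tmp/mineru_images')
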
+ 83 - 1
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -1,7 +1,7 @@
 from loguru import logger
 
 from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
-    calculate_iou
+    calculate_iou, calculate_vertical_projection_overlap_ratio
 from magic_pdf.libs.drop_tag import DropTag
 from magic_pdf.libs.ocr_content_type import BlockType
 from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
@@ -60,6 +60,88 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
     return all_bboxes, all_discarded_blocks, drop_reasons
 
 
+def add_bboxes(blocks, block_type, bboxes):
+    for block in blocks:
+        x0, y0, x1, y1 = block['bbox']
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
+        else:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
+
+
+def ocr_prepare_bboxes_for_layout_split_v2(
+        img_body_blocks, img_caption_blocks, img_footnote_blocks,
+        table_body_blocks, table_caption_blocks, table_footnote_blocks,
+        discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
+):
+    all_bboxes = []
+
+    add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
+    add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
+    add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
+    add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
+    add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
+    add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
+    add_bboxes(text_blocks, BlockType.Text, all_bboxes)
+    add_bboxes(title_blocks, BlockType.Title, all_bboxes)
+    add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
+
+    '''Resolve block nesting issues'''
+    '''When a text box overlaps a title box, trust the text box'''
+    all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
+    '''When any box overlaps a discarded box, trust the discarded box'''
+    all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
+
+    # Conflicts between interline_equation boxes and title/text boxes are handled in two cases
+    '''When the IoU between an interline_equation box and a text-type box is close to 1, trust the equation box'''
+    all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
+    '''When an interline_equation box is contained in a much larger text-type box, trust the text box and drop the equation box'''
+    # Removed later by the big-box-contains-small-box logic
+
+    '''discarded_blocks'''
+    all_discarded_blocks = []
+    add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
+
+    '''Footnote detection: wider than 1/3 of the page width, taller than 10, and located in the lower 50% of the page'''
+    footnote_blocks = []
+    for discarded in discarded_blocks:
+        x0, y0, x1, y1 = discarded['bbox']
+        if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
+            footnote_blocks.append([x0, y0, x1, y1])
+
+    '''Remove any box that sits below a footnote'''
+    need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
+    if len(need_remove_blocks) > 0:
+        for block in need_remove_blocks:
+            all_bboxes.remove(block)
+            all_discarded_blocks.append(block)
+
+    '''If nested boxes still remain after the steps above, delete the smaller box'''
+    all_bboxes = remove_overlaps_min_blocks(all_bboxes)
+    all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
+    '''Separate the remaining bboxes to prevent errors in the later layout split'''
+    all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
+
+    return all_bboxes, all_discarded_blocks
+
+
+def find_blocks_under_footnote(all_bboxes, footnote_blocks):
+    need_remove_blocks = []
+    for block in all_bboxes:
+        block_x0, block_y0, block_x1, block_y1 = block[:4]
+        for footnote_bbox in footnote_blocks:
+            footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
+            # If the footnote's vertical projection covers at least 80% of the block's vertical projection and the block's y0 >= the footnote's y1
+            if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
+                if block not in need_remove_blocks:
+                    need_remove_blocks.append(block)
+                    break
+    return need_remove_blocks
+
+
 def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
     # First collect all text and interline blocks
     text_blocks = []

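find_blocks_under_footnote relies on calculate_vertical_projection_overlap_ratio from magic_pdf.libs.boxbase. As an illustration of the geometry only (a reimplementation sketch, not the library code):

def vertical_projection_overlap_ratio(block_bbox, footnote_bbox):
    # "Vertical projection" here means projecting both boxes onto the x-axis.
    b_x0, _, b_x1, _ = block_bbox
    f_x0, _, f_x1, _ = footnote_bbox
    overlap = max(0, min(b_x1, f_x1) - max(b_x0, f_x0))
    block_width = b_x1 - b_x0
    return overlap / block_width if block_width > 0 else 0

# A block is discarded when >= 80% of its projection lies under the footnote
# and it starts below it (block_y0 >= footnote_y1):
print(vertical_projection_overlap_ratio((100, 700, 300, 720), (90, 650, 310, 680)))  # 1.0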
+ 27 - 2
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
                 continue
 
             # If the current span overlaps the last span of the current line on the y-axis, append it to the current line
-            if __is_overlaps_y_exceeds_threshold(span['bbox'],
-                                                 current_line[-1]['bbox']):
+            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
                 current_line.append(span)
             else:
                 # Otherwise, start a new line
@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
             'type': block_type,
             'bbox': block_bbox,
         }
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            block_dict["group_id"] = block[-1]
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']
@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
     return fix_blocks
 
 
+def fix_block_spans_v2(block_with_spans):
+    """1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
+    需要将caption和footnote的text_span放入相应img_block和table_block内的
+    caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
+    fix_blocks = []
+    for block in block_with_spans:
+        block_type = block['type']
+
+        if block_type in [BlockType.Text, BlockType.Title,
+                          BlockType.ImageCaption, BlockType.ImageFootnote,
+                          BlockType.TableCaption, BlockType.TableFootnote
+                          ]:
+            block = fix_text_block(block)
+        elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
+            block = fix_interline_block(block)
+        else:
+            continue
+        fix_blocks.append(block)
+    return fix_blocks
+
+
 def fix_discarded_block(discarded_block_with_spans):
     fix_discarded_blocks = []
     for block in discarded_block_with_spans:

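merge_spans_to_line now passes an explicit 0.5 threshold to __is_overlaps_y_exceeds_threshold. A sketch of what such a predicate computes, assuming the threshold applies to the overlap relative to the smaller span height (illustrative only, not the boxbase original):

def overlaps_y_exceeds_threshold(bbox1, bbox2, threshold=0.5):
    _, y0_1, _, y1_1 = bbox1
    _, y0_2, _, y1_2 = bbox2
    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
    min_height = min(y1_1 - y0_1, y1_2 - y0_2)
    return min_height > 0 and overlap / min_height > threshold

# Spans on the same visual line share most of their y-extent:
print(overlaps_y_exceeds_threshold((0, 100, 50, 112), (55, 101, 90, 113)))  # True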
+ 7 - 7
magic_pdf/resources/model_config/UniMERNet/demo.yaml

@@ -2,13 +2,13 @@ model:
   arch: unimernet
   model_type: unimernet
   model_config:
-    model_name: ./models
-    max_seq_len: 1024
-    length_aware: False
+    model_name: ./models/unimernet_base
+    max_seq_len: 1536
+
   load_pretrained: True
-  pretrained: ./models/pytorch_model.bin
+  pretrained: './models/unimernet_base/pytorch_model.pth'
   tokenizer_config:
-    path: ./models
+    path: ./models/unimernet_base
 
 datasets:
   formula_rec_eval:
@@ -18,7 +18,7 @@ datasets:
         image_size:
           - 192
           - 672
-   
+
 run:
   runner: runner_iter
   task: unimernet_train
@@ -43,4 +43,4 @@ run:
   distributed_type: ddp  # or fsdp when train llm
 
   generate_cfg:
-    temperature: 0.0
+    temperature: 0.0

+ 5 - 13
magic_pdf/resources/model_config/model_configs.yaml

@@ -1,15 +1,7 @@
-config:
-  device: cpu
-  layout: True
-  formula: True
-  table_config:
-    model: TableMaster
-    is_table_recog_enable: False
-    max_time: 400
-
 weights:
-  layout: Layout/model_final.pth
-  mfd: MFD/weights.pt
-  mfr: MFR/UniMERNet
+  layoutlmv3: Layout/LayoutLMv3/model_final.pth
+  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
+  yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
+  unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
-  TableMaster: TabRec/TableMaster
+  tablemaster: TabRec/TableMaster

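The weights section is now a flat mapping from backend name to a path relative to the models directory. A minimal sketch of resolving one of these paths with PyYAML; the models_dir location is a hypothetical example:

import os
import yaml

with open('magic_pdf/resources/model_config/model_configs.yaml') as f:
    configs = yaml.safe_load(f)

models_dir = os.path.expanduser('~/models')  # hypothetical location
layout_weight = os.path.join(models_dir, configs['weights']['doclayout_yolo'])
print(layout_weight)  # .../Layout/YOLO/doclayout_yolo_ft.pt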
+ 14 - 1
magic_pdf/tools/cli.py

@@ -45,6 +45,18 @@ without method specified, auto will be used by default.""",
     default='auto',
 )
 @click.option(
+    '-l',
+    '--lang',
+    'lang',
+    type=str,
+    help="""
+    Input the language of the PDF (if known) to improve OCR accuracy. Optional.
+    Use one of the language abbreviations listed at:
+    https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
+    """,
+    default=None,
+)
+@click.option(
     '-d',
     '--debug',
     'debug_able',
@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""",
     help='The ending page for PDF parsing, beginning from 0.',
     default=None,
 )
-def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
+def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     model_config.__use_inside_model__ = True
     model_config.__model_mode__ = 'full'
     os.makedirs(output_dir, exist_ok=True)
@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
                 debug_able,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
+                lang=lang
             )
 
         except Exception as e:

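A sketch of exercising the new -l/--lang option through click's test runner, so the example stays in Python; the -p/-o/-m option names are assumed from the upstream CLI, and the paths are hypothetical:

from click.testing import CliRunner

from magic_pdf.tools.cli import cli

runner = CliRunner()
result = runner.invoke(cli, [
    '-p', 'demo.pdf',  # hypothetical input path (assumed option name)
    '-o', 'output',    # hypothetical output dir (assumed option name)
    '-m', 'auto',
    '-l', 'en',        # new: PaddleOCR language abbreviation
])
print(result.exit_code)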
+ 18 - 8
magic_pdf/tools/common.py

@@ -6,8 +6,8 @@ import click
 from loguru import logger
 
 import magic_pdf.model as model_config
-from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox,
-                                      drow_model_bbox)
+from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
+                                      draw_model_bbox, draw_span_bbox)
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
@@ -39,16 +39,21 @@ def do_parse(
     f_dump_middle_json=True,
     f_dump_model_json=True,
     f_dump_orig_pdf=True,
-    f_dump_content_list=False,
+    f_dump_content_list=True,
     f_make_md_mode=MakeMode.MM_MD,
     f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
     start_page_id=0,
     end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
 ):
     if debug_able:
         logger.warning('debug mode is on')
-        f_dump_content_list = True
         f_draw_model_bbox = True
+        f_draw_line_sort_bbox = True
 
     orig_model_list = copy.deepcopy(model_list)
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
@@ -61,13 +66,16 @@ def do_parse(
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'txt':
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'ocr':
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     else:
         logger.error('unknown parse method')
         exit(1)
@@ -89,7 +97,9 @@ def do_parse(
     if f_draw_span_bbox:
         draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
     if f_draw_model_bbox:
-        drow_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+        draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+    if f_draw_line_sort_bbox:
+        draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
 
     md_content = pipe.pipe_mk_markdown(image_dir,
                                        drop_mode=DropMode.NONE,

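do_parse now accepts the same four switches and, in debug mode, also draws the line-sort boxes. A minimal call sketch using keyword arguments throughout; the parameter names match the local variables visible in the hunk, but treat the call as a sketch rather than canonical usage:

from magic_pdf.tools.common import do_parse

with open('demo.pdf', 'rb') as f:  # hypothetical input
    pdf_bytes = f.read()

do_parse(
    output_dir='output',
    pdf_file_name='demo',
    pdf_bytes=pdf_bytes,
    model_list=[],        # empty list: the pipeline runs doc_analyze itself
    parse_method='auto',
    debug_able=False,
    lang='en',            # new: OCR language hint
    layout_model=None,    # new: layout backend override
    formula_enable=None,  # new: formula toggle
    table_enable=None,    # new: table toggle
)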
+ 25 - 6
magic_pdf/user_api.py

@@ -26,7 +26,7 @@ PARSE_TYPE_OCR = "ocr"
 
 
 def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None,
+                  start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
     """
     Parse a text-based PDF
@@ -44,11 +44,14 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
 
     pdf_info_dict["_version_name"] = __version__
 
+    if lang is not None:
+        pdf_info_dict["_lang"] = lang
+
     return pdf_info_dict
 
 
 def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None,
+                  start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
     """
     Parse an OCR-based (scanned) PDF
@@ -66,12 +69,15 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
 
     pdf_info_dict["_version_name"] = __version__
 
+    if lang is not None:
+        pdf_info_dict["_lang"] = lang
+
     return pdf_info_dict
 
 
 def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
                     input_model_is_empty: bool = False,
-                    start_page_id=0, end_page_id=None,
+                    start_page_id=0, end_page_id=None, lang=None,
                     *args, **kwargs):
     """
     Fully parse a PDF that mixes OCR and extractable text
@@ -95,9 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         if input_model_is_empty:
-            pdf_models = doc_analyze(pdf_bytes, ocr=True,
-                                     start_page_id=start_page_id,
-                                     end_page_id=end_page_id)
+            layout_model = kwargs.get("layout_model", None)
+            formula_enable = kwargs.get("formula_enable", None)
+            table_enable = kwargs.get("table_enable", None)
+            pdf_models = doc_analyze(
+                pdf_bytes,
+                ocr=True,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
@@ -108,4 +124,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
 
     pdf_info_dict["_version_name"] = __version__
 
+    if lang is not None:
+        pdf_info_dict["_lang"] = lang
+
     return pdf_info_dict

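parse_union_pdf pulls the optional backend switches out of **kwargs, so callers can pass them straight through alongside lang. A usage sketch with a hypothetical writer and input; the signature follows the diff:

from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter  # assumed writer class
from magic_pdf.user_api import parse_union_pdf

with open('demo.pdf', 'rb') as f:  # hypothetical input
    pdf_bytes = f.read()

image_writer = DiskReaderWriter('/tmp/images')  # hypothetical image dir
pdf_info = parse_union_pdf(
    pdf_bytes, [], image_writer,
    input_model_is_empty=True,  # triggers doc_analyze on the OCR fallback path
    lang='en', layout_model=None, formula_enable=None, table_enable=None,
)
print(pdf_info.get('_lang'))  # '_lang' is recorded whenever lang is given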
+ 0 - 0
magic_pdf/utils/__init__.py


+ 11 - 0
magic_pdf/utils/annotations.py

@@ -0,0 +1,11 @@
+
+from loguru import logger
+
+
+def ImportPIL(f):
+    try:
+        import PIL  # noqa: F401
+    except ImportError:
+        logger.error('Pillow is not installed; please install it with pip.')
+        exit(1)
+    return f

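ImportPIL is a decoration-time guard: the import check fires once, when the decorator is applied, and the wrapped function is returned unchanged. A usage sketch with a hypothetical function:

from magic_pdf.utils.annotations import ImportPIL


@ImportPIL  # aborts with a clear error at import time if Pillow is missing
def load_image(path):  # hypothetical example function
    from PIL import Image
    return Image.open(path)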
+ 16 - 0
next_docs/en/.readthedocs.yaml

@@ -0,0 +1,16 @@
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
+formats:
+  - epub
+
+python:
+  install:
+    - requirements: docs/requirements.txt
+
+sphinx:
+  configuration: docs/en/conf.py

+ 20 - 0
next_docs/en/Makefile

@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

Binary
next_docs/en/_static/image/logo.png


+ 9 - 0
next_docs/en/api.rst

@@ -0,0 +1,9 @@
+Data Api
+------------------
+
+.. toctree::
+   :maxdepth: 2
+
+   api/dataset.rst
+   api/data_reader_writer.rst
+   api/read_api.rst

+ 44 - 0
next_docs/en/api/data_reader_writer.rst

@@ -0,0 +1,44 @@
+
+Data Reader Writer
+--------------------
+
+.. autoclass:: magic_pdf.data.data_reader_writer.DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.DataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataWriter
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataReader
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataWriter
+   :members:
+   :inherited-members:
+

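These autoclass entries document the new data reader/writer API from magic_pdf/data/data_reader_writer. A hedged sketch of the file-based pair; the read/write method names are the conventional ones and should be treated as assumptions:

from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
                                               FileBasedDataWriter)

reader = FileBasedDataReader('')          # assumed: '' root means absolute paths
writer = FileBasedDataWriter('/tmp/out')  # hypothetical output root

pdf_bytes = reader.read('demo.pdf')       # assumed read() method
writer.write('copy.pdf', pdf_bytes)       # assumed write() method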
+ 22 - 0
next_docs/en/api/dataset.rst

@@ -0,0 +1,22 @@
+Dataset Api
+------------------
+
+.. autoclass:: magic_pdf.data.dataset.PageableData
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.Dataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.ImageDataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.PymuDocDataset
+   :members:
+   :inherited-members:
+
+.. autoclass:: magic_pdf.data.dataset.Doc
+   :members:
+   :inherited-members:

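The dataset classes wrap a document as pageable data. A hedged sketch of constructing a PymuDocDataset from raw bytes; the get_page accessor is an assumption based on the PageableData naming:

from magic_pdf.data.dataset import PymuDocDataset

with open('demo.pdf', 'rb') as f:  # hypothetical input
    pdf_bytes = f.read()

ds = PymuDocDataset(pdf_bytes)  # assumed: constructed from PDF bytes via PyMuPDF
page = ds.get_page(0)           # assumed page accessor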
+ 0 - 0
next_docs/en/api/io.rst


+ 6 - 0
next_docs/en/api/read_api.rst

@@ -0,0 +1,6 @@
+read_api Api
+------------------
+
+.. automodule:: magic_pdf.data.read_api
+   :members:
+   :inherited-members:

+ 0 - 0
next_docs/en/api/schemas.rst


+ 1 - 0
next_docs/en/api/utils.rst

@@ -0,0 +1 @@
+

+ 122 - 0
next_docs/en/conf.py

@@ -0,0 +1,122 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import os
+import subprocess
+import sys
+
+from sphinx.ext import autodoc
+
+
+def install(package):
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+
+
+requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
+if os.path.exists(requirements_path):
+    with open(requirements_path) as f:
+        packages = f.readlines()
+    for package in packages:
+        install(package.strip())
+
+sys.path.insert(0, os.path.abspath('../..'))
+
+# -- Project information -----------------------------------------------------
+
+project = 'MinerU'
+copyright = '2024, MinerU Contributors'
+author = 'OpenDataLab'
+
+# The full version, including alpha/beta/rc tags
+version_file = '../../magic_pdf/libs/version.py'
+with open(version_file) as f:
+    exec(compile(f.read(), version_file, 'exec'))
+__version__ = locals()['__version__']
+# The short X.Y version
+version = __version__
+# The full version, including alpha/beta/rc tags
+release = __version__
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.napoleon',
+    'sphinx.ext.viewcode',
+    'sphinx.ext.intersphinx',
+    'sphinx_copybutton',
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'myst_parser',
+    'sphinxarg.ext',
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# Exclude the prompt "$" when copying code
+copybutton_prompt_text = r'\$ '
+copybutton_prompt_is_regexp = True
+
+language = 'en'
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_book_theme'
+html_logo = '_static/image/logo.png'
+html_theme_options = {
+    'path_to_docs': 'docs/en',
+    'repository_url': 'https://github.com/opendatalab/MinerU',
+    'use_repository_button': True,
+}
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ['_static']
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+    'cpuinfo',
+    'torch',
+    'transformers',
+    'psutil',
+    'prometheus_client',
+    'sentencepiece',
+    'vllm.cuda_utils',
+    'vllm._C',
+    'numpy',
+    'tqdm',
+]
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+    """Remove note about base class when a class is derived from object."""
+
+    def add_line(self, line: str, source: str, *lineno: int) -> None:
+        if line == '   Bases: :py:class:`object`':
+            return
+        super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
+
+navigation_with_keys = False

+ 38 - 0
next_docs/en/index.rst

@@ -0,0 +1,38 @@
+.. MinerU documentation master file, created by
+   sphinx-quickstart on Tue Jan  9 16:33:06 2024.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+Welcome to the MinerU Documentation
+==============================================
+
+.. figure:: ./_static/image/logo.png
+  :align: center
+  :alt: mineru
+  :class: no-scaled-link
+
+.. raw:: html
+
+   <p style="text-align:center">
+   <strong>A one-stop, open-source, high-quality data extraction tool
+   </strong>
+   </p>
+
+   <p style="text-align:center">
+   <script async defer src="https://buttons.github.io/buttons.js"></script>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU" data-show-count="true" data-size="large" aria-label="Star">Star</a>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
+   <a class="github-button" href="https://github.com/opendatalab/MinerU/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
+   </p>
+
+
+API Reference
+-------------
+
+If you are looking for information on a specific function, class or
+method, this part of the documentation is for you.
+
+.. toctree::
+   :maxdepth: 2
+
+   api

+ 35 - 0
next_docs/en/make.bat

@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.https://www.sphinx-doc.org/
+	exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd

+ 11 - 0
next_docs/requirements.txt

@@ -0,0 +1,11 @@
+boto3>=1.28.43
+loguru>=0.6.0
+myst-parser
+Pillow==8.4.0
+pydantic>=2.7.2,<2.8.0
+PyMuPDF>=1.24.9
+sphinx
+sphinx-argparse
+sphinx-book-theme
+sphinx-copybutton
+sphinx_rtd_theme

+ 16 - 0
next_docs/zh_cn/.readthedocs.yaml

@@ -0,0 +1,16 @@
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
+formats:
+  - epub
+
+python:
+  install:
+    - requirements: docs/requirements.txt
+
+sphinx:
+  configuration: docs/zh_cn/conf.py

+ 20 - 0
next_docs/zh_cn/Makefile

@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

Binary
next_docs/zh_cn/_static/image/logo.png


+ 122 - 0
next_docs/zh_cn/conf.py

@@ -0,0 +1,122 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import os
+import subprocess
+import sys
+
+from sphinx.ext import autodoc
+
+
+def install(package):
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+
+
+requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
+if os.path.exists(requirements_path):
+    with open(requirements_path) as f:
+        packages = f.readlines()
+    for package in packages:
+        install(package.strip())
+
+sys.path.insert(0, os.path.abspath('../..'))
+
+# -- Project information -----------------------------------------------------
+
+project = 'MinerU'
+copyright = '2024, OpenDataLab'
+author = 'MinerU Contributors'
+
+# The full version, including alpha/beta/rc tags
+version_file = '../../magic_pdf/libs/version.py'
+with open(version_file) as f:
+    exec(compile(f.read(), version_file, 'exec'))
+__version__ = locals()['__version__']
+# The short X.Y version
+version = __version__
+# The full version, including alpha/beta/rc tags
+release = __version__
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.napoleon',
+    'sphinx.ext.viewcode',
+    'sphinx.ext.intersphinx',
+    'sphinx_copybutton',
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'myst_parser',
+    'sphinxarg.ext',
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# Exclude the prompt "$" when copying code
+copybutton_prompt_text = r'\$ '
+copybutton_prompt_is_regexp = True
+
+language = 'zh_CN'
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_book_theme'
+html_logo = '_static/image/logo.png'
+html_theme_options = {
+    'path_to_docs': 'docs/zh_cn',
+    'repository_url': 'https://github.com/opendatalab/MinerU',
+    'use_repository_button': True,
+}
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ['_static']
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+    'cpuinfo',
+    'torch',
+    'transformers',
+    'psutil',
+    'prometheus_client',
+    'sentencepiece',
+    'vllm.cuda_utils',
+    'vllm._C',
+    'numpy',
+    'tqdm',
+]
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+    """Remove note about base class when a class is derived from object."""
+
+    def add_line(self, line: str, source: str, *lineno: int) -> None:
+        if line == '   Bases: :py:class:`object`':
+            return
+        super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
+
+navigation_with_keys = False

Some files were not shown because too many files changed in this diff