Browse Source

Merge branch 'opendatalab:dev' into dev

linfeng 1 year ago
parent
commit
ebfab42461
59 changed files with 1260 additions and 769 deletions
  1. README.md (+181 -8)
  2. README_zh-CN.md (+0 -0)
  3. magic_pdf/model/pdf_extract_kit.py (+46 -320)
  4. magic_pdf/model/pek_sub_modules/post_process.py (+0 -36)
  5. magic_pdf/model/pek_sub_modules/self_modify.py (+0 -388)
  6. magic_pdf/model/sub_modules/__init__.py (+0 -0)
  7. magic_pdf/model/sub_modules/layout/__init__.py (+0 -0)
  8. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py (+21 -0)
  9. magic_pdf/model/sub_modules/layout/doclayout_yolo/__init__.py (+0 -0)
  10. magic_pdf/model/sub_modules/layout/layoutlmv3/__init__.py (+0 -0)
  11. magic_pdf/model/sub_modules/layout/layoutlmv3/backbone.py (+0 -0)
  12. magic_pdf/model/sub_modules/layout/layoutlmv3/beit.py (+0 -0)
  13. magic_pdf/model/sub_modules/layout/layoutlmv3/deit.py (+0 -0)
  14. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/__init__.py (+0 -0)
  15. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/__init__.py (+0 -0)
  16. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/cord.py (+0 -0)
  17. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/data_collator.py (+0 -0)
  18. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/funsd.py (+0 -0)
  19. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/image_utils.py (+0 -0)
  20. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/xfund.py (+0 -0)
  21. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/__init__.py (+0 -0)
  22. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py (+0 -0)
  23. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py (+0 -0)
  24. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py (+0 -0)
  25. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py (+0 -0)
  26. magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py (+0 -0)
  27. magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py (+0 -0)
  28. magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py (+0 -0)
  29. magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py (+0 -0)
  30. magic_pdf/model/sub_modules/mfd/__init__.py (+0 -0)
  31. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py (+12 -0)
  32. magic_pdf/model/sub_modules/mfd/yolov8/__init__.py (+0 -0)
  33. magic_pdf/model/sub_modules/mfr/__init__.py (+0 -0)
  34. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py (+98 -0)
  35. magic_pdf/model/sub_modules/mfr/unimernet/__init__.py (+0 -0)
  36. magic_pdf/model/sub_modules/model_init.py (+144 -0)
  37. magic_pdf/model/sub_modules/model_utils.py (+51 -0)
  38. magic_pdf/model/sub_modules/ocr/__init__.py (+0 -0)
  39. magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py (+0 -0)
  40. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py (+259 -0)
  41. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py (+168 -0)
  42. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py (+213 -0)
  43. magic_pdf/model/sub_modules/reading_oreder/__init__.py (+0 -0)
  44. magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py (+0 -0)
  45. magic_pdf/model/sub_modules/reading_oreder/layoutreader/helpers.py (+0 -0)
  46. magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py (+0 -0)
  47. magic_pdf/model/sub_modules/table/__init__.py (+0 -0)
  48. magic_pdf/model/sub_modules/table/rapidtable/__init__.py (+0 -0)
  49. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py (+14 -0)
  50. magic_pdf/model/sub_modules/table/structeqtable/__init__.py (+0 -0)
  51. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py (+3 -11)
  52. magic_pdf/model/sub_modules/table/table_utils.py (+11 -0)
  53. magic_pdf/model/sub_modules/table/tablemaster/__init__.py (+0 -0)
  54. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py (+1 -1)
  55. magic_pdf/pdf_parse_union_core_v2.py (+3 -3)
  56. next_docs/README.md (+16 -0)
  57. next_docs/README_zh-CN.md (+16 -0)
  58. setup.py (+1 -0)
  59. tests/test_table/test_tablemaster.py (+2 -2)

+ 181 - 8
README.md

@@ -20,6 +20,7 @@
 [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb)
 [![Paper](https://img.shields.io/badge/Paper-arXiv-green)](https://arxiv.org/abs/2409.18839)
 
+
 <a href="https://trendshift.io/repositories/11174" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11174" alt="opendatalab%2FMinerU | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
 
 <!-- language -->
@@ -38,15 +39,10 @@
     👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://cdn.vansin.top/internlm/mineru.jpg" target="_blank">WeChat</a>
 </p>
 
-
-<!-- read the docs -->
-<p align="center">
-    read more docs on  <a href="https://mineru.readthedocs.io/en/latest/"> Read The Docs </a>
-</p>
-
 </div>
 
 # Changelog
+- 2024/11/15 0.9.3 released. Integrated [RapidTable](https://github.com/RapidAI/RapidTable) for table recognition, improving single-table parsing speed by more than 10 times, with higher accuracy and lower GPU memory usage.
 - 2024/11/06 0.9.2 released. Integrated the [StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B) model for table recognition functionality.
 - 2024/10/31 0.9.0 released. This is a major new version with extensive code refactoring, addressing numerous issues, improving performance, reducing hardware requirements, and enhancing usability:
   - Refactored the sorting module code to use [layoutreader](https://github.com/ppaanngggg/layoutreader) for reading order sorting, ensuring high accuracy in various layouts.
@@ -81,10 +77,12 @@
             <ul>
             <li><a href="#online-demo">Online Demo</a></li>
             <li><a href="#quick-cpu-demo">Quick CPU Demo</a></li>
+            <li><a href="#using-gpu">Using GPU</a></li>
             </ul>
         </li>
         <li><a href="#usage">Usage</a>
             <ul>
+            <li><a href="#command-line">Command Line</a></li>
             <li><a href="#api">API</a></li>
             <li><a href="#deploy-derived-projects">Deploy Derived Projects</a></li>
             <li><a href="#development-guide">Development Guide</a></li>
@@ -93,6 +91,8 @@
       </ul>
     </li>
     <li><a href="#todo">TODO</a></li>
+    <li><a href="#known-issues">Known Issues</a></li>
+    <li><a href="#faq">FAQ</a></li>
     <li><a href="#all-thanks-to-our-contributors">All Thanks To Our Contributors</a></li>
     <li><a href="#license-information">License Information</a></li>
     <li><a href="#acknowledgments">Acknowledgments</a></li>
@@ -114,12 +114,89 @@ Compared to well-known commercial products, MinerU is still young. If you encoun
 
 https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
 
+## Key Features
+
+- Remove headers, footers, footnotes, page numbers, etc., to ensure semantic coherence.
+- Output text in human-readable order, suitable for single-column, multi-column, and complex layouts.
+- Preserve the structure of the original document, including headings, paragraphs, lists, etc.
+- Extract images, image descriptions, tables, table titles, and footnotes.
+- Automatically recognize and convert formulas in the document to LaTeX format.
+- Automatically recognize and convert tables in the document to LaTeX or HTML format.
+- Automatically detect scanned PDFs and garbled PDFs and enable OCR functionality.
+- OCR supports detection and recognition of 84 languages.
+- Supports multiple output formats, such as multimodal and NLP Markdown, JSON sorted by reading order, and rich intermediate formats.
+- Supports various visualization results, including layout visualization and span visualization, for efficient confirmation of output quality.
+- Supports both CPU and GPU environments.
+- Compatible with Windows, Linux, and Mac platforms.
+
 ## Quick Start
 
-There are multiple different ways to experience MinerU:
+If you encounter any installation issues, please consult the <a href="#faq">FAQ</a> first. <br>
+If the parsing results are not as expected, refer to the <a href="#known-issues">Known Issues</a>. <br>
+There are three different ways to experience MinerU:
 
 - [Online Demo (No Installation Required)](#online-demo)
 - [Quick CPU Demo (Windows, Linux, Mac)](#quick-cpu-demo)
+- [Linux/Windows + CUDA](#using-gpu)
+
+> [!WARNING]
+> **Pre-installation Notice: Hardware and Software Environment Support**
+>
+> To ensure the stability and reliability of the project, we only optimize and test for specific hardware and software environments during development. This ensures that users deploying and running the project on recommended system configurations will get the best performance with the fewest compatibility issues.
+>
+> By focusing resources on the mainline environment, our team can more efficiently resolve potential bugs and develop new features.
+>
+> In non-mainline environments, due to the diversity of hardware and software configurations, as well as third-party dependency compatibility issues, we cannot guarantee 100% project availability. Therefore, for users who wish to use this project in non-recommended environments, we suggest carefully reading the documentation and FAQ first. Most issues already have corresponding solutions in the FAQ. We also encourage community feedback to help us gradually expand support.
+
+<table>
+    <tr>
+        <td colspan="3" rowspan="2">Operating System</td>
+    </tr>
+    <tr>
+        <td>Ubuntu 22.04 LTS</td>
+        <td>Windows 10 / 11</td>
+        <td>macOS 11+</td>
+    </tr>
+    <tr>
+        <td colspan="3">CPU</td>
+        <td>x86_64 (ARM Linux not supported)</td>
+        <td>x86_64 (ARM Windows not supported)</td>
+        <td>x86_64 / arm64</td>
+    </tr>
+    <tr>
+        <td colspan="3">Memory</td>
+        <td colspan="3">16GB or more, recommended 32GB+</td>
+    </tr>
+    <tr>
+        <td colspan="3">Python Version</td>
+        <td colspan="3">3.10(Please make sure to create a Python 3.10 virtual environment using conda)</td>
+    </tr>
+    <tr>
+        <td colspan="3">Nvidia Driver Version</td>
+        <td>latest (Proprietary Driver)</td>
+        <td>latest</td>
+        <td>None</td>
+    </tr>
+    <tr>
+        <td colspan="3">CUDA Environment</td>
+        <td>Automatic installation [12.1 (PyTorch) + 11.8 (Paddle)]</td>
+        <td>11.8 (manual installation) + cuDNN v8.7.0 (manual installation)</td>
+        <td>None</td>
+    </tr>
+    <tr>
+        <td rowspan="2">GPU Hardware Support List</td>
+        <td colspan="2">Minimum Requirement 8G+ VRAM</td>
+        <td colspan="2">3060ti/3070/4060<br>
+        8G VRAM enables layout, formula recognition acceleration and OCR acceleration</td>
+        <td rowspan="2">None</td>
+    </tr>
+    <tr>
+        <td colspan="2">Recommended Configuration 10G+ VRAM</td>
+        <td colspan="2">3080/3080ti/3090/3090ti/4070/4070ti/4070tisuper/4080/4090<br>
+        10G VRAM or more can enable layout, formula recognition, OCR acceleration and table recognition acceleration simultaneously
+        </td>
+    </tr>
+</table>
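
Before installing, you can sanity-check your machine against the table above. This is a minimal sketch, assuming PyTorch is already installed; the 8 GB threshold mirrors the minimum requirement listed here:

```python
import torch

# Quick check of the environment against the requirements table above.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    vram_gb = props.total_memory / (1024 ** 3)  # bytes -> GB
    print(f"GPU: {props.name}, VRAM: {vram_gb:.1f} GB")
    if vram_gb >= 8:
        print("Meets the 8 GB minimum for GPU acceleration.")
    else:
        print("Below the 8 GB minimum; CPU mode is recommended.")
else:
    print("No CUDA device detected; MinerU will run in CPU mode.")
```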
 
 ### Online Demo
 
@@ -154,6 +231,7 @@ You can find the `magic-pdf.json` file in your 【user directory】.
 
 You can modify certain configurations in this file to enable or disable features, such as table recognition:
 
+
 > [!NOTE]
 > If the following items are not present in the JSON, please manually add the required items and remove the comment content (standard JSON does not support comments).
 
@@ -169,15 +247,92 @@ You can modify certain configurations in this file to enable or disable features
         "enable": true  // The formula recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
     },
     "table-config": {
-        "model": "rapid_table",  // Default to using "rapid_table", can be switched to "tablemaster" or "struct_eqtable".
+        "model": "rapid_table",  // When using structEqTable, please change to "struct_eqtable".
         "enable": false, // The table recognition feature is disabled by default. If you need to enable it, please change the value here to "true".
         "max_time": 400
     }
 }
 ```
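
As an illustration, the same switch can be flipped from a script. A minimal sketch, assuming `magic-pdf.json` sits in your user directory and the comment annotations above have been removed so the file is valid JSON:

```python
import json
from pathlib import Path

config_path = Path.home() / "magic-pdf.json"
config = json.loads(config_path.read_text(encoding="utf-8"))

# Enable table recognition (disabled by default, as noted above).
config.setdefault("table-config", {})["enable"] = True

config_path.write_text(json.dumps(config, indent=4, ensure_ascii=False), encoding="utf-8")
```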
 
+### Using GPU
+
+If your device supports CUDA and meets the GPU requirements of the mainline environment, you can use GPU acceleration. Please select the appropriate guide based on your system:
+
+- [Ubuntu 22.04 LTS + GPU](docs/README_Ubuntu_CUDA_Acceleration_en_US.md)
+- [Windows 10/11 + GPU](docs/README_Windows_CUDA_Acceleration_en_US.md)
+- Quick Deployment with Docker
+> [!IMPORTANT]
+> Docker requires a GPU with at least 8GB of VRAM, and all acceleration features are enabled by default.
+>
+> Before running this Docker image, you can use the following command to check whether your device supports CUDA acceleration under Docker.
+> 
+> ```bash
+> docker run --rm --gpus=all nvidia/cuda:12.1.0-base-ubuntu22.04 nvidia-smi
+> ```
+  ```bash
+  wget https://github.com/opendatalab/MinerU/raw/master/Dockerfile
+  docker build -t mineru:latest .
+  docker run --rm -it --gpus=all mineru:latest /bin/bash
+  magic-pdf --help
+  ```
+
 ## Usage
 
+### Command Line
+
+```bash
+magic-pdf --help
+Usage: magic-pdf [OPTIONS]
+
+Options:
+  -v, --version                display the version and exit
+  -p, --path PATH              local pdf filepath or directory  [required]
+  -o, --output-dir PATH        output local directory  [required]
+  -m, --method [ocr|txt|auto]  the method for parsing pdf. ocr: using ocr
+                               technique to extract information from pdf. txt:
+                               suitable for the text-based pdf only and
+                               outperform ocr. auto: automatically choose the
+                               best method for parsing pdf from ocr and txt.
+                               without method specified, auto will be used by
+                               default.
+  -l, --lang TEXT              Input the languages in the pdf (if known) to
+                               improve OCR accuracy. Optional. You should
+                               input the "Abbreviation" from the language list:
+                               https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
+  -d, --debug BOOLEAN          Enables detailed debugging information during
+                               the execution of the CLI commands.
+  -s, --start INTEGER          The starting page for PDF parsing, beginning
+                               from 0.
+  -e, --end INTEGER            The ending page for PDF parsing, beginning from
+                               0.
+  --help                       Show this message and exit.
+
+
+## show version
+magic-pdf -v
+
+## command line example
+magic-pdf -p {some_pdf} -o {some_output_dir} -m auto
+```
+
+`{some_pdf}` can be a single PDF file or a directory containing multiple PDFs.
+The results will be saved in the `{some_output_dir}` directory. The output file list is as follows:
+
+```text
+├── some_pdf.md                          # markdown file
+├── images                               # directory for storing images
+├── some_pdf_layout.pdf                  # layout diagram (includes layout reading order)
+├── some_pdf_middle.json                 # MinerU intermediate processing result
+├── some_pdf_model.json                  # model inference result
+├── some_pdf_origin.pdf                  # original PDF file
+├── some_pdf_spans.pdf                   # span-level (smallest-granularity) bbox visualization
+└── some_pdf_content_list.json           # Rich text JSON arranged in reading order
+```
+> [!TIP]
+> For more information about the output files, please refer to the [Output File Description](docs/output_file_en_us.md).
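
For batch jobs, the documented command line can also be driven from a script. Below is a minimal sketch in Python; the `run_magic_pdf` helper is hypothetical, and `magic-pdf` is assumed to be on the PATH:

```python
import subprocess
from pathlib import Path

def run_magic_pdf(pdf_path: str, output_dir: str, method: str = "auto") -> None:
    """Hypothetical helper wrapping the CLI invocation shown above."""
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["magic-pdf", "-p", pdf_path, "-o", output_dir, "-m", method],
        check=True,  # raise CalledProcessError if parsing fails
    )

# Process every PDF in a directory (note: -p also accepts a directory directly).
for pdf in Path("pdfs").glob("*.pdf"):
    run_magic_pdf(str(pdf), "output")
```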
+
 ### API
 
 Processing files from local disk
@@ -234,6 +389,24 @@ TODO
 - [ ] [Chemical formula recognition](docs/chemical_knowledge_introduction/introduction.pdf)
 - [ ] Geometric shape recognition
 
+# Known Issues
+
+- Reading order is determined by the model based on the spatial distribution of readable content, and may be out of order in some areas under extremely complex layouts.
+- Vertical text is not supported.
+- Tables of contents and lists are recognized through rules, and some uncommon list formats may not be recognized.
+- Only one level of headings is supported; hierarchical headings are not currently supported.
+- Code blocks are not yet supported in the layout model.
+- Comic books, art albums, primary school textbooks, and exercises cannot be parsed well.
+- Table recognition may result in row/column recognition errors in complex tables.
+- OCR recognition may produce inaccurate characters in PDFs of lesser-known languages (e.g., diacritical marks in Latin script, easily confused characters in Arabic script).
+- Some formulas may not render correctly in Markdown.
+
+# FAQ
+
+[FAQ in Chinese](docs/FAQ_zh_cn.md)
+
+[FAQ in English](docs/FAQ_en_us.md)
+
 # All Thanks To Our Contributors
 
 <a href="https://github.com/opendatalab/MinerU/graphs/contributors">

Diff file suppressed because it is too large
+ 0 - 0
README_zh-CN.md


+ 46 - 320
magic_pdf/model/pdf_extract_kit.py

@@ -1,203 +1,28 @@
+import numpy as np
+import torch
 from loguru import logger
 import os
 import time
-from magic_pdf.libs.Constants import *
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.model.model_list import AtomicModel
+import cv2
+import yaml
+from PIL import Image
 
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+
 try:
-    import cv2
-    import yaml
-    import argparse
-    import numpy as np
-    import torch
     import torchtext
 
     if torchtext.__version__ >= "0.18.0":
         torchtext.disable_torchtext_deprecation_warning()
-    from PIL import Image
-    from torchvision import transforms
-    from torch.utils.data import Dataset, DataLoader
-    from ultralytics import YOLO
-    from unimernet.common.config import Config
-    import unimernet.tasks as tasks
-    from unimernet.processors import load_processor
-    from doclayout_yolo import YOLOv10
-    from rapid_table import RapidTable
-    from rapidocr_paddle import RapidOCR
-
-except ImportError as e:
-    logger.exception(e)
-    logger.error(
-        'Required dependency not installed, please install by \n'
-        '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
-    exit(1)
-
-from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
-from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
-from magic_pdf.model.ppTableModel import ppTableModel
-
-
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    ocr_engine = None
-    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time)
-    elif table_model_type == MODEL_NAME.TABLE_MASTER:
-        config = {
-            "model_dir": model_path,
-            "device": _device_
-        }
-        table_model = ppTableModel(config)
-    elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTable()
-        ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-    else:
-        logger.error("table model type not allow")
-        exit(1)
-
-    if ocr_engine:
-        return [table_model, ocr_engine]
-    else:
-        return table_model
-
-
-def mfd_model_init(weight):
-    mfd_model = YOLO(weight)
-    return mfd_model
-
-
-def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
-    args = argparse.Namespace(cfg_path=cfg_path, options=None)
-    cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-    cfg.config.model.model_config.model_name = weight_dir
-    cfg.config.model.tokenizer_config.path = weight_dir
-    task = tasks.setup_task(cfg)
-    model = task.build_model(cfg)
-    model.to(_device_)
-    model.eval()
-    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-    mfr_transform = transforms.Compose([vis_processor, ])
-    return [model, mfr_transform]
-
-
-def layout_model_init(weight, config_file, device):
-    model = Layoutlmv3_Predictor(weight, config_file, device)
-    return model
-
-
-def doclayout_yolo_model_init(weight):
-    model = YOLOv10(weight)
-    return model
-
-
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
-    if lang is not None:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    else:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    return model
-
-
-class MathDataset(Dataset):
-    def __init__(self, image_paths, transform=None):
-        self.image_paths = image_paths
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
-        if self.transform:
-            image = self.transform(raw_image)
-            return image
-
-
-class AtomModelSingleton:
-    _instance = None
-    _models = {}
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def get_atom_model(self, atom_model_name: str, **kwargs):
-        lang = kwargs.get("lang", None)
-        layout_model_name = kwargs.get("layout_model_name", None)
-        key = (atom_model_name, layout_model_name, lang)
-        if key not in self._models:
-            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[key]
-
-
-def atom_model_init(model_name: str, **kwargs):
-
-    if model_name == AtomicModel.Layout:
-        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
-            atom_model = layout_model_init(
-                kwargs.get("layout_weights"),
-                kwargs.get("layout_config_file"),
-                kwargs.get("device")
-            )
-        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
-            atom_model = doclayout_yolo_model_init(
-                kwargs.get("doclayout_yolo_weights"),
-            )
-    elif model_name == AtomicModel.MFD:
-        atom_model = mfd_model_init(
-            kwargs.get("mfd_weights")
-        )
-    elif model_name == AtomicModel.MFR:
-        atom_model = mfr_model_init(
-            kwargs.get("mfr_weight_dir"),
-            kwargs.get("mfr_cfg_path"),
-            kwargs.get("device")
-        )
-    elif model_name == AtomicModel.OCR:
-        atom_model = ocr_model_init(
-            kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh"),
-            kwargs.get("lang")
-        )
-    elif model_name == AtomicModel.Table:
-        atom_model = table_model_init(
-            kwargs.get("table_model_name"),
-            kwargs.get("table_model_path"),
-            kwargs.get("table_max_time"),
-            kwargs.get("device")
-        )
-    else:
-        logger.error("model name not allow")
-        exit(1)
-
-    return atom_model
-
-
-#  Unified crop img logic
-def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
-    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
-    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-    # Create a white background with an additional width and height of 50
-    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
-    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+except ImportError:
+    pass
 
-    # Crop image
-    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-    cropped_img = input_pil_img.crop(crop_box)
-    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-    return return_image, return_list
+from magic_pdf.libs.Constants import *
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
 
 
 class CustomPEKModel:
@@ -243,7 +68,8 @@ class CustomPEKModel:
         logger.info(
             "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
             "apply_table: {}, table_model: {}, lang: {}".format(
-                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
+                self.lang
             )
         )
        # Initialize the parsing scheme
@@ -256,17 +82,17 @@ class CustomPEKModel:
 
        # Initialize formula recognition
         if self.apply_formula:
-
            # Initialize the formula detection (MFD) model
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
+                device=self.device
             )
 
            # Initialize the formula recognition (MFR) model
             mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
+            self.mfr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
@@ -286,7 +112,8 @@ class CustomPEKModel:
             self.layout_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Layout,
                 layout_model_name=MODEL_NAME.DocLayout_YOLO,
-                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                device=self.device
             )
        # Initialize OCR
         if self.apply_ocr:
@@ -299,22 +126,13 @@ class CustomPEKModel:
         # init table model
         if self.apply_table:
             table_model_dir = self.configs["weights"][self.table_model_name]
-            if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]:
-                self.table_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.Table,
-                    table_model_name=self.table_model_name,
-                    table_model_path=str(os.path.join(models_dir, table_model_dir)),
-                    table_max_time=self.table_max_time,
-                    device=self.device
-                )
-            elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]:
-                self.table_model, self.ocr_engine =atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.Table,
-                    table_model_name=self.table_model_name,
-                    table_model_path=str(os.path.join(models_dir, table_model_dir)),
-                    table_max_time=self.table_max_time,
-                    device=self.device
-                )
+            self.table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Table,
+                table_model_name=self.table_model_name,
+                table_model_path=str(os.path.join(models_dir, table_model_dir)),
+                table_max_time=self.table_max_time,
+                device=self.device
+            )
 
         logger.info('DocAnalysis init done!')
 
@@ -322,26 +140,15 @@ class CustomPEKModel:
 
         page_start = time.time()
 
-        latex_filling_list = []
-        mf_image_list = []
-
        # Layout detection
         layout_start = time.time()
+        layout_res = []
         if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            layout_res = []
-            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
-            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 3),
-                }
-                layout_res.append(new_item)
+            layout_res = self.layout_model.predict(image)
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection time: {layout_cost}")
 
@@ -350,58 +157,21 @@ class CustomPEKModel:
         if self.apply_formula:
            # Formula detection
             mfd_start = time.time()
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            mfd_res = self.mfd_model.predict(image)
             logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
-            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': 13 + int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 2),
-                    'latex': '',
-                }
-                layout_res.append(new_item)
-                latex_filling_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-                mf_image_list.append(bbox_img)
 
            # Formula recognition
             mfr_start = time.time()
-            dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-            dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
-            mfr_res = []
-            for mf_img in dataloader:
-                mf_img = mf_img.to(self.device)
-                with torch.no_grad():
-                    output = self.mfr_model.generate({'image': mf_img})
-                mfr_res.extend(output['pred_str'])
-            for res, latex in zip(latex_filling_list, mfr_res):
-                res['latex'] = latex_rm_whitespace(latex)
+            formula_list = self.mfr_model.predict(mfd_res, image)
+            layout_res.extend(formula_list)
             mfr_cost = round(time.time() - mfr_start, 2)
-            logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
-
-        # Select regions for OCR / formula regions / table regions
-        ocr_res_list = []
-        table_res_list = []
-        single_page_mfdetrec_res = []
-        for res in layout_res:
-            if int(res['category_id']) in [13, 14]:
-                single_page_mfdetrec_res.append({
-                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
-                             int(res['poly'][4]), int(res['poly'][5])],
-                })
-            elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
-                ocr_res_list.append(res)
-            elif int(res['category_id']) in [5]:
-                table_res_list.append(res)
-
-        if torch.cuda.is_available() and self.device != 'cpu':
-            total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3)  # convert bytes to GB
-            if total_memory <= 8:
-                gc_start = time.time()
-                clean_memory()
-                gc_time = round(time.time() - gc_start, 2)
-                logger.info(f"gc time: {gc_time}")
+            logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
+
+        # Clean up VRAM
+        clean_vram(self.device, vram_threshold=8)
+
+        # Extract the OCR, table, and formula regions from layout_res
+        ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
 
        # OCR recognition
         if self.apply_ocr:
@@ -409,23 +179,7 @@ class CustomPEKModel:
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                # Adjust the coordinates of the formula area
-                adjusted_mfdetrec_res = []
-                for mf_res in single_page_mfdetrec_res:
-                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
-                    # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
-                    x0 = mf_xmin - xmin + paste_x
-                    y0 = mf_ymin - ymin + paste_y
-                    x1 = mf_xmax - xmin + paste_x
-                    y1 = mf_ymax - ymin + paste_y
-                    # Filter formula blocks outside the graph
-                    if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
-                        continue
-                    else:
-                        adjusted_mfdetrec_res.append({
-                            "bbox": [x0, y0, x1, y1],
-                        })
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
 
                 # OCR recognition
                 new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
@@ -433,22 +187,8 @@ class CustomPEKModel:
 
                 # Integration results
                 if ocr_res:
-                    for box_ocr_res in ocr_res:
-                        p1, p2, p3, p4 = box_ocr_res[0]
-                        text, score = box_ocr_res[1]
-
-                        # Convert the coordinates back to the original coordinate system
-                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
-                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
-                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
-                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
-
-                        layout_res.append({
-                            'category_id': 15,
-                            'poly': p1 + p2 + p3 + p4,
-                            'score': round(score, 2),
-                            'text': text,
-                        })
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    layout_res.extend(ocr_result_list)
 
             ocr_cost = round(time.time() - ocr_start, 2)
             logger.info(f"ocr time: {ocr_cost}")
@@ -459,8 +199,6 @@ class CustomPEKModel:
             for res in table_res_list:
                 new_image, _ = crop_img(res, pil_img)
                 single_table_start_time = time.time()
-                # logger.info("------------------table recognition processing begins-----------------")
-                latex_code = None
                 html_code = None
                 if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
@@ -470,33 +208,21 @@ class CustomPEKModel:
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    ocr_result, _ = self.ocr_engine(np.asarray(new_image))
-                    html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result)
-
+                    html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
                 run_time = time.time() - single_table_start_time
-                # logger.info(f"------------table recognition processing ends within {run_time}s-----")
                 if run_time > self.table_max_time:
-                    logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
+                    logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
                # Check whether the result is valid
-
-                if latex_code:
-                    expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
-                    if expected_ending:
-                        res["latex"] = latex_code
-                    else:
-                        logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
-                elif html_code:
+                if html_code:
                     expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
                     if expected_ending:
                         res["html"] = html_code
                     else:
                         logger.warning(f"table recognition processing fails, not found expected HTML table end")
                 else:
-                    logger.warning(f"table recognition processing fails, not get latex or html return")
+                    logger.warning(f"table recognition processing fails, not get html return")
             logger.info(f"table time: {round(time.time() - table_start, 2)}")
 
         logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
 
         return layout_res
-
-

+ 0 - 36
magic_pdf/model/pek_sub_modules/post_process.py

@@ -1,36 +0,0 @@
-import re
-
-def layout_rm_equation(layout_res):
-    rm_idxs = []
-    for idx, ele in enumerate(layout_res['layout_dets']):
-        if ele['category_id'] == 10:
-            rm_idxs.append(idx)
-    
-    for idx in rm_idxs[::-1]:
-        del layout_res['layout_dets'][idx]
-    return layout_res
-
-
-def get_croped_image(image_pil, bbox):
-    x_min, y_min, x_max, y_max = bbox
-    croped_img = image_pil.crop((x_min, y_min, x_max, y_max))
-    return croped_img
-
-
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code.
-    """
-    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
-    letter = '[a-zA-Z]'
-    noletter = '[\W_^\d]'
-    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
-        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
-        if news == s:
-            break
-    return s

+ 0 - 388
magic_pdf/model/pek_sub_modules/self_modify.py

@@ -1,388 +0,0 @@
-import time
-import copy
-import base64
-import cv2
-import numpy as np
-from io import BytesIO
-from PIL import Image
-
-from paddleocr import PaddleOCR
-from paddleocr.ppocr.utils.logging import get_logger
-from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
-from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
-
-from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
-from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
-
-logger = get_logger()
-
-
-def img_decode(content: bytes):
-    np_arr = np.frombuffer(content, dtype=np.uint8)
-    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
-
-
-def check_img(img):
-    if isinstance(img, bytes):
-        img = img_decode(img)
-    if isinstance(img, str):
-        image_file = img
-        img, flag_gif, flag_pdf = check_and_read(image_file)
-        if not flag_gif and not flag_pdf:
-            with open(image_file, 'rb') as f:
-                img_str = f.read()
-                img = img_decode(img_str)
-            if img is None:
-                try:
-                    buf = BytesIO()
-                    image = BytesIO(img_str)
-                    im = Image.open(image)
-                    rgb = im.convert('RGB')
-                    rgb.save(buf, 'jpeg')
-                    buf.seek(0)
-                    image_bytes = buf.read()
-                    data_base64 = str(base64.b64encode(image_bytes),
-                                      encoding="utf-8")
-                    image_decode = base64.b64decode(data_base64)
-                    img_array = np.frombuffer(image_decode, np.uint8)
-                    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
-                except:
-                    logger.error("error in loading image:{}".format(image_file))
-                    return None
-        if img is None:
-            logger.error("error in loading image:{}".format(image_file))
-            return None
-    if isinstance(img, np.ndarray) and len(img.shape) == 2:
-        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
-    return img
-
-
-def sorted_boxes(dt_boxes):
-    """
-    Sort text boxes in order from top to bottom, left to right
-    args:
-        dt_boxes(array):detected text boxes with shape [4, 2]
-    return:
-        sorted boxes(array) with shape [4, 2]
-    """
-    num_boxes = dt_boxes.shape[0]
-    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
-    _boxes = list(sorted_boxes)
-
-    for i in range(num_boxes - 1):
-        for j in range(i, -1, -1):
-            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
-                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
-                tmp = _boxes[j]
-                _boxes[j] = _boxes[j + 1]
-                _boxes[j + 1] = tmp
-            else:
-                break
-    return _boxes
-
-
-def bbox_to_points(bbox):
-    """ 将bbox格式转换为四个顶点的数组 """
-    x0, y0, x1, y1 = bbox
-    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
-
-
-def points_to_bbox(points):
-    """ 将四个顶点的数组转换为bbox格式 """
-    x0, y0 = points[0]
-    x1, _ = points[1]
-    _, y1 = points[2]
-    return [x0, y0, x1, y1]
-
-
-def merge_intervals(intervals):
-    # Sort the intervals based on the start value
-    intervals.sort(key=lambda x: x[0])
-
-    merged = []
-    for interval in intervals:
-        # If the list of merged intervals is empty or if the current
-        # interval does not overlap with the previous, simply append it.
-        if not merged or merged[-1][1] < interval[0]:
-            merged.append(interval)
-        else:
-            # Otherwise, there is overlap, so we merge the current and previous intervals.
-            merged[-1][1] = max(merged[-1][1], interval[1])
-
-    return merged
-
-
-def remove_intervals(original, masks):
-    # Merge all mask intervals
-    merged_masks = merge_intervals(masks)
-
-    result = []
-    original_start, original_end = original
-
-    for mask in merged_masks:
-        mask_start, mask_end = mask
-
-        # If the mask starts after the original range, ignore it
-        if mask_start > original_end:
-            continue
-
-        # If the mask ends before the original range starts, ignore it
-        if mask_end < original_start:
-            continue
-
-        # Remove the masked part from the original range
-        if original_start < mask_start:
-            result.append([original_start, mask_start - 1])
-
-        original_start = max(mask_end + 1, original_start)
-
-    # Add the remaining part of the original range, if any
-    if original_start <= original_end:
-        result.append([original_start, original_end])
-
-    return result
-
-
-def update_det_boxes(dt_boxes, mfd_res):
-    new_dt_boxes = []
-    for text_box in dt_boxes:
-        text_bbox = points_to_bbox(text_box)
-        masks_list = []
-        for mf_box in mfd_res:
-            mf_bbox = mf_box['bbox']
-            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
-                masks_list.append([mf_bbox[0], mf_bbox[2]])
-        text_x_range = [text_bbox[0], text_bbox[2]]
-        text_remove_mask_range = remove_intervals(text_x_range, masks_list)
-        temp_dt_box = []
-        for text_remove_mask in text_remove_mask_range:
-            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
-        if len(temp_dt_box) > 0:
-            new_dt_boxes.extend(temp_dt_box)
-    return new_dt_boxes
-
-
-def merge_overlapping_spans(spans):
-    """
-    Merges overlapping spans on the same line.
-
-    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
-    :return: A list of merged spans
-    """
-    # Return an empty list if the input spans list is empty
-    if not spans:
-        return []
-
-    # Sort spans by their starting x-coordinate
-    spans.sort(key=lambda x: x[0])
-
-    # Initialize the list of merged spans
-    merged = []
-    for span in spans:
-        # Unpack span coordinates
-        x1, y1, x2, y2 = span
-        # If the merged list is empty or there's no horizontal overlap, add the span directly
-        if not merged or merged[-1][2] < x1:
-            merged.append(span)
-        else:
-            # If there is horizontal overlap, merge the current span with the previous one
-            last_span = merged.pop()
-            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
-            x1 = min(last_span[0], x1)
-            y1 = min(last_span[1], y1)
-            x2 = max(last_span[2], x2)
-            y2 = max(last_span[3], y2)
-            # Add the merged span back to the list
-            merged.append((x1, y1, x2, y2))
-
-    # Return the list of merged spans
-    return merged
-
-
-def merge_det_boxes(dt_boxes):
-    """
-    Merge detection boxes.
-
-    This function takes a list of detected bounding boxes, each represented by four corner points.
-    The goal is to merge these bounding boxes into larger text regions.
-
-    Parameters:
-    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
-
-    Returns:
-    list: A list containing the merged text regions, where each region is represented by four corner points.
-    """
-    # Convert the detection boxes into a dictionary format with bounding boxes and type
-    dt_boxes_dict_list = []
-    for text_box in dt_boxes:
-        text_bbox = points_to_bbox(text_box)
-        text_box_dict = {
-            'bbox': text_bbox,
-            'type': 'text',
-        }
-        dt_boxes_dict_list.append(text_box_dict)
-
-    # Merge adjacent text regions into lines
-    lines = merge_spans_to_line(dt_boxes_dict_list)
-
-    # Initialize a new list for storing the merged text regions
-    new_dt_boxes = []
-    for line in lines:
-        line_bbox_list = []
-        for span in line:
-            line_bbox_list.append(span['bbox'])
-
-        # Merge overlapping text regions within the same line
-        merged_spans = merge_overlapping_spans(line_bbox_list)
-
-        # Convert the merged text regions back to point format and add them to the new detection box list
-        for span in merged_spans:
-            new_dt_boxes.append(bbox_to_points(span))
-
-    return new_dt_boxes
-
-
-class ModifiedPaddleOCR(PaddleOCR):
-    def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
-        """
-        OCR with PaddleOCR
-        args:
-            img: img for OCR, support ndarray, img_path and list or ndarray
-            det: use text detection or not. If False, only rec will be exec. Default is True
-            rec: use text recognition or not. If False, only det will be exec. Default is True
-            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
-            bin: binarize image to black and white. Default is False.
-            inv: invert image colors. Default is False.
-            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
-        """
-        assert isinstance(img, (np.ndarray, list, str, bytes))
-        if isinstance(img, list) and det == True:
-            logger.error('When input a list of images, det must be false')
-            exit(0)
-        if cls == True and self.use_angle_cls == False:
-            pass
-            # logger.warning(
-            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
-            # )
-
-        img = check_img(img)
-        # for infer pdf file
-        if isinstance(img, list):
-            if self.page_num > len(img) or self.page_num == 0:
-                self.page_num = len(img)
-            imgs = img[:self.page_num]
-        else:
-            imgs = [img]
-
-        def preprocess_image(_image):
-            _image = alpha_to_color(_image, alpha_color)
-            if inv:
-                _image = cv2.bitwise_not(_image)
-            if bin:
-                _image = binarize_img(_image)
-            return _image
-
-        if det and rec:
-            ocr_res = []
-            for idx, img in enumerate(imgs):
-                img = preprocess_image(img)
-                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
-                if not dt_boxes and not rec_res:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [[box.tolist(), res]
-                           for box, res in zip(dt_boxes, rec_res)]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        elif det and not rec:
-            ocr_res = []
-            for idx, img in enumerate(imgs):
-                img = preprocess_image(img)
-                dt_boxes, elapse = self.text_detector(img)
-                if not dt_boxes:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [box.tolist() for box in dt_boxes]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        else:
-            ocr_res = []
-            cls_res = []
-            for idx, img in enumerate(imgs):
-                if not isinstance(img, list):
-                    img = preprocess_image(img)
-                    img = [img]
-                if self.use_angle_cls and cls:
-                    img, cls_res_tmp, elapse = self.text_classifier(img)
-                    if not rec:
-                        cls_res.append(cls_res_tmp)
-                rec_res, elapse = self.text_recognizer(img)
-                ocr_res.append(rec_res)
-            if not rec:
-                return cls_res
-            return ocr_res
-
-    def __call__(self, img, cls=True, mfd_res=None):
-        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
-
-        if img is None:
-            logger.debug("no valid image provided")
-            return None, None, time_dict
-
-        start = time.time()
-        ori_im = img.copy()
-        dt_boxes, elapse = self.text_detector(img)
-        time_dict['det'] = elapse
-
-        if dt_boxes is None:
-            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
-            end = time.time()
-            time_dict['all'] = end - start
-            return None, None, time_dict
-        else:
-            logger.debug("dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), elapse))
-        img_crop_list = []
-
-        dt_boxes = sorted_boxes(dt_boxes)
-
-        dt_boxes = merge_det_boxes(dt_boxes)
-
-        if mfd_res:
-            bef = time.time()
-            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-            aft = time.time()
-            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), aft - bef))
-
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.args.det_box_type == "quad":
-                img_crop = get_rotate_crop_image(ori_im, tmp_box)
-            else:
-                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
-            img_crop_list.append(img_crop)
-        if self.use_angle_cls and cls:
-            img_crop_list, angle_list, elapse = self.text_classifier(
-                img_crop_list)
-            time_dict['cls'] = elapse
-            logger.debug("cls num  : {}, elapsed : {}".format(
-                len(img_crop_list), elapse))
-
-        rec_res, elapse = self.text_recognizer(img_crop_list)
-        time_dict['rec'] = elapse
-        logger.debug("rec_res num  : {}, elapsed : {}".format(
-            len(rec_res), elapse))
-        if self.args.save_crop_res:
-            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
-                                   rec_res)
-        filter_boxes, filter_rec_res = [], []
-        for box, rec_result in zip(dt_boxes, rec_res):
-            text, score = rec_result
-            if score >= self.drop_score:
-                filter_boxes.append(box)
-                filter_rec_res.append(rec_result)
-        end = time.time()
-        time_dict['all'] = end - start
-        return filter_boxes, filter_rec_res, time_dict

+ 0 - 0
magic_pdf/model/pek_sub_modules/__init__.py → magic_pdf/model/sub_modules/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py → magic_pdf/model/sub_modules/layout/__init__.py


+ 21 - 0
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -0,0 +1,21 @@
+from doclayout_yolo import YOLOv10
+
+
+class DocLayoutYOLOModel(object):
+    def __init__(self, weight, device):
+        self.model = YOLOv10(weight)
+        self.device = device
+
+    def predict(self, image):
+        layout_res = []
+        doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
+                                   doclayout_yolo_res.boxes.cls.cpu()):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                'category_id': int(cla.item()),
+                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                'score': round(float(conf.item()), 3),
+            }
+            layout_res.append(new_item)
+        return layout_res
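
A brief usage sketch for the new wrapper; the weight path below is a placeholder for the DocLayout-YOLO checkpoint in your local models directory:

```python
import cv2

from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel

# Placeholder weight path: substitute the DocLayout-YOLO checkpoint from models_dir.
model = DocLayoutYOLOModel(weight="models/doclayout_yolo.pt", device="cuda")
image = cv2.imread("page.png")  # one rendered PDF page as a BGR ndarray
for det in model.predict(image):
    # Each detection carries a category_id, an 8-value corner polygon, and a score.
    print(det["category_id"], det["score"], det["poly"])
```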

+ 0 - 0
magic_pdf/model/pek_sub_modules/structeqtable/__init__.py → magic_pdf/model/sub_modules/layout/doclayout_yolo/__init__.py


+ 0 - 0
magic_pdf/model/v3/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py → magic_pdf/model/sub_modules/layout/layoutlmv3/backbone.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py → magic_pdf/model/sub_modules/layout/layoutlmv3/beit.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py → magic_pdf/model/sub_modules/layout/layoutlmv3/deit.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/cord.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/data_collator.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/funsd.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/image_utils.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/xfund.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py → magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py → magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py → magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py


+ 0 - 0
magic_pdf/model/sub_modules/mfd/__init__.py


+ 12 - 0
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -0,0 +1,12 @@
+from ultralytics import YOLO
+
+
+class YOLOv8MFDModel(object):
+    def __init__(self, weight, device='cpu'):
+        self.mfd_model = YOLO(weight)
+        self.device = device
+
+    def predict(self, image):
+        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        return mfd_res
+
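
A matching sketch for the formula-detection wrapper (again with a hypothetical weight path):

import cv2
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel

mfd_model = YOLOv8MFDModel("models/mfd_yolov8.pt", device="cuda")  # hypothetical path
mfd_res = mfd_model.predict(cv2.imread("page_0.png"))
print(len(mfd_res.boxes))  # number of detected formula regions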

+ 0 - 0
magic_pdf/model/sub_modules/mfd/yolov8/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/mfr/__init__.py


+ 98 - 0
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -0,0 +1,98 @@
+import os
+import argparse
+import re
+
+from PIL import Image
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+from unimernet.processors import load_processor
+
+
+class MathDataset(Dataset):
+    def __init__(self, image_paths, transform=None):
+        self.image_paths = image_paths
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        # If the item is a path, load it; otherwise it is already a PIL image
+        if isinstance(self.image_paths[idx], str):
+            raw_image = Image.open(self.image_paths[idx])
+        else:
+            raw_image = self.image_paths[idx]
+        if self.transform:
+            raw_image = self.transform(raw_image)
+        return raw_image
+
+
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code."""
+    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
+    letter = '[a-zA-Z]'
+    noletter = r'[\W_^\d]'  # raw string avoids invalid-escape warnings on \W and \d
+    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+    news = s
+    while True:
+        s = news
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
+        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        if news == s:
+            break
+    return s
+
+
+class UnimernetModel(object):
+    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
+
+        args = argparse.Namespace(cfg_path=cfg_path, options=None)
+        cfg = Config(args)
+        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
+        cfg.config.model.model_config.model_name = weight_dir
+        cfg.config.model.tokenizer_config.path = weight_dir
+        task = tasks.setup_task(cfg)
+        self.model = task.build_model(cfg)
+        self.device = _device_
+        self.model.to(_device_)
+        self.model.eval()
+        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
+        self.mfr_transform = transforms.Compose([vis_processor, ])
+
+    def predict(self, mfd_res, image):
+
+        formula_list = []
+        mf_image_list = []
+        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                'category_id': 13 + int(cla.item()),
+                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                'score': round(float(conf.item()), 2),
+                'latex': '',
+            }
+            formula_list.append(new_item)
+            pil_img = Image.fromarray(image)
+            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            mf_image_list.append(bbox_img)
+
+        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+        mfr_res = []
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            with torch.no_grad():
+                output = self.model.generate({'image': mf_img})
+            mfr_res.extend(output['pred_str'])
+        for res, latex in zip(formula_list, mfr_res):
+            res['latex'] = latex_rm_whitespace(latex)
+        return formula_list
+
+
+
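
The MFR model consumes the raw MFD result together with the page image, so the two wrappers chain directly. A sketch, assuming hypothetical weight and config locations:

from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel

# hypothetical locations of the UniMERNet weights and config file
mfr_model = UnimernetModel("models/unimernet", "models/unimernet/unimernet.yaml", _device_="cuda")
formula_list = mfr_model.predict(mfd_res, page)  # mfd_res and page as in the MFD sketch above
for formula in formula_list:
    print(formula["category_id"], formula["latex"])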

+ 0 - 0
magic_pdf/model/sub_modules/mfr/unimernet/__init__.py


+ 144 - 0
magic_pdf/model/sub_modules/model_init.py

@@ -0,0 +1,144 @@
+from loguru import logger
+
+from magic_pdf.libs.Constants import MODEL_NAME
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
+
+from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
+from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
+# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
+from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+
+
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
+        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
+    elif table_model_type == MODEL_NAME.TABLE_MASTER:
+        config = {
+            "model_dir": model_path,
+            "device": _device_
+        }
+        table_model = TableMasterPaddleModel(config)
+    elif table_model_type == MODEL_NAME.RAPID_TABLE:
+        table_model = RapidTableModel()
+    else:
+        logger.error("table model type not allow")
+        exit(1)
+
+    return table_model
+
+
+def mfd_model_init(weight, device='cpu'):
+    mfd_model = YOLOv8MFDModel(weight, device)
+    return mfd_model
+
+
+def mfr_model_init(weight_dir, cfg_path, device='cpu'):
+    mfr_model = UnimernetModel(weight_dir, cfg_path, device)
+    return mfr_model
+
+
+def layout_model_init(weight, config_file, device):
+    model = Layoutlmv3_Predictor(weight, config_file, device)
+    return model
+
+
+def doclayout_yolo_model_init(weight, device='cpu'):
+    model = DocLayoutYOLOModel(weight, device)
+    return model
+
+
+def ocr_model_init(show_log: bool = False,
+                   det_db_box_thresh=0.3,
+                   lang=None,
+                   use_dilation=True,
+                   det_db_unclip_ratio=1.8,
+                   ):
+    if lang is not None:
+        model = ModifiedPaddleOCR(
+            show_log=show_log,
+            det_db_box_thresh=det_db_box_thresh,
+            lang=lang,
+            use_dilation=use_dilation,
+            det_db_unclip_ratio=det_db_unclip_ratio,
+        )
+    else:
+        model = ModifiedPaddleOCR(
+            show_log=show_log,
+            det_db_box_thresh=det_db_box_thresh,
+            use_dilation=use_dilation,
+            det_db_unclip_ratio=det_db_unclip_ratio,
+            # use_angle_cls=True,
+        )
+    return model
+
+
+class AtomModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_atom_model(self, atom_model_name: str, **kwargs):
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]
+
+
+def atom_model_init(model_name: str, **kwargs):
+    atom_model = None
+    if model_name == AtomicModel.Layout:
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+                kwargs.get("device")
+            )
+    elif model_name == AtomicModel.MFD:
+        atom_model = mfd_model_init(
+            kwargs.get("mfd_weights"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.MFR:
+        atom_model = mfr_model_init(
+            kwargs.get("mfr_weight_dir"),
+            kwargs.get("mfr_cfg_path"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.OCR:
+        atom_model = ocr_model_init(
+            kwargs.get("ocr_show_log"),
+            kwargs.get("det_db_box_thresh"),
+            kwargs.get("lang")
+        )
+    elif model_name == AtomicModel.Table:
+        atom_model = table_model_init(
+            kwargs.get("table_model_name"),
+            kwargs.get("table_model_path"),
+            kwargs.get("table_max_time"),
+            kwargs.get("device")
+        )
+    else:
+        logger.error("model name not allow")
+        exit(1)
+
+    if atom_model is None:
+        logger.error("model init failed")
+        exit(1)
+    else:
+        return atom_model
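
AtomModelSingleton caches one model instance per (atom_model_name, layout_model_name, lang) key, so repeated requests with the same key return the same object. A sketch:

from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton

singleton = AtomModelSingleton()
ocr_en = singleton.get_atom_model(AtomicModel.OCR, ocr_show_log=False, det_db_box_thresh=0.3, lang="en")
ocr_en_again = singleton.get_atom_model(AtomicModel.OCR, ocr_show_log=False, det_db_box_thresh=0.3, lang="en")
assert ocr_en is ocr_en_again  # same cached instance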

+ 51 - 0
magic_pdf/model/sub_modules/model_utils.py

@@ -0,0 +1,51 @@
+import time
+
+import torch
+from PIL import Image
+from loguru import logger
+
+from magic_pdf.libs.clean_memory import clean_memory
+
+
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+    # Create a white background padded by crop_paste_x / crop_paste_y on each side
+    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+    # Crop image
+    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+    cropped_img = input_pil_img.crop(crop_box)
+    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    return return_image, return_list
+
+
+# Select regions for OCR / formula regions / table regions
+def get_res_list_from_layout_res(layout_res):
+    ocr_res_list = []
+    table_res_list = []
+    single_page_mfdetrec_res = []
+    for res in layout_res:
+        if int(res['category_id']) in [13, 14]:
+            single_page_mfdetrec_res.append({
+                "bbox": [int(res['poly'][0]), int(res['poly'][1]),
+                         int(res['poly'][4]), int(res['poly'][5])],
+            })
+        elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
+            ocr_res_list.append(res)
+        elif int(res['category_id']) in [5]:
+            table_res_list.append(res)
+    return ocr_res_list, table_res_list, single_page_mfdetrec_res
+
+
+def clean_vram(device, vram_threshold=8):
+    if torch.cuda.is_available() and device != 'cpu':
+        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
+        if total_memory <= vram_threshold:
+            gc_start = time.time()
+            clean_memory()
+            gc_time = round(time.time() - gc_start, 2)
+            logger.info(f"gc time: {gc_time}")

+ 0 - 0
magic_pdf/model/sub_modules/ocr/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py


+ 259 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py

@@ -0,0 +1,259 @@
+import math
+
+import numpy as np
+from loguru import logger
+
+from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
+from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
+
+
+def bbox_to_points(bbox):
+    """ 将bbox格式转换为四个顶点的数组 """
+    x0, y0, x1, y1 = bbox
+    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
+
+
+def points_to_bbox(points):
+    """ 将四个顶点的数组转换为bbox格式 """
+    x0, y0 = points[0]
+    x1, _ = points[1]
+    _, y1 = points[2]
+    return [x0, y0, x1, y1]
+
+
+def merge_intervals(intervals):
+    # Sort the intervals based on the start value
+    intervals.sort(key=lambda x: x[0])
+
+    merged = []
+    for interval in intervals:
+        # If the list of merged intervals is empty or if the current
+        # interval does not overlap with the previous, simply append it.
+        if not merged or merged[-1][1] < interval[0]:
+            merged.append(interval)
+        else:
+            # Otherwise, there is overlap, so we merge the current and previous intervals.
+            merged[-1][1] = max(merged[-1][1], interval[1])
+
+    return merged
+
+
+def remove_intervals(original, masks):
+    # Merge all mask intervals
+    merged_masks = merge_intervals(masks)
+
+    result = []
+    original_start, original_end = original
+
+    for mask in merged_masks:
+        mask_start, mask_end = mask
+
+        # If the mask starts after the original range, ignore it
+        if mask_start > original_end:
+            continue
+
+        # If the mask ends before the original range starts, ignore it
+        if mask_end < original_start:
+            continue
+
+        # Remove the masked part from the original range
+        if original_start < mask_start:
+            result.append([original_start, mask_start - 1])
+
+        original_start = max(mask_end + 1, original_start)
+
+    # Add the remaining part of the original range, if any
+    if original_start <= original_end:
+        result.append([original_start, original_end])
+
+    return result
+
+
+def update_det_boxes(dt_boxes, mfd_res):
+    new_dt_boxes = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        masks_list = []
+        for mf_box in mfd_res:
+            mf_bbox = mf_box['bbox']
+            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
+                masks_list.append([mf_bbox[0], mf_bbox[2]])
+        text_x_range = [text_bbox[0], text_bbox[2]]
+        text_remove_mask_range = remove_intervals(text_x_range, masks_list)
+        temp_dt_box = []
+        for text_remove_mask in text_remove_mask_range:
+            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
+        if len(temp_dt_box) > 0:
+            new_dt_boxes.extend(temp_dt_box)
+    return new_dt_boxes
+
+
+def merge_overlapping_spans(spans):
+    """
+    Merges overlapping spans on the same line.
+
+    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
+    :return: A list of merged spans
+    """
+    # Return an empty list if the input spans list is empty
+    if not spans:
+        return []
+
+    # Sort spans by their starting x-coordinate
+    spans.sort(key=lambda x: x[0])
+
+    # Initialize the list of merged spans
+    merged = []
+    for span in spans:
+        # Unpack span coordinates
+        x1, y1, x2, y2 = span
+        # If the merged list is empty or there's no horizontal overlap, add the span directly
+        if not merged or merged[-1][2] < x1:
+            merged.append(span)
+        else:
+            # If there is horizontal overlap, merge the current span with the previous one
+            last_span = merged.pop()
+            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
+            x1 = min(last_span[0], x1)
+            y1 = min(last_span[1], y1)
+            x2 = max(last_span[2], x2)
+            y2 = max(last_span[3], y2)
+            # Add the merged span back to the list
+            merged.append((x1, y1, x2, y2))
+
+    # Return the list of merged spans
+    return merged
+
+
+def merge_det_boxes(dt_boxes):
+    """
+    Merge detection boxes.
+
+    This function takes a list of detected bounding boxes, each represented by four corner points.
+    The goal is to merge these bounding boxes into larger text regions.
+
+    Parameters:
+    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
+
+    Returns:
+    list: A list containing the merged text regions, where each region is represented by four corner points.
+    """
+    # Convert the detection boxes into a dictionary format with bounding boxes and type
+    dt_boxes_dict_list = []
+    angle_boxes_list = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]:
+            angle_boxes_list.append(text_box)
+            continue
+        text_box_dict = {
+            'bbox': text_bbox,
+            'type': 'text',
+        }
+        dt_boxes_dict_list.append(text_box_dict)
+
+    # Merge adjacent text regions into lines
+    lines = merge_spans_to_line(dt_boxes_dict_list)
+
+    # Initialize a new list for storing the merged text regions
+    new_dt_boxes = []
+    for line in lines:
+        line_bbox_list = []
+        for span in line:
+            line_bbox_list.append(span['bbox'])
+
+        # Merge overlapping text regions within the same line
+        merged_spans = merge_overlapping_spans(line_bbox_list)
+
+        # Convert the merged text regions back to point format and add them to the new detection box list
+        for span in merged_spans:
+            new_dt_boxes.append(bbox_to_points(span))
+
+    new_dt_boxes.extend(angle_boxes_list)
+
+    return new_dt_boxes
+
+
+def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
+    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+    # Adjust the coordinates of the formula area
+    adjusted_mfdetrec_res = []
+    for mf_res in single_page_mfdetrec_res:
+        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
+        # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
+        x0 = mf_xmin - xmin + paste_x
+        y0 = mf_ymin - ymin + paste_y
+        x1 = mf_xmax - xmin + paste_x
+        y1 = mf_ymax - ymin + paste_y
+        # Filter formula blocks outside the graph
+        if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
+            continue
+        else:
+            adjusted_mfdetrec_res.append({
+                "bbox": [x0, y0, x1, y1],
+            })
+    return adjusted_mfdetrec_res
+
+
+def get_ocr_result_list(ocr_res, useful_list):
+    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+    ocr_result_list = []
+    for box_ocr_res in ocr_res:
+
+        p1, p2, p3, p4 = box_ocr_res[0]
+        text, score = box_ocr_res[1]
+        average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
+        if average_angle_degrees > 0.5:
+            # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
+            # The angle with the x-axis exceeds 0.5 degrees, so rectify the box bounds
+            # Compute the geometric center of the four corners
+            x_center = sum(point[0] for point in box_ocr_res[0]) / 4
+            y_center = sum(point[1] for point in box_ocr_res[0]) / 4
+            new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
+            new_width = p3[0] - p1[0]
+            p1 = [x_center - new_width / 2, y_center - new_height / 2]
+            p2 = [x_center + new_width / 2, y_center - new_height / 2]
+            p3 = [x_center + new_width / 2, y_center + new_height / 2]
+            p4 = [x_center - new_width / 2, y_center + new_height / 2]
+
+        # Convert the coordinates back to the original coordinate system
+        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
+        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
+        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
+        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
+
+        ocr_result_list.append({
+            'category_id': 15,
+            'poly': p1 + p2 + p3 + p4,
+            'score': float(round(score, 2)),
+            'text': text,
+        })
+
+    return ocr_result_list
+
+
+def calculate_angle_degrees(poly):
+    # Define the two diagonals by their endpoints
+    diagonal1 = (poly[0], poly[2])
+    diagonal2 = (poly[1], poly[3])
+
+    # Compute the slope of each diagonal
+    def slope(p1, p2):
+        return (p2[1] - p1[1]) / (p2[0] - p1[0]) if p2[0] != p1[0] else float('inf')
+
+    slope1 = slope(diagonal1[0], diagonal1[1])
+    slope2 = slope(diagonal2[0], diagonal2[1])
+
+    # Angle between each diagonal and the x-axis, in radians
+    angle1_radians = math.atan(slope1)
+    angle2_radians = math.atan(slope2)
+
+    # Convert radians to degrees
+    angle1_degrees = math.degrees(angle1_radians)
+    angle2_degrees = math.degrees(angle2_radians)
+
+    # Average the two diagonal angles against the x-axis
+    average_angle_degrees = abs((angle1_degrees + angle2_degrees) / 2)
+    # logger.info(f"average_angle_degrees: {average_angle_degrees}")
+    return average_angle_degrees
+
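
The interval helpers implement the subtraction that update_det_boxes performs along the x-axis: formula masks are merged first, then cut out of the text span. A small worked example:

from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import merge_intervals, remove_intervals

masks = [[20, 35], [30, 50], [70, 80]]
print(merge_intervals([m[:] for m in masks]))  # [[20, 50], [70, 80]] (copies, since merging mutates)
print(remove_intervals([0, 100], masks))       # [[0, 19], [51, 69], [81, 100]]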

+ 168 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -0,0 +1,168 @@
+import copy
+import time
+
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+from paddleocr.paddleocr import check_img, logger
+from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
+from paddleocr.tools.infer.predict_system import sorted_boxes
+from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
+
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes
+
+
+class ModifiedPaddleOCR(PaddleOCR):
+    def ocr(self,
+            img,
+            det=True,
+            rec=True,
+            cls=True,
+            bin=False,
+            inv=False,
+            alpha_color=(255, 255, 255),
+            mfd_res=None,
+            ):
+        """
+        OCR with PaddleOCR
+        args:
+            img: img for OCR, support ndarray, img_path and list or ndarray
+            det: use text detection or not. If False, only rec will be exec. Default is True
+            rec: use text recognition or not. If False, only det will be exec. Default is True
+            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
+            bin: binarize image to black and white. Default is False.
+            inv: invert image colors. Default is False.
+            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
+        """
+        assert isinstance(img, (np.ndarray, list, str, bytes))
+        if isinstance(img, list) and det == True:
+            logger.error('When input a list of images, det must be false')
+            exit(0)
+        if cls == True and self.use_angle_cls == False:
+            pass
+            # logger.warning(
+            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
+            # )
+
+        img = check_img(img)
+        # for infer pdf file
+        if isinstance(img, list):
+            if self.page_num > len(img) or self.page_num == 0:
+                self.page_num = len(img)
+            imgs = img[:self.page_num]
+        else:
+            imgs = [img]
+
+        def preprocess_image(_image):
+            _image = alpha_to_color(_image, alpha_color)
+            if inv:
+                _image = cv2.bitwise_not(_image)
+            if bin:
+                _image = binarize_img(_image)
+            return _image
+
+        if det and rec:
+            ocr_res = []
+            for idx, img in enumerate(imgs):
+                img = preprocess_image(img)
+                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
+                if not dt_boxes and not rec_res:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [[box.tolist(), res]
+                           for box, res in zip(dt_boxes, rec_res)]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        elif det and not rec:
+            ocr_res = []
+            for idx, img in enumerate(imgs):
+                img = preprocess_image(img)
+                dt_boxes, elapse = self.text_detector(img)
+                if not dt_boxes:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [box.tolist() for box in dt_boxes]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        else:
+            ocr_res = []
+            cls_res = []
+            for idx, img in enumerate(imgs):
+                if not isinstance(img, list):
+                    img = preprocess_image(img)
+                    img = [img]
+                if self.use_angle_cls and cls:
+                    img, cls_res_tmp, elapse = self.text_classifier(img)
+                    if not rec:
+                        cls_res.append(cls_res_tmp)
+                rec_res, elapse = self.text_recognizer(img)
+                ocr_res.append(rec_res)
+            if not rec:
+                return cls_res
+            return ocr_res
+
+    def __call__(self, img, cls=True, mfd_res=None):
+        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+
+        if img is None:
+            logger.debug("no valid image provided")
+            return None, None, time_dict
+
+        start = time.time()
+        ori_im = img.copy()
+        dt_boxes, elapse = self.text_detector(img)
+        time_dict['det'] = elapse
+
+        if dt_boxes is None:
+            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
+            end = time.time()
+            time_dict['all'] = end - start
+            return None, None, time_dict
+        else:
+            logger.debug("dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), elapse))
+        img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        # @todo Merging is currently done at the bbox level, which handles tilted
+        # text lines poorly; it should be reworked to merge at the poly level.
+        # dt_boxes = merge_det_boxes(dt_boxes)
+
+        if mfd_res:
+            bef = time.time()
+            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+            aft = time.time()
+            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), aft - bef))
+
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            if self.args.det_box_type == "quad":
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)
+            else:
+                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+        if self.use_angle_cls and cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(
+                img_crop_list)
+            time_dict['cls'] = elapse
+            logger.debug("cls num  : {}, elapsed : {}".format(
+                len(img_crop_list), elapse))
+
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        time_dict['rec'] = elapse
+        logger.debug("rec_res num  : {}, elapsed : {}".format(
+            len(rec_res), elapse))
+        if self.args.save_crop_res:
+            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
+                                   rec_res)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+        end = time.time()
+        time_dict['all'] = end - start
+        return filter_boxes, filter_rec_res, time_dict
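
A usage sketch for the modified OCR pipeline; the page image and formula region below are illustrative stand-ins for the {'bbox': [...]} list that get_res_list_from_layout_res produces:

import cv2
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR

ocr_model = ModifiedPaddleOCR(show_log=False, det_db_box_thresh=0.3)
img = cv2.imread("page_0.png")  # hypothetical page image
mfd_formula_regions = [{"bbox": [120, 340, 260, 380]}]  # illustrative formula region
ocr_res = ocr_model.ocr(img, mfd_res=mfd_formula_regions)[0]
for box, (text, score) in ocr_res:
    print(round(score, 2), text)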

+ 213 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py

@@ -0,0 +1,213 @@
+import copy
+import time
+
+
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+from paddleocr.paddleocr import check_img, logger
+from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
+from paddleocr.tools.infer.predict_system import sorted_boxes
+from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
+    get_minarea_rect_crop
+
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes
+
+
+class ModifiedPaddleOCR(PaddleOCR):
+
+    def ocr(
+        self,
+        img,
+        det=True,
+        rec=True,
+        cls=True,
+        bin=False,
+        inv=False,
+        alpha_color=(255, 255, 255),
+        slice={},
+        mfd_res=None,
+    ):
+        """
+        OCR with PaddleOCR
+
+        Args:
+            img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
+            det: Use text detection or not. If False, only text recognition will be executed. Default is True.
+            rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
+            cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
+            bin: Binarize image to black and white. Default is False.
+            inv: Invert image colors. Default is False.
+            alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
+            slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.
+
+        Returns:
+            If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region.
+            If det is True and rec is False, returns a list of detected bounding boxes for each image.
+            If det is False and rec is True, returns a list of recognized text for each image.
+            If both det and rec are False, returns a list of angle classification results for each image.
+
+        Raises:
+            AssertionError: If the input image is not of type ndarray, list, str, or bytes.
+            SystemExit: If det is True and the input is a list of images.
+
+        Note:
+            - If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
+            - For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
+            - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
+        """
+        assert isinstance(img, (np.ndarray, list, str, bytes))
+        if isinstance(img, list) and det == True:
+            logger.error("When input a list of images, det must be false")
+            exit(0)
+        if cls == True and self.use_angle_cls == False:
+            logger.warning(
+                "Since the angle classifier is not initialized, it will not be used during the forward process"
+            )
+
+        img, flag_gif, flag_pdf = check_img(img, alpha_color)
+        # for infer pdf file
+        if isinstance(img, list) and flag_pdf:
+            if self.page_num > len(img) or self.page_num == 0:
+                imgs = img
+            else:
+                imgs = img[: self.page_num]
+        else:
+            imgs = [img]
+
+        def preprocess_image(_image):
+            _image = alpha_to_color(_image, alpha_color)
+            if inv:
+                _image = cv2.bitwise_not(_image)
+            if bin:
+                _image = binarize_img(_image)
+            return _image
+
+        if det and rec:
+            ocr_res = []
+            for img in imgs:
+                img = preprocess_image(img)
+                dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
+                if not dt_boxes and not rec_res:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        elif det and not rec:
+            ocr_res = []
+            for img in imgs:
+                img = preprocess_image(img)
+                dt_boxes, elapse = self.text_detector(img)
+                if dt_boxes.size == 0:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [box.tolist() for box in dt_boxes]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        else:
+            ocr_res = []
+            cls_res = []
+            for img in imgs:
+                if not isinstance(img, list):
+                    img = preprocess_image(img)
+                    img = [img]
+                if self.use_angle_cls and cls:
+                    img, cls_res_tmp, elapse = self.text_classifier(img)
+                    if not rec:
+                        cls_res.append(cls_res_tmp)
+                rec_res, elapse = self.text_recognizer(img)
+                ocr_res.append(rec_res)
+            if not rec:
+                return cls_res
+            return ocr_res
+
+    def __call__(self, img, cls=True, slice={}, mfd_res=None):
+        time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
+
+        if img is None:
+            logger.debug("no valid image provided")
+            return None, None, time_dict
+
+        start = time.time()
+        ori_im = img.copy()
+        if slice:
+            slice_gen = slice_generator(
+                img,
+                horizontal_stride=slice["horizontal_stride"],
+                vertical_stride=slice["vertical_stride"],
+            )
+            elapsed = []
+            dt_slice_boxes = []
+            for slice_crop, v_start, h_start in slice_gen:
+                dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
+                if dt_boxes.size:
+                    dt_boxes[:, :, 0] += h_start
+                    dt_boxes[:, :, 1] += v_start
+                    dt_slice_boxes.append(dt_boxes)
+                    elapsed.append(elapse)
+            dt_boxes = np.concatenate(dt_slice_boxes)
+
+            dt_boxes = merge_fragmented(
+                boxes=dt_boxes,
+                x_threshold=slice["merge_x_thres"],
+                y_threshold=slice["merge_y_thres"],
+            )
+            elapse = sum(elapsed)
+        else:
+            dt_boxes, elapse = self.text_detector(img)
+
+        time_dict["det"] = elapse
+
+        if dt_boxes is None:
+            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
+            end = time.time()
+            time_dict["all"] = end - start
+            return None, None, time_dict
+        else:
+            logger.debug(
+                "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
+            )
+        img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        if mfd_res:
+            bef = time.time()
+            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+            aft = time.time()
+            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), aft - bef))
+
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            if self.args.det_box_type == "quad":
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)
+            else:
+                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+        if self.use_angle_cls and cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
+            time_dict["cls"] = elapse
+            logger.debug(
+                "cls num  : {}, elapsed : {}".format(len(img_crop_list), elapse)
+            )
+        if len(img_crop_list) > 1000:
+            logger.debug(
+                f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
+            )
+
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        time_dict["rec"] = elapse
+        logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
+        if self.args.save_crop_res:
+            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result[0], rec_result[1]
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+        end = time.time()
+        time_dict["all"] = end - start
+        return filter_boxes, filter_rec_res, time_dict
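
The 2.9.1 variant additionally supports sliding-window detection for very large images via the slice argument. A sketch with illustrative stride and merge thresholds:

slice_params = {
    "horizontal_stride": 300,
    "vertical_stride": 500,
    "merge_x_thres": 50,
    "merge_y_thres": 35,
}
ocr_res = ocr_model.ocr(big_img, slice=slice_params)  # ocr_model / big_img as in the 2.7.3 sketch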

+ 0 - 0
magic_pdf/model/sub_modules/reading_oreder/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py


+ 0 - 0
magic_pdf/model/v3/helpers.py → magic_pdf/model/sub_modules/reading_oreder/layoutreader/helpers.py


+ 0 - 0
magic_pdf/model/v3/xycut.py → magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py


+ 0 - 0
magic_pdf/model/sub_modules/table/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/table/rapidtable/__init__.py


+ 14 - 0
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -0,0 +1,14 @@
+import numpy as np
+from rapid_table import RapidTable
+from rapidocr_paddle import RapidOCR
+
+
+class RapidTableModel(object):
+    def __init__(self):
+        self.table_model = RapidTable()
+        self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+
+    def predict(self, image):
+        ocr_result, _ = self.ocr_engine(np.asarray(image))
+        html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
+        return html_code, table_cell_bboxes, elapse
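
RapidTableModel runs its own RapidOCR pass and feeds the result to rapid_table. A sketch with a hypothetical table image:

from PIL import Image
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel

table_model = RapidTableModel()
html_code, table_cell_bboxes, elapse = table_model.predict(Image.open("table.png"))
print(html_code)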

+ 0 - 0
magic_pdf/model/sub_modules/table/structeqtable/__init__.py


+ 3 - 11
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py → magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py

@@ -1,8 +1,8 @@
-import re
-
 import torch
 from struct_eqtable import build_model
 
+from magic_pdf.model.sub_modules.table.table_utils import minify_html
+
 
 class StructTableModel:
     def __init__(self, model_path, max_new_tokens=1024, max_time=60):
@@ -31,15 +31,7 @@ class StructTableModel:
         )
 
         if output_format == "html":
-            results = [self.minify_html(html) for html in results]
+            results = [minify_html(html) for html in results]
 
         return results
 
-    def minify_html(self, html):
-        # Remove redundant whitespace
-        html = re.sub(r'\s+', ' ', html)
-        # Remove whitespace around '>'
-        html = re.sub(r'\s*>\s*', '>', html)
-        # Remove whitespace around '<'
-        html = re.sub(r'\s*<\s*', '<', html)
-        return html.strip()

+ 11 - 0
magic_pdf/model/sub_modules/table/table_utils.py

@@ -0,0 +1,11 @@
+import re
+
+
+def minify_html(html):
+    # Remove redundant whitespace
+    html = re.sub(r'\s+', ' ', html)
+    # Remove whitespace around '>'
+    html = re.sub(r'\s*>\s*', '>', html)
+    # Remove whitespace around '<'
+    html = re.sub(r'\s*<\s*', '<', html)
+    return html.strip()
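
The extracted helper collapses all inter-tag whitespace, e.g.:

from magic_pdf.model.sub_modules.table.table_utils import minify_html

print(minify_html("<table>\n  <tr>\n    <td> 1 </td>\n  </tr>\n</table>"))
# -> <table><tr><td>1</td></tr></table>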

+ 0 - 0
magic_pdf/model/sub_modules/table/tablemaster/__init__.py


+ 1 - 1
magic_pdf/model/ppTableModel.py → magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py

@@ -7,7 +7,7 @@ from PIL import Image
 import numpy as np
 
 
-class ppTableModel(object):
+class TableMasterPaddleModel(object):
     """
         This class is responsible for converting image of table into HTML format using a pre-trained model.
 

+ 3 - 3
magic_pdf/pdf_parse_union_core_v2.py

@@ -164,8 +164,8 @@ class ModelSingleton:
 
 
 def do_predict(boxes: List[List[int]], model) -> List[int]:
-    from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
-                                            prepare_inputs)
+    from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (boxes2inputs, parse_logits,
+                                                                                 prepare_inputs)
 
     inputs = boxes2inputs(boxes)
     inputs = prepare_inputs(inputs, model)
@@ -206,7 +206,7 @@ def cal_block_index(fix_blocks, sorted_bboxes):
                 del block['real_lines']
 
         import numpy as np
-        from magic_pdf.model.v3.xycut import recursive_xy_cut
+        from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import recursive_xy_cut
 
         random_boxes = np.array(block_bboxes)
         np.random.shuffle(random_boxes)

Diff file suppressed because it is too large
+ 16 - 0
next_docs/README.md


Diff file suppressed because it is too large
+ 16 - 0
next_docs/README_zh-CN.md


+ 1 - 0
setup.py

@@ -49,6 +49,7 @@ if __name__ == '__main__':
                      "doclayout_yolo==0.0.2",  # doclayout_yolo
                      "rapidocr-paddle",  # rapidocr-paddle
                      "rapid_table",  # rapid_table
+                     "PyYAML",  # yaml
                      "detectron2"
                      ],
         },

+ 2 - 2
tests/test_table/test_tablemaster.py

@@ -2,7 +2,7 @@ import unittest
 from PIL import Image
 from lxml import etree
 
-from magic_pdf.model.ppTableModel import ppTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
 
 
 class TestppTableModel(unittest.TestCase):
@@ -11,7 +11,7 @@ class TestppTableModel(unittest.TestCase):
        # Modify the table model path to match your environment
         config = {"device": "cuda",
                   "model_dir": "/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"}
-        table_model = ppTableModel(config)
+        table_model = TableMasterPaddleModel(config)
         res = table_model.img2html(img)
        # Verify that the generated HTML matches expectations
         parser = etree.HTMLParser()

Some files were not shown because too many files changed in this diff