
Merge pull request #964 from myhloli/dev

refactor(model): rename and restructure model modules
Xiaomeng Zhao, 1 year ago
parent commit bed386f759
55 changed files with 1047 additions and 761 deletions
  1. +46  -320  magic_pdf/model/pdf_extract_kit.py
  2. +0   -36   magic_pdf/model/pek_sub_modules/post_process.py
  3. +0   -388  magic_pdf/model/pek_sub_modules/self_modify.py
  4. +0   -0    magic_pdf/model/sub_modules/__init__.py
  5. +0   -0    magic_pdf/model/sub_modules/layout/__init__.py
  6. +21  -0    magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
  7. +0   -0    magic_pdf/model/sub_modules/layout/doclayout_yolo/__init__.py
  8. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/__init__.py
  9. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/backbone.py
 10. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/beit.py
 11. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/deit.py
 12. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/__init__.py
 13. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/__init__.py
 14. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/cord.py
 15. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/data_collator.py
 16. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/funsd.py
 17. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/image_utils.py
 18. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/xfund.py
 19. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/__init__.py
 20. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py
 21. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py
 22. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py
 23. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py
 24. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py
 25. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py
 26. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py
 27. +0   -0    magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py
 28. +0   -0    magic_pdf/model/sub_modules/mfd/__init__.py
 29. +12  -0    magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
 30. +0   -0    magic_pdf/model/sub_modules/mfd/yolov8/__init__.py
 31. +0   -0    magic_pdf/model/sub_modules/mfr/__init__.py
 32. +98  -0    magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
 33. +0   -0    magic_pdf/model/sub_modules/mfr/unimernet/__init__.py
 34. +144 -0    magic_pdf/model/sub_modules/model_init.py
 35. +51  -0    magic_pdf/model/sub_modules/model_utils.py
 36. +0   -0    magic_pdf/model/sub_modules/ocr/__init__.py
 37. +0   -0    magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py
 38. +259 -0    magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
 39. +168 -0    magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
 40. +213 -0    magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py
 41. +0   -0    magic_pdf/model/sub_modules/reading_oreder/__init__.py
 42. +0   -0    magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py
 43. +0   -0    magic_pdf/model/sub_modules/reading_oreder/layoutreader/helpers.py
 44. +0   -0    magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py
 45. +0   -0    magic_pdf/model/sub_modules/table/__init__.py
 46. +0   -0    magic_pdf/model/sub_modules/table/rapidtable/__init__.py
 47. +14  -0    magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py
 48. +0   -0    magic_pdf/model/sub_modules/table/structeqtable/__init__.py
 49. +3   -11   magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py
 50. +11  -0    magic_pdf/model/sub_modules/table/table_utils.py
 51. +0   -0    magic_pdf/model/sub_modules/table/tablemaster/__init__.py
 52. +1   -1    magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py
 53. +3   -3    magic_pdf/pdf_parse_union_core_v2.py
 54. +1   -0    setup.py
 55. +2   -2    tests/test_table/test_tablemaster.py
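
Note: the refactor relocates everything under magic_pdf/model/pek_sub_modules/ into a task-oriented magic_pdf/model/sub_modules/ tree (layout, mfd, mfr, ocr, table, reading order). A hedged sketch of what this means for downstream imports, using only paths that appear in this diff (any other call sites are assumptions):

    # before this PR
    from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
    from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR

    # after this PR
    from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
    from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR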

+ 46 - 320
magic_pdf/model/pdf_extract_kit.py

@@ -1,203 +1,28 @@
+import numpy as np
+import torch
 from loguru import logger
 import os
 import time
-from magic_pdf.libs.Constants import *
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.model.model_list import AtomicModel
+import cv2
+import yaml
+from PIL import Image
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+
 try:
-    import cv2
-    import yaml
-    import argparse
-    import numpy as np
-    import torch
     import torchtext
 
     if torchtext.__version__ >= "0.18.0":
         torchtext.disable_torchtext_deprecation_warning()
-    from PIL import Image
-    from torchvision import transforms
-    from torch.utils.data import Dataset, DataLoader
-    from ultralytics import YOLO
-    from unimernet.common.config import Config
-    import unimernet.tasks as tasks
-    from unimernet.processors import load_processor
-    from doclayout_yolo import YOLOv10
-    from rapid_table import RapidTable
-    from rapidocr_paddle import RapidOCR
-
-except ImportError as e:
-    logger.exception(e)
-    logger.error(
-        'Required dependency not installed, please install by \n'
-        '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
-    exit(1)
-
-from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
-from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
-from magic_pdf.model.ppTableModel import ppTableModel
-
-
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    ocr_engine = None
-    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time)
-    elif table_model_type == MODEL_NAME.TABLE_MASTER:
-        config = {
-            "model_dir": model_path,
-            "device": _device_
-        }
-        table_model = ppTableModel(config)
-    elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTable()
-        ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-    else:
-        logger.error("table model type not allow")
-        exit(1)
-
-    if ocr_engine:
-        return [table_model, ocr_engine]
-    else:
-        return table_model
-
-
-def mfd_model_init(weight):
-    mfd_model = YOLO(weight)
-    return mfd_model
-
-
-def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
-    args = argparse.Namespace(cfg_path=cfg_path, options=None)
-    cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-    cfg.config.model.model_config.model_name = weight_dir
-    cfg.config.model.tokenizer_config.path = weight_dir
-    task = tasks.setup_task(cfg)
-    model = task.build_model(cfg)
-    model.to(_device_)
-    model.eval()
-    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-    mfr_transform = transforms.Compose([vis_processor, ])
-    return [model, mfr_transform]
-
-
-def layout_model_init(weight, config_file, device):
-    model = Layoutlmv3_Predictor(weight, config_file, device)
-    return model
-
-
-def doclayout_yolo_model_init(weight):
-    model = YOLOv10(weight)
-    return model
-
-
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
-    if lang is not None:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    else:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    return model
-
-
-class MathDataset(Dataset):
-    def __init__(self, image_paths, transform=None):
-        self.image_paths = image_paths
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
-        if self.transform:
-            image = self.transform(raw_image)
-            return image
-
-
-class AtomModelSingleton:
-    _instance = None
-    _models = {}
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def get_atom_model(self, atom_model_name: str, **kwargs):
-        lang = kwargs.get("lang", None)
-        layout_model_name = kwargs.get("layout_model_name", None)
-        key = (atom_model_name, layout_model_name, lang)
-        if key not in self._models:
-            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[key]
-
-
-def atom_model_init(model_name: str, **kwargs):
-
-    if model_name == AtomicModel.Layout:
-        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
-            atom_model = layout_model_init(
-                kwargs.get("layout_weights"),
-                kwargs.get("layout_config_file"),
-                kwargs.get("device")
-            )
-        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
-            atom_model = doclayout_yolo_model_init(
-                kwargs.get("doclayout_yolo_weights"),
-            )
-    elif model_name == AtomicModel.MFD:
-        atom_model = mfd_model_init(
-            kwargs.get("mfd_weights")
-        )
-    elif model_name == AtomicModel.MFR:
-        atom_model = mfr_model_init(
-            kwargs.get("mfr_weight_dir"),
-            kwargs.get("mfr_cfg_path"),
-            kwargs.get("device")
-        )
-    elif model_name == AtomicModel.OCR:
-        atom_model = ocr_model_init(
-            kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh"),
-            kwargs.get("lang")
-        )
-    elif model_name == AtomicModel.Table:
-        atom_model = table_model_init(
-            kwargs.get("table_model_name"),
-            kwargs.get("table_model_path"),
-            kwargs.get("table_max_time"),
-            kwargs.get("device")
-        )
-    else:
-        logger.error("model name not allow")
-        exit(1)
-
-    return atom_model
-
-
-#  Unified crop img logic
-def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
-    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
-    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-    # Create a white background with an additional width and height of 50
-    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
-    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+except ImportError:
+    pass
 
-    # Crop image
-    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-    cropped_img = input_pil_img.crop(crop_box)
-    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-    return return_image, return_list
+from magic_pdf.libs.Constants import *
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
 
 
 class CustomPEKModel:
@@ -243,7 +68,8 @@ class CustomPEKModel:
         logger.info(
             "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
             "apply_table: {}, table_model: {}, lang: {}".format(
-                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
+                self.lang
             )
         )
         # initialize the parsing scheme
@@ -256,17 +82,17 @@ class CustomPEKModel:
 
         # initialize formula recognition
         if self.apply_formula:
-
             # initialize the formula detection (MFD) model
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
+                device=self.device
             )
 
             # initialize the formula recognition (MFR) model
             mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
+            self.mfr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
@@ -286,7 +112,8 @@ class CustomPEKModel:
             self.layout_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Layout,
                 layout_model_name=MODEL_NAME.DocLayout_YOLO,
-                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                device=self.device
             )
         # initialize OCR
         if self.apply_ocr:
@@ -299,22 +126,13 @@ class CustomPEKModel:
         # init table model
         if self.apply_table:
             table_model_dir = self.configs["weights"][self.table_model_name]
-            if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]:
-                self.table_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.Table,
-                    table_model_name=self.table_model_name,
-                    table_model_path=str(os.path.join(models_dir, table_model_dir)),
-                    table_max_time=self.table_max_time,
-                    device=self.device
-                )
-            elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]:
-                self.table_model, self.ocr_engine =atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.Table,
-                    table_model_name=self.table_model_name,
-                    table_model_path=str(os.path.join(models_dir, table_model_dir)),
-                    table_max_time=self.table_max_time,
-                    device=self.device
-                )
+            self.table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Table,
+                table_model_name=self.table_model_name,
+                table_model_path=str(os.path.join(models_dir, table_model_dir)),
+                table_max_time=self.table_max_time,
+                device=self.device
+            )
 
         logger.info('DocAnalysis init done!')
 
@@ -322,26 +140,15 @@ class CustomPEKModel:
 
         page_start = time.time()
 
-        latex_filling_list = []
-        mf_image_list = []
-
         # layout detection
         layout_start = time.time()
+        layout_res = []
         if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            layout_res = []
-            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
-            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 3),
-                }
-                layout_res.append(new_item)
+            layout_res = self.layout_model.predict(image)
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection time: {layout_cost}")
 
@@ -350,58 +157,21 @@ class CustomPEKModel:
         if self.apply_formula:
             # formula detection
             mfd_start = time.time()
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            mfd_res = self.mfd_model.predict(image)
             logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
-            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': 13 + int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 2),
-                    'latex': '',
-                }
-                layout_res.append(new_item)
-                latex_filling_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-                mf_image_list.append(bbox_img)
 
             # formula recognition
             mfr_start = time.time()
-            dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-            dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
-            mfr_res = []
-            for mf_img in dataloader:
-                mf_img = mf_img.to(self.device)
-                with torch.no_grad():
-                    output = self.mfr_model.generate({'image': mf_img})
-                mfr_res.extend(output['pred_str'])
-            for res, latex in zip(latex_filling_list, mfr_res):
-                res['latex'] = latex_rm_whitespace(latex)
+            formula_list = self.mfr_model.predict(mfd_res, image)
+            layout_res.extend(formula_list)
             mfr_cost = round(time.time() - mfr_start, 2)
-            logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
-
-        # Select regions for OCR / formula regions / table regions
-        ocr_res_list = []
-        table_res_list = []
-        single_page_mfdetrec_res = []
-        for res in layout_res:
-            if int(res['category_id']) in [13, 14]:
-                single_page_mfdetrec_res.append({
-                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
-                             int(res['poly'][4]), int(res['poly'][5])],
-                })
-            elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
-                ocr_res_list.append(res)
-            elif int(res['category_id']) in [5]:
-                table_res_list.append(res)
-
-        if torch.cuda.is_available() and self.device != 'cpu':
-            total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3)  # convert bytes to GB
-            if total_memory <= 8:
-                gc_start = time.time()
-                clean_memory()
-                gc_time = round(time.time() - gc_start, 2)
-                logger.info(f"gc time: {gc_time}")
+            logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
+
+        # clean up VRAM
+        clean_vram(self.device, vram_threshold=8)
+
+        # extract OCR, table, and formula regions from layout_res
+        ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
 
         # ocr识别
         if self.apply_ocr:
@@ -409,23 +179,7 @@ class CustomPEKModel:
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                # Adjust the coordinates of the formula area
-                adjusted_mfdetrec_res = []
-                for mf_res in single_page_mfdetrec_res:
-                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
-                    # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
-                    x0 = mf_xmin - xmin + paste_x
-                    y0 = mf_ymin - ymin + paste_y
-                    x1 = mf_xmax - xmin + paste_x
-                    y1 = mf_ymax - ymin + paste_y
-                    # Filter formula blocks outside the graph
-                    if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
-                        continue
-                    else:
-                        adjusted_mfdetrec_res.append({
-                            "bbox": [x0, y0, x1, y1],
-                        })
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
 
                 # OCR recognition
                 new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
@@ -433,22 +187,8 @@ class CustomPEKModel:
 
                 # Integration results
                 if ocr_res:
-                    for box_ocr_res in ocr_res:
-                        p1, p2, p3, p4 = box_ocr_res[0]
-                        text, score = box_ocr_res[1]
-
-                        # Convert the coordinates back to the original coordinate system
-                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
-                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
-                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
-                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
-
-                        layout_res.append({
-                            'category_id': 15,
-                            'poly': p1 + p2 + p3 + p4,
-                            'score': round(score, 2),
-                            'text': text,
-                        })
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    layout_res.extend(ocr_result_list)
 
             ocr_cost = round(time.time() - ocr_start, 2)
             logger.info(f"ocr time: {ocr_cost}")
@@ -459,8 +199,6 @@ class CustomPEKModel:
             for res in table_res_list:
                 new_image, _ = crop_img(res, pil_img)
                 single_table_start_time = time.time()
-                # logger.info("------------------table recognition processing begins-----------------")
-                latex_code = None
                 html_code = None
                 if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
@@ -470,33 +208,21 @@ class CustomPEKModel:
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    ocr_result, _ = self.ocr_engine(np.asarray(new_image))
-                    html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result)
-
+                    html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
                 run_time = time.time() - single_table_start_time
-                # logger.info(f"------------table recognition processing ends within {run_time}s-----")
                 if run_time > self.table_max_time:
-                    logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
+                    logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
                 # check whether a valid result was returned
-
-                if latex_code:
-                    expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
-                    if expected_ending:
-                        res["latex"] = latex_code
-                    else:
-                        logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
-                elif html_code:
+                if html_code:
                     expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
                     if expected_ending:
                         res["html"] = html_code
                     else:
                         logger.warning(f"table recognition processing fails, not found expected HTML table end")
                 else:
-                    logger.warning(f"table recognition processing fails, not get latex or html return")
+                    logger.warning(f"table recognition processing fails, not get html return")
             logger.info(f"table time: {round(time.time() - table_start, 2)}")
 
         logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
 
         return layout_res
-
-
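
With this refactor, CustomPEKModel.__call__ delegates each stage to the new wrapper classes instead of driving YOLO/UniMERNet inline. A condensed sketch of the resulting per-page flow, using only names from this diff (image is the numpy page image, pil_img its PIL counterpart):

    layout_res = self.layout_model.predict(image)                  # DocLayoutYOLOModel
    if self.apply_formula:
        mfd_res = self.mfd_model.predict(image)                    # YOLOv8MFDModel
        layout_res.extend(self.mfr_model.predict(mfd_res, image))  # UnimernetModel
    clean_vram(self.device, vram_threshold=8)
    ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)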

+ 0 - 36
magic_pdf/model/pek_sub_modules/post_process.py

@@ -1,36 +0,0 @@
-import re
-
-def layout_rm_equation(layout_res):
-    rm_idxs = []
-    for idx, ele in enumerate(layout_res['layout_dets']):
-        if ele['category_id'] == 10:
-            rm_idxs.append(idx)
-    
-    for idx in rm_idxs[::-1]:
-        del layout_res['layout_dets'][idx]
-    return layout_res
-
-
-def get_croped_image(image_pil, bbox):
-    x_min, y_min, x_max, y_max = bbox
-    croped_img = image_pil.crop((x_min, y_min, x_max, y_max))
-    return croped_img
-
-
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code.
-    """
-    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
-    letter = '[a-zA-Z]'
-    noletter = '[\W_^\d]'
-    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
-        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
-        if news == s:
-            break
-    return s

+ 0 - 388
magic_pdf/model/pek_sub_modules/self_modify.py

@@ -1,388 +0,0 @@
-import time
-import copy
-import base64
-import cv2
-import numpy as np
-from io import BytesIO
-from PIL import Image
-
-from paddleocr import PaddleOCR
-from paddleocr.ppocr.utils.logging import get_logger
-from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
-from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
-
-from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
-from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
-
-logger = get_logger()
-
-
-def img_decode(content: bytes):
-    np_arr = np.frombuffer(content, dtype=np.uint8)
-    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
-
-
-def check_img(img):
-    if isinstance(img, bytes):
-        img = img_decode(img)
-    if isinstance(img, str):
-        image_file = img
-        img, flag_gif, flag_pdf = check_and_read(image_file)
-        if not flag_gif and not flag_pdf:
-            with open(image_file, 'rb') as f:
-                img_str = f.read()
-                img = img_decode(img_str)
-            if img is None:
-                try:
-                    buf = BytesIO()
-                    image = BytesIO(img_str)
-                    im = Image.open(image)
-                    rgb = im.convert('RGB')
-                    rgb.save(buf, 'jpeg')
-                    buf.seek(0)
-                    image_bytes = buf.read()
-                    data_base64 = str(base64.b64encode(image_bytes),
-                                      encoding="utf-8")
-                    image_decode = base64.b64decode(data_base64)
-                    img_array = np.frombuffer(image_decode, np.uint8)
-                    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
-                except:
-                    logger.error("error in loading image:{}".format(image_file))
-                    return None
-        if img is None:
-            logger.error("error in loading image:{}".format(image_file))
-            return None
-    if isinstance(img, np.ndarray) and len(img.shape) == 2:
-        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
-    return img
-
-
-def sorted_boxes(dt_boxes):
-    """
-    Sort text boxes in order from top to bottom, left to right
-    args:
-        dt_boxes(array):detected text boxes with shape [4, 2]
-    return:
-        sorted boxes(array) with shape [4, 2]
-    """
-    num_boxes = dt_boxes.shape[0]
-    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
-    _boxes = list(sorted_boxes)
-
-    for i in range(num_boxes - 1):
-        for j in range(i, -1, -1):
-            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
-                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
-                tmp = _boxes[j]
-                _boxes[j] = _boxes[j + 1]
-                _boxes[j + 1] = tmp
-            else:
-                break
-    return _boxes
-
-
-def bbox_to_points(bbox):
-    """ 将bbox格式转换为四个顶点的数组 """
-    x0, y0, x1, y1 = bbox
-    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
-
-
-def points_to_bbox(points):
-    """ 将四个顶点的数组转换为bbox格式 """
-    x0, y0 = points[0]
-    x1, _ = points[1]
-    _, y1 = points[2]
-    return [x0, y0, x1, y1]
-
-
-def merge_intervals(intervals):
-    # Sort the intervals based on the start value
-    intervals.sort(key=lambda x: x[0])
-
-    merged = []
-    for interval in intervals:
-        # If the list of merged intervals is empty or if the current
-        # interval does not overlap with the previous, simply append it.
-        if not merged or merged[-1][1] < interval[0]:
-            merged.append(interval)
-        else:
-            # Otherwise, there is overlap, so we merge the current and previous intervals.
-            merged[-1][1] = max(merged[-1][1], interval[1])
-
-    return merged
-
-
-def remove_intervals(original, masks):
-    # Merge all mask intervals
-    merged_masks = merge_intervals(masks)
-
-    result = []
-    original_start, original_end = original
-
-    for mask in merged_masks:
-        mask_start, mask_end = mask
-
-        # If the mask starts after the original range, ignore it
-        if mask_start > original_end:
-            continue
-
-        # If the mask ends before the original range starts, ignore it
-        if mask_end < original_start:
-            continue
-
-        # Remove the masked part from the original range
-        if original_start < mask_start:
-            result.append([original_start, mask_start - 1])
-
-        original_start = max(mask_end + 1, original_start)
-
-    # Add the remaining part of the original range, if any
-    if original_start <= original_end:
-        result.append([original_start, original_end])
-
-    return result
-
-
-def update_det_boxes(dt_boxes, mfd_res):
-    new_dt_boxes = []
-    for text_box in dt_boxes:
-        text_bbox = points_to_bbox(text_box)
-        masks_list = []
-        for mf_box in mfd_res:
-            mf_bbox = mf_box['bbox']
-            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
-                masks_list.append([mf_bbox[0], mf_bbox[2]])
-        text_x_range = [text_bbox[0], text_bbox[2]]
-        text_remove_mask_range = remove_intervals(text_x_range, masks_list)
-        temp_dt_box = []
-        for text_remove_mask in text_remove_mask_range:
-            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
-        if len(temp_dt_box) > 0:
-            new_dt_boxes.extend(temp_dt_box)
-    return new_dt_boxes
-
-
-def merge_overlapping_spans(spans):
-    """
-    Merges overlapping spans on the same line.
-
-    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
-    :return: A list of merged spans
-    """
-    # Return an empty list if the input spans list is empty
-    if not spans:
-        return []
-
-    # Sort spans by their starting x-coordinate
-    spans.sort(key=lambda x: x[0])
-
-    # Initialize the list of merged spans
-    merged = []
-    for span in spans:
-        # Unpack span coordinates
-        x1, y1, x2, y2 = span
-        # If the merged list is empty or there's no horizontal overlap, add the span directly
-        if not merged or merged[-1][2] < x1:
-            merged.append(span)
-        else:
-            # If there is horizontal overlap, merge the current span with the previous one
-            last_span = merged.pop()
-            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
-            x1 = min(last_span[0], x1)
-            y1 = min(last_span[1], y1)
-            x2 = max(last_span[2], x2)
-            y2 = max(last_span[3], y2)
-            # Add the merged span back to the list
-            merged.append((x1, y1, x2, y2))
-
-    # Return the list of merged spans
-    return merged
-
-
-def merge_det_boxes(dt_boxes):
-    """
-    Merge detection boxes.
-
-    This function takes a list of detected bounding boxes, each represented by four corner points.
-    The goal is to merge these bounding boxes into larger text regions.
-
-    Parameters:
-    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
-
-    Returns:
-    list: A list containing the merged text regions, where each region is represented by four corner points.
-    """
-    # Convert the detection boxes into a dictionary format with bounding boxes and type
-    dt_boxes_dict_list = []
-    for text_box in dt_boxes:
-        text_bbox = points_to_bbox(text_box)
-        text_box_dict = {
-            'bbox': text_bbox,
-            'type': 'text',
-        }
-        dt_boxes_dict_list.append(text_box_dict)
-
-    # Merge adjacent text regions into lines
-    lines = merge_spans_to_line(dt_boxes_dict_list)
-
-    # Initialize a new list for storing the merged text regions
-    new_dt_boxes = []
-    for line in lines:
-        line_bbox_list = []
-        for span in line:
-            line_bbox_list.append(span['bbox'])
-
-        # Merge overlapping text regions within the same line
-        merged_spans = merge_overlapping_spans(line_bbox_list)
-
-        # Convert the merged text regions back to point format and add them to the new detection box list
-        for span in merged_spans:
-            new_dt_boxes.append(bbox_to_points(span))
-
-    return new_dt_boxes
-
-
-class ModifiedPaddleOCR(PaddleOCR):
-    def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
-        """
-        OCR with PaddleOCR
-        args:
-            img: img for OCR, support ndarray, img_path and list or ndarray
-            det: use text detection or not. If False, only rec will be exec. Default is True
-            rec: use text recognition or not. If False, only det will be exec. Default is True
-            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
-            bin: binarize image to black and white. Default is False.
-            inv: invert image colors. Default is False.
-            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
-        """
-        assert isinstance(img, (np.ndarray, list, str, bytes))
-        if isinstance(img, list) and det == True:
-            logger.error('When input a list of images, det must be false')
-            exit(0)
-        if cls == True and self.use_angle_cls == False:
-            pass
-            # logger.warning(
-            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
-            # )
-
-        img = check_img(img)
-        # for infer pdf file
-        if isinstance(img, list):
-            if self.page_num > len(img) or self.page_num == 0:
-                self.page_num = len(img)
-            imgs = img[:self.page_num]
-        else:
-            imgs = [img]
-
-        def preprocess_image(_image):
-            _image = alpha_to_color(_image, alpha_color)
-            if inv:
-                _image = cv2.bitwise_not(_image)
-            if bin:
-                _image = binarize_img(_image)
-            return _image
-
-        if det and rec:
-            ocr_res = []
-            for idx, img in enumerate(imgs):
-                img = preprocess_image(img)
-                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
-                if not dt_boxes and not rec_res:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [[box.tolist(), res]
-                           for box, res in zip(dt_boxes, rec_res)]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        elif det and not rec:
-            ocr_res = []
-            for idx, img in enumerate(imgs):
-                img = preprocess_image(img)
-                dt_boxes, elapse = self.text_detector(img)
-                if not dt_boxes:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [box.tolist() for box in dt_boxes]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        else:
-            ocr_res = []
-            cls_res = []
-            for idx, img in enumerate(imgs):
-                if not isinstance(img, list):
-                    img = preprocess_image(img)
-                    img = [img]
-                if self.use_angle_cls and cls:
-                    img, cls_res_tmp, elapse = self.text_classifier(img)
-                    if not rec:
-                        cls_res.append(cls_res_tmp)
-                rec_res, elapse = self.text_recognizer(img)
-                ocr_res.append(rec_res)
-            if not rec:
-                return cls_res
-            return ocr_res
-
-    def __call__(self, img, cls=True, mfd_res=None):
-        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
-
-        if img is None:
-            logger.debug("no valid image provided")
-            return None, None, time_dict
-
-        start = time.time()
-        ori_im = img.copy()
-        dt_boxes, elapse = self.text_detector(img)
-        time_dict['det'] = elapse
-
-        if dt_boxes is None:
-            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
-            end = time.time()
-            time_dict['all'] = end - start
-            return None, None, time_dict
-        else:
-            logger.debug("dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), elapse))
-        img_crop_list = []
-
-        dt_boxes = sorted_boxes(dt_boxes)
-
-        dt_boxes = merge_det_boxes(dt_boxes)
-
-        if mfd_res:
-            bef = time.time()
-            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-            aft = time.time()
-            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), aft - bef))
-
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.args.det_box_type == "quad":
-                img_crop = get_rotate_crop_image(ori_im, tmp_box)
-            else:
-                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
-            img_crop_list.append(img_crop)
-        if self.use_angle_cls and cls:
-            img_crop_list, angle_list, elapse = self.text_classifier(
-                img_crop_list)
-            time_dict['cls'] = elapse
-            logger.debug("cls num  : {}, elapsed : {}".format(
-                len(img_crop_list), elapse))
-
-        rec_res, elapse = self.text_recognizer(img_crop_list)
-        time_dict['rec'] = elapse
-        logger.debug("rec_res num  : {}, elapsed : {}".format(
-            len(rec_res), elapse))
-        if self.args.save_crop_res:
-            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
-                                   rec_res)
-        filter_boxes, filter_rec_res = [], []
-        for box, rec_result in zip(dt_boxes, rec_res):
-            text, score = rec_result
-            if score >= self.drop_score:
-                filter_boxes.append(box)
-                filter_rec_res.append(rec_result)
-        end = time.time()
-        time_dict['all'] = end - start
-        return filter_boxes, filter_rec_res, time_dict

+ 0 - 0
magic_pdf/model/pek_sub_modules/__init__.py → magic_pdf/model/sub_modules/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py → magic_pdf/model/sub_modules/layout/__init__.py


+ 21 - 0
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -0,0 +1,21 @@
+from doclayout_yolo import YOLOv10
+
+
+class DocLayoutYOLOModel(object):
+    def __init__(self, weight, device):
+        self.model = YOLOv10(weight)
+        self.device = device
+
+    def predict(self, image):
+        layout_res = []
+        doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
+                                   doclayout_yolo_res.boxes.cls.cpu()):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                'category_id': int(cla.item()),
+                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                'score': round(float(conf.item()), 3),
+            }
+            layout_res.append(new_item)
+        return layout_res
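
A hedged usage sketch for the new wrapper, assuming a numpy page image and a local weight file (the path is illustrative, not from this diff):

    model = DocLayoutYOLOModel(weight="models/doclayout_yolo.pt", device="cuda")  # hypothetical weight path
    layout_res = model.predict(image)
    # each item: {'category_id': int, 'poly': [x0, y0, x1, y0, x1, y1, x0, y1], 'score': float}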

+ 0 - 0
magic_pdf/model/pek_sub_modules/structeqtable/__init__.py → magic_pdf/model/sub_modules/layout/doclayout_yolo/__init__.py


+ 0 - 0
magic_pdf/model/v3/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py → magic_pdf/model/sub_modules/layout/layoutlmv3/backbone.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py → magic_pdf/model/sub_modules/layout/layoutlmv3/beit.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py → magic_pdf/model/sub_modules/layout/layoutlmv3/deit.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/cord.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/data_collator.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/funsd.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/image_utils.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/xfund.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py → magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py → magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py → magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py → magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py


+ 0 - 0
magic_pdf/model/sub_modules/mfd/__init__.py


+ 12 - 0
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -0,0 +1,12 @@
+from ultralytics import YOLO
+
+
+class YOLOv8MFDModel(object):
+    def __init__(self, weight, device='cpu'):
+        self.mfd_model = YOLO(weight)
+        self.device = device
+
+    def predict(self, image):
+        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        return mfd_res
+
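
Usage mirrors the layout wrapper; a sketch under the same assumptions (weight path illustrative):

    mfd_model = YOLOv8MFDModel(weight="models/yolo_v8_mfd.pt", device="cuda")  # hypothetical weight path
    mfd_res = mfd_model.predict(image)  # raw ultralytics result, consumed by UnimernetModel.predict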

+ 0 - 0
magic_pdf/model/sub_modules/mfd/yolov8/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/mfr/__init__.py


+ 98 - 0
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -0,0 +1,98 @@
+import os
+import argparse
+import re
+
+from PIL import Image
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+from unimernet.processors import load_processor
+
+
+class MathDataset(Dataset):
+    def __init__(self, image_paths, transform=None):
+        self.image_paths = image_paths
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        # if not pil image, then convert to pil image
+        if isinstance(self.image_paths[idx], str):
+            raw_image = Image.open(self.image_paths[idx])
+        else:
+            raw_image = self.image_paths[idx]
+        if self.transform:
+            image = self.transform(raw_image)
+            return image
+
+
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code.
+    """
+    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
+    letter = '[a-zA-Z]'
+    noletter = '[\W_^\d]'
+    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+    news = s
+    while True:
+        s = news
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
+        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        if news == s:
+            break
+    return s
+
+
+class UnimernetModel(object):
+    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
+
+        args = argparse.Namespace(cfg_path=cfg_path, options=None)
+        cfg = Config(args)
+        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
+        cfg.config.model.model_config.model_name = weight_dir
+        cfg.config.model.tokenizer_config.path = weight_dir
+        task = tasks.setup_task(cfg)
+        self.model = task.build_model(cfg)
+        self.device = _device_
+        self.model.to(_device_)
+        self.model.eval()
+        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
+        self.mfr_transform = transforms.Compose([vis_processor, ])
+
+    def predict(self, mfd_res, image):
+
+        formula_list = []
+        mf_image_list = []
+        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                'category_id': 13 + int(cla.item()),
+                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                'score': round(float(conf.item()), 2),
+                'latex': '',
+            }
+            formula_list.append(new_item)
+            pil_img = Image.fromarray(image)
+            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            mf_image_list.append(bbox_img)
+
+        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+        mfr_res = []
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            with torch.no_grad():
+                output = self.model.generate({'image': mf_img})
+            mfr_res.extend(output['pred_str'])
+        for res, latex in zip(formula_list, mfr_res):
+            res['latex'] = latex_rm_whitespace(latex)
+        return formula_list
+
+
+
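
A sketch of how the two formula models chain, following the call sites in pdf_extract_kit.py above (the weight and config locations are assumptions; "UniMERNet/demo.yaml" is the cfg name used in this diff):

    mfr_model = UnimernetModel(weight_dir="models/unimernet", cfg_path="model_configs/UniMERNet/demo.yaml", _device_="cuda")
    formula_list = mfr_model.predict(mfd_res, image)  # mfd_res from YOLOv8MFDModel.predict
    # each formula dict now carries a 'latex' field cleaned by latex_rm_whitespace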

+ 0 - 0
magic_pdf/model/sub_modules/mfr/unimernet/__init__.py


+ 144 - 0
magic_pdf/model/sub_modules/model_init.py

@@ -0,0 +1,144 @@
+from loguru import logger
+
+from magic_pdf.libs.Constants import MODEL_NAME
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
+
+from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
+from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
+# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
+from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+
+
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
+        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
+    elif table_model_type == MODEL_NAME.TABLE_MASTER:
+        config = {
+            "model_dir": model_path,
+            "device": _device_
+        }
+        table_model = TableMasterPaddleModel(config)
+    elif table_model_type == MODEL_NAME.RAPID_TABLE:
+        table_model = RapidTableModel()
+    else:
+        logger.error("table model type not allow")
+        exit(1)
+
+    return table_model
+
+
+def mfd_model_init(weight, device='cpu'):
+    mfd_model = YOLOv8MFDModel(weight, device)
+    return mfd_model
+
+
+def mfr_model_init(weight_dir, cfg_path, device='cpu'):
+    mfr_model = UnimernetModel(weight_dir, cfg_path, device)
+    return mfr_model
+
+
+def layout_model_init(weight, config_file, device):
+    model = Layoutlmv3_Predictor(weight, config_file, device)
+    return model
+
+
+def doclayout_yolo_model_init(weight, device='cpu'):
+    model = DocLayoutYOLOModel(weight, device)
+    return model
+
+
+def ocr_model_init(show_log: bool = False,
+                   det_db_box_thresh=0.3,
+                   lang=None,
+                   use_dilation=True,
+                   det_db_unclip_ratio=1.8,
+                   ):
+    if lang is not None:
+        model = ModifiedPaddleOCR(
+            show_log=show_log,
+            det_db_box_thresh=det_db_box_thresh,
+            lang=lang,
+            use_dilation=use_dilation,
+            det_db_unclip_ratio=det_db_unclip_ratio,
+        )
+    else:
+        model = ModifiedPaddleOCR(
+            show_log=show_log,
+            det_db_box_thresh=det_db_box_thresh,
+            use_dilation=use_dilation,
+            det_db_unclip_ratio=det_db_unclip_ratio,
+            # use_angle_cls=True,
+        )
+    return model
+
+
+class AtomModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_atom_model(self, atom_model_name: str, **kwargs):
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]
+
+
+def atom_model_init(model_name: str, **kwargs):
+    atom_model = None
+    if model_name == AtomicModel.Layout:
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+                kwargs.get("device")
+            )
+    elif model_name == AtomicModel.MFD:
+        atom_model = mfd_model_init(
+            kwargs.get("mfd_weights"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.MFR:
+        atom_model = mfr_model_init(
+            kwargs.get("mfr_weight_dir"),
+            kwargs.get("mfr_cfg_path"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.OCR:
+        atom_model = ocr_model_init(
+            kwargs.get("ocr_show_log"),
+            kwargs.get("det_db_box_thresh"),
+            kwargs.get("lang")
+        )
+    elif model_name == AtomicModel.Table:
+        atom_model = table_model_init(
+            kwargs.get("table_model_name"),
+            kwargs.get("table_model_path"),
+            kwargs.get("table_max_time"),
+            kwargs.get("device")
+        )
+    else:
+        logger.error("model name not allow")
+        exit(1)
+
+    if atom_model is None:
+        logger.error("model init failed")
+        exit(1)
+    else:
+        return atom_model
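
AtomModelSingleton caches one instance per (atom_model_name, layout_model_name, lang) key, so repeated lookups return the same object. A minimal sketch using only names defined in this file (the weight path is a hypothetical placeholder):

    atom_model_manager = AtomModelSingleton()
    mfd_model = atom_model_manager.get_atom_model(
        atom_model_name=AtomicModel.MFD,
        mfd_weights="models/yolo_v8_mfd.pt",  # hypothetical path
        device="cuda",
    )
    # a second call with the same key hits the cache and returns the same object
    assert mfd_model is atom_model_manager.get_atom_model(
        atom_model_name=AtomicModel.MFD,
        mfd_weights="models/yolo_v8_mfd.pt",
        device="cuda",
    )

Note that device and the weight path are not part of the cache key, so the first initialization wins for a given model name.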

+ 51 - 0
magic_pdf/model/sub_modules/model_utils.py

@@ -0,0 +1,51 @@
+import time
+
+import torch
+from PIL import Image
+from loguru import logger
+
+from magic_pdf.libs.clean_memory import clean_memory
+
+
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+    # Create a white background with an additional width and height of 50
+    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+    # Crop image
+    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+    cropped_img = input_pil_img.crop(crop_box)
+    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    return return_image, return_list
+
+
+# Select regions for OCR / formula regions / table regions
+def get_res_list_from_layout_res(layout_res):
+    ocr_res_list = []
+    table_res_list = []
+    single_page_mfdetrec_res = []
+    for res in layout_res:
+        if int(res['category_id']) in [13, 14]:
+            single_page_mfdetrec_res.append({
+                "bbox": [int(res['poly'][0]), int(res['poly'][1]),
+                         int(res['poly'][4]), int(res['poly'][5])],
+            })
+        elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
+            ocr_res_list.append(res)
+        elif int(res['category_id']) in [5]:
+            table_res_list.append(res)
+    return ocr_res_list, table_res_list, single_page_mfdetrec_res
+
+
+def clean_vram(device, vram_threshold=8):
+    if torch.cuda.is_available() and device != 'cpu':
+        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
+        if total_memory <= vram_threshold:
+            gc_start = time.time()
+            clean_memory()
+            gc_time = round(time.time() - gc_start, 2)
+            logger.info(f"gc time: {gc_time}")

+ 0 - 0
magic_pdf/model/sub_modules/ocr/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py


+ 259 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py

@@ -0,0 +1,259 @@
+import math
+
+import numpy as np
+from loguru import logger
+
+from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
+from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
+
+
+def bbox_to_points(bbox):
+    """ 将bbox格式转换为四个顶点的数组 """
+    x0, y0, x1, y1 = bbox
+    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
+
+
+def points_to_bbox(points):
+    """ 将四个顶点的数组转换为bbox格式 """
+    x0, y0 = points[0]
+    x1, _ = points[1]
+    _, y1 = points[2]
+    return [x0, y0, x1, y1]
+
+
+def merge_intervals(intervals):
+    # Sort the intervals based on the start value
+    intervals.sort(key=lambda x: x[0])
+
+    merged = []
+    for interval in intervals:
+        # If the list of merged intervals is empty or if the current
+        # interval does not overlap with the previous, simply append it.
+        if not merged or merged[-1][1] < interval[0]:
+            merged.append(interval)
+        else:
+            # Otherwise, there is overlap, so we merge the current and previous intervals.
+            merged[-1][1] = max(merged[-1][1], interval[1])
+
+    return merged
+
+
+def remove_intervals(original, masks):
+    # Merge all mask intervals
+    merged_masks = merge_intervals(masks)
+
+    result = []
+    original_start, original_end = original
+
+    for mask in merged_masks:
+        mask_start, mask_end = mask
+
+        # If the mask starts after the original range, ignore it
+        if mask_start > original_end:
+            continue
+
+        # If the mask ends before the original range starts, ignore it
+        if mask_end < original_start:
+            continue
+
+        # Remove the masked part from the original range
+        if original_start < mask_start:
+            result.append([original_start, mask_start - 1])
+
+        original_start = max(mask_end + 1, original_start)
+
+    # Add the remaining part of the original range, if any
+    if original_start <= original_end:
+        result.append([original_start, original_end])
+
+    return result
+
+
+def update_det_boxes(dt_boxes, mfd_res):
+    new_dt_boxes = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        masks_list = []
+        for mf_box in mfd_res:
+            mf_bbox = mf_box['bbox']
+            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
+                masks_list.append([mf_bbox[0], mf_bbox[2]])
+        text_x_range = [text_bbox[0], text_bbox[2]]
+        text_remove_mask_range = remove_intervals(text_x_range, masks_list)
+        temp_dt_box = []
+        for text_remove_mask in text_remove_mask_range:
+            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
+        if len(temp_dt_box) > 0:
+            new_dt_boxes.extend(temp_dt_box)
+    return new_dt_boxes
+
+
+def merge_overlapping_spans(spans):
+    """
+    Merges overlapping spans on the same line.
+
+    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
+    :return: A list of merged spans
+    """
+    # Return an empty list if the input spans list is empty
+    if not spans:
+        return []
+
+    # Sort spans by their starting x-coordinate
+    spans.sort(key=lambda x: x[0])
+
+    # Initialize the list of merged spans
+    merged = []
+    for span in spans:
+        # Unpack span coordinates
+        x1, y1, x2, y2 = span
+        # If the merged list is empty or there's no horizontal overlap, add the span directly
+        if not merged or merged[-1][2] < x1:
+            merged.append(span)
+        else:
+            # If there is horizontal overlap, merge the current span with the previous one
+            last_span = merged.pop()
+            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
+            x1 = min(last_span[0], x1)
+            y1 = min(last_span[1], y1)
+            x2 = max(last_span[2], x2)
+            y2 = max(last_span[3], y2)
+            # Add the merged span back to the list
+            merged.append((x1, y1, x2, y2))
+
+    # Return the list of merged spans
+    return merged
+
+
+def merge_det_boxes(dt_boxes):
+    """
+    Merge detection boxes.
+
+    This function takes a list of detected bounding boxes, each represented by four corner points.
+    The goal is to merge these bounding boxes into larger text regions.
+
+    Parameters:
+    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
+
+    Returns:
+    list: A list containing the merged text regions, where each region is represented by four corner points.
+    """
+    # Convert the detection boxes into a dictionary format with bounding boxes and type
+    dt_boxes_dict_list = []
+    angle_boxes_list = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]:
+            angle_boxes_list.append(text_box)
+            continue
+        text_box_dict = {
+            'bbox': text_bbox,
+            'type': 'text',
+        }
+        dt_boxes_dict_list.append(text_box_dict)
+
+    # Merge adjacent text regions into lines
+    lines = merge_spans_to_line(dt_boxes_dict_list)
+
+    # Initialize a new list for storing the merged text regions
+    new_dt_boxes = []
+    for line in lines:
+        line_bbox_list = []
+        for span in line:
+            line_bbox_list.append(span['bbox'])
+
+        # Merge overlapping text regions within the same line
+        merged_spans = merge_overlapping_spans(line_bbox_list)
+
+        # Convert the merged text regions back to point format and add them to the new detection box list
+        for span in merged_spans:
+            new_dt_boxes.append(bbox_to_points(span))
+
+    new_dt_boxes.extend(angle_boxes_list)
+
+    return new_dt_boxes
+
+
+def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
+    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+    # Adjust the coordinates of the formula area
+    adjusted_mfdetrec_res = []
+    for mf_res in single_page_mfdetrec_res:
+        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
+        # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
+        x0 = mf_xmin - xmin + paste_x
+        y0 = mf_ymin - ymin + paste_y
+        x1 = mf_xmax - xmin + paste_x
+        y1 = mf_ymax - ymin + paste_y
+        # Filter out formula blocks that fall outside the cropped area
+        if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
+            continue
+        else:
+            adjusted_mfdetrec_res.append({
+                "bbox": [x0, y0, x1, y1],
+            })
+    return adjusted_mfdetrec_res
+
+
+def get_ocr_result_list(ocr_res, useful_list):
+    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+    ocr_result_list = []
+    for box_ocr_res in ocr_res:
+
+        p1, p2, p3, p4 = box_ocr_res[0]
+        text, score = box_ocr_res[1]
+        average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
+        if average_angle_degrees > 0.5:
+            # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
+            # The box tilts more than 0.5 degrees from the x-axis, so rectify its boundaries
+            # Compute the geometric center
+            x_center = sum(point[0] for point in box_ocr_res[0]) / 4
+            y_center = sum(point[1] for point in box_ocr_res[0]) / 4
+            # Local names avoid shadowing new_width/new_height unpacked from useful_list
+            box_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
+            box_width = p3[0] - p1[0]
+            p1 = [x_center - box_width / 2, y_center - box_height / 2]
+            p2 = [x_center + box_width / 2, y_center - box_height / 2]
+            p3 = [x_center + box_width / 2, y_center + box_height / 2]
+            p4 = [x_center - box_width / 2, y_center + box_height / 2]
+
+        # Convert the coordinates back to the original coordinate system
+        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
+        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
+        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
+        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
+
+        ocr_result_list.append({
+            'category_id': 15,
+            'poly': p1 + p2 + p3 + p4,
+            'score': float(round(score, 2)),
+            'text': text,
+        })
+
+    return ocr_result_list
+
+
+def calculate_angle_degrees(poly):
+    # The two diagonals of the quadrilateral
+    diagonal1 = (poly[0], poly[2])
+    diagonal2 = (poly[1], poly[3])
+
+    # Slope of the line through two points
+    def slope(p1, p2):
+        return (p2[1] - p1[1]) / (p2[0] - p1[0]) if p2[0] != p1[0] else float('inf')
+
+    slope1 = slope(diagonal1[0], diagonal1[1])
+    slope2 = slope(diagonal2[0], diagonal2[1])
+
+    # Angles between the diagonals and the x-axis (in radians)
+    angle1_radians = math.atan(slope1)
+    angle2_radians = math.atan(slope2)
+
+    # Convert radians to degrees
+    angle1_degrees = math.degrees(angle1_radians)
+    angle2_degrees = math.degrees(angle2_radians)
+
+    # Average the two diagonal angles relative to the x-axis
+    average_angle_degrees = abs((angle1_degrees + angle2_degrees) / 2)
+    # logger.info(f"average_angle_degrees: {average_angle_degrees}")
+    return average_angle_degrees
+
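
The interval arithmetic above drives update_det_boxes: formula boxes that overlap a text line mask out x-ranges of the text box, and whatever remains becomes new, narrower text boxes. A worked example of the masking step:

    >>> remove_intervals([0, 100], [[20, 30], [50, 60]])
    [[0, 19], [31, 49], [61, 100]]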

+ 168 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -0,0 +1,168 @@
+import copy
+import time
+
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+from paddleocr.paddleocr import check_img, logger
+from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
+from paddleocr.tools.infer.predict_system import sorted_boxes
+from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
+
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes
+
+
+class ModifiedPaddleOCR(PaddleOCR):
+    def ocr(self,
+            img,
+            det=True,
+            rec=True,
+            cls=True,
+            bin=False,
+            inv=False,
+            alpha_color=(255, 255, 255),
+            mfd_res=None,
+            ):
+        """
+        OCR with PaddleOCR
+        args:
+            img: image for OCR; supports ndarray, image path, and lists of ndarrays
+            det: use text detection or not. If False, only recognition is executed. Default is True
+            rec: use text recognition or not. If False, only detection is executed. Default is True
+            cls: use angle classifier or not. Default is True. If True, text rotated by 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False for better performance. Text rotated by 90 or 270 degrees can be recognized even if cls=False.
+            bin: binarize image to black and white. Default is False.
+            inv: invert image colors. Default is False.
+            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
+        """
+        assert isinstance(img, (np.ndarray, list, str, bytes))
+        if isinstance(img, list) and det == True:
+            logger.error('When input a list of images, det must be false')
+            exit(0)
+        if cls == True and self.use_angle_cls == False:
+            pass
+            # logger.warning(
+            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
+            # )
+
+        img = check_img(img)
+        # for infer pdf file
+        if isinstance(img, list):
+            if self.page_num > len(img) or self.page_num == 0:
+                self.page_num = len(img)
+            imgs = img[:self.page_num]
+        else:
+            imgs = [img]
+
+        def preprocess_image(_image):
+            _image = alpha_to_color(_image, alpha_color)
+            if inv:
+                _image = cv2.bitwise_not(_image)
+            if bin:
+                _image = binarize_img(_image)
+            return _image
+
+        if det and rec:
+            ocr_res = []
+            for idx, img in enumerate(imgs):
+                img = preprocess_image(img)
+                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
+                if not dt_boxes and not rec_res:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [[box.tolist(), res]
+                           for box, res in zip(dt_boxes, rec_res)]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        elif det and not rec:
+            ocr_res = []
+            for idx, img in enumerate(imgs):
+                img = preprocess_image(img)
+                dt_boxes, elapse = self.text_detector(img)
+                if not dt_boxes:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [box.tolist() for box in dt_boxes]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        else:
+            ocr_res = []
+            cls_res = []
+            for idx, img in enumerate(imgs):
+                if not isinstance(img, list):
+                    img = preprocess_image(img)
+                    img = [img]
+                if self.use_angle_cls and cls:
+                    img, cls_res_tmp, elapse = self.text_classifier(img)
+                    if not rec:
+                        cls_res.append(cls_res_tmp)
+                rec_res, elapse = self.text_recognizer(img)
+                ocr_res.append(rec_res)
+            if not rec:
+                return cls_res
+            return ocr_res
+
+    def __call__(self, img, cls=True, mfd_res=None):
+        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+
+        if img is None:
+            logger.debug("no valid image provided")
+            return None, None, time_dict
+
+        start = time.time()
+        ori_im = img.copy()
+        dt_boxes, elapse = self.text_detector(img)
+        time_dict['det'] = elapse
+
+        if dt_boxes is None:
+            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
+            end = time.time()
+            time_dict['all'] = end - start
+            return None, None, time_dict
+        else:
+            logger.debug("dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), elapse))
+        img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        # @todo Merging is currently done at the bbox level, which handles tilted text lines poorly; it should be reworked to merge at the poly level
+        # dt_boxes = merge_det_boxes(dt_boxes)
+
+        if mfd_res:
+            bef = time.time()
+            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+            aft = time.time()
+            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), aft - bef))
+
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            if self.args.det_box_type == "quad":
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)
+            else:
+                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+        if self.use_angle_cls and cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(
+                img_crop_list)
+            time_dict['cls'] = elapse
+            logger.debug("cls num  : {}, elapsed : {}".format(
+                len(img_crop_list), elapse))
+
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        time_dict['rec'] = elapse
+        logger.debug("rec_res num  : {}, elapsed : {}".format(
+            len(rec_res), elapse))
+        if self.args.save_crop_res:
+            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
+                                   rec_res)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+        end = time.time()
+        time_dict['all'] = end - start
+        return filter_boxes, filter_rec_res, time_dict
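
A usage sketch for this paddleocr 2.7.3 variant (page_img is a blank placeholder; mfd_res entries are dicts with a 'bbox' key, matching update_det_boxes in ocr_utils.py):

    import numpy as np

    ocr = ModifiedPaddleOCR(show_log=False, det_db_box_thresh=0.3)
    page_img = np.full((1000, 800, 3), 255, dtype=np.uint8)  # blank placeholder page
    # Formula regions in mfd_res are cut out of the detected text boxes before recognition
    result = ocr.ocr(page_img, mfd_res=[{'bbox': [100, 40, 180, 60]}])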

+ 213 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py

@@ -0,0 +1,213 @@
+import copy
+import time
+
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+from paddleocr.paddleocr import check_img, logger
+from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
+from paddleocr.tools.infer.predict_system import sorted_boxes
+from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
+    get_minarea_rect_crop
+
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes
+
+
+class ModifiedPaddleOCR(PaddleOCR):
+
+    def ocr(
+        self,
+        img,
+        det=True,
+        rec=True,
+        cls=True,
+        bin=False,
+        inv=False,
+        alpha_color=(255, 255, 255),
+        slice={},
+        mfd_res=None,
+    ):
+        """
+        OCR with PaddleOCR
+
+        Args:
+            img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
+            det: Use text detection or not. If False, only text recognition will be executed. Default is True.
+            rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
+            cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
+            bin: Binarize image to black and white. Default is False.
+            inv: Invert image colors. Default is False.
+            alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
+            slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.
+
+        Returns:
+            If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region.
+            If det is True and rec is False, returns a list of detected bounding boxes for each image.
+            If det is False and rec is True, returns a list of recognized text for each image.
+            If both det and rec are False, returns a list of angle classification results for each image.
+
+        Raises:
+            AssertionError: If the input image is not of type ndarray, list, str, or bytes.
+            SystemExit: If det is True and the input is a list of images.
+
+        Note:
+            - If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
+            - For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
+            - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
+        """
+        assert isinstance(img, (np.ndarray, list, str, bytes))
+        if isinstance(img, list) and det == True:
+            logger.error("When input a list of images, det must be false")
+            exit(0)
+        if cls == True and self.use_angle_cls == False:
+            logger.warning(
+                "Since the angle classifier is not initialized, it will not be used during the forward process"
+            )
+
+        img, flag_gif, flag_pdf = check_img(img, alpha_color)
+        # for infer pdf file
+        if isinstance(img, list) and flag_pdf:
+            if self.page_num > len(img) or self.page_num == 0:
+                imgs = img
+            else:
+                imgs = img[: self.page_num]
+        else:
+            imgs = [img]
+
+        def preprocess_image(_image):
+            _image = alpha_to_color(_image, alpha_color)
+            if inv:
+                _image = cv2.bitwise_not(_image)
+            if bin:
+                _image = binarize_img(_image)
+            return _image
+
+        if det and rec:
+            ocr_res = []
+            for img in imgs:
+                img = preprocess_image(img)
+                dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
+                if not dt_boxes and not rec_res:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        elif det and not rec:
+            ocr_res = []
+            for img in imgs:
+                img = preprocess_image(img)
+                dt_boxes, elapse = self.text_detector(img)
+                if dt_boxes.size == 0:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [box.tolist() for box in dt_boxes]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        else:
+            ocr_res = []
+            cls_res = []
+            for img in imgs:
+                if not isinstance(img, list):
+                    img = preprocess_image(img)
+                    img = [img]
+                if self.use_angle_cls and cls:
+                    img, cls_res_tmp, elapse = self.text_classifier(img)
+                    if not rec:
+                        cls_res.append(cls_res_tmp)
+                rec_res, elapse = self.text_recognizer(img)
+                ocr_res.append(rec_res)
+            if not rec:
+                return cls_res
+            return ocr_res
+
+    def __call__(self, img, cls=True, slice={}, mfd_res=None):
+        time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
+
+        if img is None:
+            logger.debug("no valid image provided")
+            return None, None, time_dict
+
+        start = time.time()
+        ori_im = img.copy()
+        if slice:
+            slice_gen = slice_generator(
+                img,
+                horizontal_stride=slice["horizontal_stride"],
+                vertical_stride=slice["vertical_stride"],
+            )
+            elapsed = []
+            dt_slice_boxes = []
+            for slice_crop, v_start, h_start in slice_gen:
+                dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
+                if dt_boxes.size:
+                    dt_boxes[:, :, 0] += h_start
+                    dt_boxes[:, :, 1] += v_start
+                    dt_slice_boxes.append(dt_boxes)
+                    elapsed.append(elapse)
+            dt_boxes = np.concatenate(dt_slice_boxes)
+
+            dt_boxes = merge_fragmented(
+                boxes=dt_boxes,
+                x_threshold=slice["merge_x_thres"],
+                y_threshold=slice["merge_y_thres"],
+            )
+            elapse = sum(elapsed)
+        else:
+            dt_boxes, elapse = self.text_detector(img)
+
+        time_dict["det"] = elapse
+
+        if dt_boxes is None:
+            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
+            end = time.time()
+            time_dict["all"] = end - start
+            return None, None, time_dict
+        else:
+            logger.debug(
+                "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
+            )
+        img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        if mfd_res:
+            bef = time.time()
+            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+            aft = time.time()
+            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), aft - bef))
+
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            if self.args.det_box_type == "quad":
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)
+            else:
+                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+        if self.use_angle_cls and cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
+            time_dict["cls"] = elapse
+            logger.debug(
+                "cls num  : {}, elapsed : {}".format(len(img_crop_list), elapse)
+            )
+        if len(img_crop_list) > 1000:
+            logger.debug(
+                f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
+            )
+
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        time_dict["rec"] = elapse
+        logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
+        if self.args.save_crop_res:
+            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result[0], rec_result[1]
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+        end = time.time()
+        time_dict["all"] = end - start
+        return filter_boxes, filter_rec_res, time_dict
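
Compared with the 2.7.3 variant, this one also supports sliding-window detection for oversized images via the slice argument; a sketch with illustrative stride/merge values (see paddleocr's doc/doc_en/slice_en.md for tuning):

    result = ocr.ocr(
        large_img,
        slice={'horizontal_stride': 300, 'vertical_stride': 500,
               'merge_x_thres': 50, 'merge_y_thres': 35},
        mfd_res=None,
    )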

+ 0 - 0
magic_pdf/model/sub_modules/reading_oreder/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py


+ 0 - 0
magic_pdf/model/v3/helpers.py → magic_pdf/model/sub_modules/reading_oreder/layoutreader/helpers.py


+ 0 - 0
magic_pdf/model/v3/xycut.py → magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py


+ 0 - 0
magic_pdf/model/sub_modules/table/__init__.py


+ 0 - 0
magic_pdf/model/sub_modules/table/rapidtable/__init__.py


+ 14 - 0
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -0,0 +1,14 @@
+import numpy as np
+from rapid_table import RapidTable
+from rapidocr_paddle import RapidOCR
+
+
+class RapidTableModel(object):
+    def __init__(self):
+        self.table_model = RapidTable()
+        self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+
+    def predict(self, image):
+        ocr_result, _ = self.ocr_engine(np.asarray(image))
+        html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
+        return html_code, table_cell_bboxes, elapse
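
Usage is a single call; predict returns the table HTML, the cell bounding boxes, and the elapsed time (the image path is a placeholder):

    from PIL import Image

    table_model = RapidTableModel()
    html_code, table_cell_bboxes, elapse = table_model.predict(Image.open('table.png'))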

+ 0 - 0
magic_pdf/model/sub_modules/table/structeqtable/__init__.py


+ 3 - 11
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py → magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py

@@ -1,8 +1,8 @@
-import re
-
 import torch
 from struct_eqtable import build_model
 
+from magic_pdf.model.sub_modules.table.table_utils import minify_html
+
 
 class StructTableModel:
     def __init__(self, model_path, max_new_tokens=1024, max_time=60):
@@ -31,15 +31,7 @@ class StructTableModel:
         )
 
         if output_format == "html":
-            results = [self.minify_html(html) for html in results]
+            results = [minify_html(html) for html in results]
 
         return results
 
-    def minify_html(self, html):
-        # Remove redundant whitespace characters
-        html = re.sub(r'\s+', ' ', html)
-        # Remove whitespace around closing tag brackets
-        html = re.sub(r'\s*>\s*', '>', html)
-        # Remove whitespace around opening tag brackets
-        html = re.sub(r'\s*<\s*', '<', html)
-        return html.strip()

+ 11 - 0
magic_pdf/model/sub_modules/table/table_utils.py

@@ -0,0 +1,11 @@
+import re
+
+
+def minify_html(html):
+    # Collapse runs of whitespace into a single space
+    html = re.sub(r'\s+', ' ', html)
+    # Remove whitespace around closing tag brackets
+    html = re.sub(r'\s*>\s*', '>', html)
+    # Remove whitespace around opening tag brackets
+    html = re.sub(r'\s*<\s*', '<', html)
+    return html.strip()
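
For example, the extracted helper collapses inter-tag whitespace:

    >>> minify_html('<td> 1 </td>  <td> 2 </td>')
    '<td>1</td><td>2</td>'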

+ 0 - 0
magic_pdf/model/sub_modules/table/tablemaster/__init__.py


+ 1 - 1
magic_pdf/model/ppTableModel.py → magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py

@@ -7,7 +7,7 @@ from PIL import Image
 import numpy as np
 
 
-class ppTableModel(object):
+class TableMasterPaddleModel(object):
     """
         This class is responsible for converting image of table into HTML format using a pre-trained model.
 

+ 3 - 3
magic_pdf/pdf_parse_union_core_v2.py

@@ -164,8 +164,8 @@ class ModelSingleton:
 
 
 def do_predict(boxes: List[List[int]], model) -> List[int]:
-    from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
-                                            prepare_inputs)
+    from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (boxes2inputs, parse_logits,
+                                                                                 prepare_inputs)
 
     inputs = boxes2inputs(boxes)
     inputs = prepare_inputs(inputs, model)
@@ -206,7 +206,7 @@ def cal_block_index(fix_blocks, sorted_bboxes):
                 del block['real_lines']
 
         import numpy as np
-        from magic_pdf.model.v3.xycut import recursive_xy_cut
+        from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import recursive_xy_cut
 
         random_boxes = np.array(block_bboxes)
         np.random.shuffle(random_boxes)

+ 1 - 0
setup.py

@@ -49,6 +49,7 @@ if __name__ == '__main__':
                      "doclayout_yolo==0.0.2",  # doclayout_yolo
                      "rapidocr-paddle",  # rapidocr-paddle
                      "rapid_table",  # rapid_table
+                     "PyYAML",  # yaml
                      "detectron2"
                      ],
         },

+ 2 - 2
tests/test_table/test_tablemaster.py

@@ -2,7 +2,7 @@ import unittest
 from PIL import Image
 from lxml import etree
 
-from magic_pdf.model.ppTableModel import ppTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
 
 
 class TestppTableModel(unittest.TestCase):
@@ -11,7 +11,7 @@ class TestppTableModel(unittest.TestCase):
         # Update the table model path for your environment
         config = {"device": "cuda",
                   "model_dir": "/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"}
-        table_model = ppTableModel(config)
+        table_model = TableMasterPaddleModel(config)
         res = table_model.img2html(img)
         # 验证生成的 HTML 是否符合预期
         parser = etree.HTMLParser()