Browse Source

feat(model): add onnxruntime support for paddleocr on cpu

- Implement ONNXModelSingleton to manage ONNX models
- Modify ModifiedPaddleOCR to use ONNX models on ARM CPUs without CUDA
- Update RapidTableModel to use RapidOCR with ONNXRuntime on CPU
- Add rapidocr_onnxruntime dependency in setup.py
myhloli 10 tháng trước cách đây
mục cha
commit
512adb6701

+ 1 - 1
magic-pdf.template.json

@@ -7,7 +7,7 @@
     "layoutreader-model-dir":"/tmp/layoutreader",
     "device-mode":"cpu",
     "layout-config": {
-        "model": "layoutlmv3"
+        "model": "doclayout_yolo"
     },
     "formula-config": {
         "mfd_model": "yolo_v8_mfd",

+ 1 - 7
magic_pdf/model/sub_modules/model_init.py

@@ -70,11 +70,6 @@ def ocr_model_init(show_log: bool = False,
                    det_db_unclip_ratio=1.8,
                    ):
 
-    # use_npu = False
-    # device = get_device()
-    # if str(device).startswith("npu"):
-    #     use_npu = True
-
     if lang is not None and lang != '':
         model = ModifiedPaddleOCR(
             show_log=show_log,
@@ -82,7 +77,6 @@ def ocr_model_init(show_log: bool = False,
             lang=lang,
             use_dilation=use_dilation,
             det_db_unclip_ratio=det_db_unclip_ratio,
-            # use_npu=use_npu,
         )
     else:
         model = ModifiedPaddleOCR(
@@ -90,7 +84,6 @@ def ocr_model_init(show_log: bool = False,
             det_db_box_thresh=det_db_box_thresh,
             use_dilation=use_dilation,
             det_db_unclip_ratio=det_db_unclip_ratio,
-            # use_npu=use_npu,
         )
     return model
 
@@ -160,6 +153,7 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
+            kwargs.get('ocr_engine')
         )
     else:
         logger.error('model name not allow')

+ 51 - 1
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py

@@ -303,4 +303,54 @@ def calculate_is_angle(poly):
         return False
     else:
         # logger.info((p3[1] - p1[1])/height)
-        return True
+        return True
+
+
+class ONNXModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_onnx_model(self, **kwargs):
+
+        lang = kwargs.get('lang', None)
+        det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
+        use_dilation = kwargs.get('use_dilation', True)
+        det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
+        key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
+        if key not in self._models:
+            self._models[key] = onnx_model_init(key)
+        return self._models[key]
+
+def onnx_model_init(key):
+
+    import importlib.resources
+
+    resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
+
+    onnx_model = None
+    additional_ocr_params = {
+        "use_onnx": True,
+        "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
+        "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
+        "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
+        "det_db_box_thresh": key[1],
+        "use_dilation": key[2],
+        "det_db_unclip_ratio": key[3],
+    }
+    logger.info(f"additional_ocr_params: {additional_ocr_params}")
+    if key[0] is not None:
+        additional_ocr_params["lang"] = key[0]
+
+    from paddleocr import PaddleOCR
+    onnx_model = PaddleOCR(**additional_ocr_params)
+
+    if onnx_model is None:
+        logger.error('model init failed')
+        exit(1)
+    else:
+        return onnx_model

+ 30 - 6
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -1,7 +1,9 @@
 import copy
+import platform
 import time
 import cv2
 import numpy as np
+import torch
 
 from paddleocr import PaddleOCR
 from ppocr.utils.logging import get_logger
@@ -9,12 +11,23 @@ from ppocr.utils.utility import alpha_to_color, binarize_img
 from tools.infer.predict_system import sorted_boxes
 from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
 
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
+    ONNXModelSingleton
 
 logger = get_logger()
 
 
 class ModifiedPaddleOCR(PaddleOCR):
+    def __init__(self, *args, **kwargs):
+
+        super().__init__(*args, **kwargs)
+
+        # 在cpu架构为arm且不支持cuda时调用onnx、
+        if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
+            self.use_onnx = True
+            onnx_model_manager = ONNXModelSingleton()
+            self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
+
     def ocr(self,
             img,
             det=True,
@@ -79,7 +92,10 @@ class ModifiedPaddleOCR(PaddleOCR):
             ocr_res = []
             for img in imgs:
                 img = preprocess_image(img)
-                dt_boxes, elapse = self.text_detector(img)
+                if self.use_onnx:
+                    dt_boxes, elapse = self.additional_ocr.text_detector(img)
+                else:
+                    dt_boxes, elapse = self.text_detector(img)
                 if dt_boxes is None:
                     ocr_res.append(None)
                     continue
@@ -106,7 +122,10 @@ class ModifiedPaddleOCR(PaddleOCR):
                     img, cls_res_tmp, elapse = self.text_classifier(img)
                     if not rec:
                         cls_res.append(cls_res_tmp)
-                rec_res, elapse = self.text_recognizer(img)
+                if self.use_onnx:
+                    rec_res, elapse = self.additional_ocr.text_recognizer(img)
+                else:
+                    rec_res, elapse = self.text_recognizer(img)
                 ocr_res.append(rec_res)
             if not rec:
                 return cls_res
@@ -121,7 +140,10 @@ class ModifiedPaddleOCR(PaddleOCR):
 
         start = time.time()
         ori_im = img.copy()
-        dt_boxes, elapse = self.text_detector(img)
+        if self.use_onnx:
+            dt_boxes, elapse = self.additional_ocr.text_detector(img)
+        else:
+            dt_boxes, elapse = self.text_detector(img)
         time_dict['det'] = elapse
 
         if dt_boxes is None:
@@ -159,8 +181,10 @@ class ModifiedPaddleOCR(PaddleOCR):
             time_dict['cls'] = elapse
             logger.debug("cls num  : {}, elapsed : {}".format(
                 len(img_crop_list), elapse))
-
-        rec_res, elapse = self.text_recognizer(img_crop_list)
+        if self.use_onnx:
+            rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
+        else:
+            rec_res, elapse = self.text_recognizer(img_crop_list)
         time_dict['rec'] = elapse
         logger.debug("rec_res num  : {}, elapsed : {}".format(
             len(rec_res), elapse))

+ 7 - 2
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -1,8 +1,8 @@
 import cv2
 import numpy as np
+import torch
 from loguru import logger
 from rapid_table import RapidTable
-from rapidocr_paddle import RapidOCR
 
 
 class RapidTableModel(object):
@@ -10,7 +10,12 @@ class RapidTableModel(object):
         self.table_model = RapidTable()
         if ocr_engine is None:
             self.ocr_model_name = "RapidOCR"
-            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+            if torch.cuda.is_available():
+                from rapidocr_paddle import RapidOCR
+                self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+            else:
+                from rapidocr_onnxruntime import RapidOCR
+                self.ocr_engine = RapidOCR()
         else:
             self.ocr_model_name = "PaddleOCR"
             self.ocr_engine = ocr_engine

+ 1 - 0
magic_pdf/post_proc/llm_aided.py

@@ -5,6 +5,7 @@ from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from openai import OpenAI
 
 
+#@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
 formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
 
 1. 修正渲染或编译错误:

+ 1 - 0
setup.py

@@ -50,6 +50,7 @@ if __name__ == '__main__':
                      "accelerate",  # struct-eqtable依赖
                      "doclayout_yolo==0.0.2",  # doclayout_yolo
                      "rapidocr-paddle",  # rapidocr-paddle
+                     "rapidocr_onnxruntime",
                      "rapid_table",  # rapid_table
                      "PyYAML",  # yaml
                      "openai",  # openai SDK