Jelajahi Sumber

Merge pull request #1208 from myhloli/dev

fix(multi-threading ):Enable multi-threading support for PaddleOCR.
Xiaomeng Zhao 11 bulan lalu
induk
melakukan
92c10d1ec1

+ 4 - 2
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -143,8 +143,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
     if lang == "":
         lang = None
 
-    model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
+    # model_manager = ModelSingleton()
+    # custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
+
+    custom_model = custom_model_init(ocr, show_log, lang, layout_model, formula_enable, table_enable)
 
     with fitz.open("pdf", pdf_bytes) as doc:
         pdf_page_num = doc.page_count

+ 15 - 10
magic_pdf/model/pdf_extract_kit.py

@@ -22,7 +22,7 @@ except ImportError:
 
 from magic_pdf.config.constants import *
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
@@ -37,6 +37,7 @@ class CustomPEKModel:
         """
         ======== model init ========
         """
+        self._lock = Lock()
         # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
         current_file_path = os.path.abspath(__file__)
         # 获取当前文件所在的目录(model)
@@ -152,9 +153,14 @@ class CustomPEKModel:
                 device=self.device,
             )
         # 初始化ocr
-        self.ocr_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.OCR,
-            ocr_show_log=show_log,
+        # self.ocr_model = atom_model_manager.get_atom_model(
+        #     atom_model_name=AtomicModel.OCR,
+        #     ocr_show_log=show_log,
+        #     det_db_box_thresh=0.3,
+        #     lang=self.lang
+        # )
+        self.ocr_model = ocr_model_init(
+            show_log=show_log,
             det_db_box_thresh=0.3,
             lang=self.lang
         )
@@ -211,18 +217,17 @@ class CustomPEKModel:
         # ocr识别
         ocr_start = time.time()
         # Process each area that requires OCR processing
-        lock = Lock()
         for res in ocr_res_list:
             new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
             adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
 
             # OCR recognition
             new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-            with lock:
-                if self.apply_ocr:
-                    ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
-                else:
-                    ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
+            # with self._lock:
+            if self.apply_ocr:
+                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
+            else:
+                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
 
             # Integration results
             if ocr_res:

+ 10 - 5
magic_pdf/pdf_parse_union_core_v2.py

@@ -31,7 +31,7 @@ try:
 except ImportError:
     pass
 
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
@@ -231,10 +231,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
     if len(empty_spans) > 0:
 
         # 初始化ocr模型
-        atom_model_manager = AtomModelSingleton()
-        ocr_model = atom_model_manager.get_atom_model(
-            atom_model_name="ocr",
-            ocr_show_log=False,
+        # atom_model_manager = AtomModelSingleton()
+        # ocr_model = atom_model_manager.get_atom_model(
+        #     atom_model_name="ocr",
+        #     ocr_show_log=False,
+        #     det_db_box_thresh=0.3,
+        #     lang=lang
+        # )
+        ocr_model = ocr_model_init(
+            show_log=False,
             det_db_box_thresh=0.3,
             lang=lang
         )