Browse Source

refactor(model): implement thread-safe OCR model initialization

- Add threading support for OCR model initialization
- Modify AtomModelSingleton to handle thread-specific instances
- Update PDFExtractKit and PDFParseUnionCoreV2 to use new thread-safe OCR initialization
myhloli 11 months ago
parent
commit
f2a92d5782

+ 4 - 3
magic_pdf/model/pdf_extract_kit.py

@@ -22,7 +22,7 @@ except ImportError:
 
 from magic_pdf.config.constants import *
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
@@ -150,8 +150,9 @@ class CustomPEKModel:
                 device=self.device,
             )
         # 初始化ocr
-        self.ocr_model = ocr_model_init(
-            show_log=show_log,
+        self.ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.OCR,
+            ocr_show_log=show_log,
             det_db_box_thresh=0.3,
             lang=self.lang
         )

+ 18 - 3
magic_pdf/model/sub_modules/model_init.py

@@ -57,6 +57,11 @@ def doclayout_yolo_model_init(weight, device='cpu'):
     return model
 
 
+import threading
+current_thread = threading.current_thread()
+current_thread_id = current_thread.ident
+
+
 def ocr_model_init(show_log: bool = False,
                    det_db_box_thresh=0.3,
                    lang=None,
@@ -92,14 +97,24 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
+
         lang = kwargs.get('lang', None)
         layout_model_name = kwargs.get('layout_model_name', None)
-        key = (atom_model_name, layout_model_name, lang)
+        table_model_name = kwargs.get('table_model_name', None)
+
+        if atom_model_name in [AtomicModel.OCR]:
+            key = (atom_model_name, lang, current_thread_id)
+        elif atom_model_name in [AtomicModel.Layout]:
+            key = (atom_model_name, layout_model_name)
+        elif atom_model_name in [AtomicModel.Table]:
+            key = (atom_model_name, table_model_name)
+        else:
+            key = atom_model_name
+
         if key not in self._models:
             self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
         return self._models[key]
 
-
 def atom_model_init(model_name: str, **kwargs):
     atom_model = None
     if model_name == AtomicModel.Layout:
@@ -129,7 +144,7 @@ def atom_model_init(model_name: str, **kwargs):
         atom_model = ocr_model_init(
             kwargs.get('ocr_show_log'),
             kwargs.get('det_db_box_thresh'),
-            kwargs.get('lang')
+            kwargs.get('lang'),
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(

+ 5 - 4
magic_pdf/pdf_parse_union_core_v2.py

@@ -31,7 +31,7 @@ try:
 except ImportError:
     pass
 
-from magic_pdf.model.sub_modules.model_init import ocr_model_init
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
@@ -231,9 +231,10 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
     if len(empty_spans) > 0:
 
         # 初始化ocr模型
-
-        ocr_model = ocr_model_init(
-            show_log=False,
+        atom_model_manager = AtomModelSingleton()
+        ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
             det_db_box_thresh=0.3,
             lang=lang
         )