浏览代码

refactor(magic_pdf): optimize model initialization and threading

- Remove unnecessary threading.Lock in AtomModelSingleton
- Add threading.Lock to CustomPEKModel for OCR processing
- Simplify model initialization logic in AtomModelSingleton
赵小蒙 11 月之前
父节点
当前提交
878f3de004
共有 2 个文件被更改,包括 11 次插入17 次删除
  1. 8 4
      magic_pdf/model/pdf_extract_kit.py
  2. 3 13
      magic_pdf/model/sub_modules/model_init.py

+ 8 - 4
magic_pdf/model/pdf_extract_kit.py

@@ -28,6 +28,8 @@ from magic_pdf.model.sub_modules.model_utils import (
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
 
+from threading import Lock
+
 
 class CustomPEKModel:
 
@@ -209,16 +211,18 @@ class CustomPEKModel:
         # ocr识别
         ocr_start = time.time()
         # Process each area that requires OCR processing
+        lock = Lock()
         for res in ocr_res_list:
             new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
             adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
 
             # OCR recognition
             new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-            if self.apply_ocr:
-                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
-            else:
-                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
+            with lock:
+                if self.apply_ocr:
+                    ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
+                else:
+                    ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
 
             # Integration results
             if ocr_res:

+ 3 - 13
magic_pdf/model/sub_modules/model_init.py

@@ -82,12 +82,9 @@ def ocr_model_init(show_log: bool = False,
     return model
 
 
-from threading import Lock
-
 class AtomModelSingleton:
     _instance = None
     _models = {}
-    _lock = Lock()
 
     def __new__(cls, *args, **kwargs):
         if cls._instance is None:
@@ -98,17 +95,10 @@ class AtomModelSingleton:
         lang = kwargs.get('lang', None)
         layout_model_name = kwargs.get('layout_model_name', None)
         key = (atom_model_name, layout_model_name, lang)
-        if atom_model_name == AtomicModel.OCR:
-            with self._lock:
-                if key not in self._models:
-                    self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
-                else:
-                    return self._models[key]
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
         else:
-            if key not in self._models:
-                self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
-            else:
-                return self._models[key]
+            return self._models[key]
 
 
 def atom_model_init(model_name: str, **kwargs):