Browse Source

refactor(model): replace ModelSingleton with direct model initialization and improve threading

- Remove usage of ModelSingleton class
- Initialize model directly using custom_model_init function
- Add self._lock attribute to PDFExtractKit class for thread safety- Replace local lock with self._lock for OCR processing
myhloli 11 months ago
parent
commit
6f636b6e7e

+ 4 - 2
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -143,8 +143,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
     if lang == "":
         lang = None
 
-    model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
+    # model_manager = ModelSingleton()
+    # custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
+
+    custom_model = custom_model_init(ocr, show_log, lang, layout_model, formula_enable, table_enable)
 
     with fitz.open("pdf", pdf_bytes) as doc:
         pdf_page_num = doc.page_count

+ 2 - 2
magic_pdf/model/pdf_extract_kit.py

@@ -37,6 +37,7 @@ class CustomPEKModel:
         """
         ======== model init ========
         """
+        self._lock = Lock()
         # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
         current_file_path = os.path.abspath(__file__)
         # 获取当前文件所在的目录(model)
@@ -211,14 +212,13 @@ class CustomPEKModel:
         # ocr识别
         ocr_start = time.time()
         # Process each area that requires OCR processing
-        lock = Lock()
         for res in ocr_res_list:
             new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
             adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
 
             # OCR recognition
             new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-            with lock:
+            with self._lock:
                 if self.apply_ocr:
                     ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
                 else: