瀏覽代碼

refactor(ocr): replace AtomModelSingleton with ocr_model_init for OCR model instantiation

- Remove usage of AtomModelSingleton for OCR model creation
- Add ocr_model_init function to initialize OCR model
- Update OCR model initialization in pdf_extract_kit.py and pdf_parse_union_core_v2.py
- Modify txt_spans_extract_v2 function to accept ocr_model as a parameter
- Update parse_page_core function to use ocr_model instead of lang for OCR processing
myhloli 11 月之前
父節點
當前提交
47a83d28f5
共有 2 個文件被更改,包括 28 次插入16 次删除
  1. 9 4
      magic_pdf/model/pdf_extract_kit.py
  2. 19 12
      magic_pdf/pdf_parse_union_core_v2.py

+ 9 - 4
magic_pdf/model/pdf_extract_kit.py

@@ -22,7 +22,7 @@ except ImportError:
 
 from magic_pdf.config.constants import *
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
@@ -150,9 +150,14 @@ class CustomPEKModel:
                 device=self.device,
             )
         # 初始化ocr
-        self.ocr_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.OCR,
-            ocr_show_log=show_log,
+        # self.ocr_model = atom_model_manager.get_atom_model(
+        #     atom_model_name=AtomicModel.OCR,
+        #     ocr_show_log=show_log,
+        #     det_db_box_thresh=0.3,
+        #     lang=self.lang
+        # )
+        self.ocr_model = ocr_model_init(
+            show_log=show_log,
             det_db_box_thresh=0.3,
             lang=self.lang
         )

+ 19 - 12
magic_pdf/pdf_parse_union_core_v2.py

@@ -31,7 +31,7 @@ try:
 except ImportError:
     pass
 
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
@@ -152,7 +152,7 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
             return False
 
 
-def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
+def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_model):
 
     text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
 
@@ -231,13 +231,13 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
     if len(empty_spans) > 0:
 
         # 初始化ocr模型
-        atom_model_manager = AtomModelSingleton()
-        ocr_model = atom_model_manager.get_atom_model(
-            atom_model_name='ocr',
-            ocr_show_log=False,
-            det_db_box_thresh=0.3,
-            lang=lang
-        )
+        # atom_model_manager = AtomModelSingleton()
+        # ocr_model = atom_model_manager.get_atom_model(
+        #     atom_model_name='ocr',
+        #     ocr_show_log=False,
+        #     det_db_box_thresh=0.3,
+        #     lang=lang
+        # )
 
         for span in empty_spans:
             # 对span的bbox截图再ocr
@@ -613,7 +613,7 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
 
 
 def parse_page_core(
-    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
+    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
 ):
     need_drop = False
     drop_reason = []
@@ -682,7 +682,7 @@ def parse_page_core(
     if parse_mode == SupportedPdfParseMethod.TXT:
 
         """使用新版本的混合ocr方案"""
-        spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
+        spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, ocr_model)
 
     elif parse_mode == SupportedPdfParseMethod.OCR:
         pass
@@ -771,6 +771,13 @@ def pdf_parse_union(
     debug_mode=False,
     lang=None,
 ):
+
+    ocr_model = ocr_model_init(
+        show_log=False,
+        det_db_box_thresh=0.3,
+        lang=lang
+    )
+
     pdf_bytes_md5 = compute_md5(dataset.data_bits())
 
     """初始化空的pdf_info_dict"""
@@ -806,7 +813,7 @@ def pdf_parse_union(
         """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
             page_info = parse_page_core(
-                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
+                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
             )
         else:
             page_info = page.get_page_info()