浏览代码

refactor(magic-pdf): optimize model initialization and concurrency control

- Remove concurrency limit logic from app.py
- Update model initialization process in various modules
- Remove unused VRAM check for concurrency limit
- Refactor OCR model initialization in pdf_extract_kit.py
- Update txt_spans_extract_v2 function to use lang parameter instead of ocr_model
myhloli 11 月之前
父节点
当前提交
012a46e07d

+ 2 - 4
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -143,10 +143,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
     if lang == "":
         lang = None
 
-    # model_manager = ModelSingleton()
-    # custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
-
-    custom_model = custom_model_init(ocr, show_log, lang, layout_model, formula_enable, table_enable)
+    model_manager = ModelSingleton()
+    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
 
     with fitz.open("pdf", pdf_bytes) as doc:
         pdf_page_num = doc.page_count

+ 4 - 9
magic_pdf/model/pdf_extract_kit.py

@@ -22,7 +22,7 @@ except ImportError:
 
 from magic_pdf.config.constants import *
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
@@ -150,14 +150,9 @@ class CustomPEKModel:
                 device=self.device,
             )
         # 初始化ocr
-        # self.ocr_model = atom_model_manager.get_atom_model(
-        #     atom_model_name=AtomicModel.OCR,
-        #     ocr_show_log=show_log,
-        #     det_db_box_thresh=0.3,
-        #     lang=self.lang
-        # )
-        self.ocr_model = ocr_model_init(
-            show_log=show_log,
+        self.ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.OCR,
+            ocr_show_log=show_log,
             det_db_box_thresh=0.3,
             lang=self.lang
         )

+ 1 - 6
magic_pdf/model/sub_modules/model_init.py

@@ -57,11 +57,6 @@ def doclayout_yolo_model_init(weight, device='cpu'):
     return model
 
 
-import threading
-current_thread = threading.current_thread()
-current_thread_id = current_thread.ident
-
-
 def ocr_model_init(show_log: bool = False,
                    det_db_box_thresh=0.3,
                    lang=None,
@@ -103,7 +98,7 @@ class AtomModelSingleton:
         table_model_name = kwargs.get('table_model_name', None)
 
         if atom_model_name in [AtomicModel.OCR]:
-            key = (atom_model_name, lang, current_thread_id)
+            key = (atom_model_name, lang)
         elif atom_model_name in [AtomicModel.Layout]:
             key = (atom_model_name, layout_model_name)
         elif atom_model_name in [AtomicModel.Table]:

+ 11 - 17
magic_pdf/pdf_parse_union_core_v2.py

@@ -152,7 +152,7 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
             return False
 
 
-def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_model):
+def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
 
     text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
 
@@ -231,13 +231,13 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_
     if len(empty_spans) > 0:
 
         # 初始化ocr模型
-        # atom_model_manager = AtomModelSingleton()
-        # ocr_model = atom_model_manager.get_atom_model(
-        #     atom_model_name='ocr',
-        #     ocr_show_log=False,
-        #     det_db_box_thresh=0.3,
-        #     lang=lang
-        # )
+        atom_model_manager = AtomModelSingleton()
+        ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.3,
+            lang=lang
+        )
 
         for span in empty_spans:
             # 对span的bbox截图再ocr
@@ -613,7 +613,7 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
 
 
 def parse_page_core(
-    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
+    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
 ):
     need_drop = False
     drop_reason = []
@@ -682,7 +682,7 @@ def parse_page_core(
     if parse_mode == SupportedPdfParseMethod.TXT:
 
         """使用新版本的混合ocr方案"""
-        spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, ocr_model)
+        spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
 
     elif parse_mode == SupportedPdfParseMethod.OCR:
         pass
@@ -772,12 +772,6 @@ def pdf_parse_union(
     lang=None,
 ):
 
-    ocr_model = ocr_model_init(
-        show_log=False,
-        det_db_box_thresh=0.3,
-        lang=lang
-    )
-
     pdf_bytes_md5 = compute_md5(dataset.data_bits())
 
     """初始化空的pdf_info_dict"""
@@ -813,7 +807,7 @@ def pdf_parse_union(
         """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
             page_info = parse_page_core(
-                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
+                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
             )
         else:
             page_info = page.get_page_info()

+ 1 - 13
projects/gradio_app/app.py

@@ -14,9 +14,7 @@ from gradio_pdf import PDF
 from loguru import logger
 
 from magic_pdf.data.data_reader_writer import FileBasedDataReader
-from magic_pdf.libs.config_reader import get_device
 from magic_pdf.libs.hash_utils import compute_sha256
-from magic_pdf.model.sub_modules.model_utils import get_vram
 from magic_pdf.tools.common import do_parse, prepare_env
 
 
@@ -185,16 +183,6 @@ def to_pdf(file_path):
             return tmp_file_path
 
 
-def get_concurrency_limit(vram_threshold=7.5):
-    vram = get_vram(device = get_device())
-    if vram is not None and isinstance(vram, (int, float)):
-        concurrency_limit = max(1, int(vram // vram_threshold))
-    else:
-        concurrency_limit = 1
-    # logger.info(f'concurrency_limit: {concurrency_limit}')
-    return concurrency_limit
-
-
 if __name__ == '__main__':
     with gr.Blocks() as demo:
         gr.HTML(header)
@@ -231,7 +219,7 @@ if __name__ == '__main__':
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
         file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
         change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
-                        outputs=[md, md_text, output_file, pdf_show], concurrency_limit=get_concurrency_limit())
+                        outputs=[md, md_text, output_file, pdf_show])
         clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
 
     demo.launch(server_name='0.0.0.0')