Просмотр исходного кода

feat: batch inference with ocr and lang flag

icecraft 7 месяцев назад
Родитель
Сommit
bbba2a120c

+ 20 - 8
magic_pdf/model/batch_analyze.py

@@ -17,13 +17,25 @@ MFR_BASE_BATCH_SIZE = 16
 
 
 class BatchAnalyze:
-    def __init__(self, model: CustomPEKModel, batch_ratio: int):
-        self.model = model
+    def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
+        self.model_manager = model_manager
         self.batch_ratio = batch_ratio
-
-    def __call__(self, images: list) -> list:
+        self.show_log = show_log
+        self.layout_model = layout_model
+        self.formula_enable = formula_enable
+        self.table_enable = table_enable
+
+    def __call__(self, images_with_extra_info: list) -> list:
+        if len(images_with_extra_info) == 0:
+            return []
+    
         images_layout_res = []
         layout_start_time = time.time()
+        _, fst_ocr, fst_lang = images_with_extra_info[0]
+        self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
+
+        images = [image for image, _, _ in images_with_extra_info]
+
         if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
             for image in images:
@@ -79,6 +91,8 @@ class BatchAnalyze:
         table_count = 0
         # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
         for index in range(len(images)):
+            _, ocr_enable, _lang = images_with_extra_info[index]
+            self.model = self.model_manager.get_model(ocr_enable, self.show_log, _lang, self.layout_model, self.formula_enable, self.table_enable)
             layout_res = images_layout_res[index]
             np_array_img = images[index]
 
@@ -99,7 +113,7 @@ class BatchAnalyze:
                 # OCR recognition
                 new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
-                if self.model.apply_ocr:
+                if ocr_enable:
                     ocr_res = self.model.ocr_model.ocr(
                         new_image, mfd_res=adjusted_mfdetrec_res
                     )[0]
@@ -159,9 +173,7 @@ class BatchAnalyze:
                 table_count += len(table_res_list)
 
         if self.model.apply_ocr:
-            logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
-        else:
-            logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
+            logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
         if self.model.apply_table:
             logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
 

+ 24 - 38
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 from loguru import logger
 
 from magic_pdf.model.sub_modules.model_utils import get_vram
-
+from magic_pdf.config.enums import SupportedPdfParseMethod
 import magic_pdf.model as model_config
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.clean_memory import clean_memory
@@ -150,12 +150,13 @@ def doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
+    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
 
     if len(images) >= MIN_BATCH_INFERENCE_SIZE:
         batch_size = MIN_BATCH_INFERENCE_SIZE
-        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
     else:
-        batch_images = [images]
+        batch_images = [images_with_extra_info]
 
     results = []
     for sn, batch_image in enumerate(batch_images):
@@ -181,7 +182,7 @@ def doc_analyze(
 
 def batch_doc_analyze(
     datasets: list[Dataset],
-    ocr: bool = False,
+    parse_method: str,
     show_log: bool = False,
     lang=None,
     layout_model=None,
@@ -192,47 +193,31 @@ def batch_doc_analyze(
     batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
-    lang_list = []
-    lang_s = set()
+
+    images_with_extra_info = []
     for dataset in datasets:
         for index in range(len(dataset)):
             if lang is None or lang == 'auto':
-                lang_list.append(dataset._lang)
+                _lang = dataset._lang
             else:
-                lang_list.append(lang)
-            lang_s.add(lang_list[-1])
+                _lang = lang
+
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
+            if parse_method == 'auto':
+                images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
+            else:
+                images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
 
-    batch_images = []
-    img_idx_list = []
-    for t_lang in lang_s:
-        tmp_img_idx_list = []
-        for i, _lang in enumerate(lang_list):
-            if _lang == t_lang:
-                tmp_img_idx_list.append(i)
-        img_idx_list.extend(tmp_img_idx_list)
-
-        if batch_size >= len(tmp_img_idx_list):
-            batch_images.append((t_lang, [images[j] for j in tmp_img_idx_list]))
-        else:
-            slices = [tmp_img_idx_list[k:k+batch_size] for k in range(0, len(tmp_img_idx_list), batch_size)]
-            for arr in slices:
-                batch_images.append((t_lang, [images[j] for j in arr]))
-
-    unorder_results = []
-
-    for sn, (_lang, batch_image) in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, _lang, layout_model, formula_enable, table_enable)
-        unorder_results.extend(result)
-    results = [None] * len(img_idx_list)
-    for i, idx in enumerate(img_idx_list):
-        results[idx] = unorder_results[i]
+    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
 
     infer_results = []
-
     from magic_pdf.operators.models import InferenceResult
     for index in range(len(datasets)):
         dataset = datasets[index]
@@ -248,9 +233,9 @@ def batch_doc_analyze(
 
 
 def may_batch_image_analyze(
-        images: list[np.ndarray],
+        images_with_extra_info: list[(np.ndarray, bool, str)],
         idx: int,
-        ocr: bool = False,
+        ocr: bool,
         show_log: bool = False,
         lang=None,
         layout_model=None,
@@ -267,6 +252,7 @@ def may_batch_image_analyze(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )
 
+    images = [image for image, _, _ in images_with_extra_info]
     batch_analyze = False
     batch_ratio = 1
     device = get_device()
@@ -306,8 +292,8 @@ def may_batch_image_analyze(
                 images.append(img_dict['img'])
                 page_wh_list.append((img_dict['width'], img_dict['height']))
         """
-        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-        results = batch_model(images)
+        batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
+        results = batch_model(images_with_extra_info)
         """
         for index in range(len(dataset)):
             if start_page_id <= index <= end_page_id:

+ 3 - 14
magic_pdf/tools/common.py

@@ -314,21 +314,10 @@ def batch_do_parse(
             dss.append(PymuDocDataset(v, lang=lang))
         else:
             dss.append(v)
-    dss_with_fn = list(zip(dss, pdf_file_names))
-    if parse_method == 'auto':
-        dss_typed_txt = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.TXT]
-        dss_typed_ocr = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.OCR]
-        infer_results = [None] * len(dss_with_fn)
-        infer_results_txt = batch_doc_analyze([x[1][0] for x in dss_typed_txt], lang=lang, ocr=False,  layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-        infer_results_ocr = batch_doc_analyze([x[1][0] for x in dss_typed_ocr], lang=lang, ocr=True,  layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-        for i, infer_res in enumerate(infer_results_txt):
-            infer_results[dss_typed_txt[i][0]] = infer_res
-        for i, infer_res in enumerate(infer_results_ocr):
-            infer_results[dss_typed_ocr[i][0]] = infer_res
-    else:
-        infer_results = batch_doc_analyze(dss, lang=lang, ocr=parse_method == 'ocr',  layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
+
+    infer_results = batch_doc_analyze(dss, parse_method, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
-        _do_parse(output_dir, dss_with_fn[idx][1], dss_with_fn[idx][0], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
+        _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
 
 
 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])