Jelajahi Sumber

fix: support auto method and auto lang

icecraft 7 bulan lalu
induk
melakukan
adbf492111

+ 4 - 1
magic_pdf/data/dataset.py

@@ -143,6 +143,7 @@ class PymuDocDataset(Dataset):
         self._records = [Doc(v) for v in self._raw_fitz]
         self._data_bits = bits
         self._raw_data = bits
+        self._classify_result = None
 
         if lang == '':
             self._lang = None
@@ -218,7 +219,9 @@ class PymuDocDataset(Dataset):
         Returns:
             SupportedPdfParseMethod: _description_
         """
-        return classify(self._data_bits)
+        if self._classify_result is None:
+            self._classify_result = classify(self._data_bits)
+        return self._classify_result
 
     def clone(self):
         """clone this dataset."""

+ 31 - 9
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -189,26 +189,48 @@ def batch_doc_analyze(
     table_enable=None,
 ):
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
+    lang_list = []
+    lang_s = set()
     for dataset in datasets:
         for index in range(len(dataset)):
+            if lang is None or lang == 'auto':
+                lang_list.append(dataset._lang)
+            else:
+                lang_list.append(lang)
+            lang_s.add(lang_list[-1])
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        batch_size = MIN_BATCH_INFERENCE_SIZE
-        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-    else:
-        batch_images = [images]
+    batch_images = []
+    img_idx_list = []
+    for t_lang in lang_s:
+        tmp_img_idx_list = []
+        for i, _lang in enumerate(lang_list):
+            if _lang == t_lang:
+                tmp_img_idx_list.append(i)
+        img_idx_list.extend(tmp_img_idx_list)
+
+        if batch_size >= len(tmp_img_idx_list):
+            batch_images.append((t_lang, [images[j] for j in tmp_img_idx_list]))
+        else:
+            slices = [tmp_img_idx_list[k:k+batch_size] for k in range(0, len(tmp_img_idx_list), batch_size)]
+            for arr in slices:
+                batch_images.append((t_lang, [images[j] for j in arr]))
 
-    results = []
+    unorder_results = []
+
+    for sn, (_lang, batch_image) in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, _lang, layout_model, formula_enable, table_enable)
+        unorder_results.extend(result)
+    results = [None] * len(img_idx_list)
+    for i, idx in enumerate(img_idx_list):
+        results[idx] = unorder_results[i]
 
-    for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
-        results.extend(result)
     infer_results = []
 
     from magic_pdf.operators.models import InferenceResult

+ 15 - 3
magic_pdf/tools/common.py

@@ -281,7 +281,7 @@ def do_parse(
             ds = PymuDocDataset(pdf_bytes, lang=lang)
         else:
             ds = pdf_bytes_or_dataset
-        batch_do_parse(output_dir, [pdf_file_name], [ds], parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
+        batch_do_parse(output_dir, [pdf_file_name], [ds], parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
     else:
         _do_parse(output_dir, pdf_file_name, pdf_bytes_or_dataset, model_list, parse_method, debug_able, start_page_id=start_page_id, end_page_id=end_page_id, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable,  f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
 
@@ -314,9 +314,21 @@ def batch_do_parse(
             dss.append(PymuDocDataset(v, lang=lang))
         else:
             dss.append(v)
-    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
+    dss_with_fn = list(zip(dss, pdf_file_names))
+    if parse_method == 'auto':
+        dss_typed_txt = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.TXT]
+        dss_typed_ocr = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.OCR]
+        infer_results = [None] * len(dss_with_fn)
+        infer_results_txt = batch_doc_analyze([x[1][0] for x in dss_typed_txt], lang=lang, ocr=False,  layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
+        infer_results_ocr = batch_doc_analyze([x[1][0] for x in dss_typed_ocr], lang=lang, ocr=True,  layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
+        for i, infer_res in enumerate(infer_results_txt):
+            infer_results[dss_typed_txt[i][0]] = infer_res
+        for i, infer_res in enumerate(infer_results_ocr):
+            infer_results[dss_typed_ocr[i][0]] = infer_res
+    else:
+        infer_results = batch_doc_analyze(dss, lang=lang, ocr=parse_method == 'ocr',  layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
-        _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
+        _do_parse(output_dir, dss_with_fn[idx][1], dss_with_fn[idx][0], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
 
 
 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])