Quellcode durchsuchen

style: remove unused code

icecraft vor 8 Monaten
Ursprung
Commit
e9c2473913
2 geänderte Dateien mit 22 neuen und 78 gelöschten Zeilen
  1. 21 77
      magic_pdf/model/doc_analyze_by_custom_model.py
  2. 1 1
      magic_pdf/tools/common.py

+ 21 - 77
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -34,8 +34,6 @@ from magic_pdf.model.model_list import MODEL
 
 # from magic_pdf.operators.models import InferenceResult
 
-MIN_BATCH_INFERENCE_SIZE = 100
-
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -143,17 +141,14 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
 
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for index in range(len(dataset)):
@@ -163,41 +158,16 @@ def doc_analyze(
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
-
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
 
     model_json = []
     for index in range(len(dataset)):
@@ -224,11 +194,8 @@ def batch_doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for dataset in datasets:
@@ -238,40 +205,17 @@ def batch_doc_analyze(
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
     infer_results = []
 
     from magic_pdf.operators.models import InferenceResult

+ 1 - 1
magic_pdf/tools/common.py

@@ -314,7 +314,7 @@ def batch_do_parse(
             dss.append(PymuDocDataset(v, lang=lang))
         else:
             dss.append(v)
-    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
         _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)