@@ -34,8 +34,6 @@ from magic_pdf.model.model_list import MODEL
 
 # from magic_pdf.operators.models import InferenceResult
 
-MIN_BATCH_INFERENCE_SIZE = 100
-
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -143,17 +141,14 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
 
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for index in range(len(dataset)):
@@ -163,41 +158,16 @@ def doc_analyze(
         images.append(img_dict['img'])
         page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
-
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
 
     model_json = []
     for index in range(len(dataset)):
@@ -224,11 +194,8 @@ def batch_doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for dataset in datasets:
@@ -238,40 +205,17 @@ def batch_doc_analyze(
         images.append(img_dict['img'])
         page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
     infer_results = []
 
     from magic_pdf.operators.models import InferenceResult
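
Note: the following is a minimal, self-contained sketch of the batching behavior this patch introduces, not part of the diff. The env var name MINERU_MIN_BATCH_INFERENCE_SIZE and the chunking logic mirror the added lines above; fake_page_analyze is a hypothetical stand-in for may_batch_image_analyze, which takes more arguments in the real code.

import os

# Batch size now comes from the environment, defaulting to 100 as in the patch.
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

def fake_page_analyze(batch, sn):
    # Hypothetical stand-in for may_batch_image_analyze:
    # returns (batch serial number, one result per page image).
    return sn, [f'page-result-{sn}-{i}' for i in range(len(batch))]

images = [object() for _ in range(250)]  # e.g. 250 rendered pages

# Large documents are split into fixed-size chunks; small ones run as one batch.
if len(images) >= MIN_BATCH_INFERENCE_SIZE:
    batch_size = MIN_BATCH_INFERENCE_SIZE
    batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
else:
    batch_images = [images]

# Batches run sequentially (the patch drops the mp.Pool fan-out), so peak
# memory is bounded by one batch at a time while results stay in page order.
results = []
for sn, batch in enumerate(batch_images):
    _, result = fake_page_analyze(batch, sn)
    results.extend(result)

assert len(results) == len(images)  # 250 pages -> batches of 100, 100, 50

Running with, say, MINERU_MIN_BATCH_INFERENCE_SIZE=200 changes only the chunk boundaries; the sequential loop and the order of results are unaffected.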