@@ -1,18 +1,17 @@
import os
import time
import torch
-
+import numpy as np
+import multiprocessing as mp
+import concurrent.futures as fut

os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let MPS fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
-# disable paddle's signal handler
-import paddle
-paddle.disable_signal_handler()
+

from loguru import logger

-from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram

try:
@@ -30,8 +29,9 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult

+MIN_BATCH_INFERENCE_SIZE = 100
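+# documents with fewer pages than this threshold skip multiprocess dispatch and run in a single worker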

class ModelSingleton:
    _instance = None
@@ -72,9 +72,7 @@ def custom_model_init(
    formula_enable=None,
    table_enable=None,
):
-
    model = None
-
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
@@ -132,7 +130,6 @@ def custom_model_init(

    return custom_model

-
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
@@ -143,14 +140,166 @@ def doc_analyze(
    layout_model=None,
    formula_enable=None,
    table_enable=None,
-) -> InferenceResult:
-
+    one_shot: bool = True,
+):
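+    # one_shot=True lets a large job (at least MIN_BATCH_INFERENCE_SIZE pages) fan out across worker processes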
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
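+    # e.g. exporting MINERU_PARALLEL_INFERENCE_COUNT=4 pins the worker count; when unset, a default is picked below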
+
+    images = []
+    page_wh_list = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # should check the GPU memory first!
+        # split images into parallel_count batches
+        if parallel_count > 1:
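+            # ceiling division: e.g. 250 images with parallel_count=2 gives batch_size=125 and two batches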
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        results = []
+        parallel_count = len(batch_images)  # adjust to the real parallel count
+        # using concurrent.futures to analyze
+        """
+        result_history = {}
+        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
+            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
+            for future in fut.as_completed(futures):
+                sn, result = future.result()
+                result_history[sn] = result
+
+        for key in sorted(result_history.keys()):
+            results.extend(result_history[key])
+        """
+        with mp.Pool(processes=parallel_count) as pool:
+            mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
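+        # starmap preserves submission order, so the sn index mainly serves the as_completed variant sketched above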
+        for sn, result in mapped_results:
+            results.extend(result)
+    else:
+        _, results = may_batch_image_analyze(
+            images,
+            0,
+            ocr,
+            show_log,
+            lang, layout_model, formula_enable, table_enable)
+
+    model_json = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+        else:
+            result = []
+            page_height = 0
+            page_width = 0
+
+        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
+        model_json.append(page_dict)
+
+    from magic_pdf.operators.models import InferenceResult
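+    # local import mirrors the commented-out module-level import above (presumably to defer heavy imports until needed)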
+    return InferenceResult(model_json, dataset)
+
+def batch_doc_analyze(
+    datasets: list[Dataset],
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+    one_shot: bool = True,
+):
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    images = []
+    page_wh_list = []
+    for dataset in datasets:
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # should check the GPU memory first!
+        # split images into parallel_count batches
+        if parallel_count > 1:
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        results = []
+        parallel_count = len(batch_images)  # adjust to the real parallel count
+        # using concurrent.futures to analyze
+        """
+        result_history = {}
+        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
+            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
+            for future in fut.as_completed(futures):
+                sn, result = future.result()
+                result_history[sn] = result
+
+        for key in sorted(result_history.keys()):
+            results.extend(result_history[key])
+        """
+        with mp.Pool(processes=parallel_count) as pool:
+            mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
+        for sn, result in mapped_results:
+            results.extend(result)
+    else:
+        _, results = may_batch_image_analyze(
+            images,
+            0,
+            ocr,
+            show_log,
+            lang, layout_model, formula_enable, table_enable)
+    infer_results = []
+
+    from magic_pdf.operators.models import InferenceResult
+    for index in range(len(datasets)):
+        dataset = datasets[index]
+        model_json = []
+        for i in range(len(dataset)):
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
+        infer_results.append(InferenceResult(model_json, dataset))
+    return infer_results
+
+
+def may_batch_image_analyze(
+    images: list[np.ndarray],
+    idx: int,
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None):
+    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
+    # disable paddle's signal handler
+    import paddle
+    paddle.disable_signal_handler()
+    from magic_pdf.model.batch_analyze import BatchAnalyze
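+    # imports are deferred into the worker so that each spawned process initializes paddle and the model code on its own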
+
    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )
@@ -181,12 +330,10 @@ def doc_analyze(

    logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
    batch_analyze = True
-
-    model_json = []
    doc_analyze_start = time.time()

    if batch_analyze:
-        # batch analyze
+        """# batch analyze
        images = []
        page_wh_list = []
        for index in range(len(dataset)):
@@ -195,9 +342,10 @@ def doc_analyze(
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
+        """
        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-        analyze_result = batch_model(images)
-
+        results = batch_model(images)
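+        # images are rasterized by the caller and passed in, instead of being collected from the dataset here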
+        """
        for index in range(len(dataset)):
            if start_page_id <= index <= end_page_id:
                result = analyze_result.pop(0)
@@ -210,10 +358,10 @@ def doc_analyze(
            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
-
+        """
    else:
        # single analyze
-
+        """
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
@@ -230,6 +378,13 @@ def doc_analyze(
            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
+        """
+        results = []
+        for img_idx, img in enumerate(images):
+            inference_start = time.time()
+            result = custom_model(img)
+            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
+            results.append(result)

    gc_start = time.time()
    clean_memory(get_device())
@@ -237,10 +392,10 @@ def doc_analyze(
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
    logger.info(
        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
        f' speed: {doc_analyze_speed} pages/second'
    )
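+    # returning idx alongside the results lets out-of-order callers (e.g. the as_completed variant) restore batch order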
+    return (idx, results)

-    return InferenceResult(model_json, dataset)