@@ -150,7 +150,10 @@ def doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
+    if lang is None or lang == 'auto':
+        images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
+    else:
+        images_with_extra_info = [(images[index], ocr, lang) for index in range(len(dataset))]

     if len(images) >= MIN_BATCH_INFERENCE_SIZE:
         batch_size = MIN_BATCH_INFERENCE_SIZE
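With this hunk an explicitly passed lang overrides the language attached to the dataset; only lang=None or lang='auto' falls back to dataset._lang. A minimal sketch of the same resolution rule, with a hypothetical helper name (resolve_lang is not part of the codebase):

    def resolve_lang(lang, dataset_lang):
        # An explicit language wins; None or 'auto' defers to the dataset's detected language.
        if lang is None or lang == 'auto':
            return dataset_lang
        return lang

    # e.g. resolve_lang('ch', 'en') -> 'ch', resolve_lang('auto', 'en') -> 'en'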
@@ -160,7 +163,7 @@ def doc_analyze(

     results = []
     for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)

     model_json = []
@@ -214,7 +217,7 @@ def batch_doc_analyze(
     batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
     results = []
     for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
+        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)

     infer_results = []
@@ -237,7 +240,6 @@ def may_batch_image_analyze(
         idx: int,
         ocr: bool,
         show_log: bool = False,
-        lang=None,
         layout_model=None,
         formula_enable=None,
         table_enable=None):
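After this hunk may_batch_image_analyze no longer accepts lang: the per-page language already travels inside each (image, ocr, lang) tuple built in doc_analyze. Assuming those tuples, the updated call sites shown above reduce to:

    _, result = may_batch_image_analyze(
        batch_image,      # list of (image, ocr, lang) tuples for this batch
        sn,               # batch sequence number
        ocr,              # batch_doc_analyze passes True here
        show_log,
        layout_model,
        formula_enable,
        table_enable,
    )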
@@ -248,9 +250,6 @@ def may_batch_image_analyze(
     from magic_pdf.model.batch_analyze import BatchAnalyze

     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )

     images = [image for image, _, _ in images_with_extra_info]
     batch_analyze = False
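The removed get_model call built a single model up front from one lang value; only the shared model_manager is handed to BatchAnalyze now, since each tuple in images_with_extra_info carries its own language. A purely hypothetical illustration of a per-image lookup (BatchAnalyze's real internals are not shown in this diff), reusing the get_model signature from the removed lines:

    # Hypothetical sketch, not the actual BatchAnalyze code:
    for image, ocr_flag, img_lang in images_with_extra_info:
        model = model_manager.get_model(
            ocr_flag, show_log, img_lang, layout_model, formula_enable, table_enable
        )
        # ... run layout/formula/OCR inference on `image` with the matched model ...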
@@ -276,64 +275,15 @@ def may_batch_image_analyze(
         else:
             batch_ratio = 1
         logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
-        batch_analyze = True
+        # batch_analyze = True
     elif str(device).startswith('mps'):
-        batch_analyze = True
-    doc_analyze_start = time.time()
+        # batch_analyze = True
+        pass

-    if batch_analyze:
-        """# batch analyze
-        images = []
-        page_wh_list = []
-        for index in range(len(dataset)):
-            if start_page_id <= index <= end_page_id:
-                page_data = dataset.get_page(index)
-                img_dict = page_data.get_image()
-                images.append(img_dict['img'])
-                page_wh_list.append((img_dict['width'], img_dict['height']))
-        """
-        batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
-        results = batch_model(images_with_extra_info)
-        """
-        for index in range(len(dataset)):
-            if start_page_id <= index <= end_page_id:
-                result = analyze_result.pop(0)
-                page_width, page_height = page_wh_list.pop(0)
-            else:
-                result = []
-                page_height = 0
-                page_width = 0
-
-            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
-            page_dict = {'layout_dets': result, 'page_info': page_info}
-            model_json.append(page_dict)
-        """
-    else:
-        # single analyze
-        """
-        for index in range(len(dataset)):
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            img = img_dict['img']
-            page_width = img_dict['width']
-            page_height = img_dict['height']
-            if start_page_id <= index <= end_page_id:
-                page_start = time.time()
-                result = custom_model(img)
-                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-            else:
-                result = []
+    doc_analyze_start = time.time()

-            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
-            page_dict = {'layout_dets': result, 'page_info': page_info}
-            model_json.append(page_dict)
-        """
-        results = []
-        for img_idx, img in enumerate(images):
-            inference_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
-            results.append(result)
+    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
+    results = batch_model(images_with_extra_info)

     gc_start = time.time()
     clean_memory(get_device())
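With the batch_analyze flag and the single-image fallback removed, every device path now funnels through BatchAnalyze. A minimal sketch of the resulting flow, using only names that appear in this diff:

    model_manager = ModelSingleton()
    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log,
                               layout_model, formula_enable, table_enable)
    # images_with_extra_info holds one (image, ocr, lang) tuple per page
    results = batch_model(images_with_extra_info)
    clean_memory(get_device())  # free accelerator memory once inference is done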