@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # disable albumentations update check
from loguru import logger

from magic_pdf.model.sub_modules.model_utils import get_vram
-
+from magic_pdf.config.enums import SupportedPdfParseMethod
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
@@ -150,12 +150,13 @@ def doc_analyze(
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
+    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]

    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
-        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
    else:
-        batch_images = [images]
+        batch_images = [images_with_extra_info]

    results = []
    for sn, batch_image in enumerate(batch_images):
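Note: with this change both branches batch the same list of (image, ocr, lang) tuples, so a batch is just a fixed-size slice. A minimal standalone sketch of the slicing pattern used above (the value 100 for MIN_BATCH_INFERENCE_SIZE is illustrative; the real constant is defined elsewhere in this module):

    # Split items into consecutive slices of at most batch_size elements.
    MIN_BATCH_INFERENCE_SIZE = 100  # illustrative value only

    def chunk(items: list, batch_size: int) -> list[list]:
        return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    pages = [(f'img{n}', True, 'en') for n in range(250)]  # stand-ins for (np.ndarray, bool, str)
    assert [len(b) for b in chunk(pages, MIN_BATCH_INFERENCE_SIZE)] == [100, 100, 50]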
@@ -181,7 +182,7 @@ def doc_analyze(

def batch_doc_analyze(
    datasets: list[Dataset],
-    ocr: bool = False,
+    parse_method: str,
    show_log: bool = False,
    lang=None,
    layout_model=None,
@@ -192,47 +193,31 @@ def batch_doc_analyze(
    batch_size = MIN_BATCH_INFERENCE_SIZE
    images = []
    page_wh_list = []
-    lang_list = []
-    lang_s = set()
+
+    images_with_extra_info = []
    for dataset in datasets:
        for index in range(len(dataset)):
            if lang is None or lang == 'auto':
-                lang_list.append(dataset._lang)
+                _lang = dataset._lang
            else:
-                lang_list.append(lang)
-                lang_s.add(lang_list[-1])
+                _lang = lang
+
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
+            if parse_method == 'auto':
+                images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
+            else:
+                images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))

-    batch_images = []
-    img_idx_list = []
-    for t_lang in lang_s:
-        tmp_img_idx_list = []
-        for i, _lang in enumerate(lang_list):
-            if _lang == t_lang:
-                tmp_img_idx_list.append(i)
-        img_idx_list.extend(tmp_img_idx_list)
-
-        if batch_size >= len(tmp_img_idx_list):
-            batch_images.append((t_lang, [images[j] for j in tmp_img_idx_list]))
-        else:
-            slices = [tmp_img_idx_list[k:k+batch_size] for k in range(0, len(tmp_img_idx_list), batch_size)]
-            for arr in slices:
-                batch_images.append((t_lang, [images[j] for j in arr]))
-
-    unorder_results = []
-
-    for sn, (_lang, batch_image) in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, _lang, layout_model, formula_enable, table_enable)
-        unorder_results.extend(result)
-    results = [None] * len(img_idx_list)
-    for i, idx in enumerate(img_idx_list):
-        results[idx] = unorder_results[i]
+    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)

    infer_results = []
-
    from magic_pdf.operators.models import InferenceResult
    for index in range(len(datasets)):
        dataset = datasets[index]
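Note: the per-language regrouping (lang_list, lang_s, img_idx_list) is gone because the OCR flag and language now travel with each page, which also removes the need to reorder unorder_results at the end. A standalone restatement of the new per-page tuple construction, using only the Dataset methods already shown in this diff (the helper name build_images_with_extra_info is hypothetical):

    # Hypothetical helper mirroring the added loop above.
    def build_images_with_extra_info(datasets, parse_method: str, lang=None):
        images_with_extra_info = []
        for dataset in datasets:
            for index in range(len(dataset)):
                # Per-page language: fall back to the dataset's detected language.
                _lang = dataset._lang if lang is None or lang == 'auto' else lang
                img = dataset.get_page(index).get_image()['img']
                if parse_method == 'auto':
                    # Let the dataset classify itself as OCR vs. text-extractable.
                    use_ocr = dataset.classify() == SupportedPdfParseMethod.OCR
                else:
                    use_ocr = parse_method == 'ocr'
                images_with_extra_info.append((img, use_ocr, _lang))
        return images_with_extra_info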
@@ -248,9 +233,9 @@ def batch_doc_analyze(


def may_batch_image_analyze(
-    images: list[np.ndarray],
+    images_with_extra_info: list[tuple[np.ndarray, bool, str]],
    idx: int,
-    ocr: bool = False,
+    ocr: bool,
    show_log: bool = False,
    lang=None,
    layout_model=None,
@@ -267,6 +252,7 @@ def may_batch_image_analyze(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

+    images = [image for image, _, _ in images_with_extra_info]
    batch_analyze = False
    batch_ratio = 1
    device = get_device()
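Note: the added comprehension strips the extra info back off where only raw images are needed. For comparison only (not part of the change), the same unpacking via zip:

    # zip(*...) yields tuples rather than lists and fails on an empty list
    # (three-target unpacking raises ValueError), so the comprehension above
    # is the safer choice here.
    images, ocr_flags, langs = zip(*images_with_extra_info)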
@@ -306,8 +292,8 @@ def may_batch_image_analyze(
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
    """
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-    results = batch_model(images)
+    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
+    results = batch_model(images_with_extra_info)
    """
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
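Note: callers now choose a parse_method string instead of passing an ocr flag; 'auto' and 'ocr' appear above, and any other value (e.g. 'txt') disables OCR via the else branch. A hedged usage sketch, assuming the usual magic_pdf PymuDocDataset construction from PDF bytes (exact API may differ by version):

    from magic_pdf.data.dataset import PymuDocDataset

    # Build one Dataset per input PDF, then analyze them in shared batches.
    datasets = [PymuDocDataset(open(path, 'rb').read()) for path in ('a.pdf', 'b.pdf')]
    infer_results = batch_doc_analyze(datasets, parse_method='auto', lang='en')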