@@ -7,17 +7,17 @@ from loguru import logger
 from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 4
 MFD_BASE_BATCH_SIZE = 1
@@ -91,10 +91,12 @@ class BatchAnalyze:
             images,
             batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
         )
+        mfr_count = 0
         for image_index in range(len(images)):
             images_layout_res[image_index] += images_formula_list[image_index]
+            mfr_count += len(images_formula_list[image_index])
         logger.info(
-            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
         )
 
         # 清理显存
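
Note on the hunk above: `image num` in the MFR timing log used to report how many page images were sent to the formula-recognition model; with `mfr_count` it now reports how many formula regions were actually recognized across those pages. A minimal sketch of the counting logic, with hypothetical stand-in data in place of real model output:

    # Hypothetical per-page MFR results; the real entries are dicts produced by
    # the formula-recognition model.
    images_formula_list = [['f1', 'f2'], [], ['f3', 'f4']]

    old_image_num = len(images_formula_list)  # 3: pages in the batch
    mfr_count = 0
    for image_index in range(len(images_formula_list)):
        mfr_count += len(images_formula_list[image_index])
    print(old_image_num, mfr_count)  # 3 4: formulas recognized, not pages
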
@@ -195,81 +197,81 @@ class BatchAnalyze:
         return images_layout_res
 
 
-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
+#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)
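
Since `doc_batch_analyze` and the imports it relied on are now commented out, callers that still need the old batch flow have to assemble it themselves. The sketch below is a minimal reconstruction of that flow from the removed code, not part of this patch: the helper name `batch_analyze_dataset` is hypothetical, the module path `magic_pdf.model.batch_analyze` is assumed to be where `BatchAnalyze` lives, `RuntimeError` stands in for the no-longer-imported `CUDA_NOT_AVAILABLE`, and the page-range and memory-cleanup handling of the original are omitted.

    import torch

    from magic_pdf.data.dataset import Dataset
    from magic_pdf.model.batch_analyze import BatchAnalyze
    from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
    from magic_pdf.operators.models import InferenceResult


    def batch_analyze_dataset(dataset: Dataset, ocr: bool = False, lang=None,
                              batch_ratio: int = 1) -> InferenceResult:
        # Batch analysis still requires CUDA, as in the removed function.
        if not torch.cuda.is_available():
            raise RuntimeError('batch analyze is not supported in CPU mode')

        custom_model = ModelSingleton().get_model(ocr, False, lang, None, None, None)
        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

        # One image per page in, one list of layout detections per page out.
        images = [dataset.get_page(i).get_image()['img'] for i in range(len(dataset))]
        analyze_result = batch_model(images)

        # Wrap the per-page results in the page-level JSON layout that
        # InferenceResult expects.
        model_json = []
        for i, result in enumerate(analyze_result):
            img_dict = dataset.get_page(i).get_image()
            page_info = {'page_no': i, 'height': img_dict['height'], 'width': img_dict['width']}
            model_json.append({'layout_dets': result, 'page_info': page_info})

        return InferenceResult(model_json, dataset)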