|
|
@@ -1,23 +1,15 @@
|
|
|
import time
|
|
|
|
|
|
import cv2
|
|
|
-import numpy as np
|
|
|
import torch
|
|
|
from loguru import logger
|
|
|
-from PIL import Image
|
|
|
|
|
|
from magic_pdf.config.constants import MODEL_NAME
|
|
|
-# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
|
|
|
-# from magic_pdf.data.dataset import Dataset
|
|
|
-# from magic_pdf.libs.clean_memory import clean_memory
|
|
|
-# from magic_pdf.libs.config_reader import get_device
|
|
|
-# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
|
|
|
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
|
|
|
from magic_pdf.model.sub_modules.model_utils import (
|
|
|
clean_vram, crop_img, get_res_list_from_layout_res)
|
|
|
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
|
|
|
get_adjusted_mfdetrec_res, get_ocr_result_list)
|
|
|
-# from magic_pdf.operators.models import InferenceResult
|
|
|
|
|
|
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
|
|
|
MFD_BASE_BATCH_SIZE = 1
|
|
|
@@ -31,7 +23,6 @@ class BatchAnalyze:
|
|
|
|
|
|
def __call__(self, images: list) -> list:
|
|
|
images_layout_res = []
|
|
|
-
|
|
|
layout_start_time = time.time()
|
|
|
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
|
|
# layoutlmv3
|
|
|
@@ -41,36 +32,14 @@ class BatchAnalyze:
|
|
|
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
|
|
# doclayout_yolo
|
|
|
layout_images = []
|
|
|
- modified_images = []
|
|
|
for image_index, image in enumerate(images):
|
|
|
- pil_img = Image.fromarray(image)
|
|
|
- # width, height = pil_img.size
|
|
|
- # if height > width:
|
|
|
- # input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
|
|
|
- # new_image, useful_list = crop_img(
|
|
|
- # input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
|
|
|
- # )
|
|
|
- # layout_images.append(new_image)
|
|
|
- # modified_images.append([image_index, useful_list])
|
|
|
- # else:
|
|
|
- layout_images.append(pil_img)
|
|
|
+ layout_images.append(image)
|
|
|
|
|
|
images_layout_res += self.model.layout_model.batch_predict(
|
|
|
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
|
|
|
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
|
|
|
)
|
|
|
|
|
|
- for image_index, useful_list in modified_images:
|
|
|
- for res in images_layout_res[image_index]:
|
|
|
- for i in range(len(res['poly'])):
|
|
|
- if i % 2 == 0:
|
|
|
- res['poly'][i] = (
|
|
|
- res['poly'][i] - useful_list[0] + useful_list[2]
|
|
|
- )
|
|
|
- else:
|
|
|
- res['poly'][i] = (
|
|
|
- res['poly'][i] - useful_list[1] + useful_list[3]
|
|
|
- )
|
|
|
logger.info(
|
|
|
f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
|
|
|
)
|
|
|
@@ -111,7 +80,7 @@ class BatchAnalyze:
|
|
|
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
|
|
|
for index in range(len(images)):
|
|
|
layout_res = images_layout_res[index]
|
|
|
- pil_img = Image.fromarray(images[index])
|
|
|
+ np_array_img = images[index]
|
|
|
|
|
|
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
|
|
|
get_res_list_from_layout_res(layout_res)
|
|
|
@@ -121,14 +90,14 @@ class BatchAnalyze:
|
|
|
# Process each area that requires OCR processing
|
|
|
for res in ocr_res_list:
|
|
|
new_image, useful_list = crop_img(
|
|
|
- res, pil_img, crop_paste_x=50, crop_paste_y=50
|
|
|
+ res, np_array_img, crop_paste_x=50, crop_paste_y=50
|
|
|
)
|
|
|
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
|
single_page_mfdetrec_res, useful_list
|
|
|
)
|
|
|
|
|
|
# OCR recognition
|
|
|
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
+ new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
if self.model.apply_ocr:
|
|
|
ocr_res = self.model.ocr_model.ocr(
|
|
|
@@ -150,7 +119,7 @@ class BatchAnalyze:
|
|
|
if self.model.apply_table:
|
|
|
table_start = time.time()
|
|
|
for res in table_res_list:
|
|
|
- new_image, _ = crop_img(res, pil_img)
|
|
|
+ new_image, _ = crop_img(res, np_array_img)
|
|
|
single_table_start_time = time.time()
|
|
|
html_code = None
|
|
|
if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
|
|
|
@@ -197,83 +166,3 @@ class BatchAnalyze:
|
|
|
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
|
|
|
|
|
|
return images_layout_res
|
|
|
-
|
|
|
-
|
|
|
-# def doc_batch_analyze(
|
|
|
-# dataset: Dataset,
|
|
|
-# ocr: bool = False,
|
|
|
-# show_log: bool = False,
|
|
|
-# start_page_id=0,
|
|
|
-# end_page_id=None,
|
|
|
-# lang=None,
|
|
|
-# layout_model=None,
|
|
|
-# formula_enable=None,
|
|
|
-# table_enable=None,
|
|
|
-# batch_ratio: int | None = None,
|
|
|
-# ) -> InferenceResult:
|
|
|
-# """Perform batch analysis on a document dataset.
|
|
|
-#
|
|
|
-# Args:
|
|
|
-# dataset (Dataset): The dataset containing document pages to be analyzed.
|
|
|
-# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
|
|
|
-# show_log (bool, optional): Flag to enable logging. Defaults to False.
|
|
|
-# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
|
|
|
-# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
|
|
|
-# lang (str, optional): Language for OCR. Defaults to None.
|
|
|
-# layout_model (optional): Layout model to be used for analysis. Defaults to None.
|
|
|
-# formula_enable (optional): Flag to enable formula detection. Defaults to None.
|
|
|
-# table_enable (optional): Flag to enable table detection. Defaults to None.
|
|
|
-# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
|
|
|
-#
|
|
|
-# Raises:
|
|
|
-# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
|
|
|
-#
|
|
|
-# Returns:
|
|
|
-# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
|
|
|
-# """
|
|
|
-#
|
|
|
-# if not torch.cuda.is_available():
|
|
|
-# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
|
|
|
-#
|
|
|
-# lang = None if lang == '' else lang
|
|
|
-# # TODO: auto detect batch size
|
|
|
-# batch_ratio = 1 if batch_ratio is None else batch_ratio
|
|
|
-# end_page_id = end_page_id if end_page_id else len(dataset)
|
|
|
-#
|
|
|
-# model_manager = ModelSingleton()
|
|
|
-# custom_model: CustomPEKModel = model_manager.get_model(
|
|
|
-# ocr, show_log, lang, layout_model, formula_enable, table_enable
|
|
|
-# )
|
|
|
-# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
|
|
|
-#
|
|
|
-# model_json = []
|
|
|
-#
|
|
|
-# # batch analyze
|
|
|
-# images = []
|
|
|
-# for index in range(len(dataset)):
|
|
|
-# if start_page_id <= index <= end_page_id:
|
|
|
-# page_data = dataset.get_page(index)
|
|
|
-# img_dict = page_data.get_image()
|
|
|
-# images.append(img_dict['img'])
|
|
|
-# analyze_result = batch_model(images)
|
|
|
-#
|
|
|
-# for index in range(len(dataset)):
|
|
|
-# page_data = dataset.get_page(index)
|
|
|
-# img_dict = page_data.get_image()
|
|
|
-# page_width = img_dict['width']
|
|
|
-# page_height = img_dict['height']
|
|
|
-# if start_page_id <= index <= end_page_id:
|
|
|
-# result = analyze_result.pop(0)
|
|
|
-# else:
|
|
|
-# result = []
|
|
|
-#
|
|
|
-# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
|
|
|
-# page_dict = {'layout_dets': result, 'page_info': page_info}
|
|
|
-# model_json.append(page_dict)
|
|
|
-#
|
|
|
-# # TODO: clean memory when gpu memory is not enough
|
|
|
-# clean_memory_start_time = time.time()
|
|
|
-# clean_memory(get_device())
|
|
|
-# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
|
|
|
-#
|
|
|
-# return InferenceResult(model_json, dataset)
|