|
|
@@ -10,14 +10,11 @@ from .model_init import AtomModelSingleton
|
|
|
from .model_list import AtomicModel
|
|
|
from ...utils.config_reader import get_formula_enable, get_table_enable
|
|
|
from ...utils.model_utils import crop_img, get_res_list_from_layout_res
|
|
|
-from ...utils.ocr_utils import (
|
|
|
- get_adjusted_mfdetrec_res,
|
|
|
- get_ocr_result_list,
|
|
|
- OcrConfidence,
|
|
|
- get_rotate_crop_image,
|
|
|
-)
|
|
|
+from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
|
|
|
from ...utils.pdf_image_tools import get_crop_img
|
|
|
from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
|
|
|
+from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
|
|
|
+from ...utils.pdf_image_tools import get_crop_np_img
|
|
|
|
|
|
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
|
|
|
MFD_BASE_BATCH_SIZE = 1
|
|
|
@@ -47,29 +44,28 @@ class BatchAnalyze:
|
|
|
)
|
|
|
atom_model_manager = AtomModelSingleton()
|
|
|
|
|
|
- images = [image for image, _, _ in images_with_extra_info]
|
|
|
+ np_images = [np.asarray(image) for image, _, _ in images_with_extra_info]
|
|
|
|
|
|
# doclayout_yolo
|
|
|
- layout_images = images.copy()
|
|
|
|
|
|
images_layout_res += self.model.layout_model.batch_predict(
|
|
|
- layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
|
|
|
+ np_images, YOLO_LAYOUT_BASE_BATCH_SIZE
|
|
|
)
|
|
|
|
|
|
if self.formula_enable:
|
|
|
# 公式检测
|
|
|
images_mfd_res = self.model.mfd_model.batch_predict(
|
|
|
- images, MFD_BASE_BATCH_SIZE
|
|
|
+ np_images, MFD_BASE_BATCH_SIZE
|
|
|
)
|
|
|
|
|
|
# 公式识别
|
|
|
images_formula_list = self.model.mfr_model.batch_predict(
|
|
|
images_mfd_res,
|
|
|
- images,
|
|
|
+ np_images,
|
|
|
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
|
|
|
)
|
|
|
mfr_count = 0
|
|
|
- for image_index in range(len(images)):
|
|
|
+ for image_index in range(len(np_images)):
|
|
|
images_layout_res[image_index] += images_formula_list[image_index]
|
|
|
mfr_count += len(images_formula_list[image_index])
|
|
|
|
|
|
@@ -78,10 +74,10 @@ class BatchAnalyze:
|
|
|
|
|
|
ocr_res_list_all_page = []
|
|
|
table_res_list_all_page = []
|
|
|
- for index in range(len(images)):
|
|
|
+ for index in range(len(np_images)):
|
|
|
_, ocr_enable, _lang = images_with_extra_info[index]
|
|
|
layout_res = images_layout_res[index]
|
|
|
- pil_img = images[index]
|
|
|
+ np_img = np_images[index]
|
|
|
|
|
|
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
|
|
|
get_res_list_from_layout_res(layout_res)
|
|
|
@@ -90,7 +86,7 @@ class BatchAnalyze:
|
|
|
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
|
|
|
'lang':_lang,
|
|
|
'ocr_enable':ocr_enable,
|
|
|
- 'pil_img':pil_img,
|
|
|
+ 'np_img':np_img,
|
|
|
'single_page_mfdetrec_res':single_page_mfdetrec_res,
|
|
|
'layout_res':layout_res,
|
|
|
})
|
|
|
@@ -102,7 +98,7 @@ class BatchAnalyze:
|
|
|
crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
|
crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
|
bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
|
|
|
- table_img = get_crop_img(bbox, pil_img, scale=scale)
|
|
|
+ table_img = get_crop_np_img(bbox, np_img, scale=scale)
|
|
|
|
|
|
table_res_list_all_page.append({'table_res':table_res,
|
|
|
'lang':_lang,
|
|
|
@@ -120,17 +116,17 @@ class BatchAnalyze:
|
|
|
|
|
|
for res in ocr_res_list_dict['ocr_res_list']:
|
|
|
new_image, useful_list = crop_img(
|
|
|
- res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
+ res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
)
|
|
|
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
|
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
|
|
)
|
|
|
|
|
|
# BGR转换
|
|
|
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
+ bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
all_cropped_images_info.append((
|
|
|
- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
|
|
|
+ bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
|
|
|
))
|
|
|
|
|
|
# 按语言分组
|
|
|
@@ -195,10 +191,13 @@ class BatchAnalyze:
|
|
|
|
|
|
# 处理批处理结果
|
|
|
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
|
|
|
- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
|
|
|
+ bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
|
|
|
|
|
|
if dt_boxes is not None and len(dt_boxes) > 0:
|
|
|
# 直接应用原始OCR流程中的关键处理步骤
|
|
|
+ from mineru.utils.ocr_utils import (
|
|
|
+ merge_det_boxes, update_det_boxes, sorted_boxes
|
|
|
+ )
|
|
|
|
|
|
# 1. 排序检测框
|
|
|
if len(dt_boxes) > 0:
|
|
|
@@ -223,7 +222,7 @@ class BatchAnalyze:
|
|
|
|
|
|
if ocr_res:
|
|
|
ocr_result_list = get_ocr_result_list(
|
|
|
- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
|
|
|
+ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
|
|
|
)
|
|
|
|
|
|
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
|
|
@@ -241,21 +240,21 @@ class BatchAnalyze:
|
|
|
)
|
|
|
for res in ocr_res_list_dict['ocr_res_list']:
|
|
|
new_image, useful_list = crop_img(
|
|
|
- res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
+ res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
)
|
|
|
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
|
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
|
|
)
|
|
|
# OCR-det
|
|
|
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
+ bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
|
|
|
ocr_res = ocr_model.ocr(
|
|
|
- new_image, mfd_res=adjusted_mfdetrec_res, rec=False
|
|
|
+ bgr_image, mfd_res=adjusted_mfdetrec_res, rec=False
|
|
|
)[0]
|
|
|
|
|
|
# Integration results
|
|
|
if ocr_res:
|
|
|
ocr_result_list = get_ocr_result_list(
|
|
|
- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
|
|
|
+ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],bgr_image, _lang
|
|
|
)
|
|
|
|
|
|
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|