|
@@ -9,7 +9,7 @@ from .model_list import AtomicModel
|
|
|
from ...utils.config_reader import get_formula_enable, get_table_enable
|
|
from ...utils.config_reader import get_formula_enable, get_table_enable
|
|
|
from ...utils.model_utils import crop_img, get_res_list_from_layout_res
|
|
from ...utils.model_utils import crop_img, get_res_list_from_layout_res
|
|
|
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
|
|
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
|
|
|
-from ...utils.pdf_image_tools import get_crop_img
|
|
|
|
|
|
|
+from ...utils.pdf_image_tools import get_crop_np_img
|
|
|
|
|
|
|
|
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
|
|
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
|
|
|
MFD_BASE_BATCH_SIZE = 1
|
|
MFD_BASE_BATCH_SIZE = 1
|
|
@@ -38,29 +38,28 @@ class BatchAnalyze:
|
|
|
)
|
|
)
|
|
|
atom_model_manager = AtomModelSingleton()
|
|
atom_model_manager = AtomModelSingleton()
|
|
|
|
|
|
|
|
- images = [image for image, _, _ in images_with_extra_info]
|
|
|
|
|
|
|
+ np_images = [np.asarray(image) for image, _, _ in images_with_extra_info]
|
|
|
|
|
|
|
|
# doclayout_yolo
|
|
# doclayout_yolo
|
|
|
- layout_images = images.copy()
|
|
|
|
|
|
|
|
|
|
images_layout_res += self.model.layout_model.batch_predict(
|
|
images_layout_res += self.model.layout_model.batch_predict(
|
|
|
- layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
|
|
|
|
|
|
|
+ np_images, YOLO_LAYOUT_BASE_BATCH_SIZE
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
if self.formula_enable:
|
|
if self.formula_enable:
|
|
|
# 公式检测
|
|
# 公式检测
|
|
|
images_mfd_res = self.model.mfd_model.batch_predict(
|
|
images_mfd_res = self.model.mfd_model.batch_predict(
|
|
|
- images, MFD_BASE_BATCH_SIZE
|
|
|
|
|
|
|
+ np_images, MFD_BASE_BATCH_SIZE
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 公式识别
|
|
# 公式识别
|
|
|
images_formula_list = self.model.mfr_model.batch_predict(
|
|
images_formula_list = self.model.mfr_model.batch_predict(
|
|
|
images_mfd_res,
|
|
images_mfd_res,
|
|
|
- images,
|
|
|
|
|
|
|
+ np_images,
|
|
|
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
|
|
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
|
|
|
)
|
|
)
|
|
|
mfr_count = 0
|
|
mfr_count = 0
|
|
|
- for image_index in range(len(images)):
|
|
|
|
|
|
|
+ for image_index in range(len(np_images)):
|
|
|
images_layout_res[image_index] += images_formula_list[image_index]
|
|
images_layout_res[image_index] += images_formula_list[image_index]
|
|
|
mfr_count += len(images_formula_list[image_index])
|
|
mfr_count += len(images_formula_list[image_index])
|
|
|
|
|
|
|
@@ -69,10 +68,10 @@ class BatchAnalyze:
|
|
|
|
|
|
|
|
ocr_res_list_all_page = []
|
|
ocr_res_list_all_page = []
|
|
|
table_res_list_all_page = []
|
|
table_res_list_all_page = []
|
|
|
- for index in range(len(images)):
|
|
|
|
|
|
|
+ for index in range(len(np_images)):
|
|
|
_, ocr_enable, _lang = images_with_extra_info[index]
|
|
_, ocr_enable, _lang = images_with_extra_info[index]
|
|
|
layout_res = images_layout_res[index]
|
|
layout_res = images_layout_res[index]
|
|
|
- pil_img = images[index]
|
|
|
|
|
|
|
+ np_img = np_images[index]
|
|
|
|
|
|
|
|
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
|
|
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
|
|
|
get_res_list_from_layout_res(layout_res)
|
|
get_res_list_from_layout_res(layout_res)
|
|
@@ -81,7 +80,7 @@ class BatchAnalyze:
|
|
|
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
|
|
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
|
|
|
'lang':_lang,
|
|
'lang':_lang,
|
|
|
'ocr_enable':ocr_enable,
|
|
'ocr_enable':ocr_enable,
|
|
|
- 'pil_img':pil_img,
|
|
|
|
|
|
|
+ 'np_img':np_img,
|
|
|
'single_page_mfdetrec_res':single_page_mfdetrec_res,
|
|
'single_page_mfdetrec_res':single_page_mfdetrec_res,
|
|
|
'layout_res':layout_res,
|
|
'layout_res':layout_res,
|
|
|
})
|
|
})
|
|
@@ -93,7 +92,7 @@ class BatchAnalyze:
|
|
|
crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
|
crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
|
bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
|
|
bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
|
|
|
- table_img = get_crop_img(bbox, pil_img, scale=scale)
|
|
|
|
|
|
|
+ table_img = get_crop_np_img(bbox, np_img, scale=scale)
|
|
|
|
|
|
|
|
table_res_list_all_page.append({'table_res':table_res,
|
|
table_res_list_all_page.append({'table_res':table_res,
|
|
|
'lang':_lang,
|
|
'lang':_lang,
|
|
@@ -111,17 +110,17 @@ class BatchAnalyze:
|
|
|
|
|
|
|
|
for res in ocr_res_list_dict['ocr_res_list']:
|
|
for res in ocr_res_list_dict['ocr_res_list']:
|
|
|
new_image, useful_list = crop_img(
|
|
new_image, useful_list = crop_img(
|
|
|
- res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
|
|
|
|
+ res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
)
|
|
)
|
|
|
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
|
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
|
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# BGR转换
|
|
# BGR转换
|
|
|
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
|
+ bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
|
|
all_cropped_images_info.append((
|
|
all_cropped_images_info.append((
|
|
|
- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
|
|
|
|
|
|
|
+ bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
|
|
|
))
|
|
))
|
|
|
|
|
|
|
|
# 按语言分组
|
|
# 按语言分组
|
|
@@ -186,7 +185,7 @@ class BatchAnalyze:
|
|
|
|
|
|
|
|
# 处理批处理结果
|
|
# 处理批处理结果
|
|
|
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
|
|
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
|
|
|
- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
|
|
|
|
|
|
|
+ bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
|
|
|
|
|
|
|
|
if dt_boxes is not None and len(dt_boxes) > 0:
|
|
if dt_boxes is not None and len(dt_boxes) > 0:
|
|
|
# 直接应用原始OCR流程中的关键处理步骤
|
|
# 直接应用原始OCR流程中的关键处理步骤
|
|
@@ -217,7 +216,7 @@ class BatchAnalyze:
|
|
|
|
|
|
|
|
if ocr_res:
|
|
if ocr_res:
|
|
|
ocr_result_list = get_ocr_result_list(
|
|
ocr_result_list = get_ocr_result_list(
|
|
|
- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
|
|
|
|
|
|
|
+ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
|
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
|
@@ -235,21 +234,21 @@ class BatchAnalyze:
|
|
|
)
|
|
)
|
|
|
for res in ocr_res_list_dict['ocr_res_list']:
|
|
for res in ocr_res_list_dict['ocr_res_list']:
|
|
|
new_image, useful_list = crop_img(
|
|
new_image, useful_list = crop_img(
|
|
|
- res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
|
|
|
|
+ res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
|
|
|
)
|
|
)
|
|
|
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
|
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
|
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
|
|
)
|
|
)
|
|
|
# OCR-det
|
|
# OCR-det
|
|
|
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
|
+ bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
|
|
|
ocr_res = ocr_model.ocr(
|
|
ocr_res = ocr_model.ocr(
|
|
|
- new_image, mfd_res=adjusted_mfdetrec_res, rec=False
|
|
|
|
|
|
|
+ bgr_image, mfd_res=adjusted_mfdetrec_res, rec=False
|
|
|
)[0]
|
|
)[0]
|
|
|
|
|
|
|
|
# Integration results
|
|
# Integration results
|
|
|
if ocr_res:
|
|
if ocr_res:
|
|
|
ocr_result_list = get_ocr_result_list(
|
|
ocr_result_list = get_ocr_result_list(
|
|
|
- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
|
|
|
|
|
|
|
+ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],bgr_image, _lang
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
|
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
|
@@ -273,7 +272,7 @@ class BatchAnalyze:
|
|
|
)
|
|
)
|
|
|
rotate_label = "0"
|
|
rotate_label = "0"
|
|
|
|
|
|
|
|
- np_table_img = np.asarray(table_res_dict["table_img"])
|
|
|
|
|
|
|
+ np_table_img = table_res_dict["table_img"]
|
|
|
if rotate_label == "270":
|
|
if rotate_label == "270":
|
|
|
np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
|
|
np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
|
|
|
elif rotate_label == "90":
|
|
elif rotate_label == "90":
|