|
@@ -30,6 +30,7 @@ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomPEKModel:
|
|
class CustomPEKModel:
|
|
|
|
|
+
|
|
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
|
|
"""
|
|
"""
|
|
|
======== model init ========
|
|
======== model init ========
|
|
@@ -149,13 +150,12 @@ class CustomPEKModel:
|
|
|
device=self.device,
|
|
device=self.device,
|
|
|
)
|
|
)
|
|
|
# 初始化ocr
|
|
# 初始化ocr
|
|
|
- if self.apply_ocr:
|
|
|
|
|
- self.ocr_model = atom_model_manager.get_atom_model(
|
|
|
|
|
- atom_model_name=AtomicModel.OCR,
|
|
|
|
|
- ocr_show_log=show_log,
|
|
|
|
|
- det_db_box_thresh=0.3,
|
|
|
|
|
- lang=self.lang,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ self.ocr_model = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
|
|
+ ocr_show_log=show_log,
|
|
|
|
|
+ det_db_box_thresh=0.3,
|
|
|
|
|
+ lang=self.lang
|
|
|
|
|
+ )
|
|
|
# init table model
|
|
# init table model
|
|
|
if self.apply_table:
|
|
if self.apply_table:
|
|
|
table_model_dir = self.configs['weights'][self.table_model_name]
|
|
table_model_dir = self.configs['weights'][self.table_model_name]
|
|
@@ -208,30 +208,29 @@ class CustomPEKModel:
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# ocr识别
|
|
# ocr识别
|
|
|
|
|
+ ocr_start = time.time()
|
|
|
|
|
+ # Process each area that requires OCR processing
|
|
|
|
|
+ for res in ocr_res_list:
|
|
|
|
|
+ new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
|
|
|
|
+ adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
|
|
|
|
|
+
|
|
|
|
|
+ # OCR recognition
|
|
|
|
|
+ new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
|
|
+ if self.apply_ocr:
|
|
|
|
|
+ ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
|
|
|
|
|
+ else:
|
|
|
|
|
+ ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
|
|
|
|
|
+
|
|
|
|
|
+ # Integration results
|
|
|
|
|
+ if ocr_res:
|
|
|
|
|
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
|
|
|
|
|
+ layout_res.extend(ocr_result_list)
|
|
|
|
|
+
|
|
|
|
|
+ ocr_cost = round(time.time() - ocr_start, 2)
|
|
|
if self.apply_ocr:
|
|
if self.apply_ocr:
|
|
|
- ocr_start = time.time()
|
|
|
|
|
- # Process each area that requires OCR processing
|
|
|
|
|
- for res in ocr_res_list:
|
|
|
|
|
- new_image, useful_list = crop_img(
|
|
|
|
|
- res, pil_img, crop_paste_x=50, crop_paste_y=50
|
|
|
|
|
- )
|
|
|
|
|
- adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
|
|
|
|
- single_page_mfdetrec_res, useful_list
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- # OCR recognition
|
|
|
|
|
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
|
|
- ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[
|
|
|
|
|
- 0
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- # Integration results
|
|
|
|
|
- if ocr_res:
|
|
|
|
|
- ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
|
|
|
|
|
- layout_res.extend(ocr_result_list)
|
|
|
|
|
-
|
|
|
|
|
- ocr_cost = round(time.time() - ocr_start, 2)
|
|
|
|
|
- logger.info(f'ocr time: {ocr_cost}')
|
|
|
|
|
|
|
+ logger.info(f"ocr time: {ocr_cost}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ logger.info(f"det time: {ocr_cost}")
|
|
|
|
|
|
|
|
# 表格识别 table recognition
|
|
# 表格识别 table recognition
|
|
|
if self.apply_table:
|
|
if self.apply_table:
|