Prechádzať zdrojové kódy

Merge pull request #1049 from myhloli/dev

feat(ocr): improve text detection and OCR accuracy
Xiaomeng Zhao 1 rok pred
rodič
commit
190e22312b

+ 29 - 30
magic_pdf/model/pdf_extract_kit.py

@@ -30,6 +30,7 @@ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
 
 
 class CustomPEKModel:
+
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
         """
         ======== model init ========
@@ -149,13 +150,12 @@ class CustomPEKModel:
                 device=self.device,
             )
         # 初始化ocr
-        if self.apply_ocr:
-            self.ocr_model = atom_model_manager.get_atom_model(
-                atom_model_name=AtomicModel.OCR,
-                ocr_show_log=show_log,
-                det_db_box_thresh=0.3,
-                lang=self.lang,
-            )
+        self.ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.OCR,
+            ocr_show_log=show_log,
+            det_db_box_thresh=0.3,
+            lang=self.lang
+        )
         # init table model
         if self.apply_table:
             table_model_dir = self.configs['weights'][self.table_model_name]
@@ -208,30 +208,29 @@ class CustomPEKModel:
         )
 
         # ocr识别
+        ocr_start = time.time()
+        # Process each area that requires OCR processing
+        for res in ocr_res_list:
+            new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
+
+            # OCR recognition
+            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+            if self.apply_ocr:
+                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
+            else:
+                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
+
+            # Integration results
+            if ocr_res:
+                ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                layout_res.extend(ocr_result_list)
+
+        ocr_cost = round(time.time() - ocr_start, 2)
         if self.apply_ocr:
-            ocr_start = time.time()
-            # Process each area that requires OCR processing
-            for res in ocr_res_list:
-                new_image, useful_list = crop_img(
-                    res, pil_img, crop_paste_x=50, crop_paste_y=50
-                )
-                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-                    single_page_mfdetrec_res, useful_list
-                )
-
-                # OCR recognition
-                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[
-                    0
-                ]
-
-                # Integration results
-                if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
-                    layout_res.extend(ocr_result_list)
-
-            ocr_cost = round(time.time() - ocr_start, 2)
-            logger.info(f'ocr time: {ocr_cost}')
+            logger.info(f"ocr time: {ocr_cost}")
+        else:
+            logger.info(f"det time: {ocr_cost}")
 
         # 表格识别 table recognition
         if self.apply_table:

+ 10 - 5
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py

@@ -211,16 +211,21 @@ def get_ocr_result_list(ocr_res, useful_list):
     ocr_result_list = []
     for box_ocr_res in ocr_res:
 
-        p1, p2, p3, p4 = box_ocr_res[0]
-        text, score = box_ocr_res[1]
+        if len(box_ocr_res) == 2:
+            p1, p2, p3, p4 = box_ocr_res[0]
+            text, score = box_ocr_res[1]
+        else:
+            p1, p2, p3, p4 = box_ocr_res
+            text, score = "", 1
         # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
         # if average_angle_degrees > 0.5:
-        if calculate_is_angle(box_ocr_res[0]):
+        poly = [p1, p2, p3, p4]
+        if calculate_is_angle(poly):
             # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
             # 与x轴的夹角超过0.5度,对边界做一下矫正
             # 计算几何中心
-            x_center = sum(point[0] for point in box_ocr_res[0]) / 4
-            y_center = sum(point[1] for point in box_ocr_res[0]) / 4
+            x_center = sum(point[0] for point in poly) / 4
+            y_center = sum(point[1] for point in poly) / 4
             new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
             new_width = p3[0] - p1[0]
             p1 = [x_center - new_width / 2, y_center - new_height / 2]

+ 10 - 1
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -78,9 +78,18 @@ class ModifiedPaddleOCR(PaddleOCR):
             for idx, img in enumerate(imgs):
                 img = preprocess_image(img)
                 dt_boxes, elapse = self.text_detector(img)
-                if not dt_boxes:
+                if dt_boxes is None:
                     ocr_res.append(None)
                     continue
+                dt_boxes = sorted_boxes(dt_boxes)
+                # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
+                dt_boxes = merge_det_boxes(dt_boxes)
+                if mfd_res:
+                    bef = time.time()
+                    dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+                    aft = time.time()
+                    logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
+                        len(dt_boxes), aft - bef))
                 tmp_res = [box.tolist() for box in dt_boxes]
                 ocr_res.append(tmp_res)
             return ocr_res