Prechádzať zdrojové kódy

fix(magic_pdf): optimize formula area selection for OCR

myhloli 1 rok pred
rodič
commit
e7ce3051a4

+ 63 - 22
magic_pdf/model/pdf_extract_kit.py

@@ -168,33 +168,74 @@ class CustomPEKModel:
         if self.apply_ocr:
             ocr_start = time.time()
             pil_img = Image.fromarray(image)
+
+            # 筛选出需要OCR的区域和公式区域
+            ocr_res_list = []
             single_page_mfdetrec_res = []
             for res in layout_res:
                 if int(res['category_id']) in [13, 14]:
-                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
                     single_page_mfdetrec_res.append({
-                        "bbox": [xmin, ymin, xmax, ymax],
+                        "bbox": [int(res['poly'][0]), int(res['poly'][1]),
+                                 int(res['poly'][4]), int(res['poly'][5])],
                     })
-            for res in layout_res:
-                if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # 需要进行ocr的类别
-                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                    crop_box = (xmin, ymin, xmax, ymax)
-                    cropped_img = Image.new('RGB', pil_img.size, 'white')
-                    cropped_img.paste(pil_img.crop(crop_box), crop_box)
-                    cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
-                    ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
-                    if ocr_res:
-                        for box_ocr_res in ocr_res:
-                            p1, p2, p3, p4 = box_ocr_res[0]
-                            text, score = box_ocr_res[1]
-                            layout_res.append({
-                                'category_id': 15,
-                                'poly': p1 + p2 + p3 + p4,
-                                'score': round(score, 2),
-                                'text': text,
-                            })
+                elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
+                    ocr_res_list.append(res)
+
+            # 对每一个需OCR处理的区域进行处理
+            for res in ocr_res_list:
+                xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+
+                paste_x = 50
+                paste_y = 50
+                # 创建一个宽高各多50的白色背景
+                new_width = xmax - xmin + paste_x*2
+                new_height = ymax - ymin + paste_y*2
+                new_image = Image.new('RGB', (new_width, new_height), 'white')
+
+                # 裁剪图像
+                crop_box = (xmin, ymin, xmax, ymax)
+                cropped_img = pil_img.crop(crop_box)
+                new_image.paste(cropped_img, (paste_x, paste_y))
+
+                # 调整公式区域坐标
+                adjusted_mfdetrec_res = []
+                for mf_res in single_page_mfdetrec_res:
+                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
+                    # 将公式区域坐标调整为相对于裁剪区域的坐标
+                    x0 = mf_xmin - xmin + paste_x
+                    y0 = mf_ymin - ymin + paste_y
+                    x1 = mf_xmax - xmin + paste_x
+                    y1 = mf_ymax - ymin + paste_y
+                    if any([x0 < 0, y0 < 0, x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height, x1 > new_width, y1 > new_height]):
+                        continue
+                    else:
+                        adjusted_mfdetrec_res.append({
+                            "bbox": [x0, y0, x1, y1],
+                        })
+
+                # OCR识别
+                ocr_res = self.ocr_model.ocr(np.array(new_image), mfd_res=adjusted_mfdetrec_res)[0]
+
+                # 整合结果
+                if ocr_res:
+                    for box_ocr_res in ocr_res:
+                        p1, p2, p3, p4 = box_ocr_res[0]
+                        text, score = box_ocr_res[1]
+
+                        # 将坐标转换回原图坐标系
+                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
+                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
+                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
+                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
+
+                        layout_res.append({
+                            'category_id': 15,
+                            'poly': p1 + p2 + p3 + p4,
+                            'score': round(score, 2),
+                            'text': text,
+                        })
+
             ocr_cost = round(time.time() - ocr_start, 2)
             logger.info(f"ocr cost: {ocr_cost}")
 

+ 144 - 44
magic_pdf/model/pek_sub_modules/self_modify.py

@@ -10,12 +10,17 @@ from paddleocr import PaddleOCR
 from paddleocr.ppocr.utils.logging import get_logger
 from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
 from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
+
+from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
+
 logger = get_logger()
 
+
 def img_decode(content: bytes):
     np_arr = np.frombuffer(content, dtype=np.uint8)
     return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
 
+
 def check_img(img):
     if isinstance(img, bytes):
         img = img_decode(img)
@@ -51,6 +56,7 @@ def check_img(img):
 
     return img
 
+
 def sorted_boxes(dt_boxes):
     """
     Sort text boxes in order from top to bottom, left to right
@@ -75,49 +81,87 @@ def sorted_boxes(dt_boxes):
     return _boxes
 
 
-def formula_in_text(mf_bbox, text_bbox):
-    x1, y1, x2, y2 = mf_bbox
-    x3, y3 = text_bbox[0]
-    x4, y4 = text_bbox[2]
-    left_box, right_box = None, None
-    same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2
-    if not same_line:
-        return False, left_box, right_box
-    else:
-        drop_origin = False
-        left_x = x1 - 1
-        right_x = x2 + 1
-        if x3 < x1 and x2 < x4:
-            drop_origin = True
-            left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
-            right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
-        if x3 < x1 and x1 <= x4 <= x2:
-            drop_origin = True
-            left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
-        if x1 <= x3 <= x2 and x2 < x4:
-            drop_origin = True
-            right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
-        if x1 <= x3 < x4 <= x2:
-            drop_origin = True
-        return drop_origin, left_box, right_box
-
-    
-def update_det_boxes(dt_boxes, mfdetrec_res):
-    new_dt_boxes = dt_boxes
-    for mf_box in mfdetrec_res:
-        flag, left_box, right_box = False, None, None
-        for idx, text_box in enumerate(new_dt_boxes):
-            ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
-            if ret:
-                new_dt_boxes.pop(idx)
-                if left_box is not None:
-                    new_dt_boxes.append(left_box)
-                if right_box is not None:
-                    new_dt_boxes.append(right_box)
-                break
-            
+def bbox_to_points(bbox):
+    """ 将bbox格式转换为四个顶点的数组 """
+    x0, y0, x1, y1 = bbox
+    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
+
+
+def points_to_bbox(points):
+    """ 将四个顶点的数组转换为bbox格式 """
+    x0, y0 = points[0]
+    x1, _ = points[1]
+    _, y1 = points[2]
+    return [x0, y0, x1, y1]
+
+
+def merge_intervals(intervals):
+    # Sort the intervals based on the start value
+    intervals.sort(key=lambda x: x[0])
+
+    merged = []
+    for interval in intervals:
+        # If the list of merged intervals is empty or if the current
+        # interval does not overlap with the previous, simply append it.
+        if not merged or merged[-1][1] < interval[0]:
+            merged.append(interval)
+        else:
+            # Otherwise, there is overlap, so we merge the current and previous intervals.
+            merged[-1][1] = max(merged[-1][1], interval[1])
+
+    return merged
+
+
+def remove_intervals(original, masks):
+    # Merge all mask intervals
+    merged_masks = merge_intervals(masks)
+
+    result = []
+    original_start, original_end = original
+
+    for mask in merged_masks:
+        mask_start, mask_end = mask
+
+        # If the mask starts after the original range, ignore it
+        if mask_start > original_end:
+            continue
+
+        # If the mask ends before the original range starts, ignore it
+        if mask_end < original_start:
+            continue
+
+        # Remove the masked part from the original range
+        if original_start < mask_start:
+            result.append([original_start, mask_start - 1])
+
+        original_start = max(mask_end + 1, original_start)
+
+    # Add the remaining part of the original range, if any
+    if original_start <= original_end:
+        result.append([original_start, original_end])
+
+    return result
+
+
+def update_det_boxes(dt_boxes, mfd_res):
+    new_dt_boxes = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        masks_list = []
+        for mf_box in mfd_res:
+            mf_bbox = mf_box['bbox']
+            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
+                masks_list.append([mf_bbox[0], mf_bbox[2]])
+        text_x_range = [text_bbox[0], text_bbox[2]]
+        text_remove_mask_range = remove_intervals(text_x_range, masks_list)
+        temp_dt_box = []
+        for text_remove_mask in text_remove_mask_range:
+            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
+        if len(temp_dt_box) > 0:
+            new_dt_boxes.extend(temp_dt_box)
     return new_dt_boxes
 
+
 class ModifiedPaddleOCR(PaddleOCR):
     def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
         """
@@ -197,7 +241,7 @@ class ModifiedPaddleOCR(PaddleOCR):
             if not rec:
                 return cls_res
             return ocr_res
-        
+
     def __call__(self, img, cls=True, mfd_res=None):
         time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
 
@@ -226,7 +270,7 @@ class ModifiedPaddleOCR(PaddleOCR):
             dt_boxes = update_det_boxes(dt_boxes, mfd_res)
             aft = time.time()
             logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), aft-bef))
+                len(dt_boxes), aft - bef))
 
         for bno in range(len(dt_boxes)):
             tmp_box = copy.deepcopy(dt_boxes[bno])
@@ -257,4 +301,60 @@ class ModifiedPaddleOCR(PaddleOCR):
                 filter_rec_res.append(rec_result)
         end = time.time()
         time_dict['all'] = end - start
-        return filter_boxes, filter_rec_res, time_dict
+        return filter_boxes, filter_rec_res, time_dict
+
+
+if __name__ == '__main__':
+    def merge_intervals(intervals):
+        # Sort the intervals based on the start value
+        intervals.sort(key=lambda x: x[0])
+
+        merged = []
+        for interval in intervals:
+            # If the list of merged intervals is empty or if the current
+            # interval does not overlap with the previous, simply append it.
+            if not merged or merged[-1][1] < interval[0]:
+                merged.append(interval)
+            else:
+                # Otherwise, there is overlap, so we merge the current and previous intervals.
+                merged[-1][1] = max(merged[-1][1], interval[1])
+
+        return merged
+
+
+    def remove_intervals(original, masks):
+        # Merge all mask intervals
+        merged_masks = merge_intervals(masks)
+
+        result = []
+        original_start, original_end = original
+
+        for mask in merged_masks:
+            mask_start, mask_end = mask
+
+            # If the mask starts after the original range, ignore it
+            if mask_start > original_end:
+                continue
+
+            # If the mask ends before the original range starts, ignore it
+            if mask_end < original_start:
+                continue
+
+            # Remove the masked part from the original range
+            if original_start < mask_start:
+                result.append([original_start, mask_start - 1])
+
+            original_start = max(mask_end + 1, original_start)
+
+        # Add the remaining part of the original range, if any
+        if original_start <= original_end:
+            result.append([original_start, original_end])
+
+        return result
+
+
+    # Test the function
+    original_range = [1, 100]
+    masks = [[0, 15], [25, 40], [55, 80]]
+    result = remove_intervals(original_range, masks)
+    print(result)  # Expected output: [[1, 4], [21, 59], [81, 100]]