Эх сурвалжийг харах

fix(self_modify): merge detection boxes for optimized text region detection (#448)

Merge adjacent and overlapping detection boxes to optimize text region detection in
the document. Post processing of text boxes is enhanced by consolidating them into
larger text lines, taking into account their vertical and horizontal alignment. This
improvement reduces fragmentation and improves the readability of detected text blocks.
Xiaomeng Zhao 1 жил өмнө
parent
commit
3da5c41115

+ 84 - 0
magic_pdf/model/pek_sub_modules/self_modify.py

@@ -12,6 +12,7 @@ from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binari
 from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
 
 from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
+from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
 
 logger = get_logger()
 
@@ -162,6 +163,86 @@ def update_det_boxes(dt_boxes, mfd_res):
     return new_dt_boxes
 
 
+def merge_overlapping_spans(spans):
+    """
+    Merges overlapping spans on the same line.
+
+    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
+    :return: A list of merged spans
+    """
+    # Return an empty list if the input spans list is empty
+    if not spans:
+        return []
+
+    # Sort spans by their starting x-coordinate
+    spans.sort(key=lambda x: x[0])
+
+    # Initialize the list of merged spans
+    merged = []
+    for span in spans:
+        # Unpack span coordinates
+        x1, y1, x2, y2 = span
+        # If the merged list is empty or there's no horizontal overlap, add the span directly
+        if not merged or merged[-1][2] < x1:
+            merged.append(span)
+        else:
+            # If there is horizontal overlap, merge the current span with the previous one
+            last_span = merged.pop()
+            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
+            x1 = min(last_span[0], x1)
+            y1 = min(last_span[1], y1)
+            x2 = max(last_span[2], x2)
+            y2 = max(last_span[3], y2)
+            # Add the merged span back to the list
+            merged.append((x1, y1, x2, y2))
+
+    # Return the list of merged spans
+    return merged
+
+
+def merge_det_boxes(dt_boxes):
+    """
+    Merge detection boxes.
+
+    This function takes a list of detected bounding boxes, each represented by four corner points.
+    The goal is to merge these bounding boxes into larger text regions.
+
+    Parameters:
+    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
+
+    Returns:
+    list: A list containing the merged text regions, where each region is represented by four corner points.
+    """
+    # Convert the detection boxes into a dictionary format with bounding boxes and type
+    dt_boxes_dict_list = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        text_box_dict = {
+            'bbox': text_bbox,
+            'type': 'text',
+        }
+        dt_boxes_dict_list.append(text_box_dict)
+
+    # Merge adjacent text regions into lines
+    lines = merge_spans_to_line(dt_boxes_dict_list)
+
+    # Initialize a new list for storing the merged text regions
+    new_dt_boxes = []
+    for line in lines:
+        line_bbox_list = []
+        for span in line:
+            line_bbox_list.append(span['bbox'])
+
+        # Merge overlapping text regions within the same line
+        merged_spans = merge_overlapping_spans(line_bbox_list)
+
+        # Convert the merged text regions back to point format and add them to the new detection box list
+        for span in merged_spans:
+            new_dt_boxes.append(bbox_to_points(span))
+
+    return new_dt_boxes
+
+
 class ModifiedPaddleOCR(PaddleOCR):
     def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
         """
@@ -265,6 +346,9 @@ class ModifiedPaddleOCR(PaddleOCR):
         img_crop_list = []
 
         dt_boxes = sorted_boxes(dt_boxes)
+
+        dt_boxes = merge_det_boxes(dt_boxes)
+
         if mfd_res:
             bef = time.time()
             dt_boxes = update_det_boxes(dt_boxes, mfd_res)