Forráskód Böngészése

feat(model): add xycut algorithm for block sorting

- Implement xycut algorithm to sort blocks when layoutreader fails
- Add recursive_xy_cut function to perform the xycut algorithm- Update pdf_parse_union_core_v2.py to use xycut when layoutreader fails
- Modify draw_bbox.py to handle cases where layoutreader fails to sort blocks
myhloli 1 éve
szülő
commit
7d5850e3ce

+ 10 - 4
magic_pdf/libs/draw_bbox.py

@@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
             if block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
-                        for line in sub_block['virtual_lines']:
-                            bbox = line['bbox']
-                            index = line['index']
-                            page_line_list.append({'index': index, 'bbox': bbox})
+                        if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
+                            for line in sub_block['virtual_lines']:
+                                bbox = line['bbox']
+                                index = line['index']
+                                page_line_list.append({'index': index, 'bbox': bbox})
+                        else:
+                            for line in sub_block['lines']:
+                                bbox = line['bbox']
+                                index = line['index']
+                                page_line_list.append({'index': index, 'bbox': bbox})
                     elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
                         for line in sub_block['lines']:
                             bbox = line['bbox']

+ 242 - 0
magic_pdf/model/v3/xycut.py

@@ -0,0 +1,242 @@
+from typing import List
+import cv2
+import numpy as np
+
+
+def projection_by_bboxes(boxes: np.array, axis: int) -> np.ndarray:
+    """
+     通过一组 bbox 获得投影直方图,最后以 per-pixel 形式输出
+
+    Args:
+        boxes: [N, 4]
+        axis: 0-x坐标向水平方向投影, 1-y坐标向垂直方向投影
+
+    Returns:
+        1D 投影直方图,长度为投影方向坐标的最大值(我们不需要图片的实际边长,因为只是要找文本框的间隔)
+
+    """
+    assert axis in [0, 1]
+    length = np.max(boxes[:, axis::2])
+    res = np.zeros(length, dtype=int)
+    # TODO: how to remove for loop?
+    for start, end in boxes[:, axis::2]:
+        res[start:end] += 1
+    return res
+
+
+# from: https://dothinking.github.io/2021-06-19-%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%E7%AE%97%E6%B3%95/#:~:text=%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%EF%BC%88Recursive%20XY,%EF%BC%8C%E5%8F%AF%E4%BB%A5%E5%88%92%E5%88%86%E6%AE%B5%E8%90%BD%E3%80%81%E8%A1%8C%E3%80%82
+def split_projection_profile(arr_values: np.array, min_value: float, min_gap: float):
+    """Split projection profile:
+
+    ```
+                              ┌──┐
+         arr_values           │  │       ┌─┐───
+             ┌──┐             │  │       │ │ |
+             │  │             │  │ ┌───┐ │ │min_value
+             │  │<- min_gap ->│  │ │   │ │ │ |
+         ────┴──┴─────────────┴──┴─┴───┴─┴─┴─┴───
+         0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
+    ```
+
+    Args:
+        arr_values (np.array): 1-d array representing the projection profile.
+        min_value (float): Ignore the profile if `arr_value` is less than `min_value`.
+        min_gap (float): Ignore the gap if less than this value.
+
+    Returns:
+        tuple: Start indexes and end indexes of split groups.
+    """
+    # all indexes with projection height exceeding the threshold
+    arr_index = np.where(arr_values > min_value)[0]
+    if not len(arr_index):
+        return
+
+    # find zero intervals between adjacent projections
+    # |  |                    ||
+    # ||||<- zero-interval -> |||||
+    arr_diff = arr_index[1:] - arr_index[0:-1]
+    arr_diff_index = np.where(arr_diff > min_gap)[0]
+    arr_zero_intvl_start = arr_index[arr_diff_index]
+    arr_zero_intvl_end = arr_index[arr_diff_index + 1]
+
+    # convert to index of projection range:
+    # the start index of zero interval is the end index of projection
+    arr_start = np.insert(arr_zero_intvl_end, 0, arr_index[0])
+    arr_end = np.append(arr_zero_intvl_start, arr_index[-1])
+    arr_end += 1  # end index will be excluded as index slice
+
+    return arr_start, arr_end
+
+
+def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int]):
+    """
+
+    Args:
+        boxes: (N, 4)
+        indices: 递归过程中始终表示 box 在原始数据中的索引
+        res: 保存输出结果
+
+    """
+    # 向 y 轴投影
+    assert len(boxes) == len(indices)
+
+    _indices = boxes[:, 1].argsort()
+    y_sorted_boxes = boxes[_indices]
+    y_sorted_indices = indices[_indices]
+
+    # debug_vis(y_sorted_boxes, y_sorted_indices)
+
+    y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
+    pos_y = split_projection_profile(y_projection, 0, 1)
+    if not pos_y:
+        return
+
+    arr_y0, arr_y1 = pos_y
+    for r0, r1 in zip(arr_y0, arr_y1):
+        # [r0, r1] 表示按照水平切分,有 bbox 的区域,对这些区域会再进行垂直切分
+        _indices = (r0 <= y_sorted_boxes[:, 1]) & (y_sorted_boxes[:, 1] < r1)
+
+        y_sorted_boxes_chunk = y_sorted_boxes[_indices]
+        y_sorted_indices_chunk = y_sorted_indices[_indices]
+
+        _indices = y_sorted_boxes_chunk[:, 0].argsort()
+        x_sorted_boxes_chunk = y_sorted_boxes_chunk[_indices]
+        x_sorted_indices_chunk = y_sorted_indices_chunk[_indices]
+
+        # 往 x 方向投影
+        x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
+        pos_x = split_projection_profile(x_projection, 0, 1)
+        if not pos_x:
+            continue
+
+        arr_x0, arr_x1 = pos_x
+        if len(arr_x0) == 1:
+            # x 方向无法切分
+            res.extend(x_sorted_indices_chunk)
+            continue
+
+        # x 方向上能分开,继续递归调用
+        for c0, c1 in zip(arr_x0, arr_x1):
+            _indices = (c0 <= x_sorted_boxes_chunk[:, 0]) & (
+                x_sorted_boxes_chunk[:, 0] < c1
+            )
+            recursive_xy_cut(
+                x_sorted_boxes_chunk[_indices], x_sorted_indices_chunk[_indices], res
+            )
+
+
+def points_to_bbox(points):
+    assert len(points) == 8
+
+    # [x1,y1,x2,y2,x3,y3,x4,y4]
+    left = min(points[::2])
+    right = max(points[::2])
+    top = min(points[1::2])
+    bottom = max(points[1::2])
+
+    left = max(left, 0)
+    top = max(top, 0)
+    right = max(right, 0)
+    bottom = max(bottom, 0)
+    return [left, top, right, bottom]
+
+
+def bbox2points(bbox):
+    left, top, right, bottom = bbox
+    return [left, top, right, top, right, bottom, left, bottom]
+
+
+def vis_polygon(img, points, thickness=2, color=None):
+    br2bl_color = color
+    tl2tr_color = color
+    tr2br_color = color
+    bl2tl_color = color
+    cv2.line(
+        img,
+        (points[0][0], points[0][1]),
+        (points[1][0], points[1][1]),
+        color=tl2tr_color,
+        thickness=thickness,
+    )
+
+    cv2.line(
+        img,
+        (points[1][0], points[1][1]),
+        (points[2][0], points[2][1]),
+        color=tr2br_color,
+        thickness=thickness,
+    )
+
+    cv2.line(
+        img,
+        (points[2][0], points[2][1]),
+        (points[3][0], points[3][1]),
+        color=br2bl_color,
+        thickness=thickness,
+    )
+
+    cv2.line(
+        img,
+        (points[3][0], points[3][1]),
+        (points[0][0], points[0][1]),
+        color=bl2tl_color,
+        thickness=thickness,
+    )
+    return img
+
+
+def vis_points(
+    img: np.ndarray, points, texts: List[str] = None, color=(0, 200, 0)
+) -> np.ndarray:
+    """
+
+    Args:
+        img:
+        points: [N, 8]  8: x1,y1,x2,y2,x3,y3,x3,y4
+        texts:
+        color:
+
+    Returns:
+
+    """
+    points = np.array(points)
+    if texts is not None:
+        assert len(texts) == points.shape[0]
+
+    for i, _points in enumerate(points):
+        vis_polygon(img, _points.reshape(-1, 2), thickness=2, color=color)
+        bbox = points_to_bbox(_points)
+        left, top, right, bottom = bbox
+        cx = (left + right) // 2
+        cy = (top + bottom) // 2
+
+        txt = texts[i]
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
+
+        img = cv2.rectangle(
+            img,
+            (cx - 5 * len(txt), cy - cat_size[1] - 5),
+            (cx - 5 * len(txt) + cat_size[0], cy - 5),
+            color,
+            -1,
+        )
+
+        img = cv2.putText(
+            img,
+            txt,
+            (cx - 5 * len(txt), cy - 5),
+            font,
+            0.5,
+            (255, 255, 255),
+            thickness=1,
+            lineType=cv2.LINE_AA,
+        )
+
+    return img
+
+
+def vis_polygons_with_index(image, points):
+    texts = [str(i) for i in range(len(points))]
+    res_img = vis_points(image.copy(), points, texts)
+    return res_img

+ 54 - 17
magic_pdf/pdf_parse_union_core_v2.py

@@ -30,8 +30,8 @@ from magic_pdf.pre_proc.equations_replace import (
 from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
     ocr_prepare_bboxes_for_layout_split_v2
 from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
-                                               fix_block_spans,
-                                               fix_discarded_block, fix_block_spans_v2)
+                                               fix_discarded_block,
+                                               fix_block_spans_v2)
 from magic_pdf.pre_proc.ocr_span_list_modify import (
     get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
     remove_overlaps_min_spans)
@@ -174,23 +174,57 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
 
 
 def cal_block_index(fix_blocks, sorted_bboxes):
-    for block in fix_blocks:
 
-        line_index_list = []
-        if len(block['lines']) == 0:
-            block['index'] = sorted_bboxes.index(block['bbox'])
-        else:
+    if sorted_bboxes is not None:
+        # 使用layoutreader排序
+        for block in fix_blocks:
+            line_index_list = []
+            if len(block['lines']) == 0:
+                block['index'] = sorted_bboxes.index(block['bbox'])
+            else:
+                for line in block['lines']:
+                    line['index'] = sorted_bboxes.index(line['bbox'])
+                    line_index_list.append(line['index'])
+                median_value = statistics.median(line_index_list)
+                block['index'] = median_value
+
+            # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
+            if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+                block['virtual_lines'] = copy.deepcopy(block['lines'])
+                block['lines'] = copy.deepcopy(block['real_lines'])
+                del block['real_lines']
+    else:
+        # 使用xycut排序
+        block_bboxes = []
+        for block in fix_blocks:
+            block_bboxes.append(block['bbox'])
+
+            # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
+            if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+                block['virtual_lines'] = copy.deepcopy(block['lines'])
+                block['lines'] = copy.deepcopy(block['real_lines'])
+                del block['real_lines']
+
+        import numpy as np
+        from magic_pdf.model.v3.xycut import recursive_xy_cut
+
+        random_boxes = np.array(block_bboxes)
+        np.random.shuffle(random_boxes)
+        res = []
+        recursive_xy_cut(np.asarray(random_boxes).astype(int), np.arange(len(block_bboxes)), res)
+        assert len(res) == len(block_bboxes)
+        sorted_boxes = random_boxes[np.array(res)].tolist()
+
+        for i, block in enumerate(fix_blocks):
+            block['index'] = sorted_boxes.index(block['bbox'])
+
+        # 生成line index
+        sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
+        line_inedx = 1
+        for block in sorted_blocks:
             for line in block['lines']:
-                line['index'] = sorted_bboxes.index(line['bbox'])
-                line_index_list.append(line['index'])
-            median_value = statistics.median(line_index_list)
-            block['index'] = median_value
-
-        # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
-        if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
-            block['virtual_lines'] = copy.deepcopy(block['lines'])
-            block['lines'] = copy.deepcopy(block['real_lines'])
-            del block['real_lines']
+                line['index'] = line_inedx
+                line_inedx += 1
 
     return fix_blocks
 
@@ -264,6 +298,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
                 block['lines'].append({'bbox': line, 'spans': []})
             page_line_list.extend(lines)
 
+    if len(page_line_list) > 512:  # layoutreader最高支持512line
+        return None
+
     # 使用layoutreader排序
     x_scale = 1000.0 / page_w
     y_scale = 1000.0 / page_h