فهرست منبع

Fix markdown doc_title&layout order&cross block and page space;update v2 doc (#3392)

* delete layout_parsing_v2 color function

* update document for layout_parsing_v2 & layout_parsing_v2 order

* fix layout order

* update doc
shuai.liu 9 ماه پیش
والد
کامیت
b08a00e972

+ 16 - 5
docs/pipeline_usage/tutorials/ocr_pipelines/layout_parsing_v2.md

@@ -93,7 +93,6 @@ comments: true
 </tr>
 </tbody>
 </table>
-
 <p><b>表格结构识别模块(可选):</b></p>
 <table>
 <tr>
@@ -598,15 +597,18 @@ output = pipeline.predict(
     use_doc_unwarping=False,
     use_textline_orientation=False)
 
-markdown_texts = ""
+markdown_list = []
 markdown_images = []
 
 for res in output:
-    markdown_texts += res.markdown["markdown_texts"]
-    markdown_images.append(res.markdown["markdown_images"])
+    markdown_list.append(res.markdown)
+    markdown_images.append(res.get("markdown_images", {}))
+
+markdown_texts = pipeline.concatenate_markdown_pages(markdown_list)
 
 mkd_file_path = output_path / f"{Path(input_file).stem}.md"
 mkd_file_path.parent.mkdir(parents=True, exist_ok=True)
+
 with open(mkd_file_path, "w", encoding="utf-8") as f:
     f.write(markdown_texts)
 
@@ -1052,6 +1054,14 @@ for item in markdown_images:
 <td>保存的文件路径,支持目录或文件路径</td>
 <td>无</td>
 </tr>
+<tr>
+<td><code>concatenate_markdown_pages()</code></td>
+<td>将多页Markdown内容拼接为单一文档</td>
+<td><code>markdown_list</code></td>
+<td><code>list</code></td>
+<td>包含每一页Markdown数据的列表</td>
+<td>返回处理后的Markdown文本和图像列表</td>
+</tr>
 </table>
 
 - 调用`print()` 方法会将结果打印到终端,打印到终端的内容解释如下:
@@ -1140,6 +1150,7 @@ for item in markdown_images:
 - 调用`save_to_json()` 方法会将上述内容保存到指定的 `save_path` 中,如果指定为目录,则保存的路径为`save_path/{your_img_basename}_res.json`,如果指定为文件,则直接保存到该文件中。由于 json 文件不支持保存numpy数组,因此会将其中的 `numpy.array` 类型转换为列表形式。
 - 调用`save_to_img()` 方法会将可视化结果保存到指定的 `save_path` 中,如果指定为目录,则会将版面区域检测可视化图像、全局OCR可视化图像、版面阅读顺序可视化图像等内容保存,如果指定为文件,则直接保存到该文件中。(产线通常包含较多结果图片,不建议直接指定为具体的文件路径,否则多张图会被覆盖,仅保留最后一张图)
 - 调用`save_to_markdown()` 方法会将转化后的 Markdown 文件保存到指定的 `save_path` 中,保存的文件路径为`save_path/{your_img_basename}.md`,如果输入是 PDF 文件,建议直接指定目录,否责多个 markdown 文件会被覆盖。
+- 调用 `concatenate_markdown_pages()` 方法将 `layout parsing pipeline` 输出的多页Markdown内容`markdown_list`合并为单个完整文档,并返回合并后的Markdown内容。
 
 此外,也支持通过属性获取带结果的可视化图像和预测结果,具体如下:
 <table>
@@ -1173,7 +1184,7 @@ for item in markdown_images:
 
 - `json` 属性获取的预测结果为字典类型的数据,相关内容与调用 `save_to_json()` 方法保存的内容一致。
 - `img` 属性返回的预测结果是一个字典类型的数据。其中,键分别为 `layout_det_res`、`overall_ocr_res`、`text_paragraphs_ocr_res`、`formula_res_region1`、`table_cell_img` 和 `seal_res_region1`,对应的值是 `Image.Image` 对象:分别用于显示版面区域检测、OCR、OCR文本段落、公式、表格和印章结果的可视化图像。如果没有使用可选模块,则字典中只包含 `layout_det_res`。
-- `markdown` 属性返回的预测结果是一个字典类型的数据。其中,键分别为 `markdown_texts` 和 `markdown_images`,对应的值分别是 markdown 文本和用于在 Markdown 中显示的图像(`Image.Image` 对象)。
+- `markdown` 属性返回的预测结果是一个字典类型的数据。其中,键分别为 `markdown_texts` 、 `markdown_images`和`page_continuation_flags`,对应的值分别是 markdown 文本,在 Markdown 中显示的图像(`Image.Image` 对象)和用于标识当前页面第一个元素是否为段开始以及最后一个元素是否为段结束的bool元组
 
 此外,您可以获取版面解析产线配置文件,并加载配置文件进行预测。可执行如下命令将结果保存在 `my_path` 中:
 ```

+ 57 - 1
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -13,8 +13,9 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple, Iterator
 import numpy as np
+import re
 
 from ....utils import logging
 from ...common.batch_sampler import ImageBatchSampler
@@ -605,3 +606,58 @@ class LayoutParsingPipelineV2(BasePipeline):
                 "model_settings": model_settings,
             }
             yield LayoutParsingResultV2(single_img_res)
+
+    def concatenate_markdown_pages(self, markdown_list: list) -> tuple:
+        """
+        Concatenate Markdown content from multiple pages into a single document.
+
+        Args:
+            markdown_list (list): A list containing Markdown data for each page.
+
+        Returns:
+            tuple: A tuple containing the processed Markdown text.
+        """
+        markdown_texts = ""
+        previous_page_last_element_paragraph_end_flag = True
+
+        for res in markdown_list:
+            # Get the paragraph flags for the current page
+            page_first_element_paragraph_start_flag: bool = res[
+                "page_continuation_flags"
+            ][0]
+            page_last_element_paragraph_end_flag: bool = res["page_continuation_flags"][
+                1
+            ]
+
+            # Determine whether to add a space or a newline
+            if (
+                not page_first_element_paragraph_start_flag
+                and not previous_page_last_element_paragraph_end_flag
+            ):
+                last_char_of_markdown = markdown_texts[-1] if markdown_texts else ""
+                first_char_of_handler = (
+                    res["markdown_texts"][0] if res["markdown_texts"] else ""
+                )
+
+                # Check if the last character and the first character are Chinese characters
+                last_is_chinese_char = (
+                    re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
+                    if last_char_of_markdown
+                    else False
+                )
+                first_is_chinese_char = (
+                    re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
+                    if first_char_of_handler
+                    else False
+                )
+                if not (last_is_chinese_char or first_is_chinese_char):
+                    markdown_texts += " " + res["markdown_texts"]
+                else:
+                    markdown_texts += res["markdown_texts"]
+            else:
+                markdown_texts += "\n\n" + res["markdown_texts"]
+            previous_page_last_element_paragraph_end_flag = (
+                page_last_element_paragraph_end_flag
+            )
+
+        return markdown_texts

+ 46 - 5
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -252,7 +252,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                     if "." in content_value
                     else 1
                 )
-                return f"{'#' * level} {content_value}".replace("-\n", "").replace(
+                return f"#{'#' * level} {content_value}".replace("-\n", "").replace(
                     "\n",
                     " ",
                 )
@@ -319,28 +319,69 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
             last_label = None
             seg_start_flag = None
             seg_end_flag = None
+            page_first_element_seg_start_flag = None
+            page_last_element_seg_end_flag = None
+            parsing_res_list = sorted(
+                parsing_res_list,
+                key=lambda x: x.get("sub_index", 999),
+            )
             for block in sorted(
                 parsing_res_list,
                 key=lambda x: x.get("sub_index", 999),
             ):
                 label = block.get("block_label")
                 seg_start_flag = block.get("seg_start_flag")
+                page_first_element_seg_start_flag = (
+                    seg_start_flag
+                    if (page_first_element_seg_start_flag is None)
+                    else page_first_element_seg_start_flag
+                )
                 handler = handlers.get(label)
                 if handler:
                     if (
                         label == last_label == "text"
                         and seg_start_flag == seg_end_flag == False
                     ):
-                        markdown_content += " " + handler()
+                        last_char_of_markdown = (
+                            markdown_content[-1] if markdown_content else ""
+                        )
+                        first_char_of_handler = handler()[0] if handler() else ""
+                        last_is_chinese_char = (
+                            re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
+                            if last_char_of_markdown
+                            else False
+                        )
+                        first_is_chinese_char = (
+                            re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
+                            if first_char_of_handler
+                            else False
+                        )
+                        if not (last_is_chinese_char or first_is_chinese_char):
+                            markdown_content += " " + handler()
+                        else:
+                            markdown_content += handler()
                     else:
-                        markdown_content += "\n\n" + handler()
+                        markdown_content += (
+                            "\n\n" + handler() if markdown_content else handler()
+                        )
                     last_label = label
                     seg_end_flag = block.get("seg_end_flag")
+            page_last_element_seg_end_flag = seg_end_flag
 
-            return markdown_content
+            return markdown_content, (
+                page_first_element_seg_start_flag,
+                page_last_element_seg_end_flag,
+            )
 
         markdown_info = dict()
-        markdown_info["markdown_texts"] = _format_data(self)
+        markdown_info["markdown_texts"], (
+            page_first_element_seg_start_flag,
+            page_last_element_seg_end_flag,
+        ) = _format_data(self)
+        markdown_info["page_continuation_flags"] = (
+            page_first_element_seg_start_flag,
+            page_last_element_seg_end_flag,
+        )
         markdown_info["markdown_images"] = dict()
         for block in self["parsing_res_list"]:
             if block["block_label"] in ["image", "chart"]:

+ 127 - 29
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -760,6 +760,7 @@ def sort_by_xycut(
     block_bboxes: Union[np.ndarray, List[List[int]]],
     direction: int = 0,
     min_gap: int = 1,
+    pre_cuts: Optional[Dict[str, List[int]]] = None,
 ) -> List[int]:
     """
     Sort bounding boxes using recursive XY cut method based on the specified direction.
@@ -771,27 +772,55 @@ def sort_by_xycut(
         direction (int): Direction for the initial cut. Use 1 for Y-axis first and 0 for X-axis first.
                          Defaults to 0.
         min_gap (int): Minimum gap width to consider a separation between segments. Defaults to 1.
+        pre_cuts (Optional[Dict[str, List[int]]]): A dictionary specifying pre-cut points along the axes.
+                                                  The keys are 'x' or 'y', representing the axis to pre-cut,
+                                                  and the values are lists of integers specifying the cut points.
+                                                  For example, {'y': [100, 200]} will pre-cut the y-axis at
+                                                  positions 100 and 200 before applying the main XY cut algorithm.
+                                                  Defaults to None.
 
     Returns:
         List[int]: A list of indices representing the order of sorted bounding boxes.
     """
     block_bboxes = np.asarray(block_bboxes).astype(int)
     res = []
-
-    if direction == 1:
-        _recursive_yx_cut(
-            block_bboxes,
-            np.arange(len(block_bboxes)),
-            res,
-            min_gap,
-        )
+    axis = "x" if direction == 1 else "y"
+    if len(pre_cuts[axis]) > 0:
+        cuts = sorted(pre_cuts[axis])
+        axis_index = 1 if axis == "y" else 0
+        max_val = block_bboxes[:, 3].max() if axis == "y" else block_bboxes[:, 2].max()
+        intervals = []
+        prev = 0
+        for cut in cuts:
+            intervals.append((prev, cut))
+            prev = cut
+        intervals.append((prev, max_val))
+        for start, end in intervals:
+            mask = (block_bboxes[:, axis_index] >= start) & (
+                block_bboxes[:, axis_index] < end
+            )
+            sub_boxes = block_bboxes[mask]
+            sub_indices = np.arange(len(block_bboxes))[mask].tolist()
+            if len(sub_boxes) > 0:
+                if direction == 1:
+                    _recursive_yx_cut(sub_boxes, sub_indices, res, min_gap)
+                else:
+                    _recursive_xy_cut(sub_boxes, sub_indices, res, min_gap)
     else:
-        _recursive_xy_cut(
-            block_bboxes,
-            np.arange(len(block_bboxes)),
-            res,
-            min_gap,
-        )
+        if direction == 1:
+            _recursive_yx_cut(
+                block_bboxes,
+                np.arange(len(block_bboxes)).tolist(),
+                res,
+                min_gap,
+            )
+        else:
+            _recursive_xy_cut(
+                block_bboxes,
+                np.arange(len(block_bboxes)).tolist(),
+                res,
+                min_gap,
+            )
 
     return res
 
@@ -1090,9 +1119,9 @@ def _get_projection_iou(
 
 def _get_sub_category(
     blocks: List[Dict[str, Any]], title_labels: List[str]
-) -> List[Dict[str, Any]]:
+) -> Tuple[List[Dict[str, Any]], List[float]]:
     """
-    Determine the layout of title and text blocks.
+    Determine the layout of title and text blocks and collect pre_cuts.
 
     Args:
         blocks (List[Dict[str, Any]]): List of block dictionaries.
@@ -1100,10 +1129,29 @@ def _get_sub_category(
 
     Returns:
         List[Dict[str, Any]]: Updated list of blocks with title-text layout information.
+        List[float]: List of pre_cuts coordinates.
     """
 
     sub_title_labels = ["paragraph_title"]
     vision_labels = ["image", "table", "chart", "figure"]
+    vision_title_labels = ["figure_title", "chart_title", "table_title"]
+    all_labels = title_labels + sub_title_labels + vision_labels + vision_title_labels
+
+    relevant_blocks = [block for block in blocks if block["block_label"] in all_labels]
+
+    region_bbox = None
+    if relevant_blocks:
+        min_x = min(block["block_bbox"][0] for block in relevant_blocks)
+        min_y = min(block["block_bbox"][1] for block in relevant_blocks)
+        max_x = max(block["block_bbox"][2] for block in relevant_blocks)
+        max_y = max(block["block_bbox"][3] for block in relevant_blocks)
+        region_bbox = (min_x, min_y, max_x, max_y)
+        region_x_center = (region_bbox[0] + region_bbox[2]) / 2
+        region_y_center = (region_bbox[1] + region_bbox[3]) / 2
+        region_width = region_bbox[2] - region_bbox[0]
+        region_height = region_bbox[3] - region_bbox[1]
+
+    pre_cuts = []
 
     for i, block1 in enumerate(blocks):
         block1.setdefault("title_text", [])
@@ -1111,11 +1159,7 @@ def _get_sub_category(
         block1.setdefault("vision_footnote", [])
         block1.setdefault("sub_label", block1["block_label"])
 
-        if (
-            block1["block_label"] not in title_labels
-            and block1["block_label"] not in sub_title_labels
-            and block1["block_label"] not in vision_labels
-        ):
+        if block1["block_label"] not in all_labels:
             continue
 
         bbox1 = block1["block_bbox"]
@@ -1128,6 +1172,64 @@ def _get_sub_category(
         right_down_title_text_index = -1
         right_down_title_text_direction = None
 
+        # pre-cuts
+        # Condition 1: Length is greater than half of the layout region
+        if is_horizontal_1:
+            block_length = x2 - x1
+            required_length = region_width / 2
+        else:
+            block_length = y2 - y1
+            required_length = region_height / 2
+        length_condition = block_length > required_length
+
+        # Condition 2: Centered check (must be within ±20 in both horizontal and vertical directions)
+        block_x_center = (x1 + x2) / 2
+        block_y_center = (y1 + y2) / 2
+        tolerance_len = block_length // 5
+        is_centered = (
+            abs(block_x_center - region_x_center) <= tolerance_len
+            and abs(block_y_center - region_y_center) <= tolerance_len
+        )
+
+        # Condition 3: Check for surrounding text
+        has_left_text = False
+        has_right_text = False
+        has_above_text = False
+        has_below_text = False
+        for block2 in blocks:
+            if block2["block_label"] != "text":
+                continue
+            bbox2 = block2["block_bbox"]
+            x1_2, y1_2, x2_2, y2_2 = bbox2
+            if is_horizontal_1:
+                if x2_2 <= x1 and not (y2_2 <= y1 or y1_2 >= y2):
+                    has_left_text = True
+                if x1_2 >= x2 and not (y2_2 <= y1 or y1_2 >= y2):
+                    has_right_text = True
+            else:
+                if y2_2 <= y1 and not (x2_2 <= x1 or x1_2 >= x2):
+                    has_above_text = True
+                if y1_2 >= y2 and not (x2_2 <= x1 or x1_2 >= x2):
+                    has_below_text = True
+
+            if (is_horizontal_1 and has_left_text and has_right_text) or (
+                not is_horizontal_1 and has_above_text and has_below_text
+            ):
+                break
+
+        no_text_on_sides = (
+            not (has_left_text or has_right_text)
+            if is_horizontal_1
+            else not (has_above_text or has_below_text)
+        )
+
+        # Add coordinates if all conditions are met
+        if is_centered and length_condition and no_text_on_sides:
+            if is_horizontal_1:
+                pre_cuts.append(y1)
+            else:
+                pre_cuts.append(x1)
+
         for j, block2 in enumerate(blocks):
             if i == j:
                 continue
@@ -1315,7 +1417,7 @@ def _get_sub_category(
             if blocks[i].get("vision_footnote") == []:
                 blocks[i]["vision_footnote"] = vision_footnote
 
-    return blocks
+    return blocks, pre_cuts
 
 
 def get_layout_ordering(
@@ -1343,7 +1445,7 @@ def get_layout_ordering(
         threshold=0.5,
         smaller=True,
     )
-    parsing_res_list = _get_sub_category(parsing_res_list, title_text_labels)
+    parsing_res_list, pre_cuts = _get_sub_category(parsing_res_list, title_text_labels)
 
     doc_flag = False
     median_width = _get_text_median_width(parsing_res_list)
@@ -1414,18 +1516,14 @@ def get_layout_ordering(
             )
             block_bboxes = np.array(block_bboxes)
             sorted_indices = sort_by_xycut(
-                block_bboxes,
-                direction=1,
-                min_gap=1,
+                block_bboxes, direction=1, min_gap=1, pre_cuts={"x": pre_cuts}
             )
         else:
             block_bboxes = [block["block_bbox"] for block in parsing_res_list]
             block_bboxes.sort(key=lambda x: (x[0] // 20, x[1]))
             block_bboxes = np.array(block_bboxes)
             sorted_indices = sort_by_xycut(
-                block_bboxes,
-                direction=0,
-                min_gap=20,
+                block_bboxes, direction=0, min_gap=20, pre_cuts={"y": pre_cuts}
             )
 
         sorted_boxes = block_bboxes[sorted_indices].tolist()