Jelajahi Sumber

refactor: format result to markdown

gaotingquan 6 bulan lalu
induk
melakukan
cc8d7b6fda
1 mengubah file dengan 247 tambahan dan 244 penghapusan
  1. 247 244
      paddlex/inference/pipelines/layout_parsing/result_v2.py

+ 247 - 244
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -16,6 +16,7 @@ from __future__ import annotations
 import copy
 import math
 import re
+from functools import partial
 from pathlib import Path
 from typing import List
 
@@ -33,6 +34,168 @@ from ...common.result import (
 from .setting import BLOCK_LABEL_MAP
 
 
+def compile_title_pattern():
+    # Precompiled regex pattern for matching numbering at the beginning of the title
+    numbering_pattern = (
+        r"(?:" + r"[1-9][0-9]*(?:\.[1-9][0-9]*)*[\.、]?|" + r"[\(\(](?:[1-9][0-9]*|["
+        r"一二三四五六七八九十百千万亿零壹贰叁肆伍陆柒捌玖拾]+)[\)\)]|" + r"["
+        r"一二三四五六七八九十百千万亿零壹贰叁肆伍陆柒捌玖拾]+"
+        r"[、\.]?|" + r"(?:I|II|III|IV|V|VI|VII|VIII|IX|X)\.?" + r")"
+    )
+    return re.compile(r"^\s*(" + numbering_pattern + r")(\s*)(.*)$")
+
+
+TITLE_RE_PATTERN = compile_title_pattern()
+
+
+def format_title_func(block):
+    """
+    Normalize chapter title.
+    Add the '#' to indicate the level of the title.
+    If numbering exists, ensure there's exactly one space between it and the title content.
+    If numbering does not exist, return the original title unchanged.
+
+    :param title: Original chapter title string.
+    :return: Normalized chapter title string.
+    """
+    title = block.content
+    match = TITLE_RE_PATTERN.match(title)
+    if match:
+        numbering = match.group(1).strip()
+        title_content = match.group(3).lstrip()
+        # Return numbering and title content separated by one space
+        title = numbering + " " + title_content
+
+    title = title.rstrip(".")
+    level = (
+        title.count(
+            ".",
+        )
+        + 1
+        if "." in title
+        else 1
+    )
+    return f"#{'#' * level} {title}".replace("-\n", "").replace(
+        "\n",
+        " ",
+    )
+
+
+def format_text_centered_by_html_func(block):
+    return (
+        f'<div style="text-align: center;">{block.content}</div>'.replace(
+            "-\n",
+            "",
+        ).replace("\n", " ")
+        + "\n"
+    )
+
+
+def format_image_centered_by_html_func(block, original_image_width):
+    img_tags = []
+    image_path = "".join(block.image.keys())
+    image_width = block.image[image_path].width
+    scale = int(image_width / original_image_width * 100)
+    img_tags.append(
+        '<div style="text-align: center;"><img src="{}" alt="Image" width="{}%" /></div>'.format(
+            image_path.replace("-\n", "").replace("\n", " "), scale
+        ),
+    )
+    return "\n".join(img_tags)
+
+
+def format_image_plain_func(block):
+    img_tags = []
+    image_path = "".join(block.image.keys())
+    img_tags.append("![]({})".format(image_path.replace("-\n", "").replace("\n", " ")))
+    return "\n".join(img_tags)
+
+
+def format_chart_func(block):
+    lines_list = block.content.split("\n")
+    column_num = len(lines_list[0].split("|"))
+    lines_list.insert(1, "|".join(["---"] * column_num))
+    lines_list = [f"|{line}|" for line in lines_list]
+    return "\n".join(lines_list)
+
+
+def simplify_table_func(table_code):
+    return "\n" + table_code.replace("<html>", "").replace(
+        "</html>", ""
+    ).replace("<body>", "").replace("</body>", "")
+
+
+def format_first_line_func(block, templates, format_func, spliter):
+    lines = block.content.split(spliter)
+    for idx in range(len(lines)):
+        line = lines[idx]
+        if line.strip() == "":
+            continue
+        if line.lower() in templates:
+            lines[idx] = format_func(line)
+        break
+    return spliter.join(lines)
+
+
+def compose_funcs(block, funcs):
+    res = ""
+    for func in funcs:
+        res += func(block)
+    return res
+
+
+def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
+
+    seg_start_flag = True
+    seg_end_flag = True
+
+    block_box = block.bbox
+    context_left_coordinate = block_box[0]
+    context_right_coordinate = block_box[2]
+    seg_start_coordinate = block.seg_start_coordinate
+    seg_end_coordinate = block.seg_end_coordinate
+
+    if prev_block is not None:
+        prev_block_bbox = prev_block.bbox
+        num_of_prev_lines = prev_block.num_of_lines
+        pre_block_seg_end_coordinate = prev_block.seg_end_coordinate
+        prev_end_space_small = (
+            abs(prev_block_bbox[2] - pre_block_seg_end_coordinate) < 10
+        )
+        prev_lines_more_than_one = num_of_prev_lines > 1
+
+        overlap_blocks = context_left_coordinate < prev_block_bbox[2]
+
+        # update context_left_coordinate and context_right_coordinate
+        if overlap_blocks:
+            context_left_coordinate = min(prev_block_bbox[0], context_left_coordinate)
+            context_right_coordinate = max(prev_block_bbox[2], context_right_coordinate)
+            prev_end_space_small = (
+                abs(context_right_coordinate - pre_block_seg_end_coordinate) < 10
+            )
+            edge_distance = 0
+        else:
+            edge_distance = abs(block_box[0] - prev_block_bbox[2])
+
+        current_start_space_small = seg_start_coordinate - context_left_coordinate < 10
+
+        if (
+            prev_end_space_small
+            and current_start_space_small
+            and prev_lines_more_than_one
+            and edge_distance < max(prev_block.width, block.width)
+        ):
+            seg_start_flag = False
+    else:
+        if seg_start_coordinate - context_left_coordinate < 10:
+            seg_start_flag = False
+
+    if context_right_coordinate - seg_end_coordinate < 10:
+        seg_end_flag = False
+
+    return seg_start_flag, seg_end_flag
+
+
 class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
     """Layout Parsing Result V2"""
 
@@ -43,19 +206,6 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         XlsxMixin.__init__(self)
         MarkdownMixin.__init__(self)
         JsonMixin.__init__(self)
-        self.title_pattern = self._build_title_pattern()
-
-    def _build_title_pattern(self):
-        # Precompiled regex pattern for matching numbering at the beginning of the title
-        numbering_pattern = (
-            r"(?:"
-            + r"[1-9][0-9]*(?:\.[1-9][0-9]*)*[\.、]?|"
-            + r"[\(\(](?:[1-9][0-9]*|["
-            r"一二三四五六七八九十百千万亿零壹贰叁肆伍陆柒捌玖拾]+)[\)\)]|" + r"["
-            r"一二三四五六七八九十百千万亿零壹贰叁肆伍陆柒捌玖拾]+"
-            r"[、\.]?|" + r"(?:I|II|III|IV|V|VI|VII|VIII|IX|X)\.?" + r")"
-        )
-        return re.compile(r"^\s*(" + numbering_pattern + r")(\s*)(.*)$")
 
     def _get_input_fn(self):
         fn = super()._get_input_fn()
@@ -265,243 +415,96 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         Returns:
             Dict
         """
+        original_image_width = self["doc_preprocessor_res"]["output_img"].shape[1]
 
-        def _format_data(obj):
-
-            def format_title(title):
-                """
-                Normalize chapter title.
-                Add the '#' to indicate the level of the title.
-                If numbering exists, ensure there's exactly one space between it and the title content.
-                If numbering does not exist, return the original title unchanged.
-
-                :param title: Original chapter title string.
-                :return: Normalized chapter title string.
-                """
-                match = self.title_pattern.match(title)
-                if match:
-                    numbering = match.group(1).strip()
-                    title_content = match.group(3).lstrip()
-                    # Return numbering and title content separated by one space
-                    title = numbering + " " + title_content
-
-                title = title.rstrip(".")
-                level = (
-                    title.count(
-                        ".",
-                    )
-                    + 1
-                    if "." in title
-                    else 1
-                )
-                return f"#{'#' * level} {title}".replace("-\n", "").replace(
-                    "\n",
-                    " ",
-                )
-
-            def format_text_centered_by_html():
-                return (
-                    f'<div style="text-align: center;">{block.content}</div>'.replace(
-                        "-\n",
-                        "",
-                    ).replace("\n", " ")
-                    + "\n"
-                )
-
-            def format_text_plain():
-                return block.content
-
-            def format_image_centered_by_html():
-                img_tags = []
-                image_path = "".join(block.image.keys())
-                image_width = block.image[image_path].width
-                scale = int(image_width / original_image_width * 100)
-                img_tags.append(
-                    '<div style="text-align: center;"><img src="{}" alt="Image" width="{}%" /></div>'.format(
-                        image_path.replace("-\n", "").replace("\n", " "), scale
-                    ),
-                )
-                return "\n".join(img_tags)
-
-            def format_image_plain():
-                img_tags = []
-                image_path = "".join(block.image.keys())
-                img_tags.append(
-                    "![]({})".format(image_path.replace("-\n", "").replace("\n", " "))
-                )
-                return "\n".join(img_tags)
-
-            def format_chart():
-                if not self["model_settings"].get("use_chart_recognition", False):
-                    return format_image()
-                lines_list = block.content.split("\n")
-                column_num = len(lines_list[0].split("|"))
-                lines_list.insert(1, "|".join(["---"] * column_num))
-                lines_list = [f"|{line}|" for line in lines_list]
-                return "\n".join(lines_list)
-
-            def format_first_line(templates, format_func, spliter):
-                lines = block.content.split(spliter)
-                for idx in range(len(lines)):
-                    line = lines[idx]
-                    if line.strip() == "":
-                        continue
-                    if line.lower() in templates:
-                        lines[idx] = format_func(line)
-                    break
-                return spliter.join(lines)
-
-            def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
-
-                seg_start_flag = True
-                seg_end_flag = True
-
-                block_box = block.bbox
-                context_left_coordinate = block_box[0]
-                context_right_coordinate = block_box[2]
-                seg_start_coordinate = block.seg_start_coordinate
-                seg_end_coordinate = block.seg_end_coordinate
-
-                if prev_block is not None:
-                    prev_block_bbox = prev_block.bbox
-                    num_of_prev_lines = prev_block.num_of_lines
-                    pre_block_seg_end_coordinate = prev_block.seg_end_coordinate
-                    prev_end_space_small = (
-                        abs(prev_block_bbox[2] - pre_block_seg_end_coordinate) < 10
-                    )
-                    prev_lines_more_than_one = num_of_prev_lines > 1
-
-                    overlap_blocks = context_left_coordinate < prev_block_bbox[2]
-
-                    # update context_left_coordinate and context_right_coordinate
-                    if overlap_blocks:
-                        context_left_coordinate = min(
-                            prev_block_bbox[0], context_left_coordinate
-                        )
-                        context_right_coordinate = max(
-                            prev_block_bbox[2], context_right_coordinate
-                        )
-                        prev_end_space_small = (
-                            abs(context_right_coordinate - pre_block_seg_end_coordinate)
-                            < 10
-                        )
-                        edge_distance = 0
-                    else:
-                        edge_distance = abs(block_box[0] - prev_block_bbox[2])
-
-                    current_start_space_small = (
-                        seg_start_coordinate - context_left_coordinate < 10
-                    )
-
-                    if (
-                        prev_end_space_small
-                        and current_start_space_small
-                        and prev_lines_more_than_one
-                        and edge_distance < max(prev_block.width, block.width)
-                    ):
-                        seg_start_flag = False
-                else:
-                    if seg_start_coordinate - context_left_coordinate < 10:
-                        seg_start_flag = False
-
-                if context_right_coordinate - seg_end_coordinate < 10:
-                    seg_end_flag = False
-
-                return seg_start_flag, seg_end_flag
+        if pretty_markdown:
+            format_text_func = format_text_centered_by_html_func
+            format_image_func = partial(
+                format_image_centered_by_html_func,
+                original_image_width=original_image_width,
+            )
+            format_table = lambda block: "\n" + format_text_func(block)
+        else:
+            format_text_func = lambda block: block.content
+            format_image_func = format_image_plain_func
+            format_table = lambda block: simplify_table_func("\n" + block.content)
+
+        handle_funcs_dict = {
+            "paragraph_title": format_title_func,
+            "abstract_title": format_title_func,
+            "reference_title": format_title_func,
+            "content_title": format_title_func,
+            "doc_title": lambda block: f"# {block.content}".replace(
+                "-\n",
+                "",
+            ).replace("\n", " "),
+            "table_title": format_text_func,
+            "figure_title": format_text_func,
+            "chart_title": format_text_func,
+            "text": lambda block: block.content.replace("\n\n", "\n").replace(
+                "\n", "\n\n"
+            ),
+            "abstract": partial(
+                format_first_line_func,
+                templates=["摘要", "abstract"],
+                format_func=lambda l: f"## {l}\n",
+                spliter=" ",
+            ),
+            "content": lambda block: block.content.replace("-\n", "  \n").replace(
+                "\n", "  \n"
+            ),
+            "image": format_image_func,
+            "chart": format_chart_func,
+            "formula": lambda block: f"$${block.content}$$",
+            "table": format_table,
+            "reference": partial(
+                format_first_line_func,
+                templates=["参考文献", "references"],
+                format_func=lambda l: f"## {l}",
+                spliter="\n",
+            ),
+            "algorithm": lambda block: block.content.strip("\n"),
+            "seal": partial(compose_funcs, funcs=[format_image_func, format_text_func]),
+        }
+
+        markdown_content = ""
+        last_label = None
+        seg_start_flag = None
+        seg_end_flag = None
+        prev_block = None
+        page_first_element_seg_start_flag = None
+        page_last_element_seg_end_flag = None
+        for block in self["parsing_res_list"]:
+            seg_start_flag, seg_end_flag = get_seg_flag(block, prev_block)
 
-            def format_table_with_html_body():
-                return "\n" + block.content
+            label = block.label
+            page_first_element_seg_start_flag = (
+                seg_start_flag
+                if (page_first_element_seg_start_flag is None)
+                else page_first_element_seg_start_flag
+            )
 
-            def format_table_wo_html_body():
-                return "\n" + block.content.replace("<html>", "").replace(
-                    "</html>", ""
-                ).replace("<body>", "").replace("</body>", "")
+            handle_func = handle_funcs_dict.get(label, None)
+            if handle_func:
+                prev_block = block
+                if label == last_label == "text" and seg_start_flag == False:
+                    markdown_content += handle_func(block)
+                else:
+                    markdown_content += (
+                        "\n\n" + handle_func(block)
+                        if markdown_content
+                        else handle_func(block)
+                    )
+                last_label = label
+        page_last_element_seg_end_flag = seg_end_flag
 
-            if pretty_markdown:
-                format_text = format_text_centered_by_html
-                format_image = format_image_centered_by_html
-                format_table = format_table_with_html_body
-            else:
-                format_text = format_text_plain
-                format_image = format_image_plain
-                format_table = format_table_wo_html_body
-
-            handlers = {
-                "paragraph_title": lambda: format_title(block.content),
-                "abstract_title": lambda: format_title(block.content),
-                "reference_title": lambda: format_title(block.content),
-                "content_title": lambda: format_title(block.content),
-                "doc_title": lambda: f"# {block.content}".replace(
-                    "-\n",
-                    "",
-                ).replace("\n", " "),
-                "table_title": lambda: format_text(),
-                "figure_title": lambda: format_text(),
-                "chart_title": lambda: format_text(),
-                "text": lambda: block.content.replace("\n\n", "\n").replace(
-                    "\n", "\n\n"
-                ),
-                "abstract": lambda: format_first_line(
-                    ["摘要", "abstract"], lambda l: f"## {l}\n", " "
-                ),
-                "content": lambda: block.content.replace("-\n", "  \n").replace(
-                    "\n", "  \n"
-                ),
-                "image": lambda: format_image(),
-                "chart": lambda: format_chart(),
-                "formula": lambda: f"$${block.content}$$",
-                "table": format_table,
-                "reference": lambda: format_first_line(
-                    ["参考文献", "references"], lambda l: f"## {l}", "\n"
-                ),
-                "algorithm": lambda: block.content.strip("\n"),
-                "seal": lambda: f"Words of Seals:\n{block.content}",
-            }
-            parsing_res_list = obj["parsing_res_list"]
-            markdown_content = ""
-            last_label = None
-            seg_start_flag = None
-            seg_end_flag = None
-            prev_block = None
-            page_first_element_seg_start_flag = None
-            page_last_element_seg_end_flag = None
-            for block in parsing_res_list:
-                seg_start_flag, seg_end_flag = get_seg_flag(block, prev_block)
-
-                label = block.label
-                page_first_element_seg_start_flag = (
-                    seg_start_flag
-                    if (page_first_element_seg_start_flag is None)
-                    else page_first_element_seg_start_flag
-                )
-                handler = handlers.get(label)
-                if handler:
-                    prev_block = block
-                    if label == last_label == "text" and seg_start_flag == False:
-                        markdown_content += handler()
-                    else:
-                        markdown_content += (
-                            "\n\n" + handler() if markdown_content else handler()
-                        )
-                    last_label = label
-            page_last_element_seg_end_flag = seg_end_flag
-
-            return markdown_content, (
+        markdown_info = {
+            "markdown_texts": markdown_content,
+            "page_continuation_flags": (
                 page_first_element_seg_start_flag,
                 page_last_element_seg_end_flag,
-            )
-
-        markdown_info = dict()
-        original_image_width = self["doc_preprocessor_res"]["output_img"].shape[1]
-        markdown_info["markdown_texts"], (
-            page_first_element_seg_start_flag,
-            page_last_element_seg_end_flag,
-        ) = _format_data(self)
-        markdown_info["page_continuation_flags"] = (
-            page_first_element_seg_start_flag,
-            page_last_element_seg_end_flag,
-        )
-
+            ),
+        }
         markdown_info["markdown_images"] = {}
         for img in self["imgs_in_doc"]:
             markdown_info["markdown_images"][img["path"]] = img["img"]