Sfoglia il codice sorgente

support sort region

zhouchangda 6 mesi fa
parent
commit
9a2eeeaed3

+ 2 - 2
paddlex/configs/pipelines/PP-StructureV3.yaml

@@ -6,8 +6,8 @@ use_general_ocr: True
 use_seal_recognition: True
 use_table_recognition: True
 use_formula_recognition: True
-use_chart_recognition: True
-use_region_detection: True
+use_chart_recognition: False
+use_region_detection: False
 pretty_markdown: True
 
 SubModules:

+ 75 - 84
paddlex/inference/models/formula_recognition/processors.py

@@ -15,9 +15,7 @@
 
 import json
 import math
-import os
 import re
-import tempfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -325,14 +323,9 @@ class LaTeXOCRDecode(object):
             **kwargs: Additional keyword arguments for initialization.
         """
         super(LaTeXOCRDecode, self).__init__()
-        temp_path = tempfile.gettempdir()
-        rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
-        try:
-            with open(rec_char_dict_path, "w") as f:
-                json.dump(character_list, f)
-        except Exception as e:
-            print(f"创建 latexocr_tokenizer.json 文件失败, 原因{str(e)}")
-        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
+        fast_tokenizer_str = json.dumps(character_list)
+        fast_tokenizer_buffer = fast_tokenizer_str.encode("utf-8")
+        self.tokenizer = TokenizerFast.from_buffer(fast_tokenizer_buffer)
 
     def post_process(self, s: str) -> str:
         """Post-processes the decoded LaTeX string.
@@ -372,7 +365,7 @@ class LaTeXOCRDecode(object):
         dec = [self.tokenizer.decode(tok) for tok in tokens]
         dec_str_list = [
             "".join(detok.split(" "))
-            .replace("Ġ", " ")
+            .replace("", " ")
             .replace("[EOS]", "")
             .replace("[BOS]", "")
             .replace("[PAD]", "")
@@ -631,80 +624,65 @@ class UniMERNetDecode(object):
         self.pad_token_type_id = 0
         self.pad_to_multiple_of = None
 
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".json", delete=True
-        ) as temp_file1, tempfile.NamedTemporaryFile(
-            mode="w", suffix=".json", delete=True
-        ) as temp_file2:
-            fast_tokenizer_file = temp_file1.name
-            tokenizer_config_file = temp_file2.name
-            try:
-                with open(fast_tokenizer_file, "w") as f:
-                    json.dump(character_list["fast_tokenizer_file"], f)
-                with open(tokenizer_config_file, "w") as f:
-                    json.dump(character_list["tokenizer_config_file"], f)
-            except Exception as e:
-                print(
-                    f"创建 tokenizer.json 和 tokenizer_config.json 文件失败, 原因{str(e)}"
-                )
-
-            self.tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
-            added_tokens_decoder = {}
-            added_tokens_map = {}
-            if tokenizer_config_file is not None:
-                with open(
-                    tokenizer_config_file, encoding="utf-8"
-                ) as tokenizer_config_handle:
-                    init_kwargs = json.load(tokenizer_config_handle)
-                    if "added_tokens_decoder" in init_kwargs:
-                        for idx, token in init_kwargs["added_tokens_decoder"].items():
-                            if isinstance(token, dict):
-                                token = AddedToken(**token)
-                            if isinstance(token, AddedToken):
-                                added_tokens_decoder[int(idx)] = token
-                                added_tokens_map[str(token)] = token
-                            else:
-                                raise ValueError(
-                                    f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
-                                )
-                    init_kwargs["added_tokens_decoder"] = added_tokens_decoder
-                    added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
-                    tokens_to_add = [
-                        token
-                        for index, token in sorted(
-                            added_tokens_decoder.items(), key=lambda x: x[0]
+        fast_tokenizer_str = json.dumps(character_list["fast_tokenizer_file"])
+        fast_tokenizer_buffer = fast_tokenizer_str.encode("utf-8")
+        self.tokenizer = TokenizerFast.from_buffer(fast_tokenizer_buffer)
+        tokenizer_config = (
+            character_list["tokenizer_config_file"]
+            if "tokenizer_config_file" in character_list
+            else None
+        )
+        added_tokens_decoder = {}
+        added_tokens_map = {}
+        if tokenizer_config is not None:
+            init_kwargs = tokenizer_config
+            if "added_tokens_decoder" in init_kwargs:
+                for idx, token in init_kwargs["added_tokens_decoder"].items():
+                    if isinstance(token, dict):
+                        token = AddedToken(**token)
+                    if isinstance(token, AddedToken):
+                        added_tokens_decoder[int(idx)] = token
+                        added_tokens_map[str(token)] = token
+                    else:
+                        raise ValueError(
+                            f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
                         )
-                        if token not in added_tokens_decoder
-                    ]
-                    added_tokens_encoder = self.added_tokens_encoder(
-                        added_tokens_decoder
+            init_kwargs["added_tokens_decoder"] = added_tokens_decoder
+            added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
+            tokens_to_add = [
+                token
+                for index, token in sorted(
+                    added_tokens_decoder.items(), key=lambda x: x[0]
+                )
+                if token not in added_tokens_decoder
+            ]
+            added_tokens_encoder = self.added_tokens_encoder(added_tokens_decoder)
+            encoder = list(added_tokens_encoder.keys()) + [
+                str(token) for token in tokens_to_add
+            ]
+            tokens_to_add += [
+                token
+                for token in self.all_special_tokens_extended
+                if token not in encoder and token not in tokens_to_add
+            ]
+            if len(tokens_to_add) > 0:
+                is_last_special = None
+                tokens = []
+                special_tokens = self.all_special_tokens
+                for token in tokens_to_add:
+                    is_special = (
+                        (token.special or str(token) in special_tokens)
+                        if isinstance(token, AddedToken)
+                        else str(token) in special_tokens
                     )
-                    encoder = list(added_tokens_encoder.keys()) + [
-                        str(token) for token in tokens_to_add
-                    ]
-                    tokens_to_add += [
-                        token
-                        for token in self.all_special_tokens_extended
-                        if token not in encoder and token not in tokens_to_add
-                    ]
-                    if len(tokens_to_add) > 0:
-                        is_last_special = None
-                        tokens = []
-                        special_tokens = self.all_special_tokens
-                        for token in tokens_to_add:
-                            is_special = (
-                                (token.special or str(token) in special_tokens)
-                                if isinstance(token, AddedToken)
-                                else str(token) in special_tokens
-                            )
-                            if is_last_special is None or is_last_special == is_special:
-                                tokens.append(token)
-                            else:
-                                self._add_tokens(tokens, special_tokens=is_last_special)
-                                tokens = [token]
-                            is_last_special = is_special
-                        if tokens:
-                            self._add_tokens(tokens, special_tokens=is_last_special)
+                    if is_last_special is None or is_last_special == is_special:
+                        tokens.append(token)
+                    else:
+                        self._add_tokens(tokens, special_tokens=is_last_special)
+                        tokens = [token]
+                    is_last_special = is_special
+                if tokens:
+                    self._add_tokens(tokens, special_tokens=is_last_special)
 
     def _add_tokens(
         self, new_tokens: "List[Union[AddedToken, str]]", special_tokens: bool = False
@@ -820,7 +798,7 @@ class UniMERNetDecode(object):
             for i in reversed(range(len(toks[b]))):
                 if toks[b][i] is None:
                     toks[b][i] = ""
-                toks[b][i] = toks[b][i].replace("Ġ", " ").strip()
+                toks[b][i] = toks[b][i].replace("", " ").strip()
                 if toks[b][i] in (
                     [
                         self.tokenizer.bos_token,
@@ -876,6 +854,15 @@ class UniMERNetDecode(object):
                 break
         return s
 
+    def remove_chinese_text_wrapping(self, formula):
+        pattern = re.compile(r"\\text\s*{\s*([^}]*?[\u4e00-\u9fff]+[^}]*?)\s*}")
+
+        def replacer(match):
+            return match.group(1)
+
+        replaced_formula = pattern.sub(replacer, formula)
+        return replaced_formula.replace('"', "")
+
     def post_process(self, text: str) -> str:
         """Post-processes a string by fixing text and normalizing it.
 
@@ -887,8 +874,12 @@ class UniMERNetDecode(object):
         """
         from ftfy import fix_text
 
+        text = self.remove_chinese_text_wrapping(text)
         text = fix_text(text)
+        print("=" * 100)
+        print(text)
         text = self.normalize(text)
+        print(text)
         return text
 
     def __call__(

+ 6 - 1
paddlex/inference/models/table_structure_recognition/processors.py

@@ -130,7 +130,12 @@ class TableLabelDecode:
             structure_probs, bbox_preds, img_size, ori_img_size
         )
         structure_str_list = [
-            (["<table>"] + structure + ["</table>"]) for structure in structure_str_list
+            (
+                ["<html>", "<body>", "<table>"]
+                + structure
+                + ["</table>", "</body>", "</html>"]
+            )
+            for structure in structure_str_list
         ]
         return [
             {"bbox": bbox, "structure": structure, "structure_score": structure_score}

+ 36 - 27
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -111,7 +111,7 @@ class LayoutParsingPipelineV2(BasePipeline):
         )
         self.use_chart_recognition = config.get(
             "use_chart_recognition",
-            True,
+            False,
         )
 
         self.pretty_markdown = config.get(
@@ -642,38 +642,42 @@ class LayoutParsingPipelineV2(BasePipeline):
             block.content = ""
             return block
 
-        lines, text_direction = group_boxes_into_lines(
+        lines, text_direction, text_line_height = group_boxes_into_lines(
             ocr_rec_res,
             LINE_SETTINGS.get("line_height_iou_threshold", 0.8),
         )
 
+        # format line
+        text_lines = []
+        need_new_line_num = 0
+        # words start coordinate and stop coordinate in the line
+        words_start_index = 0 if text_direction == "horizontal" else 1
+        words_stop_index = words_start_index + 2
+        lines_start_index = 1 if text_direction == "horizontal" else 3
+        line_width_list = []
+
         if block.label == "reference":
             rec_boxes = ocr_rec_res["boxes"]
-            block_right_coordinate = max([box[2] for box in rec_boxes])
+            block_start_coordinate = min([box[words_start_index] for box in rec_boxes])
+            block_stop_coordinate = max([box[words_stop_index] for box in rec_boxes])
         else:
-            block_right_coordinate = block.bbox[2]
+            block_start_coordinate = block.bbox[words_start_index]
+            block_stop_coordinate = block.bbox[words_stop_index]
 
-        # format line
-        text_lines = []
-        need_new_line_num = 0
-        start_index = 0 if text_direction == "horizontal" else 1
-        secondary_direction_start_index = 1 if text_direction == "horizontal" else 0
-        line_height_list, line_width_list = [], []
         for idx, line in enumerate(lines):
-            line.sort(key=lambda span: span[0][start_index])
-
-            text_bboxes_height = [
-                span[0][secondary_direction_start_index + 2]
-                - span[0][secondary_direction_start_index]
-                for span in line
-            ]
-            text_bboxes_width = [
-                span[0][start_index + 2] - span[0][start_index] for span in line
-            ]
+            line.sort(
+                key=lambda span: (
+                    span[0][words_start_index] // 2,
+                    (
+                        span[0][lines_start_index]
+                        if text_direction == "horizontal"
+                        else -span[0][lines_start_index]
+                    ),
+                )
+            )
 
-            line_height = np.mean(text_bboxes_height)
-            line_height_list.append(line_height)
-            line_width_list.append(np.mean(text_bboxes_width))
+            line_width = line[-1][0][words_stop_index] - line[0][0][words_start_index]
+            line_width_list.append(line_width)
             # merge formula and text
             ocr_labels = [span[2] for span in line]
             if "formula" in ocr_labels:
@@ -683,8 +687,11 @@ class LayoutParsingPipelineV2(BasePipeline):
 
             line_text, need_new_line = format_line(
                 line,
-                block_right_coordinate,
-                last_line_span_limit=line_height * 1.5,
+                text_direction,
+                np.max(line_width_list),
+                block_start_coordinate,
+                block_stop_coordinate,
+                line_gap_limit=text_line_height * 1.5,
                 block_label=block.label,
             )
             if need_new_line:
@@ -699,12 +706,13 @@ class LayoutParsingPipelineV2(BasePipeline):
 
         delim = LINE_SETTINGS["delimiter_map"].get(block.label, "")
         if need_new_line_num > len(text_lines) * 0.5 and delim == "":
+            text_lines = [text.replace("\n", "") for text in text_lines]
             delim = "\n"
         content = delim.join(text_lines)
         block.content = content
         block.num_of_lines = len(text_lines)
         block.direction = text_direction
-        block.text_line_height = np.mean(line_height_list)
+        block.text_line_height = text_line_height
         block.text_line_width = np.mean(line_width_list)
 
         return block
@@ -816,12 +824,13 @@ class LayoutParsingPipelineV2(BasePipeline):
             region = LayoutParsingRegion(
                 bbox=region_bbox,
                 blocks=region_blocks,
+                image_shape=image.shape[:2],
             )
             region_list.append(region)
 
         region_list = sorted(
             region_list,
-            key=lambda r: (r.euclidean_distance // 50, r.center_euclidean_distance),
+            key=lambda r: (r.weighted_distance),
         )
 
         return region_list

+ 23 - 11
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -115,8 +115,8 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         for block in parsing_result:
             bbox = block.bbox
             index = block.order_index
-            label = block.order_label
-            fill_color = get_show_color(label, True)
+            label = block.label
+            fill_color = get_show_color(label, False)
             draw.rectangle(bbox, fill=fill_color)
             if index is not None:
                 text_position = (bbox[2] + 2, bbox[1] - font_size // 2)
@@ -348,9 +348,6 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                     break
                 return spliter.join(lines)
 
-            def format_table():
-                return "\n" + block.content
-
             def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
 
                 seg_start_flag = True
@@ -409,12 +406,22 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
 
                 return seg_start_flag, seg_end_flag
 
+            def format_table_with_html_body():
+                return "\n" + block.content
+
+            def format_table_wo_html_body():
+                return "\n" + block.content.replace("<html>", "").replace(
+                    "</html>", ""
+                ).replace("<body>", "").replace("</body>", "")
+
             if self["model_settings"].get("pretty_markdown", True):
                 format_text = format_text_centered_by_html
                 format_image = format_image_centered_by_html
+                format_table = format_table_with_html_body
             else:
                 format_text = format_text_plain
                 format_image = format_image_plain
+                format_table = format_table_wo_html_body
 
             handlers = {
                 "paragraph_title": lambda: format_title(block.content),
@@ -482,7 +489,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
             )
 
         markdown_info = dict()
-        original_image_width = self["doc_preprocessor_res"]["input_img"].shape[1]
+        original_image_width = self["doc_preprocessor_res"]["output_img"].shape[1]
         markdown_info["markdown_texts"], (
             page_first_element_seg_start_flag,
             page_last_element_seg_end_flag,
@@ -532,8 +539,6 @@ class LayoutParsingBlock:
         return self.__dict__
 
     def update_direction_info(self) -> None:
-        if self.order_label == "vision":
-            self.direction = "horizontal"
         if self.direction == "horizontal":
             self.secondary_direction = "vertical"
             self.short_side_length = self.height
@@ -598,11 +603,13 @@ class LayoutParsingBlock:
 
 class LayoutParsingRegion:
 
-    def __init__(self, bbox, blocks: List[LayoutParsingBlock] = []) -> None:
+    def __init__(
+        self, bbox, blocks: List[LayoutParsingBlock] = [], image_shape=None
+    ) -> None:
         self.bbox = bbox
         self.block_map = {}
         self.direction = "horizontal"
-        self.calculate_bbox_metrics()
+        self.calculate_bbox_metrics(image_shape)
         self.doc_title_block_idxes = []
         self.paragraph_title_block_idxes = []
         self.vision_block_idxes = []
@@ -678,12 +685,17 @@ class LayoutParsingRegion:
             + self.bbox[self.secondary_direction_end_index]
         ) / 2
 
-    def calculate_bbox_metrics(self):
+    def calculate_bbox_metrics(self, image_shape):
         x1, y1, x2, y2 = self.bbox
+        width = x2 - x1
+        image_height, image_width = image_shape
         x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
         self.euclidean_distance = math.sqrt(((x1) ** 2 + (y1) ** 2))
         self.center_euclidean_distance = math.sqrt(((x_center) ** 2 + (y_center) ** 2))
         self.angle_rad = math.atan2(y_center, x_center)
+        self.weighted_distance = (
+            y1 + width + (x1 // (image_width // 10)) * (image_width // 10) * 1.5
+        )
 
     def sort_normal_blocks(self, blocks):
         if self.direction == "horizontal":

+ 0 - 8
paddlex/inference/pipelines/layout_parsing/setting.py

@@ -74,16 +74,8 @@ BLOCK_LABEL_MAP = {
         "abstract",
         "paragraph_title",
         "doc_title",
-        "table_title",
-        "chart_title",
-        "figure_title",
-        "image",
-        "table",
-        "chart",
-        "figure",
         "abstract_title",
         "refer_title",
         "content_title",
-        "flowchart",
     ],
 }

+ 92 - 28
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -274,6 +274,9 @@ def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
 
     match_direction = "vertical" if text_orientation == "horizontal" else "horizontal"
 
+    line_start_index = 1 if text_orientation == "horizontal" else 0
+    line_end_index = 3 if text_orientation == "horizontal" else 2
+
     spans = list(zip(rec_boxes, rec_texts, rec_labels))
     sort_index = 1
     reverse = False
@@ -286,7 +289,7 @@ def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
     lines = []
     line = [spans[0]]
     line_region_box = spans[0][0].copy()
-
+    line_heights = []
     # merge line
     for span in spans[1:]:
         rec_bbox = span[0]
@@ -297,15 +300,36 @@ def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
             >= line_height_iou_threshold
         ):
             line.append(span)
-            line_region_box[1] = min(line_region_box[1], rec_bbox[1])
-            line_region_box[3] = max(line_region_box[3], rec_bbox[3])
+            line_region_box[line_start_index] = min(
+                line_region_box[line_start_index], rec_bbox[line_start_index]
+            )
+            line_region_box[line_end_index] = max(
+                line_region_box[line_end_index], rec_bbox[line_end_index]
+            )
         else:
+            line_heights.append(
+                line_region_box[line_end_index] - line_region_box[line_start_index]
+            )
             lines.append(line)
             line = [span]
             line_region_box = rec_bbox.copy()
 
     lines.append(line)
-    return lines, text_orientation
+    line_heights.append(
+        line_region_box[line_end_index] - line_region_box[line_start_index]
+    )
+
+    min_height = min(line_heights) if line_heights else 0
+    max_height = max(line_heights) if line_heights else 0
+
+    if max_height > min_height * 2 and text_orientation == "vertical":
+        line_heights = np.array(line_heights)
+        min_height_num = np.sum(line_heights < min_height * 1.1)
+        if min_height_num < len(lines) * 0.4:
+            condition = line_heights > min_height * 1.1
+            lines = [value for value, keep in zip(lines, condition) if keep]
+
+    return lines, text_orientation, np.mean(line_heights)
 
 
 def calculate_minimum_enclosing_bbox(bboxes):
@@ -381,6 +405,7 @@ def is_non_breaking_punctuation(char):
         ";",  # 全角分号
         ":",  # 半角冒号
         ":",  # 全角冒号
+        "-",  # 连字符
     }
 
     return char in non_breaking_punctuations
@@ -388,8 +413,11 @@ def is_non_breaking_punctuation(char):
 
 def format_line(
     line: List[List[Union[List[int], str]]],
-    block_right_coordinate: int,
-    last_line_span_limit: int = 10,
+    text_direction: int,
+    block_width: int,
+    block_start_coordinate: int,
+    block_stop_coordinate: int,
+    line_gap_limit: int = 10,
     block_label: str = "text",
 ) -> None:
     """
@@ -397,14 +425,15 @@ def format_line(
 
     Args:
         line (list): A list of spans, where each span is a list containing a bounding box and text.
-        block_left_coordinate (int): The minimum x-coordinate of the layout bounding box.
-        block_right_coordinate (int): The maximum x-coordinate of the layout bounding box.
+        block_start_coordinate (int): The minimum coordinate of the layout bounding box along the text-line direction.
+        block_stop_coordinate (int): The maximum coordinate of the layout bounding box along the text-line direction.
         first_line_span_limit (int): The limit for the number of pixels before the first span that should be considered part of the first line. Default is 10.
-        last_line_span_limit (int): The limit for the number of pixels after the last span that should be considered part of the last line. Default is 10.
+        line_gap_limit (int): The limit for the number of pixels after the last span that should be considered part of the last line. Default is 10.
         block_label (str): The label associated with the entire block. Default is 'text'.
     Returns:
         None: The function modifies the line in place.
     """
+    first_span_box = line[0][0]
     last_span_box = line[-1][0]
 
     for span in line:
@@ -414,17 +443,37 @@ def format_line(
             else:
                 span[1] = f"\n${span[1]}$"
 
-    line_text = " ".join([span[1] for span in line])
+    line_text = ""
+    for span in line:
+        _, text, label = span
+        line_text += text
+        if len(text) > 0 and is_english_letter(line_text[-1]) or label == "formula":
+            line_text += " "
+
+    if text_direction == "horizontal":
+        text_start_index = 0
+        text_stop_index = 2
+    else:
+        text_start_index = 1
+        text_stop_index = 3
 
     need_new_line = False
     if (
-        block_right_coordinate - last_span_box[2] > last_line_span_limit
-        and not line_text.endswith("-")
-        and len(line_text) > 0
+        len(line_text) > 0
         and not is_english_letter(line_text[-1])
         and not is_non_breaking_punctuation(line_text[-1])
     ):
-        need_new_line = True
+        if (
+            text_direction == "horizontal"
+            and block_stop_coordinate - last_span_box[text_stop_index] > line_gap_limit
+        ) or (
+            text_direction == "vertical"
+            and (
+                block_stop_coordinate - last_span_box[text_stop_index] > line_gap_limit
+                or first_span_box[1] - block_start_coordinate > line_gap_limit
+            )
+        ):
+            need_new_line = True
 
     if line_text.endswith("-"):
         line_text = line_text[:-1]
@@ -432,6 +481,18 @@ def format_line(
         len(line_text) > 0 and is_english_letter(line_text[-1])
     ) or line_text.endswith("$"):
         line_text += " "
+    else:
+        if (
+            block_stop_coordinate - last_span_box[text_stop_index] > block_width * 0.3
+            and block_label != "formula"
+        ):
+            line_text += "\n"
+        if (
+            first_span_box[text_start_index] - block_start_coordinate
+            > block_width * 0.3
+            and block_label != "formula"
+        ):
+            line_text = "\n" + line_text
 
     return line_text, need_new_line
 
@@ -612,6 +673,7 @@ def remove_overlap_blocks(
     """
     dropped_indexes = set()
     blocks = deepcopy(blocks)
+    overlap_image_blocks = []
     # Iterate over each pair of blocks to find overlaps
     for i, block1 in enumerate(blocks["boxes"]):
         for j in range(i + 1, len(blocks["boxes"])):
@@ -627,20 +689,17 @@ def remove_overlap_blocks(
                 smaller=smaller,
             )
             if overlap_box_index is not None:
-                if block1["label"] == "image" and block2["label"] == "image":
-                    # Determine which block to remove based on overlap_box_index
-                    if overlap_box_index == 1:
-                        drop_index = i
-                    else:
-                        drop_index = j
-                elif block1["label"] == "image" and block2["label"] != "image":
-                    drop_index = i
-                elif block1["label"] != "image" and block2["label"] == "image":
-                    drop_index = j
-                elif overlap_box_index == 1:
-                    drop_index = i
+                is_block1_image = block1["label"] == "image"
+                is_block2_image = block2["label"] == "image"
+
+                if is_block1_image != is_block2_image:
+                    # Exactly one of the two blocks is an image: drop the image block
+                    drop_index = i if is_block1_image else j
+                    overlap_image_blocks.append(blocks["boxes"][drop_index])
                 else:
-                    drop_index = j
+                    # Both or neither are images: overlap_box_index decides which block to drop
+                    drop_index = i if overlap_box_index == 1 else j
+
                 dropped_indexes.add(drop_index)
 
     # Remove marked blocks from the original list
@@ -815,7 +874,12 @@ def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
             (x_min, y_max),
         ]
         ocr_res["dt_polys"].append(poly_points)
-        ocr_res["rec_texts"].append(f"{formula_res['rec_formula']}")
+        formula_res_text: str = formula_res["rec_formula"]
+        if formula_res_text.startswith("$$") and formula_res_text.endswith("$$"):
+            formula_res_text = formula_res_text[2:-2]
+        elif formula_res_text.startswith("$") and formula_res_text.endswith("$"):
+            formula_res_text = formula_res_text[1:-1]
+        ocr_res["rec_texts"].append(formula_res_text)
         if ocr_res["rec_boxes"].size == 0:
             ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"])
         else:

+ 57 - 44
paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py

@@ -68,11 +68,17 @@ def _projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
         A 1D numpy array representing the projection histogram based on bounding box intervals.
     """
     assert axis in [0, 1]
+
     max_length = np.max(boxes[:, axis::2])
+    if max_length < 0:
+        max_length = abs(np.min(boxes[:, axis::2]))
+
     projection = np.zeros(max_length, dtype=int)
 
     # Increment projection histogram over the interval defined by each bounding box
     for start, end in boxes[:, axis::2]:
+        start = abs(start)
+        end = abs(end)
         projection[start:end] += 1
 
     return projection
@@ -170,10 +176,12 @@ def recursive_yx_cut(
             res.extend(x_sorted_indices_chunk)
             continue
 
+        if np.min(x_sorted_boxes_chunk[:, 0]) < 0:
+            x_intervals = np.flip(x_intervals, axis=1)
         # Recursively process each segment defined by X-axis projection
         for x_start, x_end in zip(*x_intervals):
-            x_interval_indices = (x_start <= x_sorted_boxes_chunk[:, 0]) & (
-                x_sorted_boxes_chunk[:, 0] < x_end
+            x_interval_indices = (x_start <= abs(x_sorted_boxes_chunk[:, 0])) & (
+                abs(x_sorted_boxes_chunk[:, 0]) < x_end
             )
             recursive_yx_cut(
                 x_sorted_boxes_chunk[x_interval_indices],
@@ -214,11 +222,13 @@ def recursive_xy_cut(
     if not x_intervals:
         return
 
+    if np.min(x_sorted_boxes[:, 0]) < 0:
+        x_intervals = np.flip(x_intervals, axis=1)
     # Process each segment defined by X-axis projection
     for x_start, x_end in zip(*x_intervals):
         # Select boxes within the current x interval
-        x_interval_indices = (x_start <= x_sorted_boxes[:, 0]) & (
-            x_sorted_boxes[:, 0] < x_end
+        x_interval_indices = (x_start <= abs(x_sorted_boxes[:, 0])) & (
+            abs(x_sorted_boxes[:, 0]) < x_end
         )
         x_boxes_chunk = x_sorted_boxes[x_interval_indices]
         x_indices_chunk = x_sorted_indices[x_interval_indices]
@@ -413,7 +423,7 @@ def insert_child_blocks(
     if block.child_blocks:
         sub_blocks = block.get_child_blocks()
         sub_blocks.append(block)
-        sub_blocks = sort_child_blocks(sub_blocks, block.direction)
+        sub_blocks = sort_child_blocks(sub_blocks, sub_blocks[0].direction)
         sorted_blocks[block_idx] = sub_blocks[0]
         for block in sub_blocks[1:]:
             block_idx += 1
@@ -439,17 +449,15 @@ def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock
                 x.bbox[0],  # x_min
                 x.bbox[1] ** 2 + x.bbox[0] ** 2,  # distance with (0,0)
             ),
-            reverse=False,
         )
     else:
         # from right to left
         blocks.sort(
             key=lambda x: (
-                x.bbox[0],  # x_min
+                -x.bbox[0],  # x_min
                 x.bbox[1],  # y_min
-                x.bbox[1] ** 2 + x.bbox[0] ** 2,  # distance with (0,0)
+                x.bbox[1] ** 2 - x.bbox[0] ** 2,  # distance with (max,0)
             ),
-            reverse=True,
         )
     return blocks
 
@@ -495,28 +503,23 @@ def _manhattan_distance(
     return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
 
 
-def sort_blocks(blocks, median_width=None, reverse=False):
-    """
-    Sort blocks based on their y_min, x_min and distance with (0,0).
-
-    Args:
-        blocks (list): list of blocks to be sorted.
-        median_width (int): the median width of the text blocks.
-        reverse (bool, optional): whether to sort in descending order. Default is False.
-
-    Returns:
-        list: a list of sorted blocks.
-    """
-    if median_width is None:
-        median_width = 1
-    blocks.sort(
-        key=lambda x: (
-            x.bbox[1] // 10,  # y_min
-            x.bbox[0] // median_width,  # x_min
-            x.bbox[1] ** 2 + x.bbox[0] ** 2,  # distance with (0,0)
-        ),
-        reverse=reverse,
-    )
+def sort_normal_blocks(blocks, text_line_height, text_line_width, region_direction):
+    """
+    Sort blocks in place into reading order for the given region direction.
+
+    Args:
+        blocks (list): blocks to be sorted; each block exposes a ``bbox``
+            indexed as ``bbox[0]`` (x_min), ``bbox[1]`` (y_min) and ``bbox[2]``
+            (presumably x_max -- confirm against the block class).
+        text_line_height: bin size used to quantize ``bbox[1]``.
+            Assumed to be non-zero -- TODO confirm with callers.
+        text_line_width: bin size used to quantize ``bbox[0]``.
+            Assumed to be non-zero -- TODO confirm with callers.
+        region_direction (str): ``"horizontal"`` selects the row-first key;
+            any other value selects the column-first key with descending x.
+
+    Returns:
+        list: the same ``blocks`` list, sorted in place.
+    """
+    if region_direction == "horizontal":
+        # Row-first: y bin, then x bin, then y_min^2 + x_min^2 as tie-breaker.
+        blocks.sort(
+            key=lambda x: (
+                x.bbox[1] // text_line_height,
+                x.bbox[0] // text_line_width,
+                x.bbox[1] ** 2 + x.bbox[0] ** 2,
+            ),
+        )
+    else:
+        # Column-first with larger x first, then y bin.
+        # Unary minus binds tighter than ``//``: the first key is
+        # ``(-x.bbox[0]) // text_line_width``, i.e. floor division of a
+        # negated coordinate, not the negation of the floored bin.
+        # NOTE(review): the tie-breaker uses bbox[2] here, while
+        # sort_child_blocks uses bbox[0] in the same position -- confirm
+        # which one is intended.
+        blocks.sort(
+            key=lambda x: (
+                -x.bbox[0] // text_line_width,
+                x.bbox[1] // text_line_height,
+                x.bbox[1] ** 2 - x.bbox[2] ** 2,  # distance with (max,0)
+            ),
+        )
     return blocks
 
 
@@ -920,7 +923,10 @@ def update_vision_child_blocks(
             )
             block_center = block.get_centroid()
             ref_block_center = ref_block.get_centroid()
-            if ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]:
+            if (
+                ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
+                and nearest_edge_distance <= ref_block.text_line_height * 2
+            ):
                 has_vision_title = True
                 ref_block.order_label = "vision_title"
                 block.append_child_block(ref_block)
@@ -928,12 +934,17 @@ def update_vision_child_blocks(
             if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
                 if (
                     not has_vision_footnote
-                    and nearest_edge_distance <= block.text_line_height * 2
-                    and ref_block.short_side_length < block.short_side_length
-                    and ref_block.long_side_length < 0.5 * block.long_side_length
                     and ref_block.direction == block.direction
-                    and (
-                        abs(block_center[0] - ref_block_center[0]) < 10
+                    and ref_block.long_side_length < block.long_side_length
+                ):
+                    if (
+                        (
+                            nearest_edge_distance <= block.text_line_height * 2
+                            and ref_block.short_side_length < block.short_side_length
+                            and ref_block.long_side_length
+                            < 0.5 * block.long_side_length
+                            and abs(block_center[0] - ref_block_center[0]) < 10
+                        )
                         or (
                             block.bbox[0] - ref_block.bbox[0] < 10
                             and ref_block.num_of_lines == 1
@@ -942,12 +953,11 @@ def update_vision_child_blocks(
                             block.bbox[2] - ref_block.bbox[2] < 10
                             and ref_block.num_of_lines == 1
                         )
-                    )
-                ):
-                    has_vision_footnote = True
-                    ref_block.order_label = "vision_footnote"
-                    block.append_child_block(ref_block)
-                    region.normal_text_block_idxes.remove(ref_block.index)
+                    ):
+                        has_vision_footnote = True
+                        ref_block.order_label = "vision_footnote"
+                        block.append_child_block(ref_block)
+                        region.normal_text_block_idxes.remove(ref_block.index)
                 break
         for ref_block in post_blocks:
             if (
@@ -960,7 +970,10 @@ def update_vision_child_blocks(
             )
             block_center = block.get_centroid()
             ref_block_center = ref_block.get_centroid()
-            if ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]:
+            if (
+                ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
+                and nearest_edge_distance <= ref_block.text_line_height * 2
+            ):
                 has_vision_title = True
                 ref_block.order_label = "vision_title"
                 block.append_child_block(ref_block)

+ 34 - 38
paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
 from typing import Dict, List, Tuple
 
 import numpy as np
@@ -24,7 +25,6 @@ from .utils import (
     get_cut_blocks,
     get_nearest_edge_distance,
     insert_child_blocks,
-    is_projection_consistent,
     manhattan_insert,
     recursive_xy_cut,
     recursive_yx_cut,
@@ -76,7 +76,10 @@ def pre_process(
         else:
             tolerance_len = block.short_side_length // 10
 
-        block_center = (block.start_coordinate + block.end_coordinate) / 2
+        block_center = (
+            block.bbox[region.direction_start_index]
+            + block.bbox[region.direction_end_index]
+        ) / 2
         center_offset = abs(block_center - region.direction_center_coordinate)
         is_centered = center_offset <= tolerance_len
 
@@ -183,6 +186,7 @@ def update_region_label(
     elif block.label in BLOCK_LABEL_MAP["vision_labels"]:
         block.order_label = "vision"
         block.num_of_lines = 1
+        block.direction = region.direction
         block.update_direction_info()
     elif block.label in BLOCK_LABEL_MAP["footer_labels"]:
         block.order_label = "footer"
@@ -280,19 +284,25 @@ def get_layout_structure(
                         second_ref_block.bbox,
                         region_direction,
                     )
-                    ref_match_projection_iou_ = calculate_projection_overlap_ratio(
-                        ref_block.bbox,
-                        second_ref_block.bbox,
-                        region_secondary_direction,
+                    secondary_direction_ref_match_projection_overlap_ratio = (
+                        calculate_projection_overlap_ratio(
+                            ref_block.bbox,
+                            second_ref_block.bbox,
+                            region_secondary_direction,
+                        )
                     )
                     if (
                         second_match_projection_iou > 0
                         and ref_match_projection_iou == 0
-                        and ref_match_projection_iou_ > 0
+                        and secondary_direction_ref_match_projection_overlap_ratio > 0
                     ):
                         if block.order_label == "vision" or (
                             ref_block.order_label == "normal_text"
                             and second_ref_block.order_label == "normal_text"
+                            and ref_block.text_line_width
+                            > ref_block.text_line_height * 5
+                            and second_ref_block.text_line_width
+                            > second_ref_block.text_line_height * 5
                         ):
                             block.order_label = (
                                 "cross_reference"
@@ -462,60 +472,46 @@ def xycut_enhanced(
             )
             if len(discontinuous) > 1:
                 xy_cut_blocks = [block for block in xy_cut_blocks]
-            # if len(discontinuous) == 1 or max(block_text_lines) == 1 or (not is_projection_consistent(xy_cut_blocks, discontinuous, direction=region.direction) and len(discontinuous) > 2 and max(block_text_lines) - min(block_text_lines) < 3):
-            if len(discontinuous) == 1 or max(block_text_lines) == 1:
-                xy_cut_blocks.sort(
-                    key=lambda x: (
-                        x.bbox[region.secondary_direction_start_index]
-                        // (region.text_line_height // 2),
-                        x.bbox[region.direction_start_index],
+            blocks_to_sort = deepcopy(xy_cut_blocks)
+            if region.direction == "vertical":
+                for block in blocks_to_sort:
+                    block.bbox = np.array(
+                        [-block.bbox[0], block.bbox[1], -block.bbox[2], block.bbox[3]]
                     )
-                )
-                xy_cut_blocks = shrink_overlapping_boxes(
-                    xy_cut_blocks, region.secondary_direction
-                )
-            if (
-                len(discontinuous) == 1
-                or max(block_text_lines) == 1
-                or (
-                    not is_projection_consistent(
-                        xy_cut_blocks, discontinuous, direction=region.direction
-                    )
-                    and len(discontinuous) > 2
-                    and max(block_text_lines) - min(block_text_lines) < 3
-                )
-            ):
-                xy_cut_blocks.sort(
+            if len(discontinuous) == 1 or max(block_text_lines) == 1:
+                blocks_to_sort.sort(
                     key=lambda x: (
                         x.bbox[region.secondary_direction_start_index]
                         // (region.text_line_height // 2),
                         x.bbox[region.direction_start_index],
                     )
                 )
-                xy_cut_blocks = shrink_overlapping_boxes(
-                    xy_cut_blocks, region.secondary_direction
+                blocks_to_sort = shrink_overlapping_boxes(
+                    blocks_to_sort, region.secondary_direction
                 )
-                block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
+                block_bboxes = np.array([block.bbox for block in blocks_to_sort])
                 sorted_indexes = sort_by_xycut(
                     block_bboxes, direction=region.secondary_direction, min_gap=1
                 )
             else:
-                xy_cut_blocks.sort(
+                blocks_to_sort.sort(
                     key=lambda x: (
                         x.bbox[region.direction_start_index]
                         // (region.text_line_width // 2),
                         x.bbox[region.secondary_direction_start_index],
                     )
                 )
-                xy_cut_blocks = shrink_overlapping_boxes(
-                    xy_cut_blocks, region.direction
+                blocks_to_sort = shrink_overlapping_boxes(
+                    blocks_to_sort, region.direction
                 )
-                block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
+                block_bboxes = np.array([block.bbox for block in blocks_to_sort])
                 sorted_indexes = sort_by_xycut(
                     block_bboxes, direction=region.direction, min_gap=1
                 )
 
-            sorted_blocks = [xy_cut_blocks[i] for i in sorted_indexes]
+            sorted_blocks = [
+                region.block_map[blocks_to_sort[i].index] for i in sorted_indexes
+            ]
 
         sorted_blocks = match_unsorted_blocks(
             sorted_blocks,