Bladeren bron

bugfix: PP-StructureV3 (#3650)

changdazhou 8 maanden geleden
bovenliggende
commit
6b631515c9

+ 0 - 1
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -360,7 +360,6 @@ class LayoutParsingPipelineV2(BasePipeline):
             layout_det_res=layout_det_res,
             table_res_list=table_res_list,
             seal_res_list=seal_res_list,
-            imgs_in_doc=imgs_in_doc,
         )
 
         return parsing_res_list

+ 62 - 12
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -317,6 +317,60 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
             def format_table():
                 return "\n" + block["block_content"]
 
+            def get_seg_flag(block, prev_block):
+
+                seg_start_flag = True
+                seg_end_flag = True
+
+                block_box = block["block_bbox"]
+                context_left_coordinate = block_box[0]
+                context_right_coordinate = block_box[2]
+                seg_start_coordinate = block.get("seg_start_coordinate")
+                seg_end_coordinate = block.get("seg_end_coordinate")
+
+                if prev_block is not None:
+                    prev_block_bbox = prev_block["block_bbox"]
+                    num_of_prev_lines = prev_block.get("num_of_lines")
+                    pre_block_seg_end_coordinate = prev_block.get("seg_end_coordinate")
+
+                    # update context_left_coordinate and context_right_coordinate
+                    if context_left_coordinate < prev_block_bbox[2]:
+                        context_left_coordinate = min(
+                            prev_block_bbox[0], context_left_coordinate
+                        )
+                        context_right_coordinate = max(
+                            prev_block_bbox[2], context_right_coordinate
+                        )
+
+                    # 判断是否需要分段
+                    prev_end_space_small = (
+                        prev_block_bbox[2] - pre_block_seg_end_coordinate < 10
+                    )
+                    current_start_space_small = (
+                        seg_start_coordinate - context_left_coordinate < 10
+                    )
+                    overlap_blocks = context_left_coordinate < prev_block_bbox[2]
+                    prev_lines_more_than_one = num_of_prev_lines > 1
+
+                    if (
+                        overlap_blocks
+                        and current_start_space_small
+                        and prev_lines_more_than_one
+                    ) or (
+                        prev_end_space_small
+                        and current_start_space_small
+                        and prev_lines_more_than_one
+                    ):
+                        seg_start_flag = False
+                else:
+                    if seg_start_coordinate - context_left_coordinate < 10:
+                        seg_start_flag = False
+
+                if context_right_coordinate - seg_end_coordinate < 10:
+                    seg_end_flag = False
+
+                return seg_start_flag, seg_end_flag
+
             handlers = {
                 "paragraph_title": lambda: format_title(block["block_content"]),
                 "doc_title": lambda: f"# {block['block_content']}".replace(
@@ -333,8 +387,8 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                     ["摘要", "abstract"], lambda l: f"## {l}\n", " "
                 ),
                 "content": lambda: block["block_content"]
-                .replace("-\n", " ")
-                .replace("\n", " "),
+                .replace("-\n", "  \n")
+                .replace("\n", "  \n"),
                 "image": lambda: format_image("block_image"),
                 "chart": lambda: format_image("block_image"),
                 "formula": lambda: f"$${block['block_content']}$$",
@@ -350,18 +404,17 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
             last_label = None
             seg_start_flag = None
             seg_end_flag = None
+            prev_block = None
             page_first_element_seg_start_flag = None
             page_last_element_seg_end_flag = None
             parsing_res_list = sorted(
                 parsing_res_list,
                 key=lambda x: x.get("sub_index", 999),
             )
-            for block in sorted(
-                parsing_res_list,
-                key=lambda x: x.get("sub_index", 999),
-            ):
+            for block in parsing_res_list:
+                seg_start_flag, seg_end_flag = get_seg_flag(block, prev_block)
+
                 label = block.get("block_label")
-                seg_start_flag = block.get("seg_start_flag")
                 page_first_element_seg_start_flag = (
                     seg_start_flag
                     if (page_first_element_seg_start_flag is None)
@@ -369,10 +422,8 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 )
                 handler = handlers.get(label)
                 if handler:
-                    if (
-                        label == last_label == "text"
-                        and seg_start_flag == seg_end_flag == False
-                    ):
+                    prev_block = block
+                    if label == last_label == "text" and seg_start_flag == False:
                         last_char_of_markdown = (
                             markdown_content[-1] if markdown_content else ""
                         )
@@ -396,7 +447,6 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                             "\n\n" + handler() if markdown_content else handler()
                         )
                     last_label = label
-                    seg_end_flag = block.get("seg_end_flag")
             page_last_element_seg_end_flag = seg_end_flag
 
             return markdown_content, (

+ 26 - 26
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -473,14 +473,18 @@ def _sort_ocr_res_by_y_projection(
             line = _sort_line_by_x_projection(input_img, general_ocr_pipeline, line)
         if label == "reference":
             line = _format_line(line, inline_x_min, inline_x_max, is_reference=True)
-        else:
+        elif label != "content":
             line = _format_line(line, x_min, x_max)
         new_lines.append(line)
 
     ocr_res["boxes"] = [span[0] for line in new_lines for span in line]
-    ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]
-
-    return ocr_res
+    if label == "content":
+        ocr_res["rec_texts"] = [
+            "".join(f"{span[1]} " for span in line).rstrip() for line in new_lines
+        ]
+    else:
+        ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]
+    return ocr_res, len(new_lines)
 
 
 def _process_text(input_text: str) -> str:
@@ -542,7 +546,6 @@ def get_single_block_parsing_res(
     layout_det_res: DetResult,
     table_res_list: list,
     seal_res_list: list,
-    imgs_in_doc: list,
 ) -> OCRResult:
     """
     Extract structured information from OCR and layout detection results.
@@ -583,8 +586,9 @@ def get_single_block_parsing_res(
         block_bbox = box_info["coordinate"]
         label = box_info["label"]
         rec_res = {"boxes": [], "rec_texts": [], "rec_labels": [], "flag": False}
-        seg_start_flag = True
-        seg_end_flag = True
+        seg_start_coordinate = float("inf")
+        seg_end_coordinate = float("-inf")
+        num_of_lines = 1
 
         if label == "table":
             for table_res in table_res_list:
@@ -599,8 +603,6 @@ def get_single_block_parsing_res(
                             "block_label": label,
                             "block_content": table_res["pred_html"],
                             "block_bbox": block_bbox,
-                            "seg_start_flag": seg_start_flag,
-                            "seg_end_flag": seg_end_flag,
                         },
                     )
                     break
@@ -613,8 +615,6 @@ def get_single_block_parsing_res(
                             ", ".join(seal_res_list[seal_index]["rec_texts"])
                         ),
                         "block_bbox": block_bbox,
-                        "seg_start_flag": seg_start_flag,
-                        "seg_end_flag": seg_end_flag,
                     },
                 )
                 seal_index += 1
@@ -637,15 +637,11 @@ def get_single_block_parsing_res(
                     rec_res["flag"] = True
 
             if rec_res["flag"]:
-                rec_res = _sort_ocr_res_by_y_projection(
+                rec_res, num_of_lines = _sort_ocr_res_by_y_projection(
                     input_img, general_ocr_pipeline, label, block_bbox, rec_res, 0.7
                 )
-                rec_res_first_bbox = rec_res["boxes"][0]
-                rec_res_end_bbox = rec_res["boxes"][-1]
-                if rec_res_first_bbox[0] - block_bbox[0] < 10:
-                    seg_start_flag = False
-                if block_bbox[2] - rec_res_end_bbox[2] < 10:
-                    seg_end_flag = False
+                seg_start_coordinate = rec_res["boxes"][0][0]
+                seg_end_coordinate = rec_res["boxes"][-1][2]
                 if label == "formula":
                     rec_res["rec_texts"] = [
                         rec_res_text.replace("$", "")
@@ -662,13 +658,13 @@ def get_single_block_parsing_res(
                         "block_content": _process_text("".join(rec_res["rec_texts"])),
                         "block_image": {img_path: img},
                         "block_bbox": block_bbox,
-                        "seg_start_flag": seg_start_flag,
-                        "seg_end_flag": seg_end_flag,
                     },
                 )
             else:
                 if label in ["doc_title"]:
                     content = " ".join(rec_res["rec_texts"])
+                elif label in ["content"]:
+                    content = "\n".join(rec_res["rec_texts"])
                 else:
                     content = "".join(rec_res["rec_texts"])
                     if label != "reference":
@@ -678,8 +674,9 @@ def get_single_block_parsing_res(
                         "block_label": label,
                         "block_content": content,
                         "block_bbox": block_bbox,
-                        "seg_start_flag": seg_start_flag,
-                        "seg_end_flag": seg_end_flag,
+                        "seg_start_coordinate": seg_start_coordinate,
+                        "seg_end_coordinate": seg_end_coordinate,
+                        "num_of_lines": num_of_lines,
                     },
                 )
 
@@ -692,8 +689,8 @@ def get_single_block_parsing_res(
                     "block_label": "text",
                     "block_content": ocr_rec_text,
                     "block_bbox": ocr_rec_box,
-                    "seg_start_flag": True,
-                    "seg_end_flag": True,
+                    "seg_start_coordinate": ocr_rec_box[0],
+                    "seg_end_coordinate": ocr_rec_box[2],
                 },
             )
 
@@ -1935,11 +1932,14 @@ def get_layout_ordering(
             "block_content": parsing_res["block_content"],
             "block_bbox": parsing_res["block_bbox"],
             "block_image": parsing_res.get("block_image", None),
-            "seg_start_flag": parsing_res["seg_start_flag"],
-            "seg_end_flag": parsing_res["seg_end_flag"],
             "sub_label": parsing_res["sub_label"],
             "sub_index": parsing_res["sub_index"],
             "index": parsing_res.get("index", None),
+            "seg_start_coordinate": parsing_res.get(
+                "seg_start_coordinate", float("inf")
+            ),
+            "seg_end_coordinate": parsing_res.get("seg_end_coordinate", float("-inf")),
+            "num_of_lines": parsing_res.get("num_of_lines", 1),
         }
         for parsing_res in final_parsing_res_list
     ]