Просмотр исходного кода

sort formulas and text in a line && bug fix

zhouchangda 8 месяцев назад
Родитель
Сommit
7b7c4599f9

+ 3 - 2
docs/pipeline_usage/tutorials/ocr_pipelines/PP-StructureV3.en.md

@@ -700,8 +700,9 @@ markdown_texts = ""
 markdown_images = []
 
 for res in output:
-    markdown_texts += res.markdown["markdown_texts"]
-    markdown_images.append(res.markdown["markdown_images"])
+    md_info = res.markdown
+    markdown_list.append(md_info)
+    markdown_images.append(md_info.get("markdown_images", {}))
 
 mkd_file_path = output_path / f"{Path(input_file).stem}.md"
 mkd_file_path.parent.mkdir(parents=True, exist_ok=True)

+ 3 - 2
docs/pipeline_usage/tutorials/ocr_pipelines/PP-StructureV3.md

@@ -647,8 +647,9 @@ markdown_list = []
 markdown_images = []
 
 for res in output:
-    markdown_list.append(res.markdown)
-    markdown_images.append(res.get("markdown_images", {}))
+    md_info = res.markdown
+    markdown_list.append(md_info)
+    markdown_images.append(md_info.get("markdown_images", {}))
 
 markdown_texts = pipeline.concatenate_markdown_pages(markdown_list)
 

+ 2 - 1
paddlex/configs/pipelines/PP-StructureV3.yaml

@@ -21,6 +21,7 @@ SubModules:
     layout_merge_bboxes_mode: 
       1: "large"  # image
       18: "large" # chart
+      7: "large"  # formula
 
 SubPipelines:
   DocPreprocessor:
@@ -45,7 +46,7 @@ SubPipelines:
     SubModules:
       TextDetection:
         module_name: text_detection
-        model_name: PP-OCRv4_mobile_det
+        model_name: PP-OCRv4_server_det
         model_dir: null
         limit_side_len: 960
         limit_type: max

+ 9 - 2
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -309,7 +309,9 @@ class LayoutParsingPipelineV2(BasePipeline):
                     del overall_ocr_res["rec_polys"][matched_idx]
                     del overall_ocr_res["rec_scores"][matched_idx]
 
-                if sub_ocr_res["rec_boxes"] is not []:
+                if sub_ocr_res["rec_boxes"].size > 0:
+                    sub_ocr_res["rec_labels"] = ["text"] * len(sub_ocr_res["rec_texts"])
+
                     overall_ocr_res["dt_polys"].extend(sub_ocr_res["dt_polys"])
                     overall_ocr_res["rec_texts"].extend(sub_ocr_res["rec_texts"])
                     overall_ocr_res["rec_boxes"] = np.concatenate(
@@ -317,6 +319,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                     )
                     overall_ocr_res["rec_polys"].extend(sub_ocr_res["rec_polys"])
                     overall_ocr_res["rec_scores"].extend(sub_ocr_res["rec_scores"])
+                    overall_ocr_res["rec_labels"].extend(sub_ocr_res["rec_labels"])
 
         for formula_res in formula_res_list:
             x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
@@ -331,10 +334,12 @@ class LayoutParsingPipelineV2(BasePipeline):
             overall_ocr_res["rec_boxes"] = np.vstack(
                 (overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
             )
+            overall_ocr_res["rec_labels"].append("formula")
             overall_ocr_res["rec_polys"].append(poly_points)
             overall_ocr_res["rec_scores"].append(1)
 
         parsing_res_list = get_single_block_parsing_res(
+            self.general_ocr_pipeline,
             overall_ocr_res=overall_ocr_res,
             layout_det_res=layout_det_res,
             table_res_list=table_res_list,
@@ -472,7 +477,7 @@ class LayoutParsingPipelineV2(BasePipeline):
         if not self.check_model_settings_valid(model_settings):
             yield {"error": "the input params for model settings are invalid!"}
 
-        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+        for batch_data in self.batch_sampler(input):
             image_array = self.img_reader(batch_data.instances)[0]
 
             if model_settings["use_doc_preprocessor"]:
@@ -536,6 +541,8 @@ class LayoutParsingPipelineV2(BasePipeline):
             else:
                 overall_ocr_res = {}
 
+            overall_ocr_res["rec_labels"] = ["text"] * len(overall_ocr_res["rec_texts"])
+
             if model_settings["use_table_recognition"]:
                 table_contents = copy.deepcopy(overall_ocr_res)
                 for formula_res in formula_res_list:

+ 1 - 1
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -310,7 +310,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 "table": format_table,
                 "reference": lambda: block["block_content"],
                 "algorithm": lambda: block["block_content"].strip("\n"),
-                "seal": lambda: format_image("block_content"),
+                "seal": lambda: f"Words of Seals:\n{block['block_content']}",
             }
             parsing_res_list = obj["parsing_res_list"]
             markdown_content = ""

+ 159 - 18
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -25,6 +25,7 @@ from PIL import Image
 import uuid
 import re
 from pathlib import Path
+from copy import deepcopy
 from typing import Optional, Union, List, Tuple, Dict, Any
 from ..ocr.result import OCRResult
 from ...models.object_detection.result import DetResult
@@ -252,6 +253,7 @@ def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = Fal
         span[1] = "\n" + span[1]
     if append:
         span[1] = span[1] + "\n"
+    return span
 
 
 def _format_line(
@@ -277,17 +279,127 @@ def _format_line(
 
     if not is_reference:
         if first_span[0][0] - layout_min > 10:
-            _adjust_span_text(first_span, prepend=True)
+            first_span = _adjust_span_text(first_span, prepend=True)
         if layout_max - end_span[0][2] > 10:
-            _adjust_span_text(end_span, append=True)
+            end_span = _adjust_span_text(end_span, append=True)
     else:
         if first_span[0][0] - layout_min < 5:
-            _adjust_span_text(first_span, prepend=True)
+            first_span = _adjust_span_text(first_span, prepend=True)
         if layout_max - end_span[0][2] > 20:
-            _adjust_span_text(end_span, append=True)
+            end_span = _adjust_span_text(end_span, append=True)
+
+    line[0] = first_span
+    line[-1] = end_span
+
+    return line
+
+
+def split_boxes_if_x_contained(boxes, offset=1e-5):
+    """
+    Check if there is any complete containment in the x-direction
+    between the bounding boxes and split the containing box accordingly.
+
+    Args:
+        boxes (list of lists): Each element is a list containing an ndarray of length 4, a description, and a label.
+        offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
+    Returns:
+        A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
+    """
+
+    def is_x_contained(box_a, box_b):
+        """Check if box_a completely contains box_b in the x-direction."""
+        return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2]
+
+    new_boxes = []
+
+    for i in range(len(boxes)):
+        box_a = boxes[i]
+        is_split = False
+        for j in range(len(boxes)):
+            if i == j:
+                continue
+            box_b = boxes[j]
+            if is_x_contained(box_a, box_b):
+                is_split = True
+                # Split box_a based on the x-coordinates of box_b
+                if box_a[0][0] < box_b[0][0]:
+                    w = box_b[0][0] - offset - box_a[0][0]
+                    if w > 1:
+                        new_boxes.append(
+                            [
+                                np.array(
+                                    [
+                                        box_a[0][0],
+                                        box_a[0][1],
+                                        box_b[0][0] - offset,
+                                        box_a[0][3],
+                                    ]
+                                ),
+                                box_a[1],
+                                box_a[2],
+                            ]
+                        )
+                if box_a[0][2] > box_b[0][2]:
+                    w = box_a[0][2] - box_b[0][2] + offset
+                    if w > 1:
+                        box_a = [
+                            np.array(
+                                [
+                                    box_b[0][2] + offset,
+                                    box_a[0][1],
+                                    box_a[0][2],
+                                    box_a[0][3],
+                                ]
+                            ),
+                            box_a[1],
+                            box_a[2],
+                        ]
+            if j == len(boxes) - 1 and is_split:
+                new_boxes.append(box_a)
+        if not is_split:
+            new_boxes.append(box_a)
+
+    return new_boxes
+
+
+def _sort_line_by_x_projection(
+    input_img: np.ndarray,
+    general_ocr_pipeline: Any,
+    line: List[List[Union[List[int], str]]],
+) -> None:
+    """
+    Sort a line of text spans based on their vertical position within the layout bounding box.
+
+    Args:
+        input_img (ndarray): The input image used for OCR.
+        general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
+        line (list): A list of spans, where each span is a list containing a bounding box and text.
+
+    Returns:
+        list: The sorted line of text spans.
+    """
+    splited_boxes = split_boxes_if_x_contained(line)
+    splited_lines = []
+    if len(line) != len(splited_boxes):
+        splited_boxes.sort(key=lambda span: span[0][0])
+        text_rec_model = general_ocr_pipeline.text_rec_model
+        for span in splited_boxes:
+            if span[2] == "text":
+                crop_img = input_img[
+                    int(span[0][1]) : int(span[0][3]),
+                    int(span[0][0]) : int(span[0][2]),
+                ]
+                span[1] = next(text_rec_model([crop_img]))["rec_text"]
+            splited_lines.append(span)
+    else:
+        splited_lines = line
+
+    return splited_lines
 
 
 def _sort_ocr_res_by_y_projection(
+    input_img: np.ndarray,
+    general_ocr_pipeline: Any,
     label: Any,
     block_bbox: Tuple[int, int, int, int],
     ocr_res: Dict[str, List[Any]],
@@ -297,6 +409,8 @@ def _sort_ocr_res_by_y_projection(
     Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.
 
     Args:
+        input_img (ndarray): The input image used for OCR.
+        general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
         label (Any): The label associated with the OCR results. It's not used in the function but might be
                      relevant for other parts of the calling context.
         block_bbox (Tuple[int, int, int, int]): A tuple representing the layout bounding box, defined as
@@ -317,12 +431,13 @@ def _sort_ocr_res_by_y_projection(
 
     boxes = ocr_res["boxes"]
     rec_texts = ocr_res["rec_texts"]
+    rec_labels = ocr_res["rec_labels"]
 
     x_min, _, x_max, _ = block_bbox
     inline_x_min = min([box[0] for box in boxes])
     inline_x_max = max([box[2] for box in boxes])
 
-    spans = list(zip(boxes, rec_texts))
+    spans = list(zip(boxes, rec_texts, rec_labels))
 
     spans.sort(key=lambda span: span[0][1])
     spans = [list(span) for span in spans]
@@ -349,16 +464,21 @@ def _sort_ocr_res_by_y_projection(
     if current_line:
         lines.append(current_line)
 
+    new_lines = []
     for line in lines:
         line.sort(key=lambda span: span[0][0])
+
+        ocr_labels = [span[2] for span in line]
+        if "formula" in ocr_labels:
+            line = _sort_line_by_x_projection(input_img, general_ocr_pipeline, line)
         if label == "reference":
             line = _format_line(line, inline_x_min, inline_x_max, is_reference=True)
         else:
             line = _format_line(line, x_min, x_max)
+        new_lines.append(line)
 
-    # Flatten lines back into a single list for boxes and texts
-    ocr_res["boxes"] = [span[0] for line in lines for span in line]
-    ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
+    ocr_res["boxes"] = [span[0] for line in new_lines for span in line]
+    ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]
 
     return ocr_res
 
@@ -417,6 +537,7 @@ def _process_text(input_text: str) -> str:
 
 
 def get_single_block_parsing_res(
+    general_ocr_pipeline: Any,
     overall_ocr_res: OCRResult,
     layout_det_res: DetResult,
     table_res_list: list,
@@ -451,10 +572,16 @@ def get_single_block_parsing_res(
     input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
     seal_index = 0
 
-    for box_info in layout_det_res["boxes"]:
+    layout_det_res_list, _ = _remove_overlap_blocks(
+        deepcopy(layout_det_res["boxes"]),
+        threshold=0.5,
+        smaller=True,
+    )
+
+    for box_info in layout_det_res_list:
         block_bbox = box_info["coordinate"]
         label = box_info["label"]
-        rec_res = {"boxes": [], "rec_texts": [], "flag": False}
+        rec_res = {"boxes": [], "rec_texts": [], "rec_labels": [], "flag": False}
         seg_start_flag = True
         seg_end_flag = True
 
@@ -503,10 +630,15 @@ def get_single_block_parsing_res(
                     rec_res["rec_texts"].append(
                         overall_ocr_res["rec_texts"][box_no],
                     )
+                    rec_res["rec_labels"].append(
+                        overall_ocr_res["rec_labels"][box_no],
+                    )
                     rec_res["flag"] = True
 
             if rec_res["flag"]:
-                rec_res = _sort_ocr_res_by_y_projection(label, block_bbox, rec_res, 0.7)
+                rec_res = _sort_ocr_res_by_y_projection(
+                    input_img, general_ocr_pipeline, label, block_bbox, rec_res, 0.7
+                )
                 rec_res_first_bbox = rec_res["boxes"][0]
                 rec_res_end_bbox = rec_res["boxes"][-1]
                 if rec_res_first_bbox[0] - block_bbox[0] < 10:
@@ -547,6 +679,20 @@ def get_single_block_parsing_res(
                     },
                 )
 
+    if len(layout_det_res_list) == 0:
+        for ocr_rec_box, ocr_rec_text in zip(
+            overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"]
+        ):
+            single_block_layout_parsing_res.append(
+                {
+                    "block_label": "text",
+                    "block_content": ocr_rec_text,
+                    "block_bbox": ocr_rec_box,
+                    "seg_start_flag": True,
+                    "seg_end_flag": True,
+                },
+            )
+
     single_block_layout_parsing_res = get_layout_ordering(
         single_block_layout_parsing_res,
         no_mask_labels=[
@@ -875,8 +1021,8 @@ def _remove_overlap_blocks(
                 continue
             # Check for overlap and determine which block to remove
             overlap_box_index = _get_minbox_if_overlap_by_ratio(
-                block1["block_bbox"],
-                block2["block_bbox"],
+                block1["coordinate"],
+                block2["coordinate"],
                 threshold,
                 smaller=smaller,
             )
@@ -1384,11 +1530,6 @@ def get_layout_ordering(
     vision_labels = ["image", "table", "seal", "chart", "figure"]
     vision_title_labels = ["table_title", "chart_title", "figure_title"]
 
-    parsing_res_list, _ = _remove_overlap_blocks(
-        parsing_res_list,
-        threshold=0.5,
-        smaller=True,
-    )
     parsing_res_list, pre_cuts = _get_sub_category(parsing_res_list, title_text_labels)
 
     parsing_res_by_pre_cuts_list = []