瀏覽代碼

Fix layout pipeline (#2969)

* fix layout_parse

* fix layout api_test
shuai.liu 9 月之前
父節點
當前提交
61268b93f0

+ 48 - 28
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -227,20 +227,15 @@ class LayoutParsingPipelineV2(BasePipeline):
         seal_res_list: list,
         seal_res_list: list,
     ) -> list:
     ) -> list:
         """
         """
-        Retrieves the layout parsing result based on the layout detection result, OCR result, and other recognition results.
+        Get the layout parsing result based on the layout detection result, OCR result, and other recognition results.
+
         Args:
         Args:
             image (list): The input image.
             image (list): The input image.
-            overall_ocr_res (OCRResult): An object containing the overall OCR results, including detected text boxes and recognized text. The structure is expected to have:
-            - "input_img": The image on which OCR was performed.
-            - "dt_boxes": A list of detected text box coordinates.
-            - "rec_texts": A list of recognized text corresponding to the detected boxes.
-
-            layout_det_res (DetResult): An object containing the layout detection results, including detected layout boxes and their labels. The structure is expected to have:
-                - "boxes": A list of dictionaries with keys "coordinate" for box coordinates and "label" for the type of content.
+            layout_det_res (DetResult): The layout detection results.
+            overall_ocr_res (OCRResult): The overall OCR results.
+            table_res_list (list): A list of table detection results.
+            seal_res_list (list): A list of seal detection results.
 
 
-            table_res_list (list): A list of table detection results, where each item is a dictionary containing:
-                - "layout_bbox": The bounding box of the table layout.
-                - "pred_html": The predicted HTML representation of the table.
         Returns:
         Returns:
             list: A list of dictionaries representing the layout parsing result.
             list: A list of dictionaries representing the layout parsing result.
         """
         """
@@ -274,14 +269,16 @@ class LayoutParsingPipelineV2(BasePipeline):
         Get the model settings based on the provided parameters or default values.
         Get the model settings based on the provided parameters or default values.
 
 
         Args:
         Args:
-            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
-            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
-            use_general_ocr (Optional[bool]): Whether to use general OCR.
-            use_seal_recognition (Optional[bool]): Whether to use seal recognition.
-            use_table_recognition (Optional[bool]): Whether to use table recognition.
+            use_doc_orientation_classify (Union[bool, None]): Enables document orientation classification if True. Defaults to system setting if None.
+            use_doc_unwarping (Union[bool, None]): Enables document unwarping if True. Defaults to system setting if None.
+            use_general_ocr (Union[bool, None]): Enables general OCR if True. Defaults to system setting if None.
+            use_seal_recognition (Union[bool, None]): Enables seal recognition if True. Defaults to system setting if None.
+            use_table_recognition (Union[bool, None]): Enables table recognition if True. Defaults to system setting if None.
+            use_formula_recognition (Union[bool, None]): Enables formula recognition if True. Defaults to system setting if None.
 
 
         Returns:
         Returns:
             dict: A dictionary containing the model settings.
             dict: A dictionary containing the model settings.
+
         """
         """
         if use_doc_orientation_classify is None and use_doc_unwarping is None:
         if use_doc_orientation_classify is None and use_doc_unwarping is None:
             use_doc_preprocessor = self.use_doc_preprocessor
             use_doc_preprocessor = self.use_doc_preprocessor
@@ -316,10 +313,15 @@ class LayoutParsingPipelineV2(BasePipeline):
         input: Union[str, list[str], np.ndarray, list[np.ndarray]],
         input: Union[str, list[str], np.ndarray, list[np.ndarray]],
         use_doc_orientation_classify: Union[bool, None] = None,
         use_doc_orientation_classify: Union[bool, None] = None,
         use_doc_unwarping: Union[bool, None] = None,
         use_doc_unwarping: Union[bool, None] = None,
+        use_textline_orientation: Optional[bool] = None,
         use_general_ocr: Union[bool, None] = None,
         use_general_ocr: Union[bool, None] = None,
         use_seal_recognition: Union[bool, None] = None,
         use_seal_recognition: Union[bool, None] = None,
         use_table_recognition: Union[bool, None] = None,
         use_table_recognition: Union[bool, None] = None,
         use_formula_recognition: Union[bool, None] = None,
         use_formula_recognition: Union[bool, None] = None,
+        layout_threshold: Optional[Union[float, dict]] = None,
+        layout_nms: Optional[bool] = None,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
         text_det_limit_side_len: Union[int, None] = None,
         text_det_limit_side_len: Union[int, None] = None,
         text_det_limit_type: Union[str, None] = None,
         text_det_limit_type: Union[str, None] = None,
         text_det_thresh: Union[float, None] = None,
         text_det_thresh: Union[float, None] = None,
@@ -332,23 +334,40 @@ class LayoutParsingPipelineV2(BasePipeline):
         seal_det_box_thresh: Union[float, None] = None,
         seal_det_box_thresh: Union[float, None] = None,
         seal_det_unclip_ratio: Union[float, None] = None,
         seal_det_unclip_ratio: Union[float, None] = None,
         seal_rec_score_thresh: Union[float, None] = None,
         seal_rec_score_thresh: Union[float, None] = None,
-        layout_threshold: Optional[Union[float, dict]] = None,
-        layout_nms: Optional[bool] = None,
-        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-        layout_merge_bboxes_mode: Optional[str] = None,
         **kwargs,
         **kwargs,
     ) -> LayoutParsingResultV2:
     ) -> LayoutParsingResultV2:
         """
         """
-        This function predicts the layout parsing result for the given input.
+        Predicts the layout parsing result for the given input.
 
 
         Args:
         Args:
-            input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or pdf(s) to be processed.
-            use_doc_orientation_classify (bool): Whether to use document orientation classification.
-            use_doc_unwarping (bool): Whether to use document unwarping.
-            use_general_ocr (bool): Whether to use general OCR.
-            use_seal_recognition (bool): Whether to use seal recognition.
-            use_table_recognition (bool): Whether to use table recognition.
-            **kwargs: Additional keyword arguments.
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
+            use_general_ocr (Optional[bool]): Whether to use general OCR.
+            use_seal_recognition (Optional[bool]): Whether to use seal recognition.
+            use_table_recognition (Optional[bool]): Whether to use table recognition.
+            use_formula_recognition (Optional[bool]): Whether to use formula recognition.
+            layout_threshold (Optional[Union[float, dict]]): The threshold value to filter out low-confidence predictions. Defaults to None.
+            layout_nms (Optional[bool]): Whether to use layout-aware NMS. Defaults to None.
+            layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
+                Defaults to None.
+                If it's a single number, then it is used for both width and height.
+                If it's a tuple of two numbers, then they are used separately for width and height respectively.
+                If it's None, then no unclipping will be performed.
+            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
+            text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
+            text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
+            text_det_thresh (Optional[float]): Threshold for text detection.
+            text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
+            text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
+            text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
+            seal_det_limit_side_len (Optional[int]): Maximum side length for seal detection.
+            seal_det_limit_type (Optional[str]): Type of limit to apply for seal detection.
+            seal_det_thresh (Optional[float]): Threshold for seal detection.
+            seal_det_box_thresh (Optional[float]): Threshold for seal detection boxes.
+            seal_det_unclip_ratio (Optional[float]): Ratio for unclipping seal detection boxes.
+            seal_rec_score_thresh (Optional[float]): Score threshold for seal recognition.
+            **kwargs (Any): Additional settings to extend functionality.
 
 
         Returns:
         Returns:
             LayoutParsingResultV2: The predicted layout parsing result.
             LayoutParsingResultV2: The predicted layout parsing result.
@@ -417,6 +436,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                 overall_ocr_res = next(
                 overall_ocr_res = next(
                     self.general_ocr_pipeline(
                     self.general_ocr_pipeline(
                         doc_preprocessor_image,
                         doc_preprocessor_image,
+                        use_textline_orientation=use_textline_orientation,
                         text_det_limit_side_len=text_det_limit_side_len,
                         text_det_limit_side_len=text_det_limit_side_len,
                         text_det_limit_type=text_det_limit_type,
                         text_det_limit_type=text_det_limit_type,
                         text_det_thresh=text_det_thresh,
                         text_det_thresh=text_det_thresh,

+ 15 - 35
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -15,7 +15,6 @@ from __future__ import annotations
 
 
 import copy
 import copy
 from pathlib import Path
 from pathlib import Path
-from PIL import Image, ImageDraw
 from typing import Dict
 from typing import Dict
 
 
 import cv2
 import cv2
@@ -236,13 +235,12 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 res_xlsx_dict[key] = table_res.xlsx["pred"]
                 res_xlsx_dict[key] = table_res.xlsx["pred"]
         return res_xlsx_dict
         return res_xlsx_dict
 
 
-    def save_to_pdf_order(self, save_path):
+    def save_to_pdf_order(self, save_path: str) -> None:
         """
         """
         Save the layout ordering to an image file.
         Save the layout ordering to an image file.
 
 
         Args:
         Args:
-            save_path (str or Path): The path where the image should be saved.
-            font_path (str): Path to the font file used for drawing text.
+            save_path (str): The path where the image should be saved.
 
 
         Returns:
         Returns:
             None
             None
@@ -257,6 +255,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 save_path = save_path / f"{input_path.stem}.jpg"
                 save_path = save_path / f"{input_path.stem}.jpg"
         else:
         else:
             save_path = save_path.with_suffix("")
             save_path = save_path.with_suffix("")
+
         ordering_image_path = (
         ordering_image_path = (
             save_path.parent / f"{save_path.stem}_layout_order_res.jpg"
             save_path.parent / f"{save_path.stem}_layout_order_res.jpg"
         )
         )
@@ -268,8 +267,8 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
             return
             return
 
 
         draw = ImageDraw.Draw(image, "RGBA")
         draw = ImageDraw.Draw(image, "RGBA")
-
         parsing_result = self["parsing_res_list"]
         parsing_result = self["parsing_res_list"]
+
         for block in parsing_result:
         for block in parsing_result:
             if self.already_sorted == False:
             if self.already_sorted == False:
                 block = get_layout_ordering(
                 block = get_layout_ordering(
@@ -295,14 +294,15 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 if index is not None:
                 if index is not None:
                     text_position = (bbox[2] + 2, bbox[1] - 10)
                     text_position = (bbox[2] + 2, bbox[1] - 10)
                     draw.text(text_position, str(index), fill="red")
                     draw.text(text_position, str(index), fill="red")
-        self.already_sorted == True
+
+        self.already_sorted = True
 
 
         # Ensure the directory exists and save the image
         # Ensure the directory exists and save the image
         ordering_image_path.parent.mkdir(parents=True, exist_ok=True)
         ordering_image_path.parent.mkdir(parents=True, exist_ok=True)
         print(f"Saving ordering image to {ordering_image_path}")
         print(f"Saving ordering image to {ordering_image_path}")
         image.save(str(ordering_image_path))
         image.save(str(ordering_image_path))
 
 
-    def _to_markdown(self):
+    def _to_markdown(self) -> dict:
         """
         """
         Save the parsing result to a Markdown file.
         Save the parsing result to a Markdown file.
 
 
@@ -366,43 +366,23 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                     + "\n"
                     + "\n"
                 )
                 )
 
 
-            def format_image():
+            def format_image(label):
                 if is_save_mk_img is False:
                 if is_save_mk_img is False:
                     return ""
                     return ""
 
 
                 img_tags = []
                 img_tags = []
-                if "img" in sub_block["image"]:
-                    img_tags.append(
-                        '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
-                            sub_block["image"]["img"]
-                            .replace("-\n", "")
-                            .replace("\n", " "),
-                        ),
-                    )
-                if "image_text" in sub_block["image"]:
-                    img_tags.append(
-                        '<div style="text-align: center;">{}</div>'.format(
-                            sub_block["image"]["image_text"]
-                            .replace("-\n", "")
-                            .replace("\n", " "),
-                        ),
-                    )
-                return "\n".join(img_tags)
-
-            def format_chart():
-                img_tags = []
-                if "img" in sub_block["chart"]:
+                if "img" in sub_block[label]:
                     img_tags.append(
                     img_tags.append(
                         '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
                         '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
-                            sub_block["chart"]["img"]
+                            sub_block[label]["img"]
                             .replace("-\n", "")
                             .replace("-\n", "")
                             .replace("\n", " "),
                             .replace("\n", " "),
                         ),
                         ),
                     )
                     )
-                if "image_text" in sub_block["chart"]:
+                if "image_text" in sub_block[label]:
                     img_tags.append(
                     img_tags.append(
                         '<div style="text-align: center;">{}</div>'.format(
                         '<div style="text-align: center;">{}</div>'.format(
-                            sub_block["chart"]["image_text"]
+                            sub_block[label]["image_text"]
                             .replace("-\n", "")
                             .replace("-\n", "")
                             .replace("\n", " "),
                             .replace("\n", " "),
                         ),
                         ),
@@ -440,14 +420,14 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 "content": lambda: sub_block["content"]
                 "content": lambda: sub_block["content"]
                 .replace("-\n", " ")
                 .replace("-\n", " ")
                 .replace("\n", " "),
                 .replace("\n", " "),
-                "image": format_image,
-                "chart": format_chart,
+                "image": lambda: format_image("image"),
+                "chart": lambda: format_image("chart"),
                 "formula": lambda: f"$${sub_block['formula']}$$",
                 "formula": lambda: f"$${sub_block['formula']}$$",
                 "table": format_table,
                 "table": format_table,
                 # "reference": format_reference,
                 # "reference": format_reference,
                 "reference": lambda: sub_block["reference"],
                 "reference": lambda: sub_block["reference"],
                 "algorithm": lambda: sub_block["algorithm"].strip("\n"),
                 "algorithm": lambda: sub_block["algorithm"].strip("\n"),
-                "seal": lambda: sub_block["seal"].strip("\n"),
+                "seal": lambda: format_image("seal"),
             }
             }
             parsing_result = obj["parsing_res_list"]
             parsing_result = obj["parsing_res_list"]
             markdown_content = ""
             markdown_content = ""

文件差異過大導致無法顯示
+ 327 - 253
paddlex/inference/pipelines/layout_parsing/utils.py


部分文件因文件數量過多而無法顯示