瀏覽代碼

fix & polish layout_parsing_v2 (#3235)

* fix & polish layout_parsing_v2

* update layout_parsing_v2 for type hint
cuicheng01 9 月之前
父節點
當前提交
0a93e36d38

+ 0 - 1
api_examples/pipelines/test_layout_parsing_v2.py

@@ -32,4 +32,3 @@ for res in output:
     res.save_to_xlsx("./output")
     res.save_to_xlsx("./output")
     res.save_to_html("./output")
     res.save_to_html("./output")
     res.save_to_markdown("./output")
     res.save_to_markdown("./output")
-    res.save_to_pdf_order("./output")

文件差異過大導致無法顯示
+ 0 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/layout_parsing_v2.md


+ 1 - 12
paddlex/configs/pipelines/layout_parsing_v2.yaml

@@ -51,7 +51,7 @@ SubPipelines:
         
         
       TextRecognition:
       TextRecognition:
         module_name: text_recognition
         module_name: text_recognition
-        model_name: PP-OCRv4_server_rec
+        model_name: PP-OCRv4_server_rec_doc
         model_dir: null
         model_dir: null
         batch_size: 1
         batch_size: 1
         score_thresh: 0.0
         score_thresh: 0.0
@@ -87,17 +87,6 @@ SubPipelines:
         model_name: RT-DETR-L_wireless_table_cell_det
         model_name: RT-DETR-L_wireless_table_cell_det
         model_dir: null
         model_dir: null
 
 
-  # TableRecognition:
-  #   pipeline_name: table_recognition
-  #   use_layout_detection: False
-  #   use_doc_preprocessor: False
-  #   use_ocr_model: False
-  #   SubModules:
-  #     TableStructureRecognition:
-  #       module_name: table_structure_recognition
-  #       model_name: SLANet_plus
-  #       model_dir: null
-
   SealRecognition:
   SealRecognition:
     pipeline_name: seal_recognition
     pipeline_name: seal_recognition
     use_layout_detection: False
     use_layout_detection: False

+ 81 - 26
paddlex/inference/common/result/mixin.py

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from typing import Union, Tuple, List, Dict, Any, Iterator
+from typing import Union, Tuple, List, Dict, Any, Iterator, Callable, Optional
 from abc import abstractmethod
 from abc import abstractmethod
 from pathlib import Path
 from pathlib import Path
+import os
 import mimetypes
 import mimetypes
 import json
 import json
 import copy
 import copy
@@ -379,7 +380,6 @@ class CSVMixin:
 
 
         if not _is_csv_file(save_path):
         if not _is_csv_file(save_path):
             fn = Path(self._get_input_fn())
             fn = Path(self._get_input_fn())
-            fn = Path(self._get_input_fn())
             stem = fn.stem
             stem = fn.stem
             base_save_path = Path(save_path)
             base_save_path = Path(save_path)
             for key in self.csv:
             for key in self.csv:
@@ -597,53 +597,108 @@ class VideoMixin:
 
 
 
 
 class MarkdownMixin:
 class MarkdownMixin:
+    """Mixin class for adding Markdown handling capabilities."""
 
 
     def __init__(self, *args: list, **kwargs: dict):
     def __init__(self, *args: list, **kwargs: dict):
+        """Initializes the Markdown writer and appends the save_to_markdown method to the save functions.
+
+        Args:
+            *args: Positional arguments to be passed to the MarkdownWriter constructor.
+            **kwargs: Keyword arguments to be passed to the MarkdownWriter constructor.
+        """
         self._markdown_writer = MarkdownWriter(*args, **kwargs)
         self._markdown_writer = MarkdownWriter(*args, **kwargs)
+        self._img_writer = ImageWriter(*args, **kwargs)
         self._save_funcs.append(self.save_to_markdown)
         self._save_funcs.append(self.save_to_markdown)
-        self.save_path = None
 
 
     @abstractmethod
     @abstractmethod
-    def _to_markdown(self):
+    def _to_markdown(self) -> Dict[str, Union[str, Dict[str, Any]]]:
         """
         """
         Convert the result to markdown format.
         Convert the result to markdown format.
+
         Returns:
         Returns:
-            Dict
+            Dict[str, Union[str, Dict[str, Any]]]: A dictionary containing markdown text and image data.
         """
         """
         raise NotImplementedError
         raise NotImplementedError
 
 
     @property
     @property
-    def markdown(self):
+    def markdown(self) -> Dict[str, Union[str, Dict[str, Any]]]:
+        """Property to access the markdown data.
+
+        Returns:
+            Dict[str, Union[str, Dict[str, Any]]]: A dictionary containing markdown text and image data.
+        """
         return self._to_markdown()
         return self._to_markdown()
 
 
-    def save_to_markdown(self, save_path, *args, **kwargs):
-        save_path = Path(save_path)
-        if not save_path.suffix.lower() == ".md":
-            save_path = save_path / f"layout_parsing_result.md"
+    def save_to_markdown(self, save_path, *args, **kwargs) -> None:
+        """Save the markdown data to a file.
+
+        Args:
+            save_path (Union[str, Path]): The path where the markdown file will be saved.
+            *args: Additional positional arguments for saving.
+            **kwargs: Additional keyword arguments for saving.
+        """
 
 
-        self.save_path = save_path
+        def _is_markdown_file(file_path) -> bool:
+            """Check if a file is a markdown file based on its extension or MIME type.
 
 
-        self._save_list_data(
+            Args:
+                file_path (Union[str, Path]): The path to the file.
+
+            Returns:
+                bool: True if the file is a markdown file, False otherwise.
+            """
+            markdown_extensions = {".md", ".markdown", ".mdown", ".mkd"}
+            _, ext = os.path.splitext(str(file_path))
+            if ext.lower() in markdown_extensions:
+                return True
+            mime_type, _ = mimetypes.guess_type(str(file_path))
+            return mime_type == "text/markdown"
+
+        if not _is_markdown_file(save_path):
+            fn = Path(self._get_input_fn())
+            suffix = fn.suffix if _is_markdown_file(fn) else ".md"
+            stem = fn.stem
+            base_save_path = Path(save_path)
+            save_path = base_save_path / f"{stem}{suffix}"
+            self.save_path = save_path
+        else:
+            self.save_path = save_path
+        self._save_data(
             self._markdown_writer.write,
             self._markdown_writer.write,
-            save_path,
+            self._img_writer.write,
+            self.save_path,
             self.markdown,
             self.markdown,
             *args,
             *args,
             **kwargs,
             **kwargs,
         )
         )
 
 
-    def _save_list_data(self, save_func, save_path, data, *args, **kwargs):
+    def _save_data(
+        self,
+        save_mkd_func: Callable,
+        save_img_func: Callable,
+        save_path: Union[str, Path],
+        data: Optional[Dict[str, Union[str, Dict[str, Any]]]],
+        *args,
+        **kwargs,
+    ) -> None:
+        """Internal method to save markdown and image data.
+
+        Args:
+            save_mkd_func (Callable): Function to save markdown text.
+            save_img_func (Callable): Function to save image data.
+            save_path (Union[str, Path]): The base path where the data will be saved.
+            data (Optional[Dict[str, Union[str, Dict[str, Any]]]]): The markdown data to save.
+            *args: Additional positional arguments for saving.
+            **kwargs: Additional keyword arguments for saving.
+        """
         save_path = Path(save_path)
         save_path = Path(save_path)
         if data is None:
         if data is None:
             return
             return
-        if isinstance(data, list):
-            for idx, single in enumerate(data):
-                save_func(
-                    (
-                        save_path.parent / f"{save_path.stem}_{idx}{save_path.suffix}"
-                    ).as_posix(),
-                    single,
-                    *args,
-                    **kwargs,
-                )
-        save_func(save_path.as_posix(), data, *args, **kwargs)
-        logging.info(f"The result has been saved in {save_path}.")
+        for key, value in data.items():
+            if isinstance(value, str):
+                save_mkd_func(save_path.as_posix(), value, *args, **kwargs)
+            if isinstance(value, dict):
+                base_save_path = save_path.parent
+                for img_path, img_data in value.items():
+                    save_path = base_save_path / img_path
+                    save_img_func(save_path.as_posix(), img_data, *args, **kwargs)

+ 63 - 97
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -48,11 +48,22 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         JsonMixin.__init__(self)
         JsonMixin.__init__(self)
         self.already_sorted = False
         self.already_sorted = False
 
 
+    def _get_input_fn(self):
+        fn = super()._get_input_fn()
+        if (page_idx := self["page_index"]) is not None:
+            fp = Path(fn)
+            stem, suffix = fp.stem, fp.suffix
+            return f"{stem}_{page_idx}{suffix}"
+        else:
+            return fn
+
     def _to_img(self) -> dict[str, np.ndarray]:
     def _to_img(self) -> dict[str, np.ndarray]:
         res_img_dict = {}
         res_img_dict = {}
         model_settings = self["model_settings"]
         model_settings = self["model_settings"]
+        page_index = self["page_index"]
         if model_settings["use_doc_preprocessor"]:
         if model_settings["use_doc_preprocessor"]:
-            res_img_dict.update(**self["doc_preprocessor_res"].img)
+            for key, value in self["doc_preprocessor_res"].img.items():
+                res_img_dict[key] = value
         res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
         res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
 
 
         if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
         if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
@@ -92,16 +103,39 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 key = f"seal_res_region{seal_region_id}"
                 key = f"seal_res_region{seal_region_id}"
                 res_img_dict[key] = sub_seal_res_dict["ocr_res_img"]
                 res_img_dict[key] = sub_seal_res_dict["ocr_res_img"]
 
 
-        # if (
-        #     model_settings["use_formula_recognition"]
-        #     and len(self["formula_res_list"]) > 0
-        # ):
-        #     for sno in range(len(self["formula_res_list"])):
-        #         formula_res = self["formula_res_list"][sno]
-        #         formula_region_id = formula_res["formula_region_id"]
-        #         sub_formula_res_dict = formula_res.img
-        #         key = f"formula_res_region{formula_region_id}"
-        #         res_img_dict[key] = sub_formula_res_dict["res"]
+        # for layout ordering image
+        image = Image.fromarray(self["doc_preprocessor_res"]["output_img"])
+        draw = ImageDraw.Draw(image, "RGBA")
+        parsing_result = self["parsing_res_list"]
+
+        for block in parsing_result:
+            if self.already_sorted == False:
+                block = get_layout_ordering(
+                    block,
+                    no_mask_labels=[
+                        "text",
+                        "formula",
+                        "algorithm",
+                        "reference",
+                        "content",
+                        "abstract",
+                    ],
+                    already_sorted=self.already_sorted,
+                )
+
+            sub_blocks = block["sub_blocks"]
+            for sub_block in sub_blocks:
+                bbox = sub_block["layout_bbox"]
+                index = sub_block.get("index", None)
+                label = sub_block["sub_label"]
+                fill_color = get_show_color(label)
+                draw.rectangle(bbox, fill=fill_color)
+                if index is not None:
+                    text_position = (bbox[2] + 2, bbox[1] - 10)
+                    draw.text(text_position, str(index), fill="red")
+
+        self.already_sorted = True
+        res_img_dict["layout_order_res"] = image
 
 
         return res_img_dict
         return res_img_dict
 
 
@@ -117,6 +151,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         data["model_settings"] = model_settings
         if self["model_settings"]["use_doc_preprocessor"]:
         if self["model_settings"]["use_doc_preprocessor"]:
@@ -167,6 +202,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         data["model_settings"] = model_settings
         if self["model_settings"]["use_doc_preprocessor"]:
         if self["model_settings"]["use_doc_preprocessor"]:
@@ -235,73 +271,6 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 res_xlsx_dict[key] = table_res.xlsx["pred"]
                 res_xlsx_dict[key] = table_res.xlsx["pred"]
         return res_xlsx_dict
         return res_xlsx_dict
 
 
-    def save_to_pdf_order(self, save_path: str) -> None:
-        """
-        Save the layout ordering to an image file.
-
-        Args:
-            save_path (str): The path where the image should be saved.
-
-        Returns:
-            None
-        """
-        input_path = Path(self["input_path"])
-        page_index = self["page_index"]
-        save_path = Path(save_path)
-        if save_path.suffix.lower() not in (".jpg", ".png"):
-            if input_path.suffix.lower() == ".pdf":
-                save_path = save_path / f"page_{page_index}.jpg"
-            else:
-                save_path = save_path / f"{input_path.stem}.jpg"
-        else:
-            save_path = save_path.with_suffix("")
-
-        ordering_image_path = (
-            save_path.parent / f"{save_path.stem}_layout_order_res.jpg"
-        )
-
-        try:
-            image = Image.fromarray(self["doc_preprocessor_res"]["output_img"])
-        except OSError as e:
-            print(f"Error opening image: {e}")
-            return
-
-        draw = ImageDraw.Draw(image, "RGBA")
-        parsing_result = self["parsing_res_list"]
-
-        for block in parsing_result:
-            if self.already_sorted == False:
-                block = get_layout_ordering(
-                    block,
-                    no_mask_labels=[
-                        "text",
-                        "formula",
-                        "algorithm",
-                        "reference",
-                        "content",
-                        "abstract",
-                    ],
-                    already_sorted=self.already_sorted,
-                )
-
-            sub_blocks = block["sub_blocks"]
-            for sub_block in sub_blocks:
-                bbox = sub_block["layout_bbox"]
-                index = sub_block.get("index", None)
-                label = sub_block["sub_label"]
-                fill_color = get_show_color(label)
-                draw.rectangle(bbox, fill=fill_color)
-                if index is not None:
-                    text_position = (bbox[2] + 2, bbox[1] - 10)
-                    draw.text(text_position, str(index), fill="red")
-
-        self.already_sorted = True
-
-        # Ensure the directory exists and save the image
-        ordering_image_path.parent.mkdir(parents=True, exist_ok=True)
-        print(f"Saving ordering image to {ordering_image_path}")
-        image.save(str(ordering_image_path))
-
     def _to_markdown(self) -> dict:
     def _to_markdown(self) -> dict:
         """
         """
         Save the parsing result to a Markdown file.
         Save the parsing result to a Markdown file.
@@ -309,14 +278,8 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         Returns:
         Returns:
             Dict
             Dict
         """
         """
-        if self.save_path == None:
-            is_save_mk_img = False
-        else:
-            is_save_mk_img = True
-            save_path = Path(self.save_path)
 
 
         parsing_result = self["parsing_res_list"]
         parsing_result = self["parsing_res_list"]
-
         for block in parsing_result:
         for block in parsing_result:
             if self.already_sorted == False:
             if self.already_sorted == False:
                 block = get_layout_ordering(
                 block = get_layout_ordering(
@@ -333,12 +296,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 )
                 )
         self.already_sorted == True
         self.already_sorted == True
 
 
-        if is_save_mk_img:
-            recursive_img_array2path(
-                self["parsing_res_list"],
-                save_path.parent,
-                labels=["img"],
-            )
+        recursive_img_array2path(self["parsing_res_list"], labels=["img"])
 
 
         def _format_data(obj):
         def _format_data(obj):
 
 
@@ -367,16 +325,12 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 )
                 )
 
 
             def format_image(label):
             def format_image(label):
-                if is_save_mk_img is False:
-                    return ""
-
                 img_tags = []
                 img_tags = []
                 if "img" in sub_block[label]:
                 if "img" in sub_block[label]:
+                    image_path = "".join(sub_block[label]["img"].keys())
                     img_tags.append(
                     img_tags.append(
                         '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
                         '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
-                            sub_block[label]["img"]
-                            .replace("-\n", "")
-                            .replace("\n", " "),
+                            image_path.replace("-\n", "").replace("\n", " "),
                         ),
                         ),
                     )
                     )
                 if "image_text" in sub_block[label]:
                 if "image_text" in sub_block[label]:
@@ -456,4 +410,16 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
 
 
             return markdown_content
             return markdown_content
 
 
-        return _format_data(self)
+        markdown_info = dict()
+        markdown_info["markdown_texts"] = _format_data(self)
+        markdown_info["markdown_images"] = dict()
+        for block in self["parsing_res_list"]:
+            sub_blocks = block["sub_blocks"]
+            for sub_block in sub_blocks:
+                if sub_block["label"] == "image":
+                    image_path, image_value = next(
+                        iter(sub_block["image"]["img"].items())
+                    )
+                    markdown_info["markdown_images"][image_path] = image_value
+
+        return markdown_info

+ 8 - 17
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -24,6 +24,7 @@ __all__ = [
 import numpy as np
 import numpy as np
 import copy
 import copy
 import cv2
 import cv2
+from PIL import Image
 import uuid
 import uuid
 from pathlib import Path
 from pathlib import Path
 from typing import Optional, Union, List, Tuple, Dict, Any
 from typing import Optional, Union, List, Tuple, Dict, Any
@@ -724,16 +725,16 @@ def sort_by_xycut(
     return res
     return res
 
 
 
 
-def _img_array2path(data: np.ndarray, save_path: Union[str, Path]) -> str:
+def _img_array2path(data: np.ndarray) -> str:
     """
     """
     Save an image array to disk and return the relative file path.
     Save an image array to disk and return the relative file path.
 
 
     Args:
     Args:
         data (np.ndarray): An image represented as a numpy array with 3 dimensions (H, W, C).
         data (np.ndarray): An image represented as a numpy array with 3 dimensions (H, W, C).
-        save_path (Union[str, Path]): The base path where images should be saved.
 
 
     Returns:
     Returns:
-        str: The relative path of the saved image file.
+        dict: A dictionary with a single key-value pair formatted as:
+              {"imgs/image_{uuid4_hex}.png": PIL.Image.Image}
 
 
     Raises:
     Raises:
         ValueError: If the input data is not a valid image array.
         ValueError: If the input data is not a valid image array.
@@ -741,17 +742,8 @@ def _img_array2path(data: np.ndarray, save_path: Union[str, Path]) -> str:
     if isinstance(data, np.ndarray) and data.ndim == 3:
     if isinstance(data, np.ndarray) and data.ndim == 3:
         # Generate a unique filename using UUID
         # Generate a unique filename using UUID
         img_name = f"image_{uuid.uuid4().hex}.png"
         img_name = f"image_{uuid.uuid4().hex}.png"
-        img_path = Path(save_path) / "imgs" / img_name
-        img_path.parent.mkdir(
-            parents=True, exist_ok=True
-        )  # Ensure the directory exists
-
-        # Save the image using OpenCV
-        success = cv2.imwrite(str(img_path), data)
-        if not success:
-            raise IOError(f"Failed to save image to {img_path}")
 
 
-        return f"imgs/{img_name}"
+        return {f"imgs/{img_name}": Image.fromarray(data[:, :, ::-1])}
     else:
     else:
         raise ValueError(
         raise ValueError(
             "Input data must be a 3-dimensional numpy array representing an image."
             "Input data must be a 3-dimensional numpy array representing an image."
@@ -760,7 +752,6 @@ def _img_array2path(data: np.ndarray, save_path: Union[str, Path]) -> str:
 
 
 def recursive_img_array2path(
 def recursive_img_array2path(
     data: Union[Dict[str, Any], List[Any]],
     data: Union[Dict[str, Any], List[Any]],
-    save_path: Union[str, Path],
     labels: List[str] = [],
     labels: List[str] = [],
 ) -> None:
 ) -> None:
     """
     """
@@ -778,12 +769,12 @@ def recursive_img_array2path(
     if isinstance(data, dict):
     if isinstance(data, dict):
         for k, v in data.items():
         for k, v in data.items():
             if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
             if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
-                data[k] = _img_array2path(v, save_path)
+                data[k] = _img_array2path(v)
             else:
             else:
-                recursive_img_array2path(v, save_path, labels)
+                recursive_img_array2path(v, labels)
     elif isinstance(data, list):
     elif isinstance(data, list):
         for item in data:
         for item in data:
-            recursive_img_array2path(item, save_path, labels)
+            recursive_img_array2path(item, labels)
 
 
 
 
 def _get_minbox_if_overlap_by_ratio(
 def _get_minbox_if_overlap_by_ratio(

+ 1 - 1
paddlex/inference/utils/io/writers.py

@@ -454,5 +454,5 @@ class MarkdownWriterBackend(_BaseWriterBackend):
 
 
     def _write_obj(self, out_path, obj):
     def _write_obj(self, out_path, obj):
         """write markdown obj"""
         """write markdown obj"""
-        with open(out_path, mode="a", encoding="utf-8", errors="replace") as f:
+        with open(out_path, mode="w", encoding="utf-8", errors="replace") as f:
             f.write(obj)
             f.write(obj)

部分文件因文件數量過多而無法顯示