Bläddra i källkod

fix & polish layout_parsing_v2 (#3235)

* fix & polish layout_parsing_v2

* update layout_parsing_v2 for type hint
cuicheng01 9 månader sedan
förälder
incheckning
0a93e36d38

+ 0 - 1
api_examples/pipelines/test_layout_parsing_v2.py

@@ -32,4 +32,3 @@ for res in output:
     res.save_to_xlsx("./output")
     res.save_to_html("./output")
     res.save_to_markdown("./output")
-    res.save_to_pdf_order("./output")

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 0 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/layout_parsing_v2.md


+ 1 - 12
paddlex/configs/pipelines/layout_parsing_v2.yaml

@@ -51,7 +51,7 @@ SubPipelines:
         
       TextRecognition:
         module_name: text_recognition
-        model_name: PP-OCRv4_server_rec
+        model_name: PP-OCRv4_server_rec_doc
         model_dir: null
         batch_size: 1
         score_thresh: 0.0
@@ -87,17 +87,6 @@ SubPipelines:
         model_name: RT-DETR-L_wireless_table_cell_det
         model_dir: null
 
-  # TableRecognition:
-  #   pipeline_name: table_recognition
-  #   use_layout_detection: False
-  #   use_doc_preprocessor: False
-  #   use_ocr_model: False
-  #   SubModules:
-  #     TableStructureRecognition:
-  #       module_name: table_structure_recognition
-  #       model_name: SLANet_plus
-  #       model_dir: null
-
   SealRecognition:
     pipeline_name: seal_recognition
     use_layout_detection: False

+ 81 - 26
paddlex/inference/common/result/mixin.py

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union, Tuple, List, Dict, Any, Iterator
+from typing import Union, Tuple, List, Dict, Any, Iterator, Callable, Optional
 from abc import abstractmethod
 from pathlib import Path
+import os
 import mimetypes
 import json
 import copy
@@ -379,7 +380,6 @@ class CSVMixin:
 
         if not _is_csv_file(save_path):
             fn = Path(self._get_input_fn())
-            fn = Path(self._get_input_fn())
             stem = fn.stem
             base_save_path = Path(save_path)
             for key in self.csv:
@@ -597,53 +597,108 @@ class VideoMixin:
 
 
 class MarkdownMixin:
+    """Mixin class for adding Markdown handling capabilities."""
 
     def __init__(self, *args: list, **kwargs: dict):
+        """Initializes the Markdown writer and appends the save_to_markdown method to the save functions.
+
+        Args:
+            *args: Positional arguments to be passed to the MarkdownWriter constructor.
+            **kwargs: Keyword arguments to be passed to the MarkdownWriter constructor.
+        """
         self._markdown_writer = MarkdownWriter(*args, **kwargs)
+        self._img_writer = ImageWriter(*args, **kwargs)
         self._save_funcs.append(self.save_to_markdown)
-        self.save_path = None
 
     @abstractmethod
-    def _to_markdown(self):
+    def _to_markdown(self) -> Dict[str, Union[str, Dict[str, Any]]]:
         """
         Convert the result to markdown format.
+
         Returns:
-            Dict
+            Dict[str, Union[str, Dict[str, Any]]]: A dictionary containing markdown text and image data.
         """
         raise NotImplementedError
 
     @property
-    def markdown(self):
+    def markdown(self) -> Dict[str, Union[str, Dict[str, Any]]]:
+        """Property to access the markdown data.
+
+        Returns:
+            Dict[str, Union[str, Dict[str, Any]]]: A dictionary containing markdown text and image data.
+        """
         return self._to_markdown()
 
-    def save_to_markdown(self, save_path, *args, **kwargs):
-        save_path = Path(save_path)
-        if not save_path.suffix.lower() == ".md":
-            save_path = save_path / f"layout_parsing_result.md"
+    def save_to_markdown(self, save_path, *args, **kwargs) -> None:
+        """Save the markdown data to a file.
+
+        Args:
+            save_path (Union[str, Path]): The path where the markdown file will be saved.
+            *args: Additional positional arguments for saving.
+            **kwargs: Additional keyword arguments for saving.
+        """
 
-        self.save_path = save_path
+        def _is_markdown_file(file_path) -> bool:
+            """Check if a file is a markdown file based on its extension or MIME type.
 
-        self._save_list_data(
+            Args:
+                file_path (Union[str, Path]): The path to the file.
+
+            Returns:
+                bool: True if the file is a markdown file, False otherwise.
+            """
+            markdown_extensions = {".md", ".markdown", ".mdown", ".mkd"}
+            _, ext = os.path.splitext(str(file_path))
+            if ext.lower() in markdown_extensions:
+                return True
+            mime_type, _ = mimetypes.guess_type(str(file_path))
+            return mime_type == "text/markdown"
+
+        if not _is_markdown_file(save_path):
+            fn = Path(self._get_input_fn())
+            suffix = fn.suffix if _is_markdown_file(fn) else ".md"
+            stem = fn.stem
+            base_save_path = Path(save_path)
+            save_path = base_save_path / f"{stem}{suffix}"
+            self.save_path = save_path
+        else:
+            self.save_path = save_path
+        self._save_data(
             self._markdown_writer.write,
-            save_path,
+            self._img_writer.write,
+            self.save_path,
             self.markdown,
             *args,
             **kwargs,
         )
 
-    def _save_list_data(self, save_func, save_path, data, *args, **kwargs):
+    def _save_data(
+        self,
+        save_mkd_func: Callable,
+        save_img_func: Callable,
+        save_path: Union[str, Path],
+        data: Optional[Dict[str, Union[str, Dict[str, Any]]]],
+        *args,
+        **kwargs,
+    ) -> None:
+        """Internal method to save markdown and image data.
+
+        Args:
+            save_mkd_func (Callable): Function to save markdown text.
+            save_img_func (Callable): Function to save image data.
+            save_path (Union[str, Path]): The base path where the data will be saved.
+            data (Optional[Dict[str, Union[str, Dict[str, Any]]]]): The markdown data to save.
+            *args: Additional positional arguments for saving.
+            **kwargs: Additional keyword arguments for saving.
+        """
         save_path = Path(save_path)
         if data is None:
             return
-        if isinstance(data, list):
-            for idx, single in enumerate(data):
-                save_func(
-                    (
-                        save_path.parent / f"{save_path.stem}_{idx}{save_path.suffix}"
-                    ).as_posix(),
-                    single,
-                    *args,
-                    **kwargs,
-                )
-        save_func(save_path.as_posix(), data, *args, **kwargs)
-        logging.info(f"The result has been saved in {save_path}.")
+        for key, value in data.items():
+            if isinstance(value, str):
+                save_mkd_func(save_path.as_posix(), value, *args, **kwargs)
+            if isinstance(value, dict):
+                base_save_path = save_path.parent
+                for img_path, img_data in value.items():
+                    save_path = base_save_path / img_path
+                    save_img_func(save_path.as_posix(), img_data, *args, **kwargs)

+ 63 - 97
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -48,11 +48,22 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         JsonMixin.__init__(self)
         self.already_sorted = False
 
+    def _get_input_fn(self):
+        fn = super()._get_input_fn()
+        if (page_idx := self["page_index"]) is not None:
+            fp = Path(fn)
+            stem, suffix = fp.stem, fp.suffix
+            return f"{stem}_{page_idx}{suffix}"
+        else:
+            return fn
+
     def _to_img(self) -> dict[str, np.ndarray]:
         res_img_dict = {}
         model_settings = self["model_settings"]
+        page_index = self["page_index"]
         if model_settings["use_doc_preprocessor"]:
-            res_img_dict.update(**self["doc_preprocessor_res"].img)
+            for key, value in self["doc_preprocessor_res"].img.items():
+                res_img_dict[key] = value
         res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
 
         if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
@@ -92,16 +103,39 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 key = f"seal_res_region{seal_region_id}"
                 res_img_dict[key] = sub_seal_res_dict["ocr_res_img"]
 
-        # if (
-        #     model_settings["use_formula_recognition"]
-        #     and len(self["formula_res_list"]) > 0
-        # ):
-        #     for sno in range(len(self["formula_res_list"])):
-        #         formula_res = self["formula_res_list"][sno]
-        #         formula_region_id = formula_res["formula_region_id"]
-        #         sub_formula_res_dict = formula_res.img
-        #         key = f"formula_res_region{formula_region_id}"
-        #         res_img_dict[key] = sub_formula_res_dict["res"]
+        # for layout ordering image
+        image = Image.fromarray(self["doc_preprocessor_res"]["output_img"])
+        draw = ImageDraw.Draw(image, "RGBA")
+        parsing_result = self["parsing_res_list"]
+
+        for block in parsing_result:
+            if self.already_sorted == False:
+                block = get_layout_ordering(
+                    block,
+                    no_mask_labels=[
+                        "text",
+                        "formula",
+                        "algorithm",
+                        "reference",
+                        "content",
+                        "abstract",
+                    ],
+                    already_sorted=self.already_sorted,
+                )
+
+            sub_blocks = block["sub_blocks"]
+            for sub_block in sub_blocks:
+                bbox = sub_block["layout_bbox"]
+                index = sub_block.get("index", None)
+                label = sub_block["sub_label"]
+                fill_color = get_show_color(label)
+                draw.rectangle(bbox, fill=fill_color)
+                if index is not None:
+                    text_position = (bbox[2] + 2, bbox[1] - 10)
+                    draw.text(text_position, str(index), fill="red")
+
+        self.already_sorted = True
+        res_img_dict["layout_order_res"] = image
 
         return res_img_dict
 
@@ -117,6 +151,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         if self["model_settings"]["use_doc_preprocessor"]:
@@ -167,6 +202,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         if self["model_settings"]["use_doc_preprocessor"]:
@@ -235,73 +271,6 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 res_xlsx_dict[key] = table_res.xlsx["pred"]
         return res_xlsx_dict
 
-    def save_to_pdf_order(self, save_path: str) -> None:
-        """
-        Save the layout ordering to an image file.
-
-        Args:
-            save_path (str): The path where the image should be saved.
-
-        Returns:
-            None
-        """
-        input_path = Path(self["input_path"])
-        page_index = self["page_index"]
-        save_path = Path(save_path)
-        if save_path.suffix.lower() not in (".jpg", ".png"):
-            if input_path.suffix.lower() == ".pdf":
-                save_path = save_path / f"page_{page_index}.jpg"
-            else:
-                save_path = save_path / f"{input_path.stem}.jpg"
-        else:
-            save_path = save_path.with_suffix("")
-
-        ordering_image_path = (
-            save_path.parent / f"{save_path.stem}_layout_order_res.jpg"
-        )
-
-        try:
-            image = Image.fromarray(self["doc_preprocessor_res"]["output_img"])
-        except OSError as e:
-            print(f"Error opening image: {e}")
-            return
-
-        draw = ImageDraw.Draw(image, "RGBA")
-        parsing_result = self["parsing_res_list"]
-
-        for block in parsing_result:
-            if self.already_sorted == False:
-                block = get_layout_ordering(
-                    block,
-                    no_mask_labels=[
-                        "text",
-                        "formula",
-                        "algorithm",
-                        "reference",
-                        "content",
-                        "abstract",
-                    ],
-                    already_sorted=self.already_sorted,
-                )
-
-            sub_blocks = block["sub_blocks"]
-            for sub_block in sub_blocks:
-                bbox = sub_block["layout_bbox"]
-                index = sub_block.get("index", None)
-                label = sub_block["sub_label"]
-                fill_color = get_show_color(label)
-                draw.rectangle(bbox, fill=fill_color)
-                if index is not None:
-                    text_position = (bbox[2] + 2, bbox[1] - 10)
-                    draw.text(text_position, str(index), fill="red")
-
-        self.already_sorted = True
-
-        # Ensure the directory exists and save the image
-        ordering_image_path.parent.mkdir(parents=True, exist_ok=True)
-        print(f"Saving ordering image to {ordering_image_path}")
-        image.save(str(ordering_image_path))
-
     def _to_markdown(self) -> dict:
         """
         Save the parsing result to a Markdown file.
@@ -309,14 +278,8 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         Returns:
             Dict
         """
-        if self.save_path == None:
-            is_save_mk_img = False
-        else:
-            is_save_mk_img = True
-            save_path = Path(self.save_path)
 
         parsing_result = self["parsing_res_list"]
-
         for block in parsing_result:
             if self.already_sorted == False:
                 block = get_layout_ordering(
@@ -333,12 +296,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 )
         self.already_sorted == True
 
-        if is_save_mk_img:
-            recursive_img_array2path(
-                self["parsing_res_list"],
-                save_path.parent,
-                labels=["img"],
-            )
+        recursive_img_array2path(self["parsing_res_list"], labels=["img"])
 
         def _format_data(obj):
 
@@ -367,16 +325,12 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 )
 
             def format_image(label):
-                if is_save_mk_img is False:
-                    return ""
-
                 img_tags = []
                 if "img" in sub_block[label]:
+                    image_path = "".join(sub_block[label]["img"].keys())
                     img_tags.append(
                         '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
-                            sub_block[label]["img"]
-                            .replace("-\n", "")
-                            .replace("\n", " "),
+                            image_path.replace("-\n", "").replace("\n", " "),
                         ),
                     )
                 if "image_text" in sub_block[label]:
@@ -456,4 +410,16 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
 
             return markdown_content
 
-        return _format_data(self)
+        markdown_info = dict()
+        markdown_info["markdown_texts"] = _format_data(self)
+        markdown_info["markdown_images"] = dict()
+        for block in self["parsing_res_list"]:
+            sub_blocks = block["sub_blocks"]
+            for sub_block in sub_blocks:
+                if sub_block["label"] == "image":
+                    image_path, image_value = next(
+                        iter(sub_block["image"]["img"].items())
+                    )
+                    markdown_info["markdown_images"][image_path] = image_value
+
+        return markdown_info

+ 8 - 17
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -24,6 +24,7 @@ __all__ = [
 import numpy as np
 import copy
 import cv2
+from PIL import Image
 import uuid
 from pathlib import Path
 from typing import Optional, Union, List, Tuple, Dict, Any
@@ -724,16 +725,16 @@ def sort_by_xycut(
     return res
 
 
-def _img_array2path(data: np.ndarray, save_path: Union[str, Path]) -> str:
+def _img_array2path(data: np.ndarray) -> str:
     """
     Save an image array to disk and return the relative file path.
 
     Args:
         data (np.ndarray): An image represented as a numpy array with 3 dimensions (H, W, C).
-        save_path (Union[str, Path]): The base path where images should be saved.
 
     Returns:
-        str: The relative path of the saved image file.
+        dict: A dictionary with a single key-value pair formatted as:
+              {"imgs/image_{uuid4_hex}.png": PIL.Image.Image}
 
     Raises:
         ValueError: If the input data is not a valid image array.
@@ -741,17 +742,8 @@ def _img_array2path(data: np.ndarray, save_path: Union[str, Path]) -> str:
     if isinstance(data, np.ndarray) and data.ndim == 3:
         # Generate a unique filename using UUID
         img_name = f"image_{uuid.uuid4().hex}.png"
-        img_path = Path(save_path) / "imgs" / img_name
-        img_path.parent.mkdir(
-            parents=True, exist_ok=True
-        )  # Ensure the directory exists
-
-        # Save the image using OpenCV
-        success = cv2.imwrite(str(img_path), data)
-        if not success:
-            raise IOError(f"Failed to save image to {img_path}")
 
-        return f"imgs/{img_name}"
+        return {f"imgs/{img_name}": Image.fromarray(data[:, :, ::-1])}
     else:
         raise ValueError(
             "Input data must be a 3-dimensional numpy array representing an image."
@@ -760,7 +752,6 @@ def _img_array2path(data: np.ndarray, save_path: Union[str, Path]) -> str:
 
 def recursive_img_array2path(
     data: Union[Dict[str, Any], List[Any]],
-    save_path: Union[str, Path],
     labels: List[str] = [],
 ) -> None:
     """
@@ -778,12 +769,12 @@ def recursive_img_array2path(
     if isinstance(data, dict):
         for k, v in data.items():
             if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
-                data[k] = _img_array2path(v, save_path)
+                data[k] = _img_array2path(v)
             else:
-                recursive_img_array2path(v, save_path, labels)
+                recursive_img_array2path(v, labels)
     elif isinstance(data, list):
         for item in data:
-            recursive_img_array2path(item, save_path, labels)
+            recursive_img_array2path(item, labels)
 
 
 def _get_minbox_if_overlap_by_ratio(

+ 1 - 1
paddlex/inference/utils/io/writers.py

@@ -454,5 +454,5 @@ class MarkdownWriterBackend(_BaseWriterBackend):
 
     def _write_obj(self, out_path, obj):
         """write markdown obj"""
-        with open(out_path, mode="a", encoding="utf-8", errors="replace") as f:
+        with open(out_path, mode="w", encoding="utf-8", errors="replace") as f:
             f.write(obj)

Vissa filer visades inte eftersom för många filer har ändrats