Kaynağa Gözat

structurev3: support rendering images in tables when saving markdown

gaotingquan 8 ay önce
ebeveyn
işleme
1f5b4d90f7

+ 30 - 10
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -26,8 +26,7 @@ from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 from ..ocr.result import OCRResult
 from .result_v2 import LayoutParsingResultV2
-from .utils import get_single_block_parsing_res
-from .utils import get_sub_regions_ocr_res
+from .utils import get_single_block_parsing_res, get_sub_regions_ocr_res, gather_imgs
 
 
 class LayoutParsingPipelineV2(BasePipeline):
@@ -498,6 +497,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                     layout_merge_bboxes_mode=layout_merge_bboxes_mode,
                 )
             )
+            imgs_in_doc = gather_imgs(doc_preprocessor_image, layout_det_res["boxes"])
 
             if model_settings["use_formula_recognition"]:
                 formula_res_all = next(
@@ -537,7 +537,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                 overall_ocr_res = {}
 
             if model_settings["use_table_recognition"]:
-                table_overall_ocr_res = copy.deepcopy(overall_ocr_res)
+                table_contents = copy.deepcopy(overall_ocr_res)
                 for formula_res in formula_res_list:
                     x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
                     poly_points = [
@@ -546,15 +546,34 @@ class LayoutParsingPipelineV2(BasePipeline):
                         (x_max, y_max),
                         (x_min, y_max),
                     ]
-                    table_overall_ocr_res["dt_polys"].append(poly_points)
-                    table_overall_ocr_res["rec_texts"].append(
+                    table_contents["dt_polys"].append(poly_points)
+                    table_contents["rec_texts"].append(
                         f"${formula_res['rec_formula']}$"
                     )
-                    table_overall_ocr_res["rec_boxes"] = np.vstack(
-                        (table_overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
+                    table_contents["rec_boxes"] = np.vstack(
+                        (table_contents["rec_boxes"], [formula_res["dt_polys"]])
                     )
-                    table_overall_ocr_res["rec_polys"].append(poly_points)
-                    table_overall_ocr_res["rec_scores"].append(1)
+                    table_contents["rec_polys"].append(poly_points)
+                    table_contents["rec_scores"].append(1)
+
+                for img in imgs_in_doc:
+                    img_path = img["path"]
+                    x_min, y_min, x_max, y_max = img["coordinate"]
+                    poly_points = [
+                        (x_min, y_min),
+                        (x_max, y_min),
+                        (x_max, y_max),
+                        (x_min, y_max),
+                    ]
+                    table_contents["dt_polys"].append(poly_points)
+                    table_contents["rec_texts"].append(
+                        f'<div style="text-align: center;"><img src="{img_path}" alt="Image" /></div>'
+                    )
+                    table_contents["rec_boxes"] = np.vstack(
+                        (table_contents["rec_boxes"], img["coordinate"])
+                    )
+                    table_contents["rec_polys"].append(poly_points)
+                    table_contents["rec_scores"].append(img["score"])
 
                 table_res_all = next(
                     self.table_recognition_pipeline(
@@ -563,7 +582,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                         use_doc_unwarping=False,
                         use_layout_detection=False,
                         use_ocr_model=False,
-                        overall_ocr_res=table_overall_ocr_res,
+                        overall_ocr_res=table_contents,
                         layout_det_res=layout_det_res,
                         cell_sort_by_y_projection=True,
                     ),
@@ -623,6 +642,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                 "seal_res_list": seal_res_list,
                 "formula_res_list": formula_res_list,
                 "parsing_res_list": parsing_res_list,
+                "imgs_in_doc": imgs_in_doc,
                 "model_settings": model_settings,
             }
             yield LayoutParsingResultV2(single_img_res)

+ 4 - 7
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -30,7 +30,6 @@ from ...common.result import (
     XlsxMixin,
 )
 from .utils import get_layout_ordering
-from .utils import recursive_img_array2path
 from .utils import get_show_color
 
 
@@ -238,7 +237,6 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
         Returns:
             Dict
         """
-        recursive_img_array2path(self["parsing_res_list"], labels=["block_image"])
 
         def _format_data(obj):
 
@@ -382,10 +380,9 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
             page_first_element_seg_start_flag,
             page_last_element_seg_end_flag,
         )
-        markdown_info["markdown_images"] = dict()
-        for block in self["parsing_res_list"]:
-            if block["block_label"] in ["image", "chart"]:
-                image_path, image_value = next(iter(block["block_image"].items()))
-                markdown_info["markdown_images"][image_path] = image_value
+
+        markdown_info["markdown_images"] = {}
+        for img in self["imgs_in_doc"]:
+            markdown_info["markdown_images"][img["path"]] = img["img"]
 
         return markdown_info

+ 16 - 51
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -16,7 +16,6 @@ __all__ = [
     "get_sub_regions_ocr_res",
     "get_layout_ordering",
     "get_single_block_parsing_res",
-    "recursive_img_array2path",
     "get_show_color",
     "sorted_layout_boxes",
 ]
@@ -797,56 +796,22 @@ def sort_by_xycut(
     return res
 
 
-def _img_array2path(data: np.ndarray) -> str:
-    """
-    Save an image array to disk and return the relative file path.
-
-    Args:
-        data (np.ndarray): An image represented as a numpy array with 3 dimensions (H, W, C).
-
-    Returns:
-        dict: A dictionary with a single key-value pair formatted as:
-              {"imgs/image_{uuid4_hex}.png": PIL.Image.Image}
-
-    Raises:
-        ValueError: If the input data is not a valid image array.
-    """
-    if isinstance(data, np.ndarray) and data.ndim == 3:
-        # Generate a unique filename using UUID
-        img_name = f"image_{uuid.uuid4().hex}.png"
-
-        return {f"imgs/{img_name}": Image.fromarray(data[:, :, ::-1])}
-    else:
-        raise ValueError(
-            "Input data must be a 3-dimensional numpy array representing an image."
-        )
-
-
-def recursive_img_array2path(
-    data: Union[Dict[str, Any], List[Any]],
-    labels: List[str] = [],
-) -> None:
-    """
-    Recursively process a dictionary or list to save image arrays to disk
-    and replace them with file paths.
-
-    Args:
-        data (Union[Dict[str, Any], List[Any]]): The data structure that may contain image arrays.
-        save_path (Union[str, Path]): The base path where images should be saved.
-        labels (List[str]): List of keys to check for image arrays in dictionaries.
-
-    Returns:
-        None: This function modifies the input data structure in place.
-    """
-    if isinstance(data, dict):
-        for k, v in data.items():
-            if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
-                data[k] = _img_array2path(v)
-            else:
-                recursive_img_array2path(v, labels)
-    elif isinstance(data, list):
-        for item in data:
-            recursive_img_array2path(item, labels)
+def gather_imgs(original_img, layout_det_objs):  # Crop every "image"/"chart" layout box out of original_img.
+    imgs_in_doc = []  # One dict per cropped region: path, img, coordinate, score.
+    for det_obj in layout_det_objs:
+        if det_obj["label"] in ("image", "chart"):  # Boxes with any other label are ignored.
+            x_min, y_min, x_max, y_max = list(map(int, det_obj["coordinate"]))  # NOTE(review): a negative value would wrap around in the slice below - confirm boxes are non-negative.
+            img_path = f"imgs/img_in_table_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg"  # Relative path built from the integer box corners.
+            img = Image.fromarray(original_img[y_min:y_max, x_min:x_max, ::-1])  # ::-1 reverses the last axis (presumably BGR -> RGB; confirm).
+            imgs_in_doc.append(
+                {
+                    "path": img_path,
+                    "img": img,  # PIL image of the cropped region.
+                    "coordinate": (x_min, y_min, x_max, y_max),  # Integer-truncated box, not the raw detector values.
+                    "score": det_obj["score"],  # Copied unchanged from the detection entry.
+                }
+            )
+    return imgs_in_doc  # Empty list when no "image"/"chart" box is present.
 
 
 def _get_minbox_if_overlap_by_ratio(