Преглед изворни кода

Support saving JSON results for different PDF pages & Support JSON without ndarray type content

cuicheng01 пре 10 месеци
родитељ
комит
fd91a79716

+ 0 - 1
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -267,7 +267,6 @@ class OCRPipeline(BasePipeline):
             dt_polys = self._sort_boxes(dt_polys)
 
             single_img_res = {
-                "input_img": image_array,
                 "doc_preprocessor_image": doc_preprocessor_image,
                 "doc_preprocessor_res": doc_preprocessor_res,
                 "dt_polys": dt_polys,

+ 52 - 1
paddlex/inference/pipelines_new/ocr/result.py

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from pathlib import Path
+import copy
 import math
 import random
 import numpy as np
@@ -41,7 +43,9 @@ class OCRResult(BaseCVResult):
         input_params = self["input_params"]
         img_id = self["img_id"]
         if input_params["use_doc_preprocessor"]:
-            save_img_path = Path(save_path) / f"doc_preprocessor_result_img_{img_id}.jpg"
+            save_img_path = (
+                Path(save_path) / f"doc_preprocessor_result_img_{img_id}.jpg"
+            )
             self["doc_preprocessor_res"].save_to_img(save_img_path)
 
         if not str(save_path).lower().endswith((".jpg", ".png")):
@@ -49,6 +53,53 @@ class OCRResult(BaseCVResult):
 
         super().save_to_img(save_path, *args, **kwargs)
 
+    def save_to_json(
+        self,
+        save_path: str,
+        indent: int = 4,
+        ensure_ascii: bool = False,
+        save_ndarray: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
+        """Save the JSON representation of the object to a file.
+
+        Args:
+            save_path (str): The path to save the JSON file. If the save path does not end with '.json', it appends the base name and suffix of the input path.
+            indent (int): The number of spaces to indent for pretty printing. Default is 4.
+            ensure_ascii (bool): If False, non-ASCII characters will be included in the output. Default is False.
+            save_ndarray (bool): If True, save the numpy arrays in the result. Default is False.
+            *args: Additional positional arguments to pass to the underlying writer.
+            **kwargs: Additional keyword arguments to pass to the underlying writer.
+        """
+        img_id = self["img_id"]
+        base_name, ext = os.path.splitext(save_path)
+        save_path = f"{base_name}_{img_id}{ext}"
+
+        def remove_ndarray(d):
+            """
+            Remove all keys from the dictionary whose values are numpy arrays.
+            """
+            keys_to_delete = []
+            for key, value in d.items():
+                if isinstance(value, dict):
+                    remove_ndarray(value)
+                    if all(isinstance(v, np.ndarray) for v in value.values()):
+                        keys_to_delete.append(key)
+                elif isinstance(value, np.ndarray):
+                    keys_to_delete.append(key)
+            for key in keys_to_delete:
+                del d[key]
+
+        if not save_ndarray:
+            self_copy = copy.deepcopy(self)
+            remove_ndarray(self_copy)
+            super(type(self_copy), self_copy).save_to_json(
+                save_path, indent, ensure_ascii, *args, **kwargs
+            )
+        else:
+            super().save_to_json(save_path, indent, ensure_ascii, *args, **kwargs)
+
     def get_minarea_rect(self, points: np.ndarray) -> np.ndarray:
         """
         Get the minimum area rectangle for the given points using OpenCV.