Răsfoiți Sursa

Change result representations to dict type and support saving multiple result files

gaotingquan 10 luni în urmă
părinte
comite
07969ad6fb

+ 109 - 68
paddlex/inference/common/result/mixin.py

@@ -39,11 +39,11 @@ class StrMixin:
     """Mixin class for adding string conversion capabilities."""
 
     @property
-    def str(self) -> str:
+    def str(self) -> Dict[str, str]:
         """Property to get the string representation of the result.
 
         Returns:
-            str: The str type string representation of the result.
+            Dict[str, str]: The string representation of the result.
         """
 
         return self._to_str(self)
@@ -54,7 +54,7 @@ class StrMixin:
         json_format: bool = False,
         indent: int = 4,
         ensure_ascii: bool = False,
-    ) -> str:
+    ):
         """Convert the given result data to a string representation.
 
         Args:
@@ -64,12 +64,14 @@ class StrMixin:
             ensure_ascii (bool): If True, ensure all characters are ASCII. Default is False.
 
         Returns:
-            str: The string representation of the data.
+            Dict[str, str]: The string representation of the result.
         """
         if json_format:
-            return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
+            return {
+                "res": json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
+            }
         else:
-            return str(data)
+            return {"res": str(data)}
 
     def print(
         self, json_format: bool = False, indent: int = 4, ensure_ascii: bool = False
@@ -84,7 +86,7 @@ class StrMixin:
         str_ = self._to_str(
             self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii
         )
-        logging.info(str_)
+        logging.info(str_["res"])
 
 
 class JsonMixin:
@@ -94,11 +96,11 @@ class JsonMixin:
         self._json_writer = JsonWriter()
         self._save_funcs.append(self.save_to_json)
 
-    def _to_json(self) -> Dict[str, Any]:
+    def _to_json(self) -> Dict[str, Dict[str, Any]]:
         """Convert the object to a JSON-serializable format.
 
         Returns:
-            Dict[str, Any]: A dictionary representation of the object that is JSON-serializable.
+            Dict[str, Dict[str, Any]]: A dictionary representation of the object that is JSON-serializable.
         """
 
         def _format_data(obj):
@@ -125,14 +127,14 @@ class JsonMixin:
             else:
                 return obj
 
-        return _format_data(copy.deepcopy(self))
+        return {"res": _format_data(copy.deepcopy(self))}
 
     @property
-    def json(self) -> Dict[str, Any]:
+    def json(self) -> Dict[str, Dict[str, Any]]:
         """Property to get the JSON representation of the result.
 
         Returns:
-            Dict[str, Any]: The dict type JSON representation of the result.
+            Dict[str, Dict[str, Any]]: The dict type JSON representation of the result.
         """
 
         return self._to_json()
@@ -160,16 +162,28 @@ class JsonMixin:
             return mime_type is not None and mime_type == "application/json"
 
         if not _is_json_file(save_path):
-            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
-            save_path = save_path.as_posix()
-        self._json_writer.write(
-            save_path,
-            self.json,
-            indent=indent,
-            ensure_ascii=ensure_ascii,
-            *args,
-            **kwargs,
-        )
+            fp = Path(self["input_path"])
+            stem = fp.stem
+            suffix = fp.suffix
+            base_save_path = Path(save_path)
+            for key in self.json:
+                save_path = base_save_path / f"{stem}_{key}.json"
+                self._json_writer.write(
+                    save_path.as_posix(), self.json[key], *args, **kwargs
+                )
+        else:
+            if len(self.json) > 1:
+                logging.warning(
+                    f"The result has multiple json files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
+                )
+            self._json_writer.write(
+                save_path,
+                self.json[list(self.json.keys())[0]],
+                indent=indent,
+                ensure_ascii=ensure_ascii,
+                *args,
+                **kwargs,
+            )
 
 
 class Base64Mixin:
@@ -186,21 +200,21 @@ class Base64Mixin:
         self._save_funcs.append(self.save_to_base64)
 
     @abstractmethod
-    def _to_base64(self) -> str:
+    def _to_base64(self) -> Dict[str, str]:
         """Abstract method to convert the result to Base64.
 
         Returns:
-        str: The str type Base64 representation result.
+            Dict[str, str]: The str type Base64 representation result.
         """
         raise NotImplementedError
 
     @property
-    def base64(self) -> str:
+    def base64(self) -> Dict[str, str]:
         """
         Property that returns the Base64 encoded content.
 
         Returns:
-            str: The base64 representation of the result.
+            Dict[str, str]: The base64 representation of the result.
         """
         return self._to_base64()
 
@@ -213,13 +227,24 @@ class Base64Mixin:
             *args: Additional positional arguments that will be passed to the base64 writer.
             **kwargs: Additional keyword arguments that will be passed to the base64 writer.
         """
-
         if not str(save_path).lower().endswith((".b64")):
             fp = Path(self["input_path"])
-            save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
+            stem = fp.stem
+            suffix = fp.suffix
+            base_save_path = Path(save_path)
+            for key in self.base64:
+                save_path = base_save_path / f"{stem}_{key}.b64"
+                self._base64_writer.write(
+                    save_path.as_posix(), self.base64[key], *args, **kwargs
+                )
         else:
-            save_path = Path(save_path)
-        self._base64_writer.write(save_path.as_posix(), self.base64, *args, **kwargs)
+            if len(self.base64) > 1:
+                logging.warning(
+                    f"The result has multiple base64 files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
+                )
+            self._base64_writer.write(
+                save_path, self.base64[list(self.base64.keys())[0]], *args, **kwargs
+            )
 
 
 class ImgMixin:
@@ -237,20 +262,20 @@ class ImgMixin:
         self._save_funcs.append(self.save_to_img)
 
     @abstractmethod
-    def _to_img(self) -> Union[Image.Image, Dict[str, Image.Image]]:
+    def _to_img(self) -> Dict[str, Image.Image]:
         """Abstract method to convert the result to an image.
 
         Returns:
-        Union[Image.Image, Dict[str, Image.Image]]: The image representation result.
+            Dict[str, Image.Image]: The image representation result.
         """
         raise NotImplementedError
 
     @property
-    def img(self) -> Union[Image.Image, Dict[str, Image.Image]]:
+    def img(self) -> Dict[str, Image.Image]:
         """Property to get the image representation of the result.
 
         Returns:
-            Union[Image.Image, Dict[str, Image.Image]]: The image representation of the result.
+            Dict[str, Image.Image]: The image representation of the result.
         """
         return self._to_img()
 
@@ -267,24 +292,24 @@ class ImgMixin:
             mime_type, _ = mimetypes.guess_type(file_path)
             return mime_type is not None and mime_type.startswith("image/")
 
-        img = self.img
-        if isinstance(img, dict):
-            if not _is_image_file(save_path):
-                fp = Path(self["input_path"])
-                stem = fp.stem
-                suffix = fp.suffix
-            else:
-                stem = save_path.stem
-                suffix = save_path.suffix
+        if not _is_image_file(save_path):
+            fp = Path(self["input_path"])
+            stem = fp.stem
+            suffix = fp.suffix
             base_save_path = Path(save_path)
-            for key in img:
+            for key in self.img:
                 save_path = base_save_path / f"{stem}_{key}{suffix}"
-                self._img_writer.write(save_path.as_posix(), img[key], *args, **kwargs)
+                self._img_writer.write(
+                    save_path.as_posix(), self.img[key], *args, **kwargs
+                )
         else:
-            if not _is_image_file(save_path):
-                fp = Path(self["input_path"])
-                save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
-            self._img_writer.write(save_path.as_posix(), img, *args, **kwargs)
+            if len(self.img) > 1:
+                logging.warning(
+                    f"The result has multiple img files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
+                )
+            self._img_writer.write(
+                save_path, self.img[list(self.img.keys())[0]], *args, **kwargs
+            )
 
 
 class CSVMixin:
@@ -304,20 +329,20 @@ class CSVMixin:
         self._save_funcs.append(self.save_to_csv)
 
     @property
-    def csv(self) -> pd.DataFrame:
+    def csv(self) -> Dict[str, pd.DataFrame]:
         """Property to get the pandas Dataframe representation of the result.
 
         Returns:
-            pandas.DataFrame: The pandas.DataFrame representation of the result.
+            Dict[str, pd.DataFrame]: The pandas.DataFrame representation of the result.
         """
         return self._to_csv()
 
     @abstractmethod
-    def _to_csv(self) -> pd.DataFrame:
+    def _to_csv(self) -> Dict[str, pd.DataFrame]:
         """Abstract method to convert the result to pandas.DataFrame.
 
         Returns:
-        pandas.DataFrame: The pandas.DataFrame representation result.
+            Dict[str, pd.DataFrame]: The pandas.DataFrame representation result.
         """
         raise NotImplementedError
 
@@ -330,11 +355,28 @@ class CSVMixin:
             *args: Optional positional arguments to pass to the CSV writer's write method.
             **kwargs: Optional keyword arguments to pass to the CSV writer's write method.
         """
-        if not str(save_path).endswith(".csv"):
-            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
+
+        def _is_csv_file(file_path):
+            mime_type, _ = mimetypes.guess_type(file_path)
+            return mime_type is not None and mime_type == "text/csv"
+
+        if not _is_csv_file(save_path):
+            fp = Path(self["input_path"])
+            stem = fp.stem
+            base_save_path = Path(save_path)
+            for key in self.csv:
+                save_path = base_save_path / f"{stem}_{key}.csv"
+                self._csv_writer.write(
+                    save_path.as_posix(), self.csv[key], *args, **kwargs
+                )
         else:
-            save_path = Path(save_path)
-        self._csv_writer.write(save_path.as_posix(), self.csv, *args, **kwargs)
+            if len(self.csv) > 1:
+                logging.warning(
+                    f"The result has multiple csv files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
+                )
+            self._csv_writer.write(
+                save_path, self.csv[list(self.csv.keys())[0]], *args, **kwargs
+            )
 
 
 class HtmlMixin:
@@ -352,7 +394,7 @@ class HtmlMixin:
         self._save_funcs.append(self.save_to_html)
 
     @property
-    def html(self) -> str:
+    def html(self) -> Dict[str, str]:
         """Property to get the HTML representation of the result.
 
         Returns:
@@ -361,11 +403,11 @@ class HtmlMixin:
         return self._to_html()
 
     @abstractmethod
-    def _to_html(self) -> str:
+    def _to_html(self) -> Dict[str, str]:
         """Abstract method to convert the result to str type HTML representation.
 
         Returns:
-        str: The str type HTML representation result.
+            Dict[str, str]: The str type HTML representation result.
         """
         raise NotImplementedError
 
@@ -381,7 +423,7 @@ class HtmlMixin:
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
         else:
             save_path = Path(save_path)
-        self._html_writer.write(save_path.as_posix(), self.html, *args, **kwargs)
+        self._html_writer.write(save_path.as_posix(), self.html["res"], *args, **kwargs)
 
 
 class XlsxMixin:
@@ -398,20 +440,20 @@ class XlsxMixin:
         self._save_funcs.append(self.save_to_xlsx)
 
     @property
-    def xlsx(self) -> str:
+    def xlsx(self) -> Dict[str, str]:
         """Property to get the XLSX representation of the result.
 
         Returns:
-            str: The str type XLSX representation of the result.
+            Dict[str, str]: The str type XLSX representation of the result.
         """
         return self._to_xlsx()
 
     @abstractmethod
-    def _to_xlsx(self) -> str:
+    def _to_xlsx(self) -> Dict[str, str]:
         """Abstract method to convert the result to str type XLSX representation.
 
         Returns:
-        str: The str type HTML representation result.
+            Dict[str, str]: The str type XLSX representation result.
         """
         raise NotImplementedError
 
@@ -442,12 +484,11 @@ class VideoMixin:
 
     @property
     def video(self):
-        video = self._to_video()
-        return video
+        return self._to_video()
 
     def save_to_video(self, save_path, *args, **kwargs):
         video_writer = VideoWriter(backend=self._backend, *args, **kwargs)
         if not str(save_path).lower().endswith((".mp4", ".avi", ".mkv", ".webm")):
             fp = Path(self["input_path"])
             save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
-        _save_list_data(video_writer.write, save_path, self.video, *args, **kwargs)
+        video_writer.write(save_path.as_posix(), self.video["video"], *args, **kwargs)

+ 1 - 1
paddlex/inference/models_new/anomaly_detection/result.py

@@ -26,7 +26,7 @@ class UadResult(BaseCVResult):
         """apply"""
         seg_map = self["pred"]
         pc_map = self.get_pseudo_color_map(seg_map[0])
-        return pc_map
+        return {"res": pc_map}
 
     def get_pseudo_color_map(self, pred):
         """get_pseudo_color_map"""

+ 3 - 3
paddlex/inference/models_new/formula_recognition/result.py

@@ -56,7 +56,7 @@ class FormulaRecResult(BaseCVResult):
             logging.warning(
                 "Please refer to 2.3 Formula Recognition Pipeline Visualization in Formula Recognition Pipeline Tutorial to install the LaTeX rendering engine at first."
             )
-            return image
+            return {"res": image}
 
         rec_formula = str(self["rec_formula"])
         image = np.array(image.convert("RGB"))
@@ -83,10 +83,10 @@ class FormulaRecResult(BaseCVResult):
             )
             new_image.paste(image, (0, 0))
             new_image.paste(img_formula, (image.width + 10, 0))
-            return new_image
+            return {"res": new_image}
         except subprocess.CalledProcessError as e:
             logging.warning("Syntax error detected in formula, rendering failed.")
-            return image
+            return {"res": image}
 
 
 def get_align_equation(equation: str) -> str:

+ 1 - 1
paddlex/inference/models_new/image_classification/result.py

@@ -66,7 +66,7 @@ class TopkResult(BaseCVResult):
         text_x = rect_left + 3
         text_y = rect_top
         draw.text((text_x, text_y), label_str, fill=font_color, font=font)
-        return image
+        return {"res": image}
 
     def _get_font_colormap(self, color_index):
         """

+ 3 - 7
paddlex/inference/models_new/image_feature/result.py

@@ -14,12 +14,8 @@
 
 from PIL import Image
 
-from ...common.result import BaseCVResult
+from ...common.result import BaseResult
 
 
-class IdentityResult(BaseCVResult):
-
-    def _to_img(self):
-        """This module does not support visualization; it simply outputs the input images"""
-        image = Image.fromarray(self["input_img"])
-        return image
+class IdentityResult(BaseResult):
+    pass

+ 1 - 1
paddlex/inference/models_new/image_multilabel_classification/result.py

@@ -70,7 +70,7 @@ class MLClassResult(BaseCVResult):
                 fill=font_color,
                 font=font,
             )
-        return new_image
+        return {"res": new_image}
 
     def _get_font_colormap(self, color_index):
         """

+ 1 - 1
paddlex/inference/models_new/image_unwarping/result.py

@@ -31,7 +31,7 @@ class DocTrResult(BaseCVResult):
 
     def _to_img(self) -> np.ndarray:
         result = np.array(self["doctr_img"])
-        return result
+        return {"res": result}
 
     def _to_str(self, _, *args, **kwargs):
         data = copy.deepcopy(self)

+ 1 - 1
paddlex/inference/models_new/instance_segmentation/result.py

@@ -147,7 +147,7 @@ class InstanceSegResult(BaseCVResult):
         else:
             image = draw_segm(image, masks, boxes)
 
-        return image
+        return {"res": image}
 
     def _to_str(self, _, *args, **kwargs):
         data = copy.deepcopy(self)

+ 1 - 1
paddlex/inference/models_new/object_detection/result.py

@@ -100,4 +100,4 @@ class DetResult(BaseCVResult):
         """apply"""
         boxes = self["boxes"]
         image = Image.fromarray(self["input_img"])
-        return draw_box(image, boxes)
+        return {"res": draw_box(image, boxes)}

+ 1 - 1
paddlex/inference/models_new/semantic_segmentation/result.py

@@ -28,7 +28,7 @@ class SegResult(BaseCVResult):
         pc_map = self.get_pseudo_color_map(seg_map[0])
         if pc_map.mode == "P":
             pc_map = pc_map.convert("RGB")
-        return pc_map
+        return {"res": pc_map}
 
     def get_pseudo_color_map(self, pred):
         """get_pseudo_color_map"""

+ 1 - 1
paddlex/inference/models_new/table_structure_recognition/result.py

@@ -34,7 +34,7 @@ class TableRecResult(BaseCVResult):
             vis_img = self.draw_rectangle(image, bbox_res)
         else:
             vis_img = self.draw_bbox(image, bbox_res)
-        return vis_img
+        return {"res": vis_img}
 
     def draw_rectangle(self, image, boxes):
         """draw_rectangle"""

+ 1 - 1
paddlex/inference/models_new/text_detection/result.py

@@ -30,4 +30,4 @@ class TextDetResult(BaseCVResult):
         for box in boxes:
             box = np.reshape(np.array(box).astype(int), [-1, 1, 2]).astype(np.int64)
             cv2.polylines(image, [box], True, (0, 0, 255), 2)
-        return image[:, :, ::-1]
+        return {"res": image[:, :, ::-1]}

+ 1 - 1
paddlex/inference/models_new/text_recognition/result.py

@@ -42,7 +42,7 @@ class TextRecResult(BaseCVResult):
             fill=(0, 0, 0),
             font=font,
         )
-        return new_image
+        return {"res": new_image}
 
     def adjust_font_size(self, image_width, text, font_path):
         font_size = int(image_width * 0.06)

+ 1 - 1
paddlex/inference/models_new/ts_anomaly_detection/result.py

@@ -26,4 +26,4 @@ class TSAdResult(BaseTSResult):
         Returns:
             Any: The anomaly data formatted for CSV output, typically a DataFrame or similar structure.
         """
-        return self["anomaly"]
+        return {"res": self["anomaly"]}

+ 1 - 1
paddlex/inference/models_new/ts_classification/result.py

@@ -26,4 +26,4 @@ class TSClsResult(BaseTSResult):
         Returns:
             Any: The classification data formatted for CSV output, typically a DataFrame or similar structure.
         """
-        return self["classification"]
+        return {"res": self["classification"]}

+ 1 - 1
paddlex/inference/models_new/ts_forecasting/result.py

@@ -26,4 +26,4 @@ class TSFcResult(BaseTSResult):
         Returns:
             Any: The forecast data formatted for CSV output, typically a DataFrame or similar structure.
         """
-        return self["forecast"]
+        return {"res": self["forecast"]}