Bläddra i källkod

fix duplicate calculation (#3236)

Co-authored-by: cuicheng01 <45199522+cuicheng01@users.noreply.github.com>
Tingquan Gao 9 månader sedan
förälder
incheckning
9b4ff88eb8
1 ändrade filer med 37 tillägg och 43 borttagningar
  1. 37 43
      paddlex/inference/common/result/mixin.py

+ 37 - 43
paddlex/inference/common/result/mixin.py

@@ -68,7 +68,7 @@ class StrMixin:
 
     def print(self) -> None:
         """Print the string representation of the result."""
-        logging.info(self.str)
+        logging.info(self._to_str())
 
 
 def _format_data(obj):
@@ -144,23 +144,24 @@ class JsonMixin:
             mime_type, _ = mimetypes.guess_type(file_path)
             return mime_type is not None and mime_type == "application/json"
 
+        json = self._to_json()
         if not _is_json_file(save_path):
             fn = Path(self._get_input_fn())
             stem = fn.stem
             base_save_path = Path(save_path)
-            for key in self.json:
+            for key in json:
                 save_path = base_save_path / f"{stem}_{key}.json"
                 self._json_writer.write(
-                    save_path.as_posix(), self.json[key], *args, **kwargs
+                    save_path.as_posix(), json[key], *args, **kwargs
                 )
         else:
-            if len(self.json) > 1:
+            if len(json) > 1:
                 logging.warning(
                     f"The result has multiple json files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
                 )
             self._json_writer.write(
                 save_path,
-                self.json[list(self.json.keys())[0]],
+                json[list(json.keys())[0]],
                 indent=indent,
                 ensure_ascii=ensure_ascii,
                 *args,
@@ -246,22 +247,23 @@ class Base64Mixin:
             *args: Additional positional arguments that will be passed to the base64 writer.
             **kwargs: Additional keyword arguments that will be passed to the base64 writer.
         """
+        base64 = self._to_base64()
         if not str(save_path).lower().endswith((".b64")):
             fn = Path(self._get_input_fn())
             stem = fn.stem
             base_save_path = Path(save_path)
-            for key in self.base64:
+            for key in base64:
                 save_path = base_save_path / f"{stem}_{key}.b64"
                 self._base64_writer.write(
-                    save_path.as_posix(), self.base64[key], *args, **kwargs
+                    save_path.as_posix(), base64[key], *args, **kwargs
                 )
         else:
-            if len(self.base64) > 1:
+            if len(base64) > 1:
                 logging.warning(
                     f"The result has multiple base64 files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
                 )
             self._base64_writer.write(
-                save_path, self.base64[list(self.base64.keys())[0]], *args, **kwargs
+                save_path, base64[list(base64.keys())[0]], *args, **kwargs
             )
 
 
@@ -310,24 +312,21 @@ class ImgMixin:
             mime_type, _ = mimetypes.guess_type(file_path)
             return mime_type is not None and mime_type.startswith("image/")
 
+        img = self._to_img()
         if not _is_image_file(save_path):
             fn = Path(self._get_input_fn())
             suffix = fn.suffix if _is_image_file(fn) else ".png"
             stem = fn.stem
             base_save_path = Path(save_path)
-            for key in self.img:
+            for key in img:
                 save_path = base_save_path / f"{stem}_{key}{suffix}"
-                self._img_writer.write(
-                    save_path.as_posix(), self.img[key], *args, **kwargs
-                )
+                self._img_writer.write(save_path.as_posix(), img[key], *args, **kwargs)
         else:
-            if len(self.img) > 1:
+            if len(img) > 1:
                 logging.warning(
                     f"The result has multiple img files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
                 )
-            self._img_writer.write(
-                save_path, self.img[list(self.img.keys())[0]], *args, **kwargs
-            )
+            self._img_writer.write(save_path, img[list(img.keys())[0]], *args, **kwargs)
 
 
 class CSVMixin:
@@ -378,23 +377,20 @@ class CSVMixin:
             mime_type, _ = mimetypes.guess_type(file_path)
             return mime_type is not None and mime_type == "text/csv"
 
+        csv = self._to_csv()
         if not _is_csv_file(save_path):
             fn = Path(self._get_input_fn())
             stem = fn.stem
             base_save_path = Path(save_path)
-            for key in self.csv:
+            for key in csv:
                 save_path = base_save_path / f"{stem}_{key}.csv"
-                self._csv_writer.write(
-                    save_path.as_posix(), self.csv[key], *args, **kwargs
-                )
+                self._csv_writer.write(save_path.as_posix(), csv[key], *args, **kwargs)
         else:
-            if len(self.csv) > 1:
+            if len(csv) > 1:
                 logging.warning(
                     f"The result has multiple csv files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
                 )
-            self._csv_writer.write(
-                save_path, self.csv[list(self.csv.keys())[0]], *args, **kwargs
-            )
+            self._csv_writer.write(save_path, csv[list(csv.keys())[0]], *args, **kwargs)
 
 
 class HtmlMixin:
@@ -442,22 +438,23 @@ class HtmlMixin:
             mime_type, _ = mimetypes.guess_type(file_path)
             return mime_type is not None and mime_type == "text/html"
 
+        html = self._to_html()
         if not _is_html_file(save_path):
             fn = Path(self._get_input_fn())
             stem = fn.stem
             base_save_path = Path(save_path)
-            for key in self.html:
+            for key in html:
                 save_path = base_save_path / f"{stem}_{key}.html"
                 self._html_writer.write(
-                    save_path.as_posix(), self.html[key], *args, **kwargs
+                    save_path.as_posix(), html[key], *args, **kwargs
                 )
         else:
-            if len(self.html) > 1:
+            if len(html) > 1:
                 logging.warning(
                     f"The result has multiple html files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
                 )
             self._html_writer.write(
-                save_path, self.html[list(self.html.keys())[0]], *args, **kwargs
+                save_path, html[list(html.keys())[0]], *args, **kwargs
             )
 
 
@@ -510,22 +507,23 @@ class XlsxMixin:
                 == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
             )
 
+        xlsx = self._to_xlsx()
         if not _is_xlsx_file(save_path):
             fn = Path(self._get_input_fn())
             stem = fn.stem
             base_save_path = Path(save_path)
-            for key in self.xlsx:
+            for key in xlsx:
                 save_path = base_save_path / f"{stem}_{key}.xlsx"
                 self._xlsx_writer.write(
-                    save_path.as_posix(), self.xlsx[key], *args, **kwargs
+                    save_path.as_posix(), xlsx[key], *args, **kwargs
                 )
         else:
-            if len(self.xlsx) > 1:
+            if len(xlsx) > 1:
                 logging.warning(
                     f"The result has multiple xlsx files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
                 )
             self._xlsx_writer.write(
-                save_path, self.xlsx[list(self.xlsx.keys())[0]], *args, **kwargs
+                save_path, xlsx[list(xlsx.keys())[0]], *args, **kwargs
             )
 
 
@@ -575,25 +573,21 @@ class VideoMixin:
             return mime_type is not None and mime_type.startswith("video/")
 
         video_writer = VideoWriter(backend=self._backend, *args, **kwargs)
-
+        video = self._to_video()
         if not _is_video_file(save_path):
             fn = Path(self._get_input_fn())
             stem = fn.stem
             suffix = fn.suffix if _is_video_file(fn) else ".mp4"
             base_save_path = Path(save_path)
-            for key in self.video:
+            for key in video:
                 save_path = base_save_path / f"{stem}_{key}{suffix}"
-                video_writer.write(
-                    save_path.as_posix(), self.video[key], *args, **kwargs
-                )
+                video_writer.write(save_path.as_posix(), video[key], *args, **kwargs)
         else:
-            if len(self.video) > 1:
+            if len(video) > 1:
                 logging.warning(
                     f"The result has multiple video files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
                 )
-            video_writer.write(
-                save_path, self.video[list(self.video.keys())[0]], *args, **kwargs
-            )
+            video_writer.write(save_path, video[list(video.keys())[0]], *args, **kwargs)
 
 
 class MarkdownMixin:
@@ -667,7 +661,7 @@ class MarkdownMixin:
             self._markdown_writer.write,
             self._img_writer.write,
             self.save_path,
-            self.markdown,
+            self._to_markdown(),
             *args,
             **kwargs,
         )