Browse Source

repair video class to support dict

liuhongen1234567 10 months ago
parent
commit
3b84718fee

+ 1 - 0
api_examples/pipelines/test_video_classification.py

@@ -26,5 +26,6 @@ output = pipeline.predict("./test_samples/general_video_classification_001.mp4",
 for res in output:
     print(res)
     res.print()  ## 打印预测的结构化输出
+    res.save_to_video("./output/1.mp4")  ## 保存结果可视化视频
     res.save_to_video("./output/")  ## 保存结果可视化视频
     res.save_to_json("./output/")  ## 保存预测的结构化输出

+ 1 - 1
paddlex/inference/common/batch_sampler/video_batch_sampler.py

@@ -37,7 +37,7 @@ class VideoBatchSampler(BaseBatchSampler):
     def _get_files_list(self, fp):
         file_list = []
         if fp is None or not os.path.exists(fp):
-            raise Exception(f"Not found any img file in path: {fp}")
+            raise Exception(f"Not found any video file in path: {fp}")
         if os.path.isfile(fp) and fp.split(".")[-1] in self.SUFFIX:
             file_list.append(fp)
         elif os.path.isdir(fp):

+ 53 - 7
paddlex/inference/common/result/mixin.py

@@ -453,21 +453,67 @@ class XlsxMixin:
 
 
 class VideoMixin:
-    def __init__(self, backend="opencv", *args, **kwargs):
+    """Mixin class for adding Video handling capabilities."""
+
+    def __init__(self, backend: str = "opencv", *args: List, **kwargs: Dict) -> None:
+        """Initializes VideoMixin.
+
+        Args:
+            backend (str): The backend to use for video processing. Defaults to "opencv".
+            *args: Additional positional arguments to pass to the VideoWriter.
+            **kwargs: Additional keyword arguments to pass to the VideoWriter.
+        """
         self._backend = backend
         self._save_funcs.append(self.save_to_video)
 
     @abstractmethod
-    def _to_video(self):
+    def _to_video(self) -> Dict[str, np.array]:
+        """Abstract method to convert the result to a video.
+
+        Returns:
+            Dict[str, np.array]: The video representation result.
+        """
         raise NotImplementedError
 
     @property
-    def video(self):
+    def video(self) -> Dict[str, np.array]:
+        """Property to get the video representation of the result.
+
+        Returns:
+            Dict[str, np.array]: The video representation of the result.
+        """
         return self._to_video()
 
-    def save_to_video(self, save_path, *args, **kwargs):
+    def save_to_video(self, save_path: str, *args: List, **kwargs: Dict) -> None:
+        """Saves the video representation of the result to the specified path.
+
+        Args:
+            save_path (str): The path to save the video. If the save path does not end with .mp4 or .avi, it appends the input path's stem and suffix to the save path.
+            *args: Additional positional arguments that will be passed to the video writer.
+            **kwargs: Additional keyword arguments that will be passed to the video writer.
+        """
+
+        def _is_video_file(file_path):
+            mime_type, _ = mimetypes.guess_type(file_path)
+            return mime_type is not None and mime_type.startswith("video/")
+
         video_writer = VideoWriter(backend=self._backend, *args, **kwargs)
-        if not str(save_path).lower().endswith((".mp4", ".avi", ".mkv", ".webm")):
+
+        if not _is_video_file(save_path):
             fp = Path(self["input_path"])
-            save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
-        video_writer.write(save_path.as_posix(), self.video["video"], *args, **kwargs)
+            stem = fp.stem
+            suffix = fp.suffix
+            base_save_path = Path(save_path)
+            for key in self.video:
+                save_path = base_save_path / f"{stem}_{key}{suffix}"
+                video_writer.write(
+                    save_path.as_posix(), self.video[key], *args, **kwargs
+                )
+        else:
+            if len(self.video) > 1:
+                logging.warning(
+                    f"The result has multiple video files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
+                )
+            video_writer.write(
+                save_path, self.video[list(self.video.keys())[0]], *args, **kwargs
+            )

+ 1 - 1
paddlex/inference/models_new/video_classification/result.py

@@ -77,7 +77,7 @@ class TopkVideoResult(BaseVideoResult):
             draw.text((text_x, text_y), label_str, fill=font_color, font=font)
             image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
             video_list.append(image)
-        return np.array(video_list), write_fps
+        return {"res": (np.array(video_list), write_fps)}
 
     def _get_font_colormap(self, color_index):
         """