Browse Source

yowo for paddlex (#2784)

* yowo for paddlex

* yowo for paddlex

* yowo for paddlex

* Update YOWO.yaml

* Update result.py

* Update YOWO.yaml

* Update trainer.py

* Update config.py

* Update result.py
Sunflower7788 10 months ago
parent
commit
8de9e2e2b6
30 changed files with 2654 additions and 5 deletions
  1. api_examples/pipelines/test_video_detection.py (+24 -0)
  2. paddlex/configs/modules/video_detection/YOWO.yaml (+40 -0)
  3. paddlex/configs/pipelines/video_detection.yaml (+9 -0)
  4. paddlex/inference/models_new/__init__.py (+1 -0)
  5. paddlex/inference/models_new/video_detection/__init__.py (+15 -0)
  6. paddlex/inference/models_new/video_detection/predictor.py (+117 -0)
  7. paddlex/inference/models_new/video_detection/processors.py (+459 -0)
  8. paddlex/inference/models_new/video_detection/result.py (+104 -0)
  9. paddlex/inference/pipelines_new/__init__.py (+1 -0)
  10. paddlex/inference/pipelines_new/video_detection/__init__.py (+15 -0)
  11. paddlex/inference/pipelines_new/video_detection/pipeline.py (+67 -0)
  12. paddlex/inference/utils/io/readers.py (+22 -5)
  13. paddlex/inference/utils/official_models.py (+1 -0)
  14. paddlex/modules/__init__.py (+7 -0)
  15. paddlex/modules/video_detection/__init__.py (+18 -0)
  16. paddlex/modules/video_detection/dataset_checker/__init__.py (+86 -0)
  17. paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py (+17 -0)
  18. paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py (+101 -0)
  19. paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py (+134 -0)
  20. paddlex/modules/video_detection/evaluator.py (+42 -0)
  21. paddlex/modules/video_detection/exportor.py (+22 -0)
  22. paddlex/modules/video_detection/model_list.py (+15 -0)
  23. paddlex/modules/video_detection/trainer.py (+82 -0)
  24. paddlex/repo_apis/PaddleVideo_api/__init__.py (+1 -0)
  25. paddlex/repo_apis/PaddleVideo_api/configs/YOWO.yaml (+144 -0)
  26. paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py (+19 -0)
  27. paddlex/repo_apis/PaddleVideo_api/video_det/config.py (+548 -0)
  28. paddlex/repo_apis/PaddleVideo_api/video_det/model.py (+298 -0)
  29. paddlex/repo_apis/PaddleVideo_api/video_det/register.py (+45 -0)
  30. paddlex/repo_apis/PaddleVideo_api/video_det/runner.py (+200 -0)

+ 24 - 0
api_examples/pipelines/test_video_detection.py

@@ -0,0 +1,24 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="video_detection")
+output = pipeline.predict("./test_samples/HorseRiding.avi")
+
+for res in output:
+    print(res)
+    res.print()  ## print the structured prediction output
+    res.save_to_video("./output/")  ## save the visualized result video
+    res.save_to_json("./output/")  ## save the structured prediction output

+ 40 - 0
paddlex/configs/modules/video_detection/YOWO.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: YOWO
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/video_det/video_det_examples"
+  device: gpu:0
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 24
+  epochs_iters: 5
+  batch_size: 8
+  learning_rate: 0.0001
+  pretrain_weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/YOWO_pretrain.pdparams"
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/YOWO_pretrain.pdparams"
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/videos/demo_video/HorseRiding.avi"
+  kernel_option:
+    run_mode: paddle

+ 9 - 0
paddlex/configs/pipelines/video_detection.yaml

@@ -0,0 +1,9 @@
+pipeline_name: video_detection
+
+SubModules:
+  VideoDetection:
+    module_name: video_detection
+    model_name: YOWO
+    model_dir: null
+    batch_size: 1    
+    topk: 1
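This pipeline config is what `create_pipeline` resolves when asked for "video_detection". A minimal usage sketch, assuming the standard PaddleX entry point (the `device` argument follows the common PaddleX pattern; the input path is illustrative):

from paddlex import create_pipeline

# create_pipeline looks up paddlex/configs/pipelines/video_detection.yaml
# by pipeline name; the device selection shown here is an assumption.
pipeline = create_pipeline(pipeline="video_detection", device="gpu:0")

# predict yields one DetVideoResult per input video
for res in pipeline.predict("./test_samples/HorseRiding.avi"):
    res.save_to_video("./output/")
    res.save_to_json("./output/")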

+ 1 - 0
paddlex/inference/models_new/__init__.py

@@ -44,6 +44,7 @@ from .anomaly_detection import UadPredictor
 # from .face_recognition import FaceRecPredictor
 from .multilingual_speech_recognition import WhisperPredictor
 from .video_classification import VideoClasPredictor
+from .video_detection import VideoDetPredictor
 
 
 def _create_hp_predictor(

+ 15 - 0
paddlex/inference/models_new/video_detection/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import VideoDetPredictor

+ 117 - 0
paddlex/inference/models_new/video_detection/predictor.py

@@ -0,0 +1,117 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+from ....utils.func_register import FuncRegister
+from ....modules.video_detection.model_list import MODELS
+from ...common.batch_sampler import VideoBatchSampler
+from ...common.reader import ReadVideo
+from ..common import (
+    ToBatch,
+    StaticInfer,
+)
+from ..base import BasicPredictor
+from .processors import ResizeVideo, Image2Array, NormalizeVideo, DetVideoPostProcess
+from .result import DetVideoResult
+
+
+class VideoDetPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def __init__(self, topk: Union[int, None] = None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pre_tfs, self.infer, self.post_op = self._build()
+
+    def _build_batch_sampler(self):
+        return VideoBatchSampler()
+
+    def _get_result_class(self):
+        return DetVideoResult
+
+    def _build(self):
+        pre_tfs = {}
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            tf_key = list(cfg.keys())[0]
+            assert tf_key in self._FUNC_MAP
+            func = self._FUNC_MAP[tf_key]
+            args = cfg.get(tf_key, {})
+            name, op = func(self, **args) if args else func(self)
+            if op:
+                pre_tfs[name] = op
+
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        post_op = {}
+        for cfg in self.config["PostProcess"]["transform_ops"]:
+            tf_key = list(cfg.keys())[0]
+            assert tf_key in self._FUNC_MAP
+            func = self._FUNC_MAP[tf_key]
+            args = cfg.get(tf_key, {})
+            if tf_key == "DetVideoPostProcess":
+                args["label_list"] = self.config["label_list"]
+            name, op = func(self, **args) if args else func(self)
+            if op:
+                post_op[name] = op
+
+        return pre_tfs, infer, post_op
+
+    def process(self, batch_data):
+        batch_raw_videos = self.pre_tfs["ReadVideo"](videos=batch_data)
+        batch_videos = self.pre_tfs["ResizeVideo"](videos=batch_raw_videos)
+        batch_videos = self.pre_tfs["Image2Array"](videos=batch_videos)
+        x = self.pre_tfs["NormalizeVideo"](videos=batch_videos)
+        num_seg = len(x[0])
+        pred_seg = []
+        for i in range(num_seg):
+            batch_preds = self.infer(x=[x[0][i]])
+            pred_seg.append(batch_preds)
+        batch_bboxes = self.post_op["DetVideoPostProcess"](preds=[pred_seg])
+        return {
+            "input_path": batch_data,
+            "result": batch_bboxes,
+        }
+
+    @register("ReadVideo")
+    def build_readvideo(self, num_seg=8):
+        return "ReadVideo", ReadVideo(backend="opencv", num_seg=num_seg)
+
+    @register("ResizeVideo")
+    def build_resize(self, target_size=224):
+        return "ResizeVideo", ResizeVideo(
+            target_size=target_size,
+        )
+
+    @register("Image2Array")
+    def build_image2array(self, data_format="tchw"):
+        return "Image2Array", Image2Array(data_format="tchw")
+
+    @register("NormalizeVideo")
+    def build_normalize(
+        self,
+        scale=255.0,
+    ):
+        return "NormalizeVideo", NormalizeVideo(scale=scale)
+
+    @register("DetVideoPostProcess")
+    def build_postprocess(self, nms_thresh=0.5, score_thresh=0.4, label_list=[]):
+        return "DetVideoPostProcess", DetVideoPostProcess(
+            nms_thresh=nms_thresh, score_thresh=score_thresh, label_list=label_list
+        )
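The predictor assembles its pre- and post-processing ops from the YAML config through the decorator-driven `_FUNC_MAP`. Below is a minimal, self-contained sketch of that registration pattern; it mimics `FuncRegister` rather than importing it, drops the `self` argument the real builder methods take, and uses placeholder op objects:

class FuncRegister:
    """Decorator factory that records builder functions in a shared map."""
    def __init__(self, func_map):
        self._func_map = func_map

    def __call__(self, key):
        def _decorator(func):
            self._func_map[key] = func
            return func
        return _decorator

_FUNC_MAP = {}
register = FuncRegister(_FUNC_MAP)

@register("ResizeVideo")
def build_resize(target_size=224):
    return "ResizeVideo", f"<ResizeVideo op, target_size={target_size}>"

# _build() walks config["PreProcess"]["transform_ops"] and instantiates
# each op by its registered key, passing any per-op arguments through.
transform_ops = [{"ResizeVideo": {"target_size": 224}}]
ops = {}
for cfg in transform_ops:
    key = list(cfg.keys())[0]
    args = cfg.get(key) or {}
    name, op = _FUNC_MAP[key](**args)
    ops[name] = op
print(ops)  # {'ResizeVideo': '<ResizeVideo op, target_size=224>'}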

+ 459 - 0
paddlex/inference/models_new/video_detection/processors.py

@@ -0,0 +1,459 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import os.path as osp
+from typing import List, Sequence, Union, Optional, Tuple
+
+import numpy as np
+import cv2
+import lazy_paddle as paddle
+
+
+class ResizeVideo:
+    """Resizes frames of a video to a specified target size.
+
+    This class provides functionality to resize each frame of a video to
+    a specified square dimension (height and width are equal).
+
+    Attributes:
+        target_size (int): The desired size (in pixels) for both the height
+            and width of each frame in the video.
+    """
+
+    def __init__(self, target_size: int = 224) -> None:
+        """Initializes the ResizeVideo with a target size.
+
+        Args:
+            target_size (int): The desired size in pixels for the output
+                frames. Defaults to 224.
+        """
+        super().__init__()
+        self.target_size = target_size
+
+    def resize(self, video: List) -> List:
+        """Resizes all frames of a single video.
+
+        Args:
+            video (list): A list of segments, where each segment is a list
+                of frames represented as numpy arrays.
+
+        Returns:
+            list: The input video with each frame resized to the target size.
+
+        Raises:
+            NotImplementedError: If a frame is not an instance of numpy.ndarray.
+        """
+        num_seg = len(video)
+        seg_len = len(video[0])
+
+        for i in range(num_seg):
+            for j in range(seg_len):
+                img = video[i][j]
+                if isinstance(img, np.ndarray):
+                    h, w, _ = img.shape
+                else:
+                    raise NotImplementedError(
+                        "Currently, only numpy.ndarray frames are supported."
+                    )
+                video[i][j] = cv2.resize(
+                    img,
+                    (self.target_size, self.target_size),
+                    interpolation=cv2.INTER_LINEAR,
+                )
+        return video
+
+    def __call__(self, videos: List) -> List:
+        """Resizes frames of multiple videos.
+
+        Args:
+            videos (list): A list containing multiple videos, where each video
+                is a list of segments, and each segment is a list of frames.
+
+        Returns:
+            list: A list of videos with each frame resized to the target size.
+        """
+        return [self.resize(video) for video in videos]
+
+
+class Image2Array:
+    """Convert a sequence of images to a numpy array with optional transposition."""
+
+    def __init__(self, data_format: str = "tchw") -> None:
+        """
+        Initializes the Image2Array class.
+
+        Args:
+            data_format (str): The format to transpose to, either 'tchw' or 'cthw'.
+
+        Raises:
+            AssertionError: If data_format is not one of the allowed values.
+        """
+        super().__init__()
+        assert data_format in [
+            "tchw",
+            "cthw",
+        ], f"Target format must be in ['tchw', 'cthw'], but got {data_format}"
+        self.data_format = data_format
+
+    def img2array(self, video: List) -> List:
+        """
+        Converts a list of video frames to a numpy array, with frames transposed.
+
+        Args:
+            video (List): A list of frames represented as numpy arrays.
+
+        Returns:
+            List: A numpy array with the video frames transposed and concatenated.
+        """
+        # Transpose each image from HWC to CHW format
+        num_seg = len(video)
+        for i in range(num_seg):
+            video_one = video[i]
+            video_one = [img.transpose([2, 0, 1]) for img in video_one]
+            video_one = np.concatenate(
+                [np.expand_dims(img, axis=1) for img in video_one], axis=1
+            )
+            video[i] = video_one
+        return video
+
+    def __call__(self, videos: List[List[np.ndarray]]) -> List[np.ndarray]:
+        """
+        Process videos by converting each video to a transposed numpy array.
+
+        Args:
+            videos (List[List[np.ndarray]]): A list of videos, where each video is a list
+                of frames represented as numpy arrays.
+
+        Returns:
+            List[np.ndarray]: A list of processed videos with transposed frames.
+        """
+        return [self.img2array(video) for video in videos]
+
+
+class NormalizeVideo:
+    """
+    A class to normalize video frames by scaling the pixel values.
+    """
+
+    def __init__(self, scale: float = 255.0) -> None:
+        """
+        Initializes the NormalizeVideo class.
+
+        Args:
+            scale (float): The scale factor to normalize the frames, usually the max pixel value.
+        """
+        super().__init__()
+        self.scale = scale
+
+    def normalize_video(self, video: List[np.ndarray]) -> List[np.ndarray]:
+        """
+        Normalizes a sequence of images by scaling the pixel values.
+
+        Args:
+            video (List[np.ndarray]): A list of frames, where each frame is a numpy array to be normalized.
+
+        Returns:
+            List[np.ndarray]: The normalized video frames as a list of numpy arrays.
+        """
+        num_seg = len(video)  # Number of frames in the video
+        for i in range(num_seg):
+            # Convert frame to float32 and scale pixel values
+            video[i] = video[i].astype(np.float32) / self.scale
+            # Expand dimensions if needed
+            video[i] = np.expand_dims(video[i], axis=0)
+
+        return video
+
+    def __call__(self, videos: List[List[np.ndarray]]) -> List[List[np.ndarray]]:
+        """
+        Apply normalization to a list of videos.
+
+        Args:
+            videos (List[List[np.ndarray]]): A list of videos, where each video is a list of frames
+                represented as numpy arrays.
+
+        Returns:
+            List[List[np.ndarray]]: A list of normalized videos, each represented as a list of normalized frames.
+        """
+        return [self.normalize_video(video) for video in videos]
+
+
+def convert2cpu(gpu_matrix):
+    float_32_g = gpu_matrix.astype("float32")
+    return float_32_g.cpu()
+
+
+def convert2cpu_long(gpu_matrix):
+    int_64_g = gpu_matrix.astype("int64")
+    return int_64_g.cpu()
+
+
+def get_region_boxes(
+    output,
+    conf_thresh=0.005,
+    num_classes=24,
+    anchors=[
+        0.70458,
+        1.18803,
+        1.26654,
+        2.55121,
+        1.59382,
+        4.08321,
+        2.30548,
+        4.94180,
+        3.52332,
+        5.91979,
+    ],
+    num_anchors=5,
+    only_objectness=1,
+):
+    """
+    Processes the output of a neural network to extract bounding box predictions.
+
+    Args:
+        output (Tensor): The output tensor from the neural network.
+        conf_thresh (float): The confidence threshold for filtering predictions. Default is 0.005.
+        num_classes (int): The number of classes for classification. Default is 24.
+        anchors (List[float]): A list of anchor box dimensions used in the model. Default is a list
+            of 10 predefined anchor values.
+        num_anchors (int): The number of anchor boxes used in the model. Default is 5.
+        only_objectness (int): If set to 1, only objectness scores are considered for filtering. Default is 1.
+    Returns:
+        all_box(List[List[float]]): A list of predicted bounding boxes for each image in the batch.
+    """
+    anchor_step = len(anchors) // num_anchors
+    if output.dim() == 3:
+        output = output.unsqueeze(0)
+    batch = output.shape[0]
+    assert output.shape[1] == (5 + num_classes) * num_anchors
+    h = output.shape[2]
+    w = output.shape[3]
+    all_boxes = []
+    output = paddle.reshape(output, [batch * num_anchors, 5 + num_classes, h * w])
+    output = paddle.transpose(output, (1, 0, 2))
+    output = paddle.reshape(output, [5 + num_classes, batch * num_anchors * h * w])
+
+    grid_x = paddle.linspace(0, w - 1, w)
+    grid_x = paddle.tile(grid_x, [h, 1])
+    grid_x = paddle.tile(grid_x, [batch * num_anchors, 1, 1])
+    grid_x = paddle.reshape(grid_x, [batch * num_anchors * h * w]).cuda()
+
+    grid_y = paddle.linspace(0, h - 1, h)
+    grid_y = paddle.tile(grid_y, [w, 1]).t()
+    grid_y = paddle.tile(grid_y, [batch * num_anchors, 1, 1])
+    grid_y = paddle.reshape(grid_y, [batch * num_anchors * h * w]).cuda()
+
+    sigmoid = paddle.nn.Sigmoid()
+    xs = sigmoid(output[0]) + grid_x
+    ys = sigmoid(output[1]) + grid_y
+
+    anchor_w = paddle.to_tensor(anchors)
+    anchor_w = paddle.reshape(anchor_w, [num_anchors, anchor_step])
+    anchor_w = paddle.index_select(
+        anchor_w, index=paddle.to_tensor(np.array([0]).astype("int32")), axis=1
+    )
+
+    anchor_h = paddle.to_tensor(anchors)
+    anchor_h = paddle.reshape(anchor_h, [num_anchors, anchor_step])
+    anchor_h = paddle.index_select(
+        anchor_h, index=paddle.to_tensor(np.array([1]).astype("int32")), axis=1
+    )
+
+    anchor_w = paddle.tile(anchor_w, [batch, 1])
+    anchor_w = paddle.tile(anchor_w, [1, 1, h * w])
+    anchor_w = paddle.reshape(anchor_w, [batch * num_anchors * h * w]).cuda()
+
+    anchor_h = paddle.tile(anchor_h, [batch, 1])
+    anchor_h = paddle.tile(anchor_h, [1, 1, h * w])
+    anchor_h = paddle.reshape(anchor_h, [batch * num_anchors * h * w]).cuda()
+
+    ws = paddle.exp(output[2]) * anchor_w
+    hs = paddle.exp(output[3]) * anchor_h
+
+    det_confs = sigmoid(output[4])
+
+    cls_confs = paddle.to_tensor(output[5 : 5 + num_classes], stop_gradient=True)
+    cls_confs = paddle.transpose(cls_confs, [1, 0])
+    s = paddle.nn.Softmax()
+    cls_confs = paddle.to_tensor(s(cls_confs))
+
+    cls_max_confs = paddle.max(cls_confs, axis=1)
+    cls_max_ids = paddle.argmax(cls_confs, axis=1)
+
+    cls_max_confs = paddle.reshape(cls_max_confs, [-1])
+    cls_max_ids = paddle.reshape(cls_max_ids, [-1])
+
+    sz_hw = h * w
+    sz_hwa = sz_hw * num_anchors
+
+    det_confs = convert2cpu(det_confs)
+    cls_max_confs = convert2cpu(cls_max_confs)
+    cls_max_ids = convert2cpu_long(cls_max_ids)
+    xs = convert2cpu(xs)
+    ys = convert2cpu(ys)
+    ws = convert2cpu(ws)
+    hs = convert2cpu(hs)
+    for b in range(batch):
+        boxes = []
+        for cy in range(h):
+            for cx in range(w):
+                for i in range(num_anchors):
+                    ind = b * sz_hwa + i * sz_hw + cy * w + cx
+                    det_conf = det_confs[ind]
+                    if only_objectness:
+                        conf = det_confs[ind]
+                    else:
+                        conf = det_confs[ind] * cls_max_confs[ind]
+
+                    if conf > conf_thresh:
+                        bcx = xs[ind]
+                        bcy = ys[ind]
+                        bw = ws[ind]
+                        bh = hs[ind]
+                        cls_max_conf = cls_max_confs[ind]
+                        cls_max_id = cls_max_ids[ind]
+                        box = [
+                            bcx / w,
+                            bcy / h,
+                            bw / w,
+                            bh / h,
+                            det_conf,
+                            cls_max_conf,
+                            cls_max_id,
+                        ]
+                        boxes.append(box)
+        all_boxes.append(boxes)
+    return all_boxes
+
+
+def nms(boxes, nms_thresh):
+    """
+    Performs non-maximum suppression on the input boxes based on their IoUs.
+    """
+    if len(boxes) == 0:
+        return boxes
+    det_confs = paddle.zeros([len(boxes)])
+    for i in range(len(boxes)):
+        det_confs[i] = 1 - boxes[i][4]
+
+    sortIds = paddle.argsort(det_confs)
+    out_boxes = []
+    for i in range(len(boxes)):
+        box_i = boxes[sortIds[i]]
+        if box_i[4] > 0:
+            out_boxes.append(box_i)
+            for j in range(i + 1, len(boxes)):
+                box_j = boxes[sortIds[j]]
+                if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
+                    box_j[4] = 0
+    return out_boxes
+
+
+def bbox_iou(box1, box2, x1y1x2y2=True):
+    """
+    Returns the Intersection over Union (IoU) of two bounding boxes.
+    """
+    if x1y1x2y2:
+        mx = min(box1[0], box2[0])
+        Mx = max(box1[2], box2[2])
+        my = min(box1[1], box2[1])
+        My = max(box1[3], box2[3])
+        w1 = box1[2] - box1[0]
+        h1 = box1[3] - box1[1]
+        w2 = box2[2] - box2[0]
+        h2 = box2[3] - box2[1]
+    else:
+        mx = min(float(box1[0] - box1[2] / 2.0), float(box2[0] - box2[2] / 2.0))
+        Mx = max(float(box1[0] + box1[2] / 2.0), float(box2[0] + box2[2] / 2.0))
+        my = min(float(box1[1] - box1[3] / 2.0), float(box2[1] - box2[3] / 2.0))
+        My = max(float(box1[1] + box1[3] / 2.0), float(box2[1] + box2[3] / 2.0))
+        w1 = box1[2]
+        h1 = box1[3]
+        w2 = box2[2]
+        h2 = box2[3]
+    uw = Mx - mx
+    uh = My - my
+    cw = w1 + w2 - uw
+    ch = h1 + h2 - uh
+    carea = 0
+    if cw <= 0 or ch <= 0:
+        return paddle.to_tensor(0.0)
+
+    area1 = w1 * h1
+    area2 = w2 * h2
+    carea = cw * ch
+    uarea = area1 + area2 - carea
+    return carea / uarea
+
+
+class DetVideoPostProcess:
+    """
+    A class used to perform post-processing on detection results in videos.
+    """
+
+    def __init__(
+        self,
+        nms_thresh: float = 0.5,
+        score_thresh: float = 0.5,
+        label_list: List[str] = [],
+    ) -> None:
+        """
+        Args:
+            nms_thresh : float
+                The IoU (Intersection over Union) threshold used for Non-Maximum Suppression (NMS).
+                Detections with an IoU greater than this threshold will be suppressed.
+            score_thresh : float
+                The threshold for filtering out low-confidence detections.
+                Detections with a confidence score below this threshold will be discarded.
+            labels : List[str]
+                A list of labels or class names associated with the detection results.
+        """
+        super().__init__()
+
+        self.nms_thresh = nms_thresh
+        self.score_thresh = score_thresh
+        self.labels = label_list
+
+    def postprocess(self, pred: List) -> List:
+        num_seg = len(pred)
+        pred_all = []
+        for i in range(num_seg):
+            outputs = pred[i]
+            for out in outputs:
+                preds = []
+                out = paddle.to_tensor(out)
+                all_boxes = get_region_boxes(out, self.score_thresh, len(self.labels))
+                for b in range(out.shape[0]):
+                    boxes = all_boxes[b]
+                    boxes = nms(boxes, self.nms_thresh)
+
+                    for box in boxes:
+                        x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
+                        y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
+                        x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
+                        y2 = round(float(box[1] + box[3] / 2.0) * 240.0)
+
+                        det_conf = float(box[4])
+                        for j in range((len(box) - 5) // 2):
+                            cls_conf = float(box[5 + 2 * j].item())
+                            prob = det_conf * cls_conf
+                        preds.append([[x1, y1, x2, y2], prob, self.labels[int(box[6])]])
+            pred_all.append(preds)
+        return pred_all
+
+    def __call__(self, preds: List) -> List:
+        return [self.postprocess(pred) for pred in preds]
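As a sanity check on the conventions above: with `x1y1x2y2=False`, boxes are (cx, cy, w, h) in normalized coordinates, and `nms` keeps the highest-detection-confidence box while zeroing the confidence of overlapping ones. A toy, pure-Python sketch of the same suppression logic (box values are made up):

# Each box is [cx, cy, w, h, det_conf]; coordinates are normalized.
def iou_cxcywh(a, b):
    ax1, ay1 = a[0] - a[2] / 2, a[1] - a[3] / 2
    ax2, ay2 = a[0] + a[2] / 2, a[1] + a[3] / 2
    bx1, by1 = b[0] - b[2] / 2, b[1] - b[3] / 2
    bx2, by2 = b[0] + b[2] / 2, b[1] + b[3] / 2
    iw = min(ax2, bx2) - max(ax1, bx1)
    ih = min(ay2, by2) - max(ay1, by1)
    if iw <= 0 or ih <= 0:
        return 0.0
    inter = iw * ih
    return inter / (a[2] * a[3] + b[2] * b[3] - inter)

# Two heavily overlapping boxes and one distant box; with
# nms_thresh=0.5 the weaker of the overlapping pair is suppressed.
boxes = [
    [0.50, 0.50, 0.20, 0.20, 0.90],  # kept (highest confidence)
    [0.51, 0.50, 0.20, 0.20, 0.60],  # suppressed (IoU ~0.90 with box 0)
    [0.90, 0.90, 0.10, 0.10, 0.70],  # kept (no overlap)
]
keep = []
for b in sorted(boxes, key=lambda x: -x[4]):  # highest confidence first
    if all(iou_cxcywh(b, k) <= 0.5 for k in keep):
        keep.append(b)
print(len(keep))  # -> 2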

+ 104 - 0
paddlex/inference/models_new/video_detection/result.py

@@ -0,0 +1,104 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+import random
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ...utils.color_map import get_colormap
+from ...utils.io import VideoReader
+from ...common.result import BaseVideoResult
+
+
+class DetVideoResult(BaseVideoResult):
+
+    def _to_video(self):
+        """Draw label on image"""
+        video_reader = VideoReader(backend="decord")
+        video = video_reader.read(self["input_path"])
+        video = list(video)
+        write_fps = video_reader.get_fps()
+        label2color = {}
+        catid2fontcolor = {}
+        color_list = get_colormap(rgb=True)
+        video_list = []
+
+        for i in range(len(video)):
+            image = Image.fromarray(video[i].asnumpy())
+            image_size = image.size
+            font_size = int(0.018 * int(image.width)) + 2
+            font = ImageFont.truetype(
+                PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8"
+            )
+            draw_thickness = int(max(image.size) * 0.002)
+            draw = ImageDraw.Draw(image)
+            results = self["result"][i]
+            for result in results:
+                bbox, score, class_name = result
+                if class_name not in label2color:
+                    random_index = random.randint(0, len(color_list) - 1)
+                    label2color[class_name] = color_list[random_index]
+                    catid2fontcolor[class_name] = self._get_font_colormap(random_index)
+                color = tuple(label2color[class_name])
+                font_color = tuple(catid2fontcolor[class_name])
+                xmin, ymin, xmax, ymax = bbox
+                rectangle = [
+                    (xmin, ymin),
+                    (xmin, ymax),
+                    (xmax, ymax),
+                    (xmax, ymin),
+                    (xmin, ymin),
+                ]
+                draw.line(
+                    rectangle,
+                    width=draw_thickness,
+                    fill=color,
+                )
+                text = "{} {:.2f}".format(class_name, score)
+                # ImageDraw.textsize was removed in Pillow 10.0.0
+                if tuple(map(int, PIL.__version__.split("."))) < (10, 0, 0):
+                    tw, th = draw.textsize(text, font=font)
+                else:
+                    left, top, right, bottom = draw.textbbox((0, 0), text, font)
+                    tw, th = right - left, bottom - top + 4
+                if ymin < th:
+                    draw.rectangle(
+                        [(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color
+                    )
+                    draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+                else:
+                    draw.rectangle(
+                        [(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color
+                    )
+                    draw.text(
+                        (xmin + 2, ymin - th - 2), text, fill=font_color, font=font
+                    )
+
+            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+            video_list.append(image)
+        return {"res": (np.array(video_list), write_fps)}
+
+    def _get_font_colormap(self, color_index):
+        """
+        Get font colormap
+        """
+        dark = np.array([0x14, 0x0E, 0x35])
+        light = np.array([0xFF, 0xFF, 0xFF])
+        light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+        if color_index in light_indexs:
+            return light.astype("int32")
+        else:
+            return dark.astype("int32")

+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -30,6 +30,7 @@ from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeli
 from .formula_recognition import FormulaRecognitionPipeline
 from .image_multilabel_classification import ImageMultiLabelClassificationPipeline
 from .video_classification import VideoClassificationPipeline
+from .video_detection import VideoDetectionPipeline
 from .anomaly_detection import AnomalyDetectionPipeline
 from .ts_forecasting import TSFcPipeline
 from .ts_anomaly_detection import TSAnomalyDetPipeline

+ 15 - 0
paddlex/inference/pipelines_new/video_detection/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import VideoDetectionPipeline

+ 67 - 0
paddlex/inference/pipelines_new/video_detection/pipeline.py

@@ -0,0 +1,67 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Union
+import numpy as np
+from ...utils.pp_option import PaddlePredictorOption
+from ..base import BasePipeline
+
+# [TODO] Update models_new to models
+from ...models_new.video_detection.result import DetVideoResult
+
+
+class VideoDetectionPipeline(BasePipeline):
+    """Video detection Pipeline"""
+
+    entities = "video_detection"
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Initializes the class with given configurations and options.
+
+        Args:
+            config (Dict): Configuration dictionary containing model and other parameters.
+            device (str): The device to run the prediction on. Default is None.
+            pp_option (PaddlePredictorOption): Options for PaddlePaddle predictor. Default is None.
+            use_hpip (bool): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]]): HPIP specific parameters. Default is None.
+        """
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        video_detection_model_config = config["SubModules"]["VideoDetection"]
+        self.video_detection_model = self.create_model(video_detection_model_config)
+
+    def predict(
+        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+    ) -> DetVideoResult:
+        """Predicts video detection results for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input video(s) or path(s) to the videos.
+            **kwargs: Additional keyword arguments that can be passed to the function.
+
+        Yields:
+            DetVideoResult: The predicted video detection results.
+        """
+
+        yield from self.video_detection_model(input)

+ 22 - 5
paddlex/inference/utils/io/readers.py

@@ -275,6 +275,7 @@ class OpenCVVideoReaderBackend(_VideoReaderBackend):
     def __init__(self, **bk_args):
         super().__init__()
         self.cap_init_args = bk_args
+        self.num_seg = bk_args.get("num_seg", None)
         self._cap = None
         self._pos = 0
         self._max_num_frames = None
@@ -293,11 +294,27 @@ class OpenCVVideoReaderBackend(_VideoReaderBackend):
 
     def _read_frames(self, cap):
         """read frames"""
-        while True:
-            ret, frame = cap.read()
-            if not ret:
-                break
-            yield frame
+        if self.num_seg:
+            queue = []
+            while True:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                if not queue:
+                    # at initialization, seed the window with copies of the first frame
+                    for i in range(self.num_seg):
+                        queue.append(frame)
+                queue.append(frame)
+                queue.pop(0)
+                yield queue.copy()
+        else:
+            while True:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                yield frame
         self._cap_release()
 
     def _cap_open(self, video_path):
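The `num_seg` branch above maintains a sliding window: the queue is seeded with copies of the first frame, then each decoded frame pushes the window forward, so the reader yields one `num_seg`-frame clip per frame. A standalone sketch of that behavior, using integers in place of frames:

def sliding_windows(frames, num_seg):
    """Yield a num_seg-long window for every frame, seeded with the first frame."""
    queue = []
    for frame in frames:
        if not queue:
            queue = [frame] * num_seg  # seed with copies of the first frame
        queue.append(frame)
        queue.pop(0)
        yield list(queue)

for window in sliding_windows(range(4), num_seg=3):
    print(window)
# [0, 0, 0]
# [0, 0, 1]
# [0, 1, 2]
# [1, 2, 3]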

+ 1 - 0
paddlex/inference/utils/official_models.py

@@ -318,6 +318,7 @@ PP-LCNet_x1_0_vehicle_attribute_infer.tar",
     "PP-TSMv2-LCNetV2_16frames_uniform": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-TSMv2-LCNetV2_16frames_uniform_infer.tar",
     "MaskFormer_tiny": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/MaskFormer_tiny_infer.tar",
     "MaskFormer_small": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/MaskFormer_small_infer.tar",
+    "YOWO": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/YOWO_infer.tar",
     "PP-TinyPose_128x96": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-TinyPose_128x96_infer.tar",
     "PP-TinyPose_256x192": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-TinyPose_256x192_infer.tar",
 }

+ 7 - 0
paddlex/modules/__init__.py

@@ -118,6 +118,13 @@ from .video_classification import (
     VideoClsExportor,
 )
 
+from .video_detection import (
+    VideoDetDatasetChecker,
+    VideoDetTrainer,
+    VideoDetEvaluator,
+    VideoDetExportor,
+)
+
 from .multilingual_speech_recognition import (
     WhisperDatasetChecker,
     WhisperTrainer,

+ 18 - 0
paddlex/modules/video_detection/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .trainer import VideoDetTrainer
+from .dataset_checker import VideoDetDatasetChecker
+from .evaluator import VideoDetEvaluator
+from .exportor import VideoDetExportor

+ 86 - 0
paddlex/modules/video_detection/dataset_checker/__init__.py

@@ -0,0 +1,86 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+from ...base import BaseDatasetChecker
+from .dataset_src import check, deep_analyse
+from ..model_list import MODELS
+
+
+class VideoDetDatasetChecker(BaseDatasetChecker):
+    """Dataset Checker for Image Classification Model"""
+
+    entities = MODELS
+    sample_num = 10
+
+    def convert_dataset(self, src_dataset_dir: str) -> str:
+        """convert the dataset from other type to specified type
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of converted dataset.
+        """
+        return src_dataset_dir
+
+    def split_dataset(self, src_dataset_dir: str) -> str:
+        """repartition the train and validation dataset
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of the split dataset.
+        """
+        return src_dataset_dir
+
+    def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
+        """check if the dataset meets the specifications and get dataset summary
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+            sample_num (int): the number to be sampled.
+        Returns:
+            dict: dataset summary.
+        """
+        return check(dataset_dir, self.output, sample_num)
+
+    def analyse(self, dataset_dir: str) -> dict:
+        """deep analyse dataset
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            dict: the deep analysis results.
+        """
+        return deep_analyse(dataset_dir, self.output)
+
+    def get_show_type(self) -> str:
+        """get the show type of dataset
+
+        Returns:
+            str: show type
+        """
+        return "video"
+
+    def get_dataset_type(self) -> str:
+        """return the dataset type
+
+        Returns:
+            str: dataset type
+        """
+        return "VideoDetDataset"

+ 17 - 0
paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .check_dataset import check
+from .analyse_dataset import deep_analyse

+ 101 - 0
paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py

@@ -0,0 +1,101 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import math
+import platform
+from pathlib import Path
+
+from collections import defaultdict
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import font_manager
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+from .....utils.file_interface import custom_open
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def deep_analyse(dataset_path, output):
+    """class analysis for dataset"""
+    tags = ["train", "val"]
+    labels_cnt = defaultdict(str)
+    label_path = os.path.join(dataset_path, "label_map.txt")
+    with custom_open(label_path, "r") as f:
+        lines = f.readlines()
+    for line in lines:
+        line = line.strip().split()
+        # the class index is the last token; the class name may contain spaces
+        labels_cnt[line[-1]] = " ".join(line[:-1])
+    for tag in tags:
+        anno_path = os.path.join(dataset_path, f"{tag}.txt")
+        classes_num = defaultdict(int)
+        for label_idx in labels_cnt:
+            classes_num[labels_cnt[label_idx]] = 0
+        with custom_open(anno_path, "r") as f:
+            lines = f.readlines()
+        for line in lines:
+            line = line.strip().split()
+            label_file_path = os.path.join(dataset_path, line[0])
+            with custom_open(label_file_path, "r") as f:
+                label_lines = f.readlines()
+                for label_line in label_lines:
+                    label_info = label_line.strip().split(" ")
+                    classes_num[labels_cnt[label_info[0]]] += 1
+
+        if tag == "train":
+            cnts_train = [cat_ids for cat_name, cat_ids in classes_num.items()]
+        elif tag == "val":
+            cnts_val = [cat_ids for cat_name, cat_ids in classes_num.items()]
+
+    classes = [cat_name for cat_name, cat_ids in classes_num.items()]
+    sorted_id = sorted(
+        range(len(cnts_train)), key=lambda k: cnts_train[k], reverse=True
+    )
+    cnts_train_sorted = [cnts_train[index] for index in sorted_id]
+    cnts_val_sorted = [cnts_val[index] for index in sorted_id]
+    classes_sorted = [classes[index] for index in sorted_id]
+    x = np.arange(len(classes))
+    width = 0.5
+
+    # bar
+    os_system = platform.system().lower()
+    if os_system == "windows":
+        plt.rcParams["font.sans-serif"] = "FangSong"
+    else:
+        font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH, size=10)
+    fig, ax = plt.subplots(figsize=(max(8, int(len(classes) / 5)), 5), dpi=300)
+    ax.bar(x, cnts_train_sorted, width=0.5, label="train")
+    ax.bar(x + width, cnts_val_sorted, width=0.5, label="val")
+    plt.xticks(
+        x + width / 2,
+        classes_sorted,
+        rotation=90,
+        fontproperties=None if os_system == "windows" else font,
+    )
+    ax.set_xlabel(
+        "类别名称", fontproperties=None if os_system == "windows" else font, fontsize=12
+    )
+    ax.set_ylabel(
+        "样本框数量",
+        fontproperties=None if os_system == "windows" else font,
+        fontsize=12,
+    )
+    plt.legend(loc=1)
+    fig.tight_layout()
+    file_path = os.path.join(output, "histogram.png")
+    fig.savefig(file_path, dpi=300)
+
+    return {"histogram": os.path.join("check_dataset", "histogram.png")}

+ 134 - 0
paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py

@@ -0,0 +1,134 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import random
+from PIL import Image, ImageOps
+from collections import defaultdict
+
+from .....utils.file_interface import custom_open
+from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
+
+
+def check(dataset_dir, output, sample_num=10):
+    """check dataset"""
+    dataset_dir = osp.abspath(dataset_dir)
+    # Custom dataset
+    if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
+        raise DatasetFileNotFoundError(file_path=dataset_dir)
+
+    tags = ["train", "val"]
+    delim = " "
+    valid_num_parts = 5
+
+    sample_cnts = dict()
+    label_map_dict = dict()
+    sample_paths = defaultdict(list)
+    labels = []
+    image_dir = osp.join(dataset_dir, "rgb-images")
+    label_dir = osp.join(dataset_dir, "labels")
+    if not osp.exists(image_dir):
+        raise DatasetFileNotFoundError(file_path=image_dir)
+    if not osp.exists(label_dir):
+        raise DatasetFileNotFoundError(file_path=label_dir)
+
+    label_map_file = osp.join(dataset_dir, "label_map.txt")
+    if not osp.exists(label_map_file):
+        raise DatasetFileNotFoundError(
+            file_path=label_map_file,
+            solution=f"Ensure that `label_map.txt` exist in {dataset_dir}",
+        )
+    with open(label_map_file, "r", encoding="utf-8") as f:
+        all_lines = f.readlines()
+        for line in all_lines:
+            substr = line.strip("\n").split(" ", 1)
+            try:
+                label_idx = int(substr[1])
+                labels.append(label_idx)
+                label_map_dict[label_idx] = str(substr[0])
+            except (ValueError, IndexError):
+                raise CheckFailedError(
+                    f"Ensure that the second field in each line of {label_map_file} is an int."
+                )
+    if min(labels) != 1:
+        raise CheckFailedError(
+            f"Ensure that the index starts from 1 in `{label_map_file}`."
+        )
+
+    for tag in tags:
+        file_list = osp.join(dataset_dir, f"{tag}.txt")
+        if not osp.exists(file_list):
+            if tag in ("train", "val"):
+                # train and val file lists must exist
+                raise DatasetFileNotFoundError(
+                    file_path=file_list,
+                    solution=f"Ensure that both `train.txt` and `val.txt` exist in {dataset_dir}",
+                )
+            else:
+                # tag == 'test'
+                continue
+        else:
+            with open(file_list, "r", encoding="utf-8") as f:
+                all_lines = f.readlines()
+                random.seed(123)
+                random.shuffle(all_lines)
+                sample_cnts[tag] = len(all_lines)
+
+                for line in all_lines:
+                    substr = line.strip("\n")
+                    label_path = osp.join(dataset_dir, substr)
+                    img_path = (
+                        osp.join(dataset_dir, substr)
+                        .replace("labels", "rgb-images")
+                        .replace("txt", "jpg")
+                    )
+
+                    if not osp.exists(img_path):
+                        raise DatasetFileNotFoundError(file_path=img_path)
+                    if not osp.exists(label_path):
+                        raise DatasetFileNotFoundError(file_path=label_path)
+                    with custom_open(label_path, "r") as f:
+                        label_lines = f.readlines()
+                        for label_line in label_lines:
+                            label_info = label_line.strip().split(" ")
+                            if len(label_info) != valid_num_parts:
+                                raise CheckFailedError(
+                                    f"Ensure that each line in {label_path} has exactly {valid_num_parts} space-separated fields."
+                                )
+                            try:
+                                label = int(label_info[0])
+                            except (ValueError, TypeError) as e:
+                                raise CheckFailedError(
+                                    f"Ensure that the first number in each line in {label_path} is an int."
+                                ) from e
+
+                    if len(sample_paths[tag]) < sample_num:
+                        sample_path = osp.join(
+                            "check_dataset", os.path.relpath(img_path, output)
+                        )
+                        sample_paths[tag].append(sample_path)
+
+    num_classes = max(labels)
+
+    attrs = {}
+    attrs["label_file"] = osp.relpath(label_map_file, output)
+    attrs["num_classes"] = num_classes
+    attrs["train_samples"] = sample_cnts["train"]
+    attrs["train_sample_paths"] = sample_paths["train"]
+
+    attrs["val_samples"] = sample_cnts["val"]
+    attrs["val_sample_paths"] = sample_paths["val"]
+
+    return attrs
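Read together, the checks above imply a UCF24-style layout: `labels/` mirrors `rgb-images/`, `label_map.txt` maps class names to 1-based indices, and each annotation line carries `valid_num_parts = 5` space-separated fields (a class index plus four box values). An illustrative layout (the file and directory names are hypothetical):

video_det_examples/
    label_map.txt                       # e.g. "HorseRiding 11" (name, 1-based index)
    train.txt                           # one relative label path per line
    val.txt
    rgb-images/HorseRiding/v_001/00001.jpg
    labels/HorseRiding/v_001/00001.txt  # e.g. "11 120 52 230 210" (class + four box values)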

+ 42 - 0
paddlex/modules/video_detection/evaluator.py

@@ -0,0 +1,42 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseEvaluator
+from .model_list import MODELS
+
+
+class VideoDetEvaluator(BaseEvaluator):
+    """Image Classification Model Evaluator"""
+
+    entities = MODELS
+
+    def update_config(self):
+        """update evalution config"""
+        if self.eval_config.log_interval:
+            self.pdx_config.update_log_interval(self.eval_config.log_interval)
+        self.pdx_config.update_dataset(
+            self.global_config.dataset_dir, "VideoDetDataset"
+        )
+        self.pdx_config.update_pretrained_weights(self.eval_config.weight_path)
+
+    def get_eval_kwargs(self) -> dict:
+        """get key-value arguments of model evalution function
+
+        Returns:
+            dict: the arguments of evaluation function.
+        """
+        return {
+            "weight_path": self.eval_config.weight_path,
+            "device": self.get_device(using_device_number=1),
+        }

+ 22 - 0
paddlex/modules/video_detection/exportor.py

@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class VideoDetExportor(BaseExportor):
+    """Image Classification Model Exportor"""
+
+    entities = MODELS

+ 15 - 0
paddlex/modules/video_detection/model_list.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+MODELS = ["YOWO"]

+ 82 - 0
paddlex/modules/video_detection/trainer.py

@@ -0,0 +1,82 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import shutil
+from pathlib import Path
+
+from ..base import BaseTrainer
+from .model_list import MODELS
+from ...utils.config import AttrDict
+
+
+class VideoDetTrainer(BaseTrainer):
+    """Image Classification Model Trainer"""
+
+    entities = MODELS
+
+    def dump_label_dict(self, src_label_dict_path: str):
+        """dump label dict config
+
+        Args:
+            src_label_dict_path (str): path to the source label dict file to copy.
+        """
+        dst_label_dict_path = Path(self.global_config.output).joinpath("label_map.txt")
+        shutil.copyfile(src_label_dict_path, dst_label_dict_path)
+
+    def update_config(self):
+        """update training config"""
+        if self.train_config.log_interval:
+            self.pdx_config.update_log_interval(self.train_config.log_interval)
+        if self.train_config.eval_interval:
+            self.pdx_config.update_eval_interval(self.train_config.eval_interval)
+        if self.train_config.save_interval:
+            self.pdx_config.update_save_interval(self.train_config.save_interval)
+        if self.train_config.num_classes is not None:
+            self.pdx_config.update_num_classes(self.train_config.num_classes)
+
+        self.pdx_config.update_dataset(
+            self.global_config.dataset_dir, "VideoDetDataset"
+        )
+        if self.train_config.pretrain_weight_path != "":
+            self.pdx_config.update_pretrained_weights(
+                self.train_config.pretrain_weight_path
+            )
+
+        label_dict_path = Path(self.global_config.dataset_dir).joinpath("label_map.txt")
+        if label_dict_path.exists():
+            self.pdx_config.update_label_list(label_dict_path)
+        if self.train_config.batch_size is not None:
+            self.pdx_config.update_batch_size(self.train_config.batch_size)
+        if self.train_config.learning_rate is not None:
+            self.pdx_config.update_learning_rate(self.train_config.learning_rate)
+        if self.train_config.epochs_iters is not None:
+            self.pdx_config._update_epochs(self.train_config.epochs_iters)
+        if self.global_config.output is not None:
+            self.pdx_config._update_output_dir(self.global_config.output)
+
+    def get_train_kwargs(self) -> dict:
+        """get key-value arguments of model training function
+
+        Returns:
+            dict: the arguments of training function.
+        """
+        train_args = {"device": self.get_device()}
+        if (
+            self.train_config.resume_path is not None
+            and self.train_config.resume_path != ""
+        ):
+            train_args["resume_path"] = self.train_config.resume_path
+        train_args["dy2st"] = self.train_config.get("dy2st", False)
+        return train_args
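
A minimal standalone sketch of the argument assembly in `get_train_kwargs` above: an empty `resume_path` is treated the same as `None` (train from scratch), and `dy2st` defaults to `False`. The plain dict below stands in for the `AttrDict`-based train config; the function and field names mirror the trainer but are illustrative only.

```python
def build_train_kwargs(train_config: dict, device: str = "gpu:0") -> dict:
    """Mirror of VideoDetTrainer.get_train_kwargs, for illustration only."""
    kwargs = {"device": device}
    resume = train_config.get("resume_path")
    if resume:  # skipped when None or "", matching the trainer's check
        kwargs["resume_path"] = resume
    kwargs["dy2st"] = train_config.get("dy2st", False)
    return kwargs

print(build_train_kwargs({"resume_path": "", "dy2st": True}))
# {'device': 'gpu:0', 'dy2st': True}
```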

+ 1 - 0
paddlex/repo_apis/PaddleVideo_api/__init__.py

@@ -14,3 +14,4 @@
 
 
 from .video_cls import VideoClsModel, VideoClsRunner, register
+from .video_det import VideoDetModel, VideoDetRunner, register

+ 144 - 0
paddlex/repo_apis/PaddleVideo_api/configs/YOWO.yaml

@@ -0,0 +1,144 @@
+Global:
+  checkpoints: null
+  pretrained_model: https://videotag.bj.bcebos.com/PaddleVideo-release2.3/YOWO_epoch_00005.pdparams
+  output_dir: ./output/
+  device: gpu
+  use_visualdl: False
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+  algorithm: YOWO
+
+MODEL: #MODEL field
+    framework: "YOWOLocalizer" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/' .
+    backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/' .
+        name: "YOWO" #Mandatory, The name of backbone.
+        num_class: 24
+    loss:
+        name: "RegionLoss"
+        num_classes: 24
+        num_anchors: 5
+        anchors: [0.70458, 1.18803, 1.26654, 2.55121, 1.59382, 4.08321, 2.30548, 4.94180, 3.52332, 5.91979]
+        object_scale: 5
+        noobject_scale: 1
+        class_scale: 1
+        coord_scale: 1
+
+DATASET: #DATASET field
+    batch_size: 8 #Mandatory, batch size
+    num_workers: 4 #Mandatory, the number of subprocesses on each GPU.
+    test_batch_size: 4
+    valid_batch_size: 4
+    train:
+        format: "UCF24Dataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+        image_dir: "data/ucf24" #Mandatory, raw data folder path
+        file_path: "data/ucf24/trainlist.txt" #Mandatory, train data index file path
+    valid:
+        format: "UCF24Dataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+        image_dir: "data/ucf24" #Mandatory, raw data folder path
+        file_path: "data/ucf24/testlist.txt" #Mandatory, test data index file path
+    test:
+        format: "UCF24Dataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+        image_dir: "data/ucf24" 
+        file_path: "data/ucf24/testlist.txt" #Mandatory, test data index file path
+
+PIPELINE: #PIPELINE field TODO.....
+    train: #Mandatory, indicate the pipeline to deal with the training data, associated with 'paddlevideo/loader/pipelines/'
+        sample:
+            name: "SamplerUCF24"
+            num_frames: 16
+            valid_mode: False
+        transform: #Mandatory, image transform operator.
+            - YowoAug:
+                valid_mode: False
+    valid: #Mandatory, indicate the pipeline to deal with the validation data, associated with 'paddlevideo/loader/pipelines/'
+        sample:
+            name: "SamplerUCF24"
+            num_frames: 16
+            valid_mode: True
+        transform: #Mandatory, image transform operator.
+            - YowoAug:
+                valid_mode: True
+    test:
+        sample:
+            name: "SamplerUCF24"
+            num_frames: 16
+            valid_mode: True
+        transform: #Mandatory, image transform operator.
+            - YowoAug:
+                valid_mode: True
+
+OPTIMIZER: #OPTIMIZER field
+    name: Adam
+    learning_rate:
+        learning_rate: 0.0001
+        name: 'MultiStepDecay'
+        milestones: [1, 2, 3, 4]
+        gamma: 0.5
+    weight_decay:
+        name: "L2"
+        value: 0.0005
+
+GRADIENT_ACCUMULATION:
+    global_batch_size: 128 # Specify the sum of batches to be calculated by all GPUs
+
+METRIC:
+    name: 'YOWOMetric'
+    gt_folder: 'data/ucf24/groundtruths_ucf'
+    result_path: 'output/detections_test'
+    threshold: 0.5
+    log_interval: 100
+    for_paddlex: True
+
+INFERENCE:
+    name: 'YOWO_Inference_helper'
+    num_seg: 16
+    target_size: 224
+
+Infer:
+    transforms:
+        - ReadVideo:
+            num_seg: 16
+        - ResizeVideo:
+            target_size: 224
+        - Image2Array:
+            data_format: 'tchw'
+        - NormalizeVideo:
+            scale: 255.0
+    PostProcess:
+        - DetVideoPostProcess:
+            nms_thresh: 0.5
+            score_thresh: 0.4
+label_list:
+    - Basketball
+    - BasketballDunk
+    - Biking
+    - CliffDiving
+    - CricketBowling
+    - Diving
+    - Fencing
+    - FloorGymnastics
+    - GolfSwing
+    - HorseRiding
+    - IceDancing
+    - LongJump
+    - PoleVault
+    - RopeClimbing
+    - SalsaSpin
+    - SkateBoarding
+    - Skiing
+    - Skijet
+    - SoccerJuggling
+    - Surfing
+    - TennisSwing
+    - TrampolineJumping
+    - VolleyballSpiking
+    - WalkingWithDog
+
+model_name: "YOWO"
+log_interval: 20 #Optional, the interval of logging, default: 10
+save_interval: 1
+epochs: 5 #Mandatory, total epoch
+log_level: "INFO" #Optional, the logger level. default: "INFO"
+val_interval: 1
+label_dict_path: data/ucf24/label_map.txt
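
The dotted keys used throughout `VideoDetConfig` (e.g. `DATASET.batch_size`, `MODEL.backbone.num_class`) address nested sections of this YAML. Below is a toy sketch of that addressing scheme; it only mimics what `merge_config` is assumed to do and is not the paddlex implementation.

```python
import yaml

def set_by_dotted_key(cfg: dict, dotted_key: str, value) -> None:
    # Walk down nested dicts, creating intermediate levels as needed.
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

cfg = yaml.safe_load("""
MODEL:
  backbone:
    name: YOWO
    num_class: 24
DATASET:
  batch_size: 8
""")
set_by_dotted_key(cfg, "MODEL.backbone.num_class", 11)
set_by_dotted_key(cfg, "DATASET.batch_size", 4)
assert cfg["MODEL"]["backbone"]["num_class"] == 11
assert cfg["DATASET"]["batch_size"] == 4
```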

+ 19 - 0
paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py

@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .model import VideoDetModel
+from .runner import VideoDetRunner
+from .config import VideoDetConfig
+from . import register

+ 548 - 0
paddlex/repo_apis/PaddleVideo_api/video_det/config.py

@@ -0,0 +1,548 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import yaml
+from typing import Union
+
+from ...base import BaseConfig
+from ....utils.misc import abspath
+from ..config_utils import merge_config
+
+
+class VideoDetConfig(BaseConfig):
+    """Image Det Task Config"""
+
+    def update(self, dict_like_obj: Union[dict, list]):
+        """update self
+
+        Args:
+            dict_like_obj (dict | list): a dict mapping dotted keys (`key0.key1.idx.key2`)
+                to values, or a list of `key0.key1.idx.key2=value` strings.
+        """
+        dict_ = merge_config(self.dict, dict_like_obj)
+        self.reset_from_dict(dict_)
+
+    def load(self, config_file_path: str):
+        """load config from yaml file
+
+        Args:
+            config_file_path (str): the path of yaml file.
+
+        Raises:
+            TypeError: the content of yaml file `config_file_path` error.
+        """
+        dict_ = yaml.load(open(config_file_path, "rb"), Loader=yaml.Loader)
+        if not isinstance(dict_, dict):
+            raise TypeError
+        self.reset_from_dict(dict_)
+
+    def dump(self, config_file_path: str):
+        """dump self to yaml file
+
+        Args:
+            config_file_path (str): the path to save self as yaml file.
+        """
+        with open(config_file_path, "w", encoding="utf-8") as f:
+            yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)
+
+    def update_dataset(
+        self,
+        dataset_path: str,
+        dataset_type: str = None,
+        *,
+        train_list_path: str = None,
+    ):
+        """update dataset settings
+
+        Args:
+            dataset_path (str): the root path of dataset.
+            dataset_type (str, optional): dataset type. Defaults to None.
+            train_list_path (str, optional): the path of the train dataset annotation file. Defaults to None.
+
+        Raises:
+            ValueError: the dataset_type error.
+        """
+        dataset_path = abspath(dataset_path)
+        if dataset_type is None:
+            dataset_type = "VideoDetDataset"
+        if not train_list_path:
+            train_list_path = f"{dataset_path}/train.txt"
+
+        if dataset_type in ["VideoDetDataset"]:
+            _cfg = {
+                "DATASET.train.image_dir": dataset_path,
+                "DATASET.train.file_path": os.path.join(dataset_path, "train.txt"),
+                "DATASET.valid.image_dir": dataset_path,
+                "DATASET.valid.file_path": os.path.join(dataset_path, "val.txt"),
+                "DATASET.test.image_dir": dataset_path,
+                "DATASET.test.file_path": os.path.join(dataset_path, "val.txt"),
+                "METRIC.gt_folder": os.path.join(dataset_path, "val.txt"),
+                "label_dict_path": os.path.join(dataset_path, "label_map.txt"),
+            }
+        else:
+            raise ValueError(f"{repr(dataset_type)} is not supported.")
+        self.update(_cfg)
+
+    def update_batch_size(self, batch_size: int, mode: str = "train"):
+        """update batch size setting
+
+        Args:
+            batch_size (int): the batch size number to set.
+            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
+                Defaults to 'train'.
+
+        Raises:
+            ValueError: `mode` error.
+        """
+
+        if mode == "train":
+            _cfg = {"DATASET.batch_size": batch_size}
+        elif mode == "eval":
+            _cfg = {"DATASET.test_batch_size": batch_size}
+        elif mode == "test":
+            _cfg = {"DATASET.test_batch_size": batch_size}
+        else:
+            raise ValueError("The input `mode` should be train, eval or test.")
+        self.update(_cfg)
+
+    def update_learning_rate(self, learning_rate: float):
+        """update learning rate
+
+        Args:
+            learning_rate (float): the learning rate value to set.
+        """
+        if (
+            self._dict["OPTIMIZER"]["learning_rate"].get("learning_rate", None)
+            is not None
+        ):
+            _cfg = {"OPTIMIZER.learning_rate.learning_rate": learning_rate}
+        else:
+            raise ValueError("unsupported lr format")
+        self.update(_cfg)
+
+    def update_num_classes(self, num_classes: int):
+        """update classes number
+
+        Args:
+            num_classes (int): the classes number value to set.
+        """
+
+        update_str_list = {"MODEL.backbone.num_class": num_classes}
+        self.update(update_str_list)
+        update_str_list = {"MODEL.loss.num_classes": num_classes}
+        self.update(update_str_list)
+
+    def update_label_list(self, label_path: str):
+        """update label list
+
+        Args:
+            label_path (str): the path of the label list file to set.
+        """
+        with open(label_path, "r") as f:
+            lines = [line.strip().split(" ") for line in f.readlines()]
+        sorted_lines = sorted(lines, key=lambda x: int(x[1]))
+        label_list = [line[0] for line in sorted_lines]
+
+        self.update({"label_list": label_list})
+
+    def update_pretrained_weights(self, pretrained_model: str):
+        """update pretrained weight path
+
+        Args:
+            pretrained_model (str): the local path or url of pretrained weight file to set.
+        """
+        assert isinstance(
+            pretrained_model, (str, type(None))
+        ), "The 'pretrained_model' should be a string, indicating the path to the '*.pdparams' file, or 'None', \
+indicating that no pretrained model to be used."
+
+        if pretrained_model is None:
+            self.update({"Global.pretrained_model": None})
+        else:
+            if pretrained_model.lower() == "default":
+                self.update({"Global.pretrained_model": None})
+            else:
+                if not pretrained_model.startswith(("http://", "https://")):
+                    pretrained_model = abspath(pretrained_model)
+                self.update({"Global.pretrained_model": pretrained_model})
+
+    def _update_slim_config(self, slim_config_path: str):
+        """update slim settings
+
+        Args:
+            slim_config_path (str): the path to slim config yaml file.
+        """
+        with open(slim_config_path, "rb") as f:
+            slim_config = yaml.load(f, Loader=yaml.Loader)["Slim"]
+        self.update({"Slim": slim_config})
+
+    def _update_amp(self, amp: Union[None, str]):
+        """update AMP settings
+
+        Args:
+            amp (None | str): the AMP settings.
+
+        Raises:
+            ValueError: AMP setting `amp` error, missing field `AMP`.
+        """
+        if amp is None or amp == "OFF":
+            if "AMP" in self.dict:
+                self._dict.pop("AMP")
+        else:
+            if "AMP" not in self.dict:
+                raise ValueError("Config must have AMP information.")
+            _cfg = {"AMP.use_amp": True, "AMP.level": amp}
+            self.update(_cfg)
+
+    def update_num_workers(self, num_workers: int):
+        """update workers number of train and eval dataloader
+
+        Args:
+            num_workers (int): the value of train and eval dataloader workers number to set.
+        """
+        _cfg = {
+            "DATASET.num_workers": num_workers,
+        }
+        self.update(_cfg)
+
+    def update_shared_memory(self, shared_memory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memory, bool), "shared_memory should be a bool"
+        _cfg = {
+            "DataLoader.Train.loader.use_shared_memory": shared_memory,
+            "DataLoader.Eval.loader.use_shared_memory": shared_memory,
+        }
+        self.update(_cfg)
+
+    def update_shuffle(self, shuffle: bool):
+        """update shuffle setting of train and eval dataloader
+
+        Args:
+            shuffle (bool): whether or not to shuffle the data
+        """
+        assert isinstance(shuffle, bool), "shuffle should be a bool"
+        _cfg = [
+            f"DataLoader.Train.loader.shuffle={shuffle}",
+            f"DataLoader.Eval.loader.shuffle={shuffle}",
+        ]
+        self.update(_cfg)
+
+    def update_dali(self, dali: bool):
+        """enable DALI setting of train and eval dataloader
+
+        Args:
+            dali (bool): whether or not to use DALI
+        """
+        assert isinstance(dali, bool), "dali should be a bool"
+        _cfg = [
+            f"Global.use_dali={dali}",
+        ]
+        self.update(_cfg)
+
+    def update_seed(self, seed: int):
+        """update seed
+
+        Args:
+            seed (int): the random seed value to set
+        """
+        _cfg = {"Global.seed": seed}
+        self.update(_cfg)
+
+    def update_device(self, device: str):
+        """update device setting
+
+        Args:
+            device (str): the running device to set
+        """
+        device = device.split(":")[0]
+        _cfg = {"Global.device": device}
+        self.update(_cfg)
+
+    def update_label_dict_path(self, dict_path: str):
+        """update label dict file path
+
+        Args:
+            dict_path (str): the path of label dict file to set
+        """
+        _cfg = {
+            "label_dict_path": abspath(dict_path),
+        }
+        self.update(_cfg)
+
+    def _update_to_static(self, dy2st: bool):
+        """update config to set dynamic to static mode
+
+        Args:
+            dy2st (bool): whether or not to use the dynamic to static mode.
+        """
+        self.update({"Global.to_static": dy2st})
+
+    def _update_use_vdl(self, use_vdl: bool):
+        """update config to set VisualDL
+
+        Args:
+            use_vdl (bool): whether or not to use VisualDL.
+        """
+        self.update({"Global.use_visuald": use_vdl})
+
+    def _update_epochs(self, epochs: int):
+        """update epochs setting
+
+        Args:
+            epochs (int): the epochs number value to set
+        """
+        self.update({"epochs": epochs})
+
+    def _update_checkpoints(self, resume_path: Union[None, str]):
+        """update checkpoint setting
+
+        Args:
+            resume_path (None | str): the resume training setting. if is `None`, train from scratch, otherwise,
+                train from checkpoint file that path is `.pdparams` file.
+        """
+        if resume_path is not None:
+            resume_path = resume_path.replace(".pdparams", "")
+        self.update({"Global.checkpoints": resume_path})
+
+    def _update_output_dir(self, save_dir: str):
+        """update output directory
+
+        Args:
+            save_dir (str): the path to save outputs.
+        """
+        self.update({"output_dir": abspath(save_dir)})
+        self.update({"METRIC.result_path": abspath(save_dir)})
+
+    def update_log_interval(self, log_interval: int):
+        """update log interval(steps)
+
+        Args:
+            log_interval (int): the log interval value to set.
+        """
+        self.update({"log_interval": log_interval})
+
+    def update_eval_interval(self, eval_interval: int):
+        """update eval interval(epochs)
+
+        Args:
+            eval_interval (int): the eval interval value to set.
+        """
+        self.update({"val_interval": eval_interval})
+
+    def update_save_interval(self, save_interval: int):
+        """update eval interval(epochs)
+
+        Args:
+            save_interval (int): the save interval value to set.
+        """
+        self.update({"save_interval": save_interval})
+
+    def update_log_ranks(self, device):
+        """update log ranks
+
+        Args:
+            device (str): the running device to set
+        """
+        log_ranks = device.split(":")[1]
+        self.update({"Global.log_ranks": log_ranks})
+
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting print memory info"""
+        assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
+        self.update({"Global.print_mem_info": print_mem_info})
+
+    def _update_predict_video(self, infer_video: str, infer_list: str = None):
+        """update video to be predicted
+
+        Args:
+            infer_video (str): the path of the video to be predicted.
+            infer_list (str, optional): the path of a file listing videos to be predicted. Defaults to None.
+        """
+        if infer_list:
+            self.update({"Infer.infer_list": infer_list})
+        self.update({"Infer.infer_videos": infer_video})
+
+    def _update_save_inference_dir(self, save_inference_dir: str):
+        """update directory path to save inference model files
+
+        Args:
+            save_inference_dir (str): the directory path to set.
+        """
+        self.update({"Global.save_inference_dir": abspath(save_inference_dir)})
+
+    def _update_inference_model_dir(self, model_dir: str):
+        """update inference model directory
+
+        Args:
+            model_dir (str): the directory path of inference model fils that used to predict.
+        """
+        self.update({"Global.inference_model_dir": abspath(model_dir)})
+
+    def _update_infer_video(self, infer_video: str):
+        """update path of image that would be predict
+
+        Args:
+            infer_video (str): the image path.
+        """
+        self.update({"Global.infer_videos": infer_video})
+
+    def _update_infer_device(self, device: str):
+        """update the device used in predicting
+
+        Args:
+            device (str): the running device setting
+        """
+        self.update({"Global.use_gpu": device.split(":")[0] == "gpu"})
+
+    def _update_enable_mkldnn(self, enable_mkldnn: bool):
+        """update whether to enable MKLDNN
+
+        Args:
+            enable_mkldnn (bool): `True` is enable, otherwise is disable.
+        """
+        self.update({"Global.enable_mkldnn": enable_mkldnn})
+
+    def _update_infer_video_shape(self, img_shape: str):
+        """update image cropping shape in the preprocessing
+
+        Args:
+            img_shape (str): the shape of cropping in the preprocessing,
+                i.e. `PreProcess.transform_ops.1.CropImage.size`.
+        """
+        self.update({"INFERENCE.target_size": img_shape})
+
+    def _update_save_predict_result(self, save_dir: str):
+        """update directory that save predicting output
+
+        Args:
+            save_dir (str): the directory path to save the predicting output.
+        """
+        self.update({"Infer.save_dir": save_dir})
+
+    def get_epochs_iters(self) -> int:
+        """get epochs
+
+        Returns:
+            int: the epochs value, i.e., `epochs` in config.
+        """
+        return self.dict["epochs"]
+
+    def get_log_interval(self) -> int:
+        """get log interval(steps)
+
+        Returns:
+            int: the log interval value, i.e., `log_interval` in config.
+        """
+        return self.dict["log_interval"]
+
+    def get_eval_interval(self) -> int:
+        """get eval interval(epochs)
+
+        Returns:
+            int: the eval interval value, i.e., `val_interval` in config.
+        """
+        return self.dict["val_interval"]
+
+    def get_save_interval(self) -> int:
+        """get save interval(epochs)
+
+        Returns:
+            int: the save interval value, i.e., `save_interval` in config.
+        """
+        return self.dict["save_interval"]
+
+    def get_learning_rate(self) -> float:
+        """get learning rate
+
+        Returns:
+            float: the learning rate value, i.e., `OPTIMIZER.learning_rate.learning_rate` in config.
+        """
+        return self.dict["OPTIMIZER"]["learning_rate"]["learning_rate"]
+
+    def get_warmup_epochs(self) -> int:
+        """get warmup epochs
+
+        Returns:
+            int: the warmup epochs value, i.e., `Optimizer.lr.warmup_epochs` in config.
+        """
+        return self.dict["Optimizer"]["lr"]["warmup_epoch"]
+
+    def get_label_dict_path(self) -> str:
+        """get label dict file path
+
+        Returns:
+            str: the label dict file path, i.e., `label_dict_path` in config.
+        """
+        return self.dict["label_dict_path"]
+
+    def get_batch_size(self, mode="train") -> int:
+        """get batch size
+
+        Args:
+            mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
+                Defaults to 'train'.
+
+        Returns:
+            int: the batch size value of `mode`: `DATASET.batch_size` for train,
+                `DATASET.test_batch_size` otherwise.
+        """
+        if mode == "train":
+            return self.dict["DATASET"]["batch_size"]
+        return self.dict["DATASET"]["test_batch_size"]
+
+    def get_qat_epochs_iters(self) -> int:
+        """get qat epochs
+
+        Returns:
+            int: the epochs value.
+        """
+        return self.get_epochs_iters()
+
+    def get_qat_learning_rate(self) -> float:
+        """get qat learning rate
+
+        Returns:
+            float: the learning rate value.
+        """
+        return self.get_learning_rate()
+
+    def _get_arch_name(self) -> str:
+        """get architecture name of model
+
+        Returns:
+            str: the model arch name, i.e., `model_name` in config.
+        """
+        return self.dict["model_name"]
+
+    def _get_dataset_root(self) -> str:
+        """get root directory of dataset, i.e. `DataLoader.Train.dataset.video_root`
+
+        Returns:
+            str: the root directory of dataset
+        """
+        return self.dict["DataLoader"]["Train"]["dataset"]["video_root"]
+
+    def get_train_save_dir(self) -> str:
+        """get the directory to save output
+
+        Returns:
+            str: the directory to save output
+        """
+        return self["output_dir"]

+ 298 - 0
paddlex/repo_apis/PaddleVideo_api/video_det/model.py

@@ -0,0 +1,298 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ...base import BaseModel
+from ...base.utils.arg import CLIArgument
+from ...base.utils.subprocess import CompletedProcess
+from ....utils.misc import abspath
+from ....utils import logging
+from ....utils.errors import raise_unsupported_api_error
+
+
+class VideoDetModel(BaseModel):
+    """Video Det Model"""
+
+    def train(
+        self,
+        batch_size: int = None,
+        learning_rate: float = None,
+        epochs_iters: int = None,
+        ips: str = None,
+        device: str = "gpu",
+        resume_path: str = None,
+        dy2st: bool = False,
+        amp: str = "OFF",
+        num_workers: int = None,
+        use_vdl: bool = True,
+        save_dir: str = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """train self
+
+        Args:
+            batch_size (int, optional): the train batch size value. Defaults to None.
+            learning_rate (float, optional): the train learning rate value. Defaults to None.
+            epochs_iters (int, optional): the train epochs value. Defaults to None.
+            ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
+                to None. Defaults to None.
+            dy2st (bool, optional): Enable dynamic to static. Defaults to False.
+            amp (str, optional): the amp settings. Defaults to 'OFF'.
+            num_workers (int, optional): the workers number. Defaults to None.
+            use_vdl (bool, optional): enable VisualDL. Defaults to True.
+            save_dir (str, optional): the directory path to save train output. Defaults to None.
+
+        Returns:
+           CompletedProcess: the result of training subprocess execution.
+        """
+        if resume_path is not None:
+            resume_path = abspath(resume_path)
+
+        with self._create_new_config_file() as config_path:
+            # Update YAML config file
+            config = self.config.copy()
+            config.update_device(device)
+            config._update_to_static(dy2st)
+            config._update_use_vdl(use_vdl)
+
+            if batch_size is not None:
+                config.update_batch_size(batch_size)
+            if learning_rate is not None:
+                config.update_learning_rate(learning_rate)
+            if epochs_iters is not None:
+                config._update_epochs(epochs_iters)
+            config._update_checkpoints(resume_path)
+            if save_dir is not None:
+                save_dir = abspath(save_dir)
+            else:
+                # `save_dir` is None
+                save_dir = abspath(config.get_train_save_dir())
+            config._update_output_dir(save_dir)
+            if num_workers is not None:
+                config.update_num_workers(num_workers)
+
+            cli_args = []
+            do_eval = kwargs.pop("do_eval", True)
+            profile = kwargs.pop("profile", None)
+            if profile is not None:
+                cli_args.append(CLIArgument("--profiler_options", profile))
+
+            # Benchmarking mode settings
+            benchmark = kwargs.pop("benchmark", None)
+            if benchmark is not None:
+                envs = benchmark.get("env", None)
+                seed = benchmark.get("seed", None)
+                do_eval = benchmark.get("do_eval", False)
+                num_workers = benchmark.get("num_workers", None)
+                config.update_log_ranks(device)
+                config._update_amp(benchmark.get("amp", None))
+                config.update_dali(benchmark.get("dali", False))
+                config.update_shuffle(benchmark.get("shuffle", False))
+                config.update_shared_memory(benchmark.get("shared_memory", True))
+                config.update_print_mem_info(benchmark.get("print_mem_info", True))
+                if num_workers is not None:
+                    config.update_num_workers(num_workers)
+                if seed is not None:
+                    config.update_seed(seed)
+                if envs is not None:
+                    for env_name, env_value in envs.items():
+                        os.environ[env_name] = str(env_value)
+            else:
+                config._update_amp(amp)
+            # PDX related settings
+            device_type = device.split(":")[0]
+            uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+            config.update({"Global.uniform_output_enabled": uniform_output_enabled})
+            config.update({"Global.pdx_model_name": self.name})
+
+            config.dump(config_path)
+            self._assert_empty_kwargs(kwargs)
+            return self.runner.train(
+                config_path, cli_args, device, ips, save_dir, do_eval=do_eval
+            )
+
+    def evaluate(
+        self,
+        weight_path: str,
+        batch_size: int = None,
+        ips: str = None,
+        device: str = "gpu",
+        amp: str = "OFF",
+        num_workers: int = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """evaluate self using specified weight
+
+        Args:
+            weight_path (str): the path of model weight file to be evaluated.
+            batch_size (int, optional): the batch size value in evaluating. Defaults to None.
+            ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            amp (str, optional): the AMP setting. Defaults to 'OFF'.
+            num_workers (int, optional): the workers number in evaluating. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of evaluating subprocess execution.
+        """
+
+        with self._create_new_config_file() as config_path:
+            # Update YAML config file
+            config = self.config.copy()
+            config._update_amp(amp)
+            config.update_device(device)
+            config.update_pretrained_weights(weight_path)
+            if batch_size is not None:
+                config.update_batch_size(batch_size)
+            if num_workers is not None:
+                config.update_num_workers(num_workers)
+
+            config.dump(config_path)
+
+            self._assert_empty_kwargs(kwargs)
+
+            cp = self.runner.evaluate(config_path, [], device, ips)
+            return cp
+
+    def predict(
+        self,
+        weight_path: str,
+        input_path: str,
+        input_list_path: str = None,
+        device: str = "gpu",
+        save_dir: str = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """predict using specified weight
+
+        Args:
+            weight_path (str): the path of model weight file used to predict.
+            input_path (str): the path of the video file to be predicted.
+            input_list_path (str, optional): the path of a file listing videos to be predicted. Defaults to None.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            save_dir (str, optional): the directory path to save predict output. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of predicting subprocess execution.
+        """
+        input_path = abspath(input_path)
+        if input_list_path:
+            input_list_path = abspath(input_list_path)
+
+        with self._create_new_config_file() as config_path:
+            # Update YAML config file
+            config = self.config.copy()
+            config.update_pretrained_weights(weight_path)
+            config._update_predict_video(input_path, input_list_path)
+            config.update_device(device)
+            config._update_save_predict_result(save_dir)
+
+            config.dump(config_path)
+
+            self._assert_empty_kwargs(kwargs)
+
+            return self.runner.predict(config_path, [], device)
+
+    def export(self, weight_path: str, save_dir: str, **kwargs) -> CompletedProcess:
+        """export the dynamic model to static model
+
+        Args:
+            weight_path (str): the model weight file path that used to export.
+            save_dir (str): the directory path to save export output.
+
+        Returns:
+            CompletedProcess: the result of exporting subprocess execution.
+        """
+        if not weight_path.startswith(("http://", "https://")):
+            weight_path = abspath(weight_path)
+        save_dir = abspath(save_dir)
+
+        with self._create_new_config_file() as config_path:
+            # Update YAML config file
+            config = self.config.copy()
+            config.update_pretrained_weights(weight_path)
+            config._update_save_inference_dir(save_dir)
+            device = kwargs.pop("device", None)
+            if device:
+                config.update_device(device)
+            # PDX related settings
+            uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+            config.update({"Global.uniform_output_enabled": uniform_output_enabled})
+            config.update({"Global.pdx_model_name": self.name})
+
+            config.dump(config_path)
+
+            self._assert_empty_kwargs(kwargs)
+
+            return self.runner.export(config_path, [], None, save_dir)
+
+    def infer(
+        self,
+        model_dir: str,
+        input_path: str,
+        device: str = "gpu",
+        save_dir: str = None,
+        dict_path: str = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """predict image using infernece model
+
+        Args:
+            model_dir (str): the directory path of inference model files used to predict.
+            input_path (str): the path of the video to be predicted.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            save_dir (str, optional): the directory path to save output. Defaults to None.
+            dict_path (str, optional): the label dict file path. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of inferring subprocess execution.
+        """
+        model_dir = abspath(model_dir)
+        input_path = abspath(input_path)
+        if save_dir is not None:
+            logging.warning("`save_dir` will not be used.")
+        config_path = os.path.join(model_dir, "inference.yml")
+        config = self.config.copy()
+        config.load(config_path)
+        config._update_inference_model_dir(model_dir)
+        config._update_infer_video(input_path)
+        config._update_infer_device(device)
+        if dict_path is not None:
+            dict_path = abspath(dict_path)
+            config.update_label_dict_path(dict_path)
+        if "enable_mkldnn" in kwargs:
+            config._update_enable_mkldnn(kwargs.pop("enable_mkldnn"))
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+
+            self._assert_empty_kwargs(kwargs)
+
+            return self.runner.infer(config_path, [], device)
+
+    def compression(
+        self,
+        weight_path: str,
+        batch_size: int = None,
+        learning_rate: float = None,
+        epochs_iters: int = None,
+        device: str = "gpu",
+        use_vdl: bool = True,
+        save_dir: str = None,
+        **kwargs,
+    ):
+        """compression model"""
+        raise_unsupported_api_error("compression", self.__class__)

+ 45 - 0
paddlex/repo_apis/PaddleVideo_api/video_det/register.py

@@ -0,0 +1,45 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from pathlib import Path
+
+from ...base.register import register_model_info, register_suite_info
+from .model import VideoDetModel
+from .runner import VideoDetRunner
+from .config import VideoDetConfig
+
+REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLEVIDEO_PATH")
+PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
+
+register_suite_info(
+    {
+        "suite_name": "VideoDet",
+        "model": VideoDetModel,
+        "runner": VideoDetRunner,
+        "config": VideoDetConfig,
+        "runner_root_path": REPO_ROOT_PATH,
+    }
+)
+
+################ Models Using Universal Config ################
+register_model_info(
+    {
+        "model_name": "YOWO",
+        "suite": "VideoDet",
+        "config_path": osp.join(PDX_CONFIG_DIR, "YOWO.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+    }
+)
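
For orientation, here is a toy re-implementation of what `register_suite_info` and `register_model_info` are assumed to do: record entries in module-level tables so that the name "YOWO" can later be resolved to its suite, config class, and runner. This is illustrative only, not paddlex's actual registry.

```python
_SUITES: dict = {}
_MODELS: dict = {}

def register_suite_info(info: dict) -> None:
    _SUITES[info["suite_name"]] = info

def register_model_info(info: dict) -> None:
    _MODELS[info["model_name"]] = info

register_suite_info({"suite_name": "VideoDet"})
register_model_info({"model_name": "YOWO", "suite": "VideoDet"})
print(_MODELS["YOWO"]["suite"])  # VideoDet
```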

+ 200 - 0
paddlex/repo_apis/PaddleVideo_api/video_det/runner.py

@@ -0,0 +1,200 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+
+from ...base import BaseRunner
+from ...base.utils.subprocess import CompletedProcess
+
+
+class VideoDetRunner(BaseRunner):
+    """Cls Runner"""
+
+    def train(
+        self,
+        config_path: str,
+        cli_args: list,
+        device: str,
+        ips: str,
+        save_dir: str,
+        do_eval=True,
+    ) -> CompletedProcess:
+        """train model
+
+        Args:
+            config_path (str): the config file path used to train.
+            cli_args (list): the additional parameters.
+            device (str): the training device.
+            ips (str): the ip addresses of nodes when using distribution.
+            save_dir (str): the directory path to save training output.
+            do_eval (bool, optional): whether or not to evaluate model during training. Defaults to True.
+
+        Returns:
+            CompletedProcess: the result of training subprocess execution.
+        """
+        args, env = self.distributed(device, ips, log_dir=save_dir)
+        cmd = [*args, "main.py", "--validate", "-c", config_path, *cli_args]
+        cmd.extend(["-o", f"Global.eval_during_train={do_eval}"])
+        return self.run_cmd(
+            cmd,
+            env=env,
+            switch_wdir=True,
+            echo=True,
+            silent=False,
+            capture_output=True,
+            log_path=self._get_train_log_path(save_dir),
+        )
+
+    def evaluate(
+        self, config_path: str, cli_args: list, device: str, ips: str
+    ) -> CompletedProcess:
+        """run model evaluating
+
+        Args:
+            config_path (str): the config file path used to evaluate.
+            cli_args (list): the additional parameters.
+            device (str): the evaluating device.
+            ips (str): the ip addresses of nodes when using distribution.
+
+        Returns:
+            CompletedProcess: the result of evaluating subprocess execution.
+        """
+        args, env = self.distributed(device, ips)
+        cmd = [*args, "main.py", "--test", "-c", config_path, *cli_args]
+        cp = self.run_cmd(
+            cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
+        )
+
+        if cp.returncode == 0:
+            metric_dict = _extract_eval_metrics(cp.stdout)
+            cp.metrics = metric_dict
+        return cp
+
+    def predict(
+        self, config_path: str, cli_args: list, device: str
+    ) -> CompletedProcess:
+        """run predicting using dynamic mode
+
+        Args:
+            config_path (str): the config file path used to predict.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+
+        Returns:
+            CompletedProcess: the result of predicting subprocess execution.
+        """
+        # `device` unused
+        cmd = [self.python, "tools/infer.py", "-c", config_path, *cli_args]
+        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+    def export(
+        self, config_path: str, cli_args: list, device: str, save_dir: str = None
+    ) -> CompletedProcess:
+        """run exporting
+
+        Args:
+            config_path (str): the path of config file used to export.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+            save_dir (str, optional): the directory path to save exporting output. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of exporting subprocess execution.
+        """
+        # `device` unused
+
+        cmd = [
+            self.python,
+            "tools/export_model.py",
+            "-c",
+            config_path,
+            *cli_args,
+            "-o",
+            save_dir,
+        ]
+
+        cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+        return cp
+
+    def infer(self, config_path: str, cli_args: list, device: str) -> CompletedProcess:
+        """run predicting using inference model
+
+        Args:
+            config_path (str): the path of config file used to predict.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+
+        Returns:
+            CompletedProcess: the result of inferring subprocess execution.
+        """
+        # `device` unused
+        cmd = [self.python, "python/predict_cls.py", "-c", config_path, *cli_args]
+        return self.run_cmd(cmd, switch_wdir="deploy", echo=True, silent=False)
+
+    def compression(
+        self,
+        config_path: str,
+        train_cli_args: list,
+        export_cli_args: list,
+        device: str,
+        train_save_dir: str,
+    ) -> CompletedProcess:
+        """run compression model
+
+        Args:
+            config_path (str): the path of config file used to predict.
+            train_cli_args (list): the additional training parameters.
+            export_cli_args (list): the additional exporting parameters.
+            device (str): the running device.
+            train_save_dir (str): the directory path to save output.
+
+        Returns:
+            CompletedProcess: the result of compression subprocess execution.
+        """
+        # Step 1: Train model
+        cp_train = self.train(config_path, train_cli_args, device, None, train_save_dir)
+
+        # Step 2: Export model
+        weight_path = os.path.join(train_save_dir, "best_model", "model")
+        export_cli_args = [
+            *export_cli_args,
+            "-o",
+            f"Global.pretrained_model={weight_path}",
+        ]
+        cp_export = self.export(config_path, export_cli_args, device)
+
+        return cp_train, cp_export
+
+
+def _extract_eval_metrics(stdout: str) -> dict:
+    """extract evaluation metrics from training log
+
+    Args:
+        stdout (str): the training log
+
+    Returns:
+        dict: the training metric
+    """
+    import re
+
+    pattern = r"mAP:\s*([\d.]+)"
+    compiled_pattern = re.compile(pattern)
+    metric_dict = {}
+    for line in stdout.splitlines():
+        match = compiled_pattern.search(line)
+        if match:
+            metric_dict["mAP"] = float(match.group(1))
+    return metric_dict
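
A quick check of the mAP extraction above against a plausible log line; the exact PaddleVideo log format is assumed here for illustration.

```python
import re

stdout = "[EVAL] epoch 5 finished, mAP: 0.8523"
match = re.search(r"mAP:\s*([\d.]+)", stdout)
assert match is not None and float(match.group(1)) == 0.8523
```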