Selaa lähdekoodia

support object detection pipeline

leo-q8 10 kuukautta sitten
vanhempi
commit
382ee3c229

+ 10 - 0
paddlex/configs/pipelines/object_detection.yaml

@@ -0,0 +1,10 @@
+pipeline_name: object_detection
+
+SubModules:
+  ObjectDetection:  # read by ObjectDetectionPipeline as config["SubModules"]["ObjectDetection"]
+    module_name: object_detection
+    model_name: PicoDet-S
+    model_dir: null
+    batch_size: 1
+    imgsz: null  # passed to the pipeline's create_model() as a keyword argument when this key is present
+    threshold: null  # passed to the pipeline's create_model() as a keyword argument when this key is present

+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -23,6 +23,7 @@ from .doc_preprocessor import DocPreprocessorPipeline
 from .layout_parsing import LayoutParsingPipeline
 from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
 from .image_classification import ImageClassificationPipeline
+from .object_detection import ObjectDetectionPipeline
 from .seal_recognition import SealRecognitionPipeline
 from .table_recognition import TableRecognitionPipeline
 from .formula_recognition import FormulaRecognitionPipeline

+ 15 - 0
paddlex/inference/pipelines_new/object_detection/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import ObjectDetectionPipeline

+ 74 - 0
paddlex/inference/pipelines_new/object_detection/pipeline.py

@@ -0,0 +1,74 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Iterator, List, Optional, Union
+import numpy as np
+
+from ...utils.pp_option import PaddlePredictorOption
+from ..base import BasePipeline
+
+# [TODO] Update this import from models_new to models.
+from ...models_new.object_detection.result import DetResult
+
+
+class ObjectDetectionPipeline(BasePipeline):
+    """Object Detection Pipeline"""
+
+    entities = "object_detection"
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Initializes the class with given configurations and options.
+
+        Args:
+            config (Dict): Configuration dictionary containing model and other parameters.
+            device (str): The device to run the prediction on. Default is None.
+            pp_option (PaddlePredictorOption): Options for PaddlePaddle predictor. Default is None.
+            use_hpip (bool): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]]): HPIP specific parameters. Default is None.
+        """
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+        model_cfg = config["SubModules"]["ObjectDetection"]
+        model_kwargs = {}
+        if "threshold" in model_cfg:
+            model_kwargs["threshold"] = model_cfg["threshold"]
+        if "imgsz" in model_cfg:
+            model_kwargs["imgsz"] = model_cfg["imgsz"]
+        self.det_model = self.create_model(model_cfg, **model_kwargs)
+
+    def predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        threshold: Optional[float] = None,
+        **kwargs,
+    ) -> DetResult:
+        """Predicts object detection results for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
+            threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
+            **kwargs: Additional keyword arguments that can be passed to the function.
+        Returns:
+            DetResult: The predicted detection results.
+        """
+        yield from self.det_model(input, threshold=threshold, **kwargs)