Explorar o código

support image_multilabel_classification pipeline

zhouchangda hai 10 meses
pai
achega
ed03e1a7e7

+ 27 - 0
api_examples/pipelines/test_image_multilabel_classification.py

@@ -0,0 +1,27 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="image_multilabel_classification")
+
+output = pipeline.predict("./test_samples/general_image_classification_001.jpg")
+
+# output = pipeline.predict("./test_samples/财报1.pdf")
+
+for res in output:
+    print(res)
+    res.print()  ## 打印预测的结构化输出
+    res.save_to_img("./output/")  ## 保存结果可视化图像
+    res.save_to_json("./output/")  ## 保存预测的结构化输出

+ 9 - 0
paddlex/configs/pipelines/image_multilabel_classification.yaml

@@ -0,0 +1,9 @@
+
+pipeline_name: image_multilabel_classification
+
+SubModules:
+  ImageMultiLabelClassification:
+    module_name: image_multilabel_classification
+    model_name: PP-HGNetV2-B6_ML
+    model_dir: null
+    batch_size: 4    

+ 6 - 3
paddlex/inference/models_new/image_multilabel_classification/predictor.py

@@ -43,10 +43,10 @@ class MLClasPredictor(ClasPredictor):
         super().__init__(*args, **kwargs)
 
     def _get_result_class(self) -> type:
-        """Returns the result class, TopkResult.
+        """Returns the result class, MLClassResult.
 
         Returns:
-            type: The TopkResult class.
+            type: The MLClassResult class.
         """
 
         return MLClassResult
@@ -74,7 +74,10 @@ class MLClasPredictor(ClasPredictor):
         batch_preds = self.infer(x=x)
         batch_class_ids, batch_scores, batch_label_names = self.postprocessors[
             "MultiLabelThreshOutput"
-        ](preds=batch_preds, threshold=threshold or self.threshold)
+        ](
+            preds=batch_preds,
+            threshold=self.threshold if threshold is None else threshold,
+        )
         return {
             "input_path": batch_data,
             "input_img": batch_raw_imgs,

+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -28,6 +28,7 @@ from .seal_recognition import SealRecognitionPipeline
 from .table_recognition import TableRecognitionPipeline
 from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
 from .formula_recognition import FormulaRecognitionPipeline
+from .image_multilabel_classification import ImageMultiLabelClassificationPipeline
 from .video_classification import VideoClassificationPipeline
 from .anomaly_detection import AnomalyDetectionPipeline
 from .ts_forecasting import TSFcPipeline

+ 15 - 0
paddlex/inference/pipelines_new/image_multilabel_classification/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import ImageMultiLabelClassificationPipeline

+ 83 - 0
paddlex/inference/pipelines_new/image_multilabel_classification/pipeline.py

@@ -0,0 +1,83 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+import numpy as np
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
+from ...utils.pp_option import PaddlePredictorOption
+from ..base import BasePipeline
+
+# [TODO] 待更新models_new到models
+from ...models_new.image_multilabel_classification.result import MLClassResult
+
+
+class ImageMultiLabelClassificationPipeline(BasePipeline):
+    """Image Multi Label Classification Pipeline"""
+
+    entities = "image_multilabel_classification"
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Initializes the class with given configurations and options.
+
+        Args:
+            config (Dict): Configuration dictionary containing model and other parameters.
+            device (str): The device to run the prediction on. Default is None.
+            pp_option (PaddlePredictorOption): Options for PaddlePaddle predictor. Default is None.
+            use_hpip (bool): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]]): HPIP specific parameters. Default is None.
+        """
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        self.threshold = config["SubModules"]["ImageMultiLabelClassification"].get(
+            "threshold", None
+        )
+        image_multilabel_classification_model_config = config["SubModules"][
+            "ImageMultiLabelClassification"
+        ]
+        self.image_multilabel_classification_model = self.create_model(
+            image_multilabel_classification_model_config
+        )
+        batch_size = image_multilabel_classification_model_config["batch_size"]
+
+    def predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        threshold: float | dict | list | None = None,
+        **kwargs
+    ) -> MLClassResult:
+        """Predicts image classification results for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
+            **kwargs: Additional keyword arguments that can be passed to the function.
+
+        Returns:
+            TopkResult: The predicted top k results.
+        """
+
+        yield from self.image_multilabel_classification_model(
+            input=input,
+            threshold=self.threshold if threshold is None else threshold,
+        )