Procházet zdrojové kódy

support FaceRecPipeline (#2834)

学卿 před 10 měsíci
rodič
revize
f1325124ff

+ 18 - 0
paddlex/configs/pipelines/face_recognition.yaml

@@ -0,0 +1,18 @@
+# Default configuration for the face recognition pipeline.
+pipeline_name: face_recognition
+
+# NOTE(review): YAML reads a bare `None` as the string "None", not as null
+# (compare `model_dir: null` below) - confirm the pipeline loader expects this.
+index: None
+det_threshold: 0.6
+rec_threshold: 0.4
+rec_topk: 5
+
+# One entry per sub-model used by the pipeline.
+SubModules:
+  Detection:
+    module_name: face_detection
+    model_name: PP-YOLOE_plus-S_face
+    model_dir: null
+    batch_size: 1 
+  Recognition:
+    module_name: face_feature
+    model_name: ResNet50_face
+    model_dir: null
+    batch_size: 1

+ 1 - 0
paddlex/inference/models_new/__init__.py

@@ -35,6 +35,7 @@ from .ts_anomaly_detection import TSAdPredictor
 from .ts_classification import TSClsPredictor
 from .image_unwarping import WarpPredictor
 from .image_multilabel_classification import MLClasPredictor
+from .face_feature import FaceFeaturePredictor
 from .open_vocabulary_detection import OVDetPredictor
 from .open_vocabulary_segmentation import OVSegPredictor
 

+ 15 - 0
paddlex/inference/models_new/face_feature/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import FaceFeaturePredictor

+ 64 - 0
paddlex/inference/models_new/face_feature/predictor.py

@@ -0,0 +1,64 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Union
+
+import numpy as np
+from ....modules.face_recognition.model_list import MODELS
+from ..image_feature import ImageFeaturePredictor
+
+
+class FaceFeaturePredictor(ImageFeaturePredictor):
+    """FaceFeaturePredictor that inherits from ImageFeaturePredictor."""
+
+    entities = MODELS
+
+    def __init__(self, *args: List, flip: bool = False, **kwargs: Dict) -> None:
+        """Initializes ClasPredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            flip: Whether to perform face flipping during inference. Default is False.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.flip = flip
+
+    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+
+        Returns:
+            dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
+        """
+        batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
+        batch_imgs = self.preprocessors["Resize"](imgs=batch_raw_imgs)
+        batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
+        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
+        x = self.preprocessors["ToBatch"](imgs=batch_imgs)
+        batch_preds = self.infer(x=x)
+        if self.flip:
+            batch_preds_flipped = self.infer(x=[np.flip(data, axis=3) for data in x])
+            for i in range(len(batch_preds)):
+                batch_preds[i] = batch_preds[i] + batch_preds_flipped[i]
+        features = self.postprocessors["NormalizeFeatures"](batch_preds)
+
+        return {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "feature": features,
+        }

+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -36,6 +36,7 @@ from .ts_forecasting import TSFcPipeline
 from .ts_anomaly_detection import TSAnomalyDetPipeline
 from .ts_classification import TSClsPipeline
 from .pp_shitu_v2 import ShiTuV2Pipeline
+from .face_recognition import FaceRecPipeline
 from .attribute_recognition import (
     PedestrianAttributeRecPipeline,
     VehicleAttributeRecPipeline,

+ 15 - 0
paddlex/inference/pipelines_new/face_recognition/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import FaceRecPipeline

+ 57 - 0
paddlex/inference/pipelines_new/face_recognition/pipeline.py

@@ -0,0 +1,57 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..pp_shitu_v2 import ShiTuV2Pipeline
+from .result import FaceRecResult
+
+
+class FaceRecPipeline(ShiTuV2Pipeline):
+    """Face Recognition Pipeline"""
+
+    entities = "face_recognition"
+
+    def get_rec_result(
+        self, raw_img, det_res, indexer, rec_threshold, hamming_radius, topk
+    ):
+        if len(det_res["boxes"]) == 0:
+            return {"label": [], "score": []}
+        subs_of_img = list(self.crop_by_boxes(raw_img, det_res["boxes"]))
+        img_list = [img["img"] for img in subs_of_img]
+        all_rec_res = list(self.rec_model(img_list))
+        all_rec_res = indexer(
+            [rec_res["feature"] for rec_res in all_rec_res],
+            score_thres=rec_threshold,
+            hamming_radius=hamming_radius,
+            topk=topk,
+        )
+        output = {"label": [], "score": []}
+        for res in all_rec_res:
+            output["label"].append(res["label"])
+            output["score"].append(res["score"])
+        return output
+
+    def get_final_result(self, input_data, raw_img, det_res, rec_res):
+        single_img_res = {"input_path": input_data, "input_img": raw_img, "boxes": []}
+        for i, obj in enumerate(det_res["boxes"]):
+            rec_scores = rec_res["score"][i]
+            labels = rec_res["label"][i]
+            single_img_res["boxes"].append(
+                {
+                    "labels": labels,
+                    "rec_scores": rec_scores,
+                    "det_score": obj["score"],
+                    "coordinate": obj["coordinate"],
+                }
+            )
+        return FaceRecResult(single_img_res)

+ 32 - 0
paddlex/inference/pipelines_new/face_recognition/result.py

@@ -0,0 +1,32 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...common.result import BaseCVResult
+from ..pp_shitu_v2.result import draw_box
+
+
+class FaceRecResult(BaseCVResult):
+
+    def _to_img(self):
+        """apply"""
+        boxes = [
+            {
+                "coordinate": box["coordinate"],
+                "label": box["labels"][0] if box["labels"] is not None else "Unknown",
+                "score": box["rec_scores"][0] if box["rec_scores"] is not None else 0,
+            }
+            for box in self["boxes"]
+        ]
+        image = draw_box(self["input_img"][..., ::-1], boxes)
+        return {"res": image}