Browse Source

add image_feature module for new inference interface

cuicheng01 11 months ago
parent
commit
01be9f3e37

+ 1 - 0
paddlex/inference/models_new/__init__.py

@@ -29,6 +29,7 @@ from .text_recognition import TextRecPredictor
 # from .object_detection import DetPredictor
 # from .instance_segmentation import InstanceSegPredictor
 from .semantic_segmentation import SegPredictor
+from .image_feature import ImageFeaturePredictor
 
 # from .general_recognition import ShiTuRecPredictor
 

+ 15 - 0
paddlex/inference/models_new/image_feature/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import ImageFeaturePredictor

+ 155 - 0
paddlex/inference/models_new/image_feature/predictor.py

@@ -0,0 +1,155 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+
+from ....utils.func_register import FuncRegister
+from ....modules.general_recognition.model_list import MODELS
+from ...common.batch_sampler import ImageBatchSampler
+from ...common.reader import ReadImage
+from ..common import (
+    Resize,
+    ResizeByShort,
+    Normalize,
+    ToCHWImage,
+    ToBatch,
+    StaticInfer,
+)
+from ..base import BasicPredictor
+from .processors import NormalizeFeatures
+from .result import IdentityResult
+
+
+class ImageFeaturePredictor(BasicPredictor):
+    """Predictor that runs an image feature model and returns normalized features.
+
+    The preprocessing and postprocessing operators are built from the
+    ``PreProcess`` and ``PostProcess`` sections of ``self.config`` by the
+    builder methods registered in ``_FUNC_MAP``.
+    """
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """Initializes ImageFeaturePredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> ImageBatchSampler:
+        """Builds and returns an ImageBatchSampler instance.
+
+        Returns:
+            ImageBatchSampler: An instance of ImageBatchSampler.
+        """
+        return ImageBatchSampler()
+
+    def _get_result_class(self) -> type:
+        """Returns the result class, IdentityResult.
+
+        Returns:
+            type: The IdentityResult class.
+        """
+        return IdentityResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+
+        Raises:
+            KeyError: If the configuration names a preprocessing or
+                postprocessing op that has no registered builder.
+        """
+        preprocessors = {"Read": ReadImage(format="RGB")}
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            # Each entry is a single-key mapping: {op_name: op_kwargs or None}.
+            tf_key = next(iter(cfg))
+            func = self._FUNC_MAP[tf_key]
+            # Work on a copy so that dropping the unsupported "return_numpy"
+            # option does not modify the loaded configuration.
+            args = dict(cfg.get(tf_key) or {})
+            args.pop("return_numpy", None)
+            name, op = func(self, **args) if args else func(self)
+            preprocessors[name] = op
+        preprocessors["ToBatch"] = ToBatch()
+
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        postprocessors = {}
+        post_cfg = self.config["PostProcess"]
+        for key in post_cfg:
+            func = self._FUNC_MAP.get(key)
+            if func is None:
+                # Fail with the op name instead of "'NoneType' object is not
+                # callable", matching the KeyError raised for preprocessing.
+                raise KeyError(f"Unsupported postprocess op: {key!r}")
+            args = post_cfg.get(key, {})
+            name, op = func(self, **args) if args else func(self)
+            postprocessors[name] = op
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+
+        Returns:
+            dict: A dictionary holding, for every instance of the batch, the input as given ('input_path'), the raw image read from it ('input_img') and its normalized feature ('feature').
+        """
+        batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
+        batch_imgs = self.preprocessors["Resize"](imgs=batch_raw_imgs)
+        batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
+        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
+        x = self.preprocessors["ToBatch"](imgs=batch_imgs)
+        batch_preds = self.infer(x=x)
+        # NormalizeFeatures only normalizes ``preds[0][0]``, so hand it one
+        # sample at a time to obtain a feature for every instance.
+        # Assumes batch_preds[0] holds one feature per sample -- TODO confirm.
+        normalize = self.postprocessors["NormalizeFeatures"]
+        features = [normalize([[feature]]) for feature in batch_preds[0]]
+        return {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "feature": features,
+        }
+
+    @register("ResizeImage")
+    # TODO(gaotingquan): backend & interpolation
+    def build_resize(
+        self, resize_short=None, size=None, backend="cv2", interpolation="LINEAR"
+    ):
+        """Build the resize op for the "ResizeImage" config entry.
+
+        Args:
+            resize_short: Target length of the short edge. Takes priority
+                over ``size`` when both are given.
+            size: Target size, used when ``resize_short`` is not given.
+            backend: Accepted for config compatibility; not used yet.
+            interpolation: Accepted for config compatibility; not used yet
+                (linear interpolation is always requested for short-edge
+                resizing).
+
+        Returns:
+            tuple: The preprocessor name "Resize" and the op instance.
+        """
+        assert resize_short or size
+        if resize_short:
+            op = ResizeByShort(
+                target_short_edge=resize_short, size_divisor=None, interp="LINEAR"
+            )
+        else:
+            op = Resize(target_size=size)
+        return "Resize", op
+
+    @register("NormalizeImage")
+    def build_normalize(
+        self,
+        mean=None,
+        std=None,
+        scale=1 / 255,
+        order="hwc",
+        channel_num=3,
+    ):
+        """Build the normalize op for the "NormalizeImage" config entry.
+
+        Args:
+            mean: Per-channel mean; defaults to [0.485, 0.456, 0.406].
+            std: Per-channel std; defaults to [0.229, 0.224, 0.225].
+            scale: Factor forwarded to ``Normalize`` as ``scale``.
+            order: Channel layout of the input; only "hwc" is supported.
+            channel_num: Number of channels; only 3 is supported.
+
+        Returns:
+            tuple: The preprocessor name "Normalize" and the op instance.
+        """
+        # Fresh lists per call instead of shared mutable default arguments.
+        if mean is None:
+            mean = [0.485, 0.456, 0.406]
+        if std is None:
+            std = [0.229, 0.224, 0.225]
+        assert channel_num == 3
+        # The default used to be "", which could never satisfy this check.
+        assert order == "hwc"
+        return "Normalize", Normalize(scale=scale, mean=mean, std=std)
+
+    @register("ToCHWImage")
+    def build_to_chw(self):
+        """Build the op for the "ToCHWImage" config entry."""
+        return "ToCHW", ToCHWImage()
+
+    @register("NormalizeFeatures")
+    def build_normalize_features(self):
+        """Build the op for the "NormalizeFeatures" postprocess entry."""
+        return "NormalizeFeatures", NormalizeFeatures()

+ 29 - 0
paddlex/inference/models_new/image_feature/processors.py

@@ -0,0 +1,29 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+class NormalizeFeatures:
+    """Normalize Features Transform: scale a feature to unit L2 norm."""
+
+    def _normalize(self, preds):
+        """Return ``preds[0][0]`` divided by its L2 norm along axis 0.
+
+        Args:
+            preds: Nested sequence or array whose ``[0][0]`` element is the
+                feature to normalize (presumably the first sample of the
+                first model output -- confirm against the caller).
+
+        Returns:
+            numpy.ndarray: The normalized feature. Where the norm is zero
+            the values are left as zeros instead of becoming NaN.
+        """
+        feature = np.asarray(preds[0][0])
+        feas_norm = np.sqrt(np.sum(np.square(feature), axis=0, keepdims=True))
+        # Replace zero norms by 1 so an all-zero feature does not divide by
+        # zero; non-zero norms are used unchanged.
+        safe_norm = np.where(feas_norm == 0, 1, feas_norm)
+        return np.divide(feature, safe_norm)
+
+    def __call__(self, preds):
+        """Normalize ``preds[0][0]``; see ``_normalize``."""
+        return self._normalize(preds)

+ 25 - 0
paddlex/inference/models_new/image_feature/result.py

@@ -0,0 +1,25 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PIL import Image
+
+from ...common.result import BaseCVResult
+
+
+class IdentityResult(BaseCVResult):
+    """Result whose image output is just the input image."""
+
+    def _to_img(self):
+        """Return the input image as a PIL image.
+
+        This module does not support visualization, so the input image is
+        output as is.
+        """
+        return Image.fromarray(self._input_img)