浏览代码

add anomaly detection predict (#2743)

Sunflower7788 10 月之前
父节点
当前提交
3360516111

+ 2 - 1
paddlex/inference/models_new/__init__.py

@@ -34,9 +34,10 @@ from .ts_classify import TSClsPredictor
 from .image_unwarping import WarpPredictor
 from .image_multilabel_classification import MLClasPredictor
 
+
 # from .table_recognition import TablePredictor
 # from .general_recognition import ShiTuRecPredictor
-# from .anomaly_detection import UadPredictor
+from .anomaly_detection import UadPredictor
 # from .face_recognition import FaceRecPredictor
 from .video_classification import VideoClasPredictor
 

+ 15 - 0
paddlex/inference/models_new/anomaly_detection/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import UadPredictor

+ 144 - 0
paddlex/inference/models_new/anomaly_detection/predictor.py

@@ -0,0 +1,144 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+
+from ....utils.func_register import FuncRegister
+from ....modules.anomaly_detection.model_list import MODELS
+from ...common.batch_sampler import ImageBatchSampler
+from ...common.reader import ReadImage
+from ..common import (
+    Resize,
+    ResizeByShort,
+    Normalize,
+    ToCHWImage,
+    ToBatch,
+    StaticInfer,
+)
+from .processors import MapToMask
+from ..base import BasicPredictor
+from .result import UadResult
+
+
+class UadPredictor(BasicPredictor):
+    """UadPredictor that inherits from BasicPredictor."""
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def __init__(self, *args: List, **kwargs: Dict) -> None:
+        """Initializes UadPredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> ImageBatchSampler:
+        """Builds and returns an ImageBatchSampler instance.
+
+        Returns:
+            ImageBatchSampler: An instance of ImageBatchSampler.
+        """
+        return ImageBatchSampler()
+
+    def _get_result_class(self) -> type:
+        """Returns the result class, UadResult.
+
+        Returns:
+            type: The UadResult class.
+        """
+        return UadResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+        """
+        preprocessors = {"Read": ReadImage(format="RGB")}
+        preprocessors["ToCHW"] = ToCHWImage()
+        for cfg in self.config["Deploy"]["transforms"]:
+            tf_key = cfg.pop("type")
+            func = self._FUNC_MAP[tf_key]
+            args = cfg
+            name, op = func(self, **args) if args else func(self)
+            preprocessors[name] = op
+        preprocessors["ToBatch"] = ToBatch()
+
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        postprocessors = {"Map_to_mask": MapToMask()}
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+
+        Returns:
+            dict: A dictionary containing the input path, raw image, and predicted segmentation maps for every instance of the batch. Keys include 'input_path', 'input_img', and 'pred'.
+        """
+        batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
+        batch_imgs = self.preprocessors["Resize"](imgs=batch_raw_imgs)
+        batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
+        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
+        x = self.preprocessors["ToBatch"](imgs=batch_imgs)
+        batch_preds = self.infer(x=x)
+        batch_preds = self.postprocessors["Map_to_mask"](preds=batch_preds)
+        if len(batch_data) > 1:
+            batch_preds = np.split(batch_preds[0], len(batch_data), axis=0)
+
+        return {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "pred": batch_preds,
+        }
+
+    @register("Resize")
+    def build_resize(
+        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
+    ):
+        assert target_size
+        op = Resize(
+            target_size=target_size,
+            keep_ratio=keep_ratio,
+            size_divisor=size_divisor,
+            interp=interp,
+        )
+        return "Resize", op
+
+    @register("Normalize")
+    def build_normalize(
+        self,
+        mean=0.5,
+        std=0.5,
+    ):
+        op = Normalize(mean=mean, std=std)
+        return "Normalize", op
+
+    @register("Map_to_mask")
+    def map_to_mask(self, mask_map):
+        op = MapToMask()
+        return "Map_to_mask", op

+ 46 - 0
paddlex/inference/models_new/anomaly_detection/processors.py

@@ -0,0 +1,46 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from skimage import measure, morphology
+
+
+class MapToMask:
+    """Map_to_mask"""
+
+    def __init__(self):
+        """
+        Initialize the instance.
+        """
+        super().__init__()
+
+    def __call__(self, preds, *args):
+        """apply"""
+        return [self.apply(pred) for pred in preds]
+
+    def apply(
+        self,
+        pred,
+    ):
+        """apply"""
+        score_map = pred[0]
+        thred = 0.01
+        mask = score_map[0]
+        mask[mask > thred] = 255
+        mask[mask <= thred] = 0
+        kernel = morphology.disk(4)
+        mask = morphology.opening(mask, kernel)
+        mask = mask.astype(np.uint8)
+
+        return mask[None, :, :]

+ 64 - 0
paddlex/inference/models_new/anomaly_detection/result.py

@@ -0,0 +1,64 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from PIL import Image
+import copy
+
+from ...common.result import BaseCVResult
+
+
+class UadResult(BaseCVResult):
+    """Save Result Transform"""
+
+    def _to_img(self):
+        """apply"""
+        seg_map = self["pred"]
+        pc_map = self.get_pseudo_color_map(seg_map[0])
+        return pc_map
+
+    def get_pseudo_color_map(self, pred):
+        """get_pseudo_color_map"""
+        if pred.min() < 0 or pred.max() > 255:
+            raise ValueError("`pred` cannot be cast to uint8.")
+        pred = pred.astype(np.uint8)
+        pred_mask = Image.fromarray(pred, mode="P")
+        color_map = self._get_color_map_list(256)
+        pred_mask.putpalette(color_map)
+        return pred_mask
+
+    @staticmethod
+    def _get_color_map_list(num_classes, custom_color=None):
+        """_get_color_map_list"""
+        num_classes += 1
+        color_map = num_classes * [0, 0, 0]
+        for i in range(0, num_classes):
+            j = 0
+            lab = i
+            while lab:
+                color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
+                color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
+                color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
+                j += 1
+                lab >>= 3
+        color_map = color_map[3:]
+
+        if custom_color:
+            color_map[: len(custom_color)] = custom_color
+        return color_map
+
+    def _to_str(self, _, *args, **kwargs):
+        data = copy.deepcopy(self)
+        data["pred"] = "..."
+        return super()._to_str(data, *args, **kwargs)