Преглед изворни кода

add fast infer for semantic seg; by zzl

zhangzelun пре 11 месеци
родитељ
комит
120916b7f8

+ 1 - 1
paddlex/inference/models_new/__init__.py

@@ -27,7 +27,7 @@ from .text_recognition import TextRecPredictor
 # from .table_recognition import TablePredictor
 # from .object_detection import DetPredictor
 # from .instance_segmentation import InstanceSegPredictor
-# from .semantic_segmentation import SegPredictor
+from .semantic_segmentation import SegPredictor
 # from .general_recognition import ShiTuRecPredictor
 # from .ts_fc import TSFcPredictor
 # from .ts_ad import TSAdPredictor

+ 15 - 0
paddlex/inference/models_new/semantic_segmentation/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import SegPredictor

+ 154 - 0
paddlex/inference/models_new/semantic_segmentation/predictor.py

@@ -0,0 +1,154 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+
+from ....utils.func_register import FuncRegister
+from ....modules.semantic_segmentation.model_list import MODELS
+from ...common.batch_sampler import ImageBatchSampler
+from ...common.reader import ReadImage
+from ..common import (
+    Resize,
+    ResizeByShort,
+    Normalize,
+    ToCHWImage,
+    ToBatch,
+    StaticInfer,
+)
+from ..base import BasicPredictor
+from .result import SegResult
+
+
+class SegPredictor(BasicPredictor):
+    """SegPredictor that inherits from BasicPredictor."""
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def __init__(self, *args: List, **kwargs: Dict) -> None:
+        """Initializes SegPredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> ImageBatchSampler:
+        """Builds and returns an ImageBatchSampler instance.
+
+        Returns:
+            ImageBatchSampler: An instance of ImageBatchSampler.
+        """
+        return ImageBatchSampler()
+
+    def _get_result_class(self) -> type:
+        """Returns the result class, SegResult.
+
+        Returns:
+            type: The SegResult class.
+        """
+        return SegResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+        """
+        preprocessors = {"Read": ReadImage(format="RGB")}
+        preprocessors['ToCHW'] = ToCHWImage()
+        for cfg in self.config["Deploy"]["transforms"]:
+            tf_key = cfg.pop('type')
+            func = self._FUNC_MAP[tf_key]
+            args = cfg
+            name, op = func(self, **args) if args else func(self)
+            preprocessors[name] = op
+        preprocessors["ToBatch"] = ToBatch()
+
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        postprocessors = {} # Empty for Semantic Segmentation for now
+
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+
+        Returns:
+            dict: A dictionary containing the input path, raw image, and predicted segmentation maps for every instance of the batch. Keys include 'input_path', 'input_img', and 'pred'.
+        """
+        batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
+        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_raw_imgs)
+        batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
+        x = self.preprocessors["ToBatch"](imgs=batch_imgs)
+        batch_preds = self.infer(x=x)
+        if len(batch_data) > 1:
+            batch_preds = np.split(batch_preds[0], len(batch_data), axis = 0)
+        # postprocessors is empty for static infer of semantic segmentation
+        return {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "pred": batch_preds,
+        }
+
+    @register("Resize")
+    def build_resize(
+        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
+    ):
+        assert target_size
+        op = Resize(
+            target_size=target_size,
+            keep_ratio=keep_ratio,
+            size_divisor=size_divisor,
+            interp=interp,
+        )
+        return "Resize", op
+
+    @register("ResizeByLong")
+    def build_resizebylong(self, long_size):
+        assert long_size
+        op = ResizeByLong(
+            target_long_edge=long_size, size_divisor=size_divisor, interp=interp
+        )
+        return "ResizeByLong", op
+
+    @register("ResizeByShort")
+    def build_resizebylong(self, short_size):
+        assert short_size
+        op = ResizeByLong(
+            target_long_edge=short_size, size_divisor=size_divisor, interp=interp
+        )
+        return "ResizeByShort", op
+
+    @register("Normalize")
+    def build_normalize(
+        self,
+        mean=0.5,
+        std=0.5,
+    ):
+        op = Normalize(mean=mean, std=std)
+        return "Normalize", op

+ 64 - 0
paddlex/inference/models_new/semantic_segmentation/result.py

@@ -0,0 +1,64 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from PIL import Image
+import copy
+
+from ...common.result import BaseCVResult
+
+
+class SegResult(BaseCVResult):
+    """Save Result Transform"""
+
+    def _to_img(self):
+        """apply"""
+        seg_map = self["pred"]
+        pc_map = self.get_pseudo_color_map(seg_map[0])
+        return pc_map
+
+    def get_pseudo_color_map(self, pred):
+        """get_pseudo_color_map"""
+        if pred.min() < 0 or pred.max() > 255:
+            raise ValueError("`pred` cannot be cast to uint8.")
+        pred = pred.astype(np.uint8)
+        pred_mask = Image.fromarray(pred, mode="P")
+        color_map = self._get_color_map_list(256)
+        pred_mask.putpalette(color_map)
+        return pred_mask
+
+    @staticmethod
+    def _get_color_map_list(num_classes, custom_color=None):
+        """_get_color_map_list"""
+        num_classes += 1
+        color_map = num_classes * [0, 0, 0]
+        for i in range(0, num_classes):
+            j = 0
+            lab = i
+            while lab:
+                color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
+                color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
+                color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
+                j += 1
+                lab >>= 3
+        color_map = color_map[3:]
+
+        if custom_color:
+            color_map[: len(custom_color)] = custom_color
+        return color_map
+
+    def _to_str(self, _, *args, **kwargs):
+        data = copy.deepcopy(self)
+        data["pred"] = "..."
+        return super()._to_str(data, *args, **kwargs)