Browse Source

expose param "target_size" to users (#2726)

Zhang Zelun 10 months ago
parent
commit
2f22bb890e

+ 49 - 6
paddlex/inference/models_new/semantic_segmentation/predictor.py

@@ -20,13 +20,13 @@ from ....modules.semantic_segmentation.model_list import MODELS
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ..common import (
-    Resize,
     ResizeByShort,
     Normalize,
     ToCHWImage,
     ToBatch,
     StaticInfer,
 )
+from .processors import Resize, SegPostProcess
 from ..base import BasicPredictor
 from .result import SegResult
 
@@ -39,15 +39,22 @@ class SegPredictor(BasicPredictor):
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def __init__(self, *args: List, **kwargs: Dict) -> None:
+    def __init__(
+        self,
+        target_size: Union[int, Tuple[int], None] = None,
+        *args: List,
+        **kwargs: Dict,
+    ) -> None:
         """Initializes SegPredictor.
 
         Args:
+            target_size: Image size used for inference.
             *args: Arbitrary positional arguments passed to the superclass.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
         super().__init__(*args, **kwargs)
-        self.preprocessors, self.infer = self._build()
+        self.target_size = target_size
+        self.preprocessors, self.infer, self.postprocessers = self._build()
 
     def _build_batch_sampler(self) -> ImageBatchSampler:
         """Builds and returns an ImageBatchSampler instance.
@@ -80,6 +87,13 @@ class SegPredictor(BasicPredictor):
             name, op = func(self, **args) if args else func(self)
             preprocessors[name] = op
         preprocessors["ToBatch"] = ToBatch()
+        if "Resize" not in preprocessors:
+            _, op = self._FUNC_MAP["Resize"](self, target_size=-1)
+            preprocessors["Resize"] = op
+
+        if self.target_size is not None:
+            _, op = self._FUNC_MAP["Resize"](self, target_size=self.target_size)
+            preprocessors["Resize"] = op
 
         infer = StaticInfer(
             model_dir=self.model_dir,
@@ -87,26 +101,39 @@ class SegPredictor(BasicPredictor):
             option=self.pp_option,
         )
 
-        return preprocessors, infer
+        postprocessers = SegPostProcess()
+
+        return preprocessors, infer, postprocessers
 
-    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+    def process(
+        self,
+        batch_data: List[Union[str, np.ndarray]],
+        target_size: Union[int, Tuple[int], None] = None,
+    ) -> Dict[str, Any]:
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
 
         Args:
             batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+            target_size: Image size used for inference.
 
         Returns:
             dict: A dictionary containing the input path, raw image, and predicted segmentation maps for every instance of the batch. Keys include 'input_path', 'input_img', and 'pred'.
         """
         batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
-        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_raw_imgs)
+        batch_imgs = self.preprocessors["Resize"](
+            imgs=batch_raw_imgs, target_size=target_size
+        )
+        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
         batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
         x = self.preprocessors["ToBatch"](imgs=batch_imgs)
         batch_preds = self.infer(x=x)
         if len(batch_data) > 1:
             batch_preds = np.split(batch_preds[0], len(batch_data), axis=0)
 
+        # postprocess
+        batch_preds = self.postprocessers(batch_preds, batch_raw_imgs)
+
         return {
             "input_path": batch_data,
             "input_img": batch_raw_imgs,
@@ -121,3 +148,19 @@ class SegPredictor(BasicPredictor):
     ):
         op = Normalize(mean=mean, std=std)
         return "Normalize", op
+
+    @register("Resize")
+    def build_resize(
+        self,
+        target_size=-1,
+        keep_ratio=True,
+        size_divisor=32,
+        interp="LINEAR",
+    ):
+        op = Resize(
+            target_size=target_size,
+            keep_ratio=keep_ratio,
+            size_divisor=size_divisor,
+            interp=interp,
+        )
+        return "Resize", op

+ 114 - 0
paddlex/inference/models_new/semantic_segmentation/processors.py

@@ -0,0 +1,114 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple, Union
+import os
+import sys
+import cv2
+import copy
+import math
+import pyclipper
+import numpy as np
+from ..common.vision.processors import _BaseResize
+
+from ..common.vision import funcs as F
+
+
+class Resize(_BaseResize):
+    """Resize the image."""
+
+    def __init__(
+        self, target_size=-1, keep_ratio=False, size_divisor=None, interp="LINEAR"
+    ):
+        """
+        Initialize the instance.
+
+        Args:
+            target_size (list|tuple|int, optional): Target width and height. -1 will return the images directly without resizing.
+            keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
+                image. Default: False.
+            size_divisor (int|None, optional): Divisor of resized image size.
+                Default: None.
+            interp (str, optional): Interpolation method. Choices are 'NEAREST',
+                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
+        """
+        super().__init__(size_divisor=size_divisor, interp=interp)
+
+        if isinstance(target_size, int):
+            target_size = (target_size, target_size)
+        F.check_image_size(target_size)
+        self.target_size = target_size
+
+        self.keep_ratio = keep_ratio
+
+    def __call__(self, imgs, target_size=None):
+        """apply"""
+        target_size = self.target_size if target_size is None else target_size
+        if isinstance(target_size, int):
+            target_size = (target_size, target_size)
+        F.check_image_size(target_size)
+        return [self.resize(img, target_size) for img in imgs]
+
+    def resize(self, img, target_size):
+
+        if target_size == (-1, -1):
+            # If the final target_size == (-1, -1), it means use the source input image directly.
+            return img
+        original_size = img.shape[:2][::-1]
+        assert target_size[0] > 0 and target_size[1] > 0
+
+        if self.keep_ratio:
+            h, w = img.shape[0:2]
+            target_size, _ = self._rescale_size((w, h), target_size)
+
+        if self.size_divisor:
+            target_size = [
+                math.ceil(i / self.size_divisor) * self.size_divisor
+                for i in target_size
+            ]
+        img = F.resize(img, target_size, interp=self.interp)
+        return img
+
+
+class SegPostProcess:
+    """Semantic Segmentation PostProcess
+
+    This class is responsible for post-processing detection results, only including
+    restoring the prediction segmentation map to the original image size for now.
+    """
+
+    def __call__(self, imgs, src_images):
+        assert len(imgs) == len(src_images)
+
+        src_sizes = [src_image.shape[:2][::-1] for src_image in src_images]
+        return [
+            self.reverse_resize(img, src_size) for img, src_size in zip(imgs, src_sizes)
+        ]
+
+    def reverse_resize(self, img, src_size):
+        """Restore the prediction map to source image size using nearest interpolation.
+
+        Args:
+             img (np.ndarray): prediction map with shape of (1, width, height)
+             src_size (Tuple[int, int]): source size of the input image, with format of (width, height).
+        """
+        assert isinstance(src_size, (tuple, list)) and len(src_size) == 2
+        assert src_size[0] > 0 and src_size[1] > 0
+        assert img.ndim == 3
+
+        reversed_img = cv2.resize(
+            img[0], dsize=src_size, interpolation=cv2.INTER_NEAREST
+        )
+        reversed_img = np.expand_dims(reversed_img, axis=0)
+        return reversed_img