瀏覽代碼

Add fast infer for inst seg; support thr setting (#2696)

* add fast infer for semantic seg; by zzl

* fix sth

* fix sth oh seg

* add platform judge

* add platform judge

* Re-trigger CI

* add inst seg; fix sth

* fix sth
Zhang Zelun 10 月之前
父節點
當前提交
44d7207d51

+ 1 - 1
paddlex/inference/models_new/__init__.py

@@ -27,7 +27,7 @@ from .text_recognition import TextRecPredictor
 
 # from .table_recognition import TablePredictor
 # from .object_detection import DetPredictor
-# from .instance_segmentation import InstanceSegPredictor
+from .instance_segmentation import InstanceSegPredictor
 from .semantic_segmentation import SegPredictor
 from .image_feature import ImageFeaturePredictor
 

+ 15 - 0
paddlex/inference/models_new/instance_segmentation/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import InstanceSegPredictor

+ 209 - 0
paddlex/inference/models_new/instance_segmentation/predictor.py

@@ -0,0 +1,209 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple, Sequence, Optional
+import numpy as np
+
+from ....modules.instance_segmentation.model_list import MODELS
+from ...common.batch_sampler import ImageBatchSampler
+from ..common import StaticInfer
+from ..object_detection.processors import (
+    ReadImage,
+    ToBatch,
+)
+from .processors import InstanceSegPostProcess
+from ..object_detection import DetPredictor
+from .result import InstanceSegResult
+from ....utils import logging
+
+
+class InstanceSegPredictor(DetPredictor):
+    """InstanceSegPredictor that inherits from DetPredictor."""
+
+    entities = MODELS
+
+    def __init__(self, *args, threshold: Optional[float] = None, **kwargs):
+        """Initializes InstanceSegPredictor.
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            threshold (Optional[float], optional): The threshold for filtering out low-confidence predictions.
+                Defaults to None, in which case will use default from the config file.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.model_names_only_supports_batchsize_of_one = {
+            "SOLOv2",
+            "PP-YOLOE_seg-S",
+            "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
+            "Cascade-MaskRCNN-ResNet50-FPN",
+        }
+        if self.model_name in self.model_names_only_supports_batchsize_of_one:
+            logging.warning(
+                f"Instance Segmentation Models: \"{', '.join(list(self.model_names_only_supports_batchsize_of_one))}\" only supports prediction with a batch_size of one, "
+                "if you set the predictor with a batch_size larger than one, no error will occur, however, it will actually inference with a batch_size of one, "
+                f"which will lead to a slower inference speed. You are now using {self.config['Global']['model_name']}."
+            )
+
+        self.threshold = threshold
+
+    def _get_result_class(self) -> type:
+        """Returns the result class, InstanceSegResult.
+
+        Returns:
+            type: The InstanceSegResult class.
+        """
+        return InstanceSegResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+        """
+        # build preprocess ops
+        pre_ops = [ReadImage(format="RGB")]
+        for cfg in self.config["Preprocess"]:
+            tf_key = cfg["type"]
+            func = self._FUNC_MAP[tf_key]
+            cfg.pop("type")
+            args = cfg
+            op = func(self, **args) if args else func(self)
+            if op:
+                pre_ops.append(op)
+        pre_ops.append(self.build_to_batch())
+
+        # build infer
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        # build postprocess op
+        post_op = self.build_postprocess()
+
+        return pre_ops, infer, post_op
+
+    def build_to_batch(self):
+
+        ordered_required_keys = (
+            "img_size",
+            "img",
+            "scale_factors",
+        )
+
+        return ToBatch(ordered_required_keys=ordered_required_keys)
+
+    def process(self, batch_data: List[Any], threshold: Optional[float] = None):
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+
+        Returns:
+            dict: A dictionary containing the input path, raw image, box and mask
+                for every instance of the batch. Keys include 'input_path', 'input_img', 'boxes' and 'masks'.
+        """
+        datas = batch_data
+        # preprocess
+        for pre_op in self.pre_ops[:-1]:
+            datas = pre_op(datas)
+
+        # use `ToBatch` format batch inputs
+        batch_inputs = self.pre_ops[-1](datas)
+
+        # do infer
+        if self.model_name in self.model_names_only_supports_batchsize_of_one:
+            batch_preds = []
+            for i in range(batch_inputs[0].shape[0]):
+                batch_inputs_ = [
+                    batch_input_[i][None, ...] for batch_input_ in batch_inputs
+                ]
+                batch_pred_ = self.infer(batch_inputs_)
+                batch_preds.append(batch_pred_)
+        else:
+            batch_preds = self.infer(batch_inputs)
+
+        # process a batch of predictions into a list of single image result
+        preds_list = self._format_output(batch_preds)
+
+        # postprocess
+        boxes_masks = self.post_op(
+            preds_list, datas, threshold if threshold is not None else self.threshold
+        )
+
+        return {
+            "input_path": [data.get("img_path", None) for data in datas],
+            "input_img": [data["ori_img"] for data in datas],
+            "boxes": [result["boxes"] for result in boxes_masks],
+            "masks": [result["masks"] for result in boxes_masks],
+        }
+
+    def _format_output(self, pred: Sequence[Any]) -> List[dict]:
+        """
+        Transform batch outputs into a list of single image output.
+
+        Args:
+            pred (Sequence[Any]): The input predictions, which can be either a list of 3 or 4 elements.
+                - When len(pred) == 4, it is expected to be in the format [boxes, class_ids, scores, masks],
+                  compatible with SOLOv2 output.
+                - When len(pred) == 3, it is expected to be in the format [boxes, box_nums, masks],
+                  compatible with Instance Segmentation output.
+
+        Returns:
+            List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks' (for SOLOv2),
+                or 'boxes' and 'masks' (for Instance Segmentation), or just 'boxes' if no masks are provided.
+        """
+        box_idx_start = 0
+        pred_box = []
+
+        if isinstance(pred, list) and len(pred[0]) == 4:
+            # Adapt to SOLOv2, which only support prediction with a batch_size of 1.
+            pred_class_id = [[pred_[1], pred_[2]] for pred_ in pred]
+            pred_mask = [pred_[3] for pred_ in pred]
+            return [
+                {
+                    "class_id": np.array(pred_class_id[i]),
+                    "masks": np.array(pred_mask[i]),
+                }
+                for i in range(len(pred_class_id))
+            ]
+        if isinstance(pred, list) and len(pred[0]) == 3:
+            # Adapt to PP-YOLOE_seg-S, which only support prediction with a batch_size of 1.
+            return [
+                {"boxes": np.array(pred[i][0]), "masks": np.array(pred[i][2])}
+                for i in range(len(pred))
+            ]
+
+        pred_mask = []
+        for idx in range(len(pred[1])):
+            np_boxes_num = pred[1][idx]
+            box_idx_end = box_idx_start + np_boxes_num
+            np_boxes = pred[0][box_idx_start:box_idx_end]
+            pred_box.append(np_boxes)
+            np_masks = pred[2][box_idx_start:box_idx_end]
+            pred_mask.append(np_masks)
+            box_idx_start = box_idx_end
+
+        return [
+            {"boxes": np.array(pred_box[i]), "masks": np.array(pred_mask[i])}
+            for i in range(len(pred_box))
+        ]
+
+    def build_postprocess(self):
+        return InstanceSegPostProcess(
+            threshold=self.config["draw_threshold"], labels=self.config["label_list"]
+        )

+ 105 - 0
paddlex/inference/models_new/instance_segmentation/processors.py

@@ -0,0 +1,105 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import List, Sequence, Tuple, Union, Optional
+
+import numpy as np
+from ....utils import logging
+from ..object_detection.processors import restructured_boxes
+
+import cv2
+
+
+def extract_masks_from_boxes(boxes, masks):
+    """
+    Extracts the portion of each mask that is within the corresponding box.
+    """
+    new_masks = []
+
+    for i, box in enumerate(boxes):
+        x_min, y_min, x_max, y_max = box["coordinate"]
+        x_min, y_min, x_max, y_max = map(
+            lambda x: int(round(x)), [x_min, y_min, x_max, y_max]
+        )
+
+        cropped_mask = masks[i][y_min:y_max, x_min:x_max]
+        new_masks.append(cropped_mask)
+
+    return new_masks
+
+
+class InstanceSegPostProcess(object):
+    """Save Result Transform"""
+
+    def __init__(self, threshold=0.5, labels=None):
+        super().__init__()
+        self.threshold = threshold
+        self.labels = labels
+
+    def apply(self, masks, img_size, boxes=None, class_id=None, threshold=None):
+        """apply"""
+        if boxes is not None:
+            expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
+            boxes = boxes[expect_boxes, :]
+            boxes = restructured_boxes(boxes, self.labels, img_size)
+            masks = masks[expect_boxes, :, :]
+            masks = extract_masks_from_boxes(boxes, masks)
+            result = {"boxes": boxes, "masks": masks}
+        else:
+            mask_info = []
+            class_id = [list(item) for item in zip(class_id[0], class_id[1])]
+
+            selected_masks = []
+            for i, info in enumerate(class_id):
+                label_id = int(info[0])
+                if info[1] < threshold:
+                    continue
+                mask_info.append(
+                    {
+                        "label": self.labels[label_id],
+                        "score": info[1],
+                        "class_id": label_id,
+                    }
+                )
+                selected_masks.append(masks[i])
+            result = {"boxes": mask_info, "masks": selected_masks}
+
+        return result
+
+    def __call__(
+        self,
+        batch_outputs: List[dict],
+        datas: List[dict],
+        threshold: Optional[float] = None,
+    ):
+        """Apply the post-processing to a batch of outputs.
+
+        Args:
+            batch_outputs (List[dict]): The list of detection outputs.
+            datas (List[dict]): The list of input data.
+            threshold: Optional[float]: object score threshold for postprocess.
+
+        Returns:
+            List[Boxes]: The list of post-processed detection boxes.
+        """
+        outputs = []
+        for data, output in zip(datas, batch_outputs):
+            boxes_masks = self.apply(
+                img_size=data["ori_img_size"],
+                **output,
+                threshold=threshold if threshold is not None else self.threshold
+            )
+            outputs.append(boxes_masks)
+        return outputs

+ 155 - 0
paddlex/inference/models_new/instance_segmentation/result.py

@@ -0,0 +1,155 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import numpy as np
+import copy
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ...utils.color_map import get_colormap, font_colormap
+from ...common.result import BaseCVResult
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..object_detection.result import draw_box
+
+
+def draw_segm(im, masks, mask_info, alpha=0.7):
+    """
+    Draw segmentation on image
+    """
+    mask_color_id = 0
+    w_ratio = 0.4
+    color_list = get_colormap(rgb=True)
+    im = np.array(im).astype("float32")
+    clsid2color = {}
+    masks = np.array(masks)
+    masks = masks.astype(np.uint8)
+    for i in range(masks.shape[0]):
+        mask, score, clsid = masks[i], mask_info[i]["score"], mask_info[i]["class_id"]
+
+        if clsid not in clsid2color:
+            color_index = i % len(color_list)
+            clsid2color[clsid] = color_list[color_index]
+        color_mask = clsid2color[clsid]
+        for c in range(3):
+            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+        idx = np.nonzero(mask)
+        color_mask = np.array(color_mask)
+        idx0 = np.minimum(idx[0], im.shape[0] - 1)
+        idx1 = np.minimum(idx[1], im.shape[1] - 1)
+        im[idx0, idx1, :] *= 1.0 - alpha
+        im[idx0, idx1, :] += alpha * color_mask
+        sum_x = np.sum(mask, axis=0)
+        x = np.where(sum_x > 0.5)[0]
+        sum_y = np.sum(mask, axis=1)
+        y = np.where(sum_y > 0.5)[0]
+        x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
+        cv2.rectangle(
+            im, (x0, y0), (x1, y1), tuple(color_mask.astype("int32").tolist()), 1
+        )
+        bbox_text = "%s %.2f" % (mask_info[i]["label"], score)
+        t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
+        cv2.rectangle(
+            im,
+            (x0, y0),
+            (x0 + t_size[0], y0 - t_size[1] - 3),
+            tuple(color_mask.astype("int32").tolist()),
+            -1,
+        )
+        cv2.putText(
+            im,
+            bbox_text,
+            (x0, y0 - 2),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.3,
+            (0, 0, 0),
+            1,
+            lineType=cv2.LINE_AA,
+        )
+    return Image.fromarray(im.astype("uint8"))
+
+
+def restore_to_draw_masks(img_size, boxes, masks):
+    """
+    Restores extracted masks to the original shape and draws them on a blank image.
+
+    """
+
+    restored_masks = []
+
+    for i, (box, mask) in enumerate(zip(boxes, masks)):
+        restored_mask = np.zeros(img_size, dtype=np.uint8)
+        x_min, y_min, x_max, y_max = map(lambda x: int(round(x)), box["coordinate"])
+        restored_mask[y_min:y_max, x_min:x_max] = mask
+        restored_masks.append(restored_mask)
+
+    return np.array(restored_masks)
+
+
+def draw_mask(im, boxes, np_masks, img_size):
+    """
+    Args:
+        im (PIL.Image.Image): PIL image
+        boxes (list): a list of dictionaries representing detection box information.
+        np_masks (np.ndarray): shape:[N, im_h, im_w]
+    Returns:
+        im (PIL.Image.Image): visualized image
+    """
+    color_list = get_colormap(rgb=True)
+    w_ratio = 0.4
+    alpha = 0.7
+    im = np.array(im).astype("float32")
+    clsid2color = {}
+    np_masks = restore_to_draw_masks(img_size, boxes, np_masks)
+    im_h, im_w = im.shape[:2]
+    np_masks = np_masks[:, :im_h, :im_w]
+    for i in range(len(np_masks)):
+        clsid, score = int(boxes[i]["cls_id"]), boxes[i]["score"]
+        mask = np_masks[i]
+        if clsid not in clsid2color:
+            color_index = i % len(color_list)
+            clsid2color[clsid] = color_list[color_index]
+        color_mask = clsid2color[clsid]
+        for c in range(3):
+            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+        idx = np.nonzero(mask)
+        color_mask = np.array(color_mask)
+        im[idx[0], idx[1], :] *= 1.0 - alpha
+        im[idx[0], idx[1], :] += alpha * color_mask
+    return Image.fromarray(im.astype("uint8"))
+
+
+class InstanceSegResult(BaseCVResult):
+    """Save Result Transform"""
+
+    def _to_img(self):
+        """apply"""
+        # image = self._img_reader.read(self["input_path"])
+        image = Image.fromarray(self._input_img)
+        ori_img_size = list(image.size)[::-1]
+        boxes = self["boxes"]
+        masks = self["masks"]
+        if next((True for item in self["boxes"] if "coordinate" in item), False):
+            image = draw_mask(image, boxes, masks, ori_img_size)
+            image = draw_box(image, boxes)
+        else:
+            image = draw_segm(image, masks, boxes)
+
+        return image
+
+    def _to_str(self, _, *args, **kwargs):
+        data = copy.deepcopy(self)
+        data["masks"] = "..."
+        return super()._to_str(data, *args, **kwargs)

+ 3 - 5
paddlex/inference/models_new/semantic_segmentation/predictor.py

@@ -47,7 +47,7 @@ class SegPredictor(BasicPredictor):
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
         super().__init__(*args, **kwargs)
-        self.preprocessors, self.infer, self.postprocessors = self._build()
+        self.preprocessors, self.infer = self._build()
 
     def _build_batch_sampler(self) -> ImageBatchSampler:
         """Builds and returns an ImageBatchSampler instance.
@@ -87,9 +87,7 @@ class SegPredictor(BasicPredictor):
             option=self.pp_option,
         )
 
-        postprocessors = {}  # Empty for Semantic Segmentation for now
-
-        return preprocessors, infer, postprocessors
+        return preprocessors, infer
 
     def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
         """
@@ -108,7 +106,7 @@ class SegPredictor(BasicPredictor):
         batch_preds = self.infer(x=x)
         if len(batch_data) > 1:
             batch_preds = np.split(batch_preds[0], len(batch_data), axis=0)
-        # postprocessors is empty for static infer of semantic segmentation
+
         return {
             "input_path": batch_data,
             "input_img": batch_raw_imgs,

+ 5 - 4
paddlex/repo_manager/utils.py

@@ -111,7 +111,8 @@ def install_external_deps(repo_name, repo_root):
         if os.path.exists(os.path.join(repo_root, "ppdet", "ext_op")):
             """Install custom op for rotated object detection"""
             if (
-                _compare_version(gcc_version, "8.2.0") >= 0
+                PLATFORM == "Linux"
+                and _compare_version(gcc_version, "8.2.0") >= 0
                 and "gpu" in get_device_type()
                 and (
                     paddle.is_compiled_with_cuda()
@@ -123,9 +124,9 @@ def install_external_deps(repo_name, repo_root):
                     _check_call(args)
             else:
                 logging.warning(
-                    "The custom operators in PaddleDetection for Rotated Object Detection is only supported when using CUDA, GCC>=8.2.0 and Paddle>=2.0.1, \
-                        your environment does not meet these requirements, so we will skip the installation of custom operators under PaddleDetection/ppdet/ext_ops, \
-                            which means you can not train the Rotated Object Detection models."
+                    "The custom operators in PaddleDetection for Rotated Object Detection is only supported when using CUDA, GCC>=8.2.0 and Paddle>=2.0.1, "
+                    "your environment does not meet these requirements, so we will skip the installation of custom operators under PaddleDetection/ppdet/ext_ops, "
+                    "which means you can not train the Rotated Object Detection models."
                 )