|
|
@@ -0,0 +1,209 @@
|
|
|
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+
|
|
|
+from typing import Any, Union, Dict, List, Tuple, Sequence, Optional
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+from ....modules.instance_segmentation.model_list import MODELS
|
|
|
+from ...common.batch_sampler import ImageBatchSampler
|
|
|
+from ..common import StaticInfer
|
|
|
+from ..object_detection.processors import (
|
|
|
+ ReadImage,
|
|
|
+ ToBatch,
|
|
|
+)
|
|
|
+from .processors import InstanceSegPostProcess
|
|
|
+from ..object_detection import DetPredictor
|
|
|
+from .result import InstanceSegResult
|
|
|
+from ....utils import logging
|
|
|
+
|
|
|
+
|
|
|
+class InstanceSegPredictor(DetPredictor):
|
|
|
+ """InstanceSegPredictor that inherits from DetPredictor."""
|
|
|
+
|
|
|
+ entities = MODELS
|
|
|
+
|
|
|
+ def __init__(self, *args, threshold: Optional[float] = None, **kwargs):
|
|
|
+ """Initializes InstanceSegPredictor.
|
|
|
+ Args:
|
|
|
+ *args: Arbitrary positional arguments passed to the superclass.
|
|
|
+ threshold (Optional[float], optional): The threshold for filtering out low-confidence predictions.
|
|
|
+ Defaults to None, in which case will use default from the config file.
|
|
|
+ **kwargs: Arbitrary keyword arguments passed to the superclass.
|
|
|
+ """
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+
|
|
|
+ self.model_names_only_supports_batchsize_of_one = {
|
|
|
+ "SOLOv2",
|
|
|
+ "PP-YOLOE_seg-S",
|
|
|
+ "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
|
|
|
+ "Cascade-MaskRCNN-ResNet50-FPN",
|
|
|
+ }
|
|
|
+ if self.model_name in self.model_names_only_supports_batchsize_of_one:
|
|
|
+ logging.warning(
|
|
|
+ f"Instance Segmentation Models: \"{', '.join(list(self.model_names_only_supports_batchsize_of_one))}\" only supports prediction with a batch_size of one, "
|
|
|
+ "if you set the predictor with a batch_size larger than one, no error will occur, however, it will actually inference with a batch_size of one, "
|
|
|
+ f"which will lead to a slower inference speed. You are now using {self.config['Global']['model_name']}."
|
|
|
+ )
|
|
|
+
|
|
|
+ self.threshold = threshold
|
|
|
+
|
|
|
+ def _get_result_class(self) -> type:
|
|
|
+ """Returns the result class, InstanceSegResult.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ type: The InstanceSegResult class.
|
|
|
+ """
|
|
|
+ return InstanceSegResult
|
|
|
+
|
|
|
+ def _build(self) -> Tuple:
|
|
|
+ """Build the preprocessors, inference engine, and postprocessors based on the configuration.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
|
|
|
+ """
|
|
|
+ # build preprocess ops
|
|
|
+ pre_ops = [ReadImage(format="RGB")]
|
|
|
+ for cfg in self.config["Preprocess"]:
|
|
|
+ tf_key = cfg["type"]
|
|
|
+ func = self._FUNC_MAP[tf_key]
|
|
|
+ cfg.pop("type")
|
|
|
+ args = cfg
|
|
|
+ op = func(self, **args) if args else func(self)
|
|
|
+ if op:
|
|
|
+ pre_ops.append(op)
|
|
|
+ pre_ops.append(self.build_to_batch())
|
|
|
+
|
|
|
+ # build infer
|
|
|
+ infer = StaticInfer(
|
|
|
+ model_dir=self.model_dir,
|
|
|
+ model_prefix=self.MODEL_FILE_PREFIX,
|
|
|
+ option=self.pp_option,
|
|
|
+ )
|
|
|
+
|
|
|
+ # build postprocess op
|
|
|
+ post_op = self.build_postprocess()
|
|
|
+
|
|
|
+ return pre_ops, infer, post_op
|
|
|
+
|
|
|
+ def build_to_batch(self):
|
|
|
+
|
|
|
+ ordered_required_keys = (
|
|
|
+ "img_size",
|
|
|
+ "img",
|
|
|
+ "scale_factors",
|
|
|
+ )
|
|
|
+
|
|
|
+ return ToBatch(ordered_required_keys=ordered_required_keys)
|
|
|
+
|
|
|
+ def process(self, batch_data: List[Any], threshold: Optional[float] = None):
|
|
|
+ """
|
|
|
+ Process a batch of data through the preprocessing, inference, and postprocessing.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: A dictionary containing the input path, raw image, box and mask
|
|
|
+ for every instance of the batch. Keys include 'input_path', 'input_img', 'boxes' and 'masks'.
|
|
|
+ """
|
|
|
+ datas = batch_data
|
|
|
+ # preprocess
|
|
|
+ for pre_op in self.pre_ops[:-1]:
|
|
|
+ datas = pre_op(datas)
|
|
|
+
|
|
|
+ # use `ToBatch` format batch inputs
|
|
|
+ batch_inputs = self.pre_ops[-1](datas)
|
|
|
+
|
|
|
+ # do infer
|
|
|
+ if self.model_name in self.model_names_only_supports_batchsize_of_one:
|
|
|
+ batch_preds = []
|
|
|
+ for i in range(batch_inputs[0].shape[0]):
|
|
|
+ batch_inputs_ = [
|
|
|
+ batch_input_[i][None, ...] for batch_input_ in batch_inputs
|
|
|
+ ]
|
|
|
+ batch_pred_ = self.infer(batch_inputs_)
|
|
|
+ batch_preds.append(batch_pred_)
|
|
|
+ else:
|
|
|
+ batch_preds = self.infer(batch_inputs)
|
|
|
+
|
|
|
+ # process a batch of predictions into a list of single image result
|
|
|
+ preds_list = self._format_output(batch_preds)
|
|
|
+
|
|
|
+ # postprocess
|
|
|
+ boxes_masks = self.post_op(
|
|
|
+ preds_list, datas, threshold if threshold is not None else self.threshold
|
|
|
+ )
|
|
|
+
|
|
|
+ return {
|
|
|
+ "input_path": [data.get("img_path", None) for data in datas],
|
|
|
+ "input_img": [data["ori_img"] for data in datas],
|
|
|
+ "boxes": [result["boxes"] for result in boxes_masks],
|
|
|
+ "masks": [result["masks"] for result in boxes_masks],
|
|
|
+ }
|
|
|
+
|
|
|
+ def _format_output(self, pred: Sequence[Any]) -> List[dict]:
|
|
|
+ """
|
|
|
+ Transform batch outputs into a list of single image output.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pred (Sequence[Any]): The input predictions, which can be either a list of 3 or 4 elements.
|
|
|
+ - When len(pred) == 4, it is expected to be in the format [boxes, class_ids, scores, masks],
|
|
|
+ compatible with SOLOv2 output.
|
|
|
+ - When len(pred) == 3, it is expected to be in the format [boxes, box_nums, masks],
|
|
|
+ compatible with Instance Segmentation output.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks' (for SOLOv2),
|
|
|
+ or 'boxes' and 'masks' (for Instance Segmentation), or just 'boxes' if no masks are provided.
|
|
|
+ """
|
|
|
+ box_idx_start = 0
|
|
|
+ pred_box = []
|
|
|
+
|
|
|
+ if isinstance(pred, list) and len(pred[0]) == 4:
|
|
|
+ # Adapt to SOLOv2, which only support prediction with a batch_size of 1.
|
|
|
+ pred_class_id = [[pred_[1], pred_[2]] for pred_ in pred]
|
|
|
+ pred_mask = [pred_[3] for pred_ in pred]
|
|
|
+ return [
|
|
|
+ {
|
|
|
+ "class_id": np.array(pred_class_id[i]),
|
|
|
+ "masks": np.array(pred_mask[i]),
|
|
|
+ }
|
|
|
+ for i in range(len(pred_class_id))
|
|
|
+ ]
|
|
|
+ if isinstance(pred, list) and len(pred[0]) == 3:
|
|
|
+ # Adapt to PP-YOLOE_seg-S, which only support prediction with a batch_size of 1.
|
|
|
+ return [
|
|
|
+ {"boxes": np.array(pred[i][0]), "masks": np.array(pred[i][2])}
|
|
|
+ for i in range(len(pred))
|
|
|
+ ]
|
|
|
+
|
|
|
+ pred_mask = []
|
|
|
+ for idx in range(len(pred[1])):
|
|
|
+ np_boxes_num = pred[1][idx]
|
|
|
+ box_idx_end = box_idx_start + np_boxes_num
|
|
|
+ np_boxes = pred[0][box_idx_start:box_idx_end]
|
|
|
+ pred_box.append(np_boxes)
|
|
|
+ np_masks = pred[2][box_idx_start:box_idx_end]
|
|
|
+ pred_mask.append(np_masks)
|
|
|
+ box_idx_start = box_idx_end
|
|
|
+
|
|
|
+ return [
|
|
|
+ {"boxes": np.array(pred_box[i]), "masks": np.array(pred_mask[i])}
|
|
|
+ for i in range(len(pred_box))
|
|
|
+ ]
|
|
|
+
|
|
|
+ def build_postprocess(self):
|
|
|
+ return InstanceSegPostProcess(
|
|
|
+ threshold=self.config["draw_threshold"], labels=self.config["label_list"]
|
|
|
+ )
|