| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, List, Optional, Sequence, Tuple
- import numpy as np
- from ....modules.instance_segmentation.model_list import MODELS
- from ....utils import logging
- from ..object_detection import DetPredictor
- from ..object_detection.processors import ReadImage, ToBatch
- from .processors import InstanceSegPostProcess
- from .result import InstanceSegResult
class InstanceSegPredictor(DetPredictor):
    """Instance segmentation predictor that inherits from DetPredictor.

    Reuses the detection pipeline and adds mask handling: raw model outputs are
    split into one boxes/masks entry per image and then passed through
    ``InstanceSegPostProcess``.
    """

    entities = MODELS

    def __init__(self, *args, threshold: Optional[float] = None, **kwargs):
        """Initializes InstanceSegPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            threshold (Optional[float], optional): The threshold for filtering out
                low-confidence predictions. Defaults to None, in which case
                ``None`` is forwarded to the post-processor, which was built with
                ``draw_threshold`` from the config file (presumably used as the
                fallback — confirm in InstanceSegPostProcess).
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        # `process` runs these models image by image, whatever batch size the
        # predictor was configured with.
        self.model_names_only_supports_batchsize_of_one = {
            "SOLOv2",
            "PP-YOLOE_seg-S",
            "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
            "Cascade-MaskRCNN-ResNet50-FPN",
        }
        if self.model_name in self.model_names_only_supports_batchsize_of_one:
            # Sort the names so the message is stable: set iteration order of
            # strings changes between interpreter runs (hash randomization).
            logging.warning(
                f"Instance Segmentation Models: \"{', '.join(sorted(self.model_names_only_supports_batchsize_of_one))}\" only supports prediction with a batch_size of one, "
                "if you set the predictor with a batch_size larger than one, no error will occur, however, it will actually inference with a batch_size of one, "
                f"which will lead to a slower inference speed. You are now using {self.model_name}."
            )
        self.threshold = threshold

    def _get_result_class(self) -> type:
        """Returns the result class, InstanceSegResult.

        Returns:
            type: The InstanceSegResult class.
        """
        return InstanceSegResult

    def _build(self) -> Tuple:
        """Build the preprocessors, inference engine, and postprocessor from the config.

        Returns:
            tuple: ``(pre_ops, infer, post_op)``. ``pre_ops`` starts with an RGB
            ``ReadImage`` op, continues with the ops listed under
            ``config["Preprocess"]`` and ends with the ``ToBatch`` op.
        """
        pre_ops = [ReadImage(format="RGB")]
        for cfg in self.config["Preprocess"]:
            # Build the keyword arguments from a copy instead of popping "type"
            # out of `cfg`: popping mutated the shared config, so a second call
            # to `_build` failed with KeyError.
            op_args = {key: value for key, value in cfg.items() if key != "type"}
            func = self._FUNC_MAP[cfg["type"]]
            op = func(self, **op_args)
            # Falsy results from an op factory are skipped.
            if op:
                pre_ops.append(op)
        pre_ops.append(self.build_to_batch())
        infer = self.create_static_infer()
        post_op = self.build_postprocess()
        return pre_ops, infer, post_op

    def build_to_batch(self):
        """Create the ``ToBatch`` op that collates per-image data into model inputs.

        Returns:
            ToBatch: The op, configured to emit ``img_size``, ``img`` and
            ``scale_factors`` in that order.
        """
        ordered_required_keys = (
            "img_size",
            "img",
            "scale_factors",
        )
        return ToBatch(ordered_required_keys=ordered_required_keys)

    def process(self, batch_data: List[Any], threshold: Optional[float] = None):
        """Run preprocessing, inference, and postprocessing on a batch.

        Args:
            batch_data: The batch to process. Despite the annotation, this is
                an object exposing ``instances`` (the per-image inputs),
                ``input_paths`` and ``page_indexes``.
            threshold (Optional[float]): Per-call score threshold. When None,
                the threshold given to ``__init__`` is used instead.

        Returns:
            dict: Per-batch lists keyed by ``'input_path'``, ``'page_index'``,
            ``'input_img'``, ``'boxes'`` and ``'masks'``, one entry per image.
        """
        datas = batch_data.instances
        # Every op except the last transforms the per-image data; the last one
        # (`ToBatch`) collates it into the model's input arrays.
        for pre_op in self.pre_ops[:-1]:
            datas = pre_op(datas)
        batch_inputs = self.pre_ops[-1](datas)

        if self.model_name in self.model_names_only_supports_batchsize_of_one:
            # Feed one image at a time, keeping a leading batch axis of size 1.
            num_images = batch_inputs[0].shape[0]
            batch_preds = [
                self.infer([batch_input[i][None, ...] for batch_input in batch_inputs])
                for i in range(num_images)
            ]
        else:
            batch_preds = self.infer(batch_inputs)

        preds_list = self._format_output(batch_preds)
        boxes_masks = self.post_op(
            preds_list, datas, threshold if threshold is not None else self.threshold
        )
        return {
            "input_path": batch_data.input_paths,
            "page_index": batch_data.page_indexes,
            "input_img": [data["ori_img"] for data in datas],
            "boxes": [result["boxes"] for result in boxes_masks],
            "masks": [result["masks"] for result in boxes_masks],
        }

    def _format_output(self, pred: Sequence[Any]) -> List[dict]:
        """Split raw inference outputs into one dict per image.

        Args:
            pred (Sequence[Any]): Raw predictions in one of two layouts.
                - Image-by-image inference: a list with one entry per image,
                  where each entry is itself a list.
                    * If the entry has 4 elements, elements 1 and 2 are packed
                      together as ``class_id`` and element 3 is the mask output.
                      The original comment ties this layout to SOLOv2.
                    * If the entry has 3 elements, element 0 is the boxes and
                      element 2 is the masks.
                - Batched inference: ``[boxes, box_nums, masks]``. Here
                  ``box_nums[i]`` is the number of consecutive rows of ``boxes``
                  and ``masks`` that belong to image ``i``.

        Returns:
            List[dict]: One dict per image. It holds ``'class_id'`` and
            ``'masks'`` for the 4-element layout, and ``'boxes'`` and
            ``'masks'`` otherwise. An empty input gives an empty list.
        """
        if len(pred) == 0:
            return []

        if isinstance(pred[0], list) and len(pred[0]) == 4:
            # SOLOv2-style output; only produced by image-by-image inference.
            return [
                {
                    "class_id": np.array([pred_[1], pred_[2]]),
                    "masks": np.array(pred_[3]),
                }
                for pred_ in pred
            ]

        if isinstance(pred[0], list) and len(pred[0]) == 3:
            # Per-image [boxes, _, masks] output from image-by-image inference
            # (e.g. PP-YOLOE_seg-S).
            return [
                {"boxes": np.array(pred_[0]), "masks": np.array(pred_[2])}
                for pred_ in pred
            ]

        # Batched output: slice the flat boxes/masks by the per-image counts.
        boxes, box_nums, masks = pred[0], pred[1], pred[2]
        results = []
        start = 0
        for num_boxes in box_nums:
            end = start + num_boxes
            results.append(
                {
                    "boxes": np.array(boxes[start:end]),
                    "masks": np.array(masks[start:end]),
                }
            )
            start = end
        return results

    def build_postprocess(self):
        """Create the post-processing op.

        Returns:
            InstanceSegPostProcess: Built with the config's ``draw_threshold``
            and ``label_list``.
        """
        return InstanceSegPostProcess(
            threshold=self.config["draw_threshold"], labels=self.config["label_list"]
        )
|