predictor.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union, Dict, List, Tuple, Sequence, Optional

import numpy as np

from ....modules.instance_segmentation.model_list import MODELS
from ...common.batch_sampler import ImageBatchSampler
from ..common import StaticInfer
from ..object_detection.processors import (
    ReadImage,
    ToBatch,
)
from .processors import InstanceSegPostProcess
from ..object_detection import DetPredictor
from .result import InstanceSegResult
from ....utils import logging


class InstanceSegPredictor(DetPredictor):
    """InstanceSegPredictor that inherits from DetPredictor."""

    entities = MODELS

    def __init__(self, *args, threshold: Optional[float] = None, **kwargs):
        """Initializes InstanceSegPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            threshold (Optional[float], optional): The threshold for filtering out
                low-confidence predictions. Defaults to None, in which case the
                default from the config file is used.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        self.model_names_only_supports_batchsize_of_one = {
            "SOLOv2",
            "PP-YOLOE_seg-S",
            "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
            "Cascade-MaskRCNN-ResNet50-FPN",
        }
        if self.model_name in self.model_names_only_supports_batchsize_of_one:
            logging.warning(
                f"Instance segmentation models \"{', '.join(self.model_names_only_supports_batchsize_of_one)}\" "
                "only support prediction with a batch size of one. Setting a larger batch size will not raise "
                "an error, but inference will still run with a batch size of one, which slows it down. "
                f"You are now using {self.config['Global']['model_name']}."
            )
        self.threshold = threshold
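
    # Illustrative usage sketch (not part of the original file). The positional and
    # keyword arguments are forwarded to `DetPredictor`; the `model_name` and
    # `model_dir` keywords below are assumptions about that base class, so verify
    # its signature before relying on them:
    #
    #     predictor = InstanceSegPredictor(
    #         model_name="SOLOv2",            # assumed: one of the entries in MODELS
    #         model_dir="./inference_model",  # assumed: path to the exported model
    #         threshold=0.5,                  # overrides the config default at postprocess time
    #     )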

    def _get_result_class(self) -> type:
        """Returns the result class, InstanceSegResult.

        Returns:
            type: The InstanceSegResult class.
        """
        return InstanceSegResult

    def _build(self) -> Tuple:
        """Build the preprocessors, inference engine, and postprocessors based on the configuration.

        Returns:
            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
        """
        # build preprocess ops
        pre_ops = [ReadImage(format="RGB")]
        for cfg in self.config["Preprocess"]:
            tf_key = cfg["type"]
            func = self._FUNC_MAP[tf_key]
            cfg.pop("type")
            args = cfg
            op = func(self, **args) if args else func(self)
            if op:
                pre_ops.append(op)
        pre_ops.append(self.build_to_batch())

        # build infer
        infer = StaticInfer(
            model_dir=self.model_dir,
            model_prefix=self.MODEL_FILE_PREFIX,
            option=self.pp_option,
        )

        # build postprocess op
        post_op = self.build_postprocess()

        return pre_ops, infer, post_op

    def build_to_batch(self):
        ordered_required_keys = (
            "img_size",
            "img",
            "scale_factors",
        )
        return ToBatch(ordered_required_keys=ordered_required_keys)
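
    # Sketch of the expected batch layout (an assumption inferred from how `process`
    # consumes the result, not from `ToBatch` itself): for a batch of N samples,
    # `ToBatch` should yield one array per key in `ordered_required_keys`, each with
    # N as the leading dimension, e.g.
    #
    #     batch_inputs = [
    #         img_size,       # shape (N, 2), hypothetical
    #         img,            # shape (N, 3, H, W), hypothetical
    #         scale_factors,  # shape (N, 2), hypothetical
    #     ]
    #
    # This is what lets `process` read `batch_inputs[0].shape[0]` as the batch size
    # and slice `batch_input_[i][None, ...]` when falling back to per-sample inference.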

    def process(self, batch_data: List[Any], threshold: Optional[float] = None):
        """
        Process a batch of data through preprocessing, inference, and postprocessing.

        Args:
            batch_data (List[Any]): A batch of input data (e.g., image file paths), as
                produced by the batch sampler.
            threshold (Optional[float], optional): A per-call threshold for filtering out
                low-confidence predictions. Defaults to None, in which case `self.threshold` is used.

        Returns:
            dict: A dictionary containing the input path, raw image, boxes, and masks
                for every instance of the batch. Keys include 'input_path', 'page_index',
                'input_img', 'boxes', and 'masks'.
        """
        datas = batch_data.instances

        # preprocess
        for pre_op in self.pre_ops[:-1]:
            datas = pre_op(datas)

        # use `ToBatch` to format batch inputs
        batch_inputs = self.pre_ops[-1](datas)

        # do infer
        if self.model_name in self.model_names_only_supports_batchsize_of_one:
            # these models only support a batch size of one, so run inference per sample
            batch_preds = []
            for i in range(batch_inputs[0].shape[0]):
                batch_inputs_ = [
                    batch_input_[i][None, ...] for batch_input_ in batch_inputs
                ]
                batch_pred_ = self.infer(batch_inputs_)
                batch_preds.append(batch_pred_)
        else:
            batch_preds = self.infer(batch_inputs)

        # process a batch of predictions into a list of single-image results
        preds_list = self._format_output(batch_preds)

        # postprocess
        boxes_masks = self.post_op(
            preds_list, datas, threshold if threshold is not None else self.threshold
        )

        return {
            "input_path": batch_data.input_paths,
            "page_index": batch_data.page_indexes,
            "input_img": [data["ori_img"] for data in datas],
            "boxes": [result["boxes"] for result in boxes_masks],
            "masks": [result["masks"] for result in boxes_masks],
        }
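
    # Illustrative shape of the returned dictionary (values are hypothetical, not
    # from the original file): for a batch of two images it would look roughly like
    #
    #     {
    #         "input_path": ["a.jpg", "b.jpg"],
    #         "page_index": [None, None],
    #         "input_img": [ndarray(H1, W1, 3), ndarray(H2, W2, 3)],
    #         "boxes": [boxes_for_a, boxes_for_b],
    #         "masks": [masks_for_a, masks_for_b],
    #     }
    #
    # i.e. every value is a per-image list aligned with the batch order.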

    def _format_output(self, pred: Sequence[Any]) -> List[dict]:
        """
        Transform batch outputs into a list of single-image outputs.

        Args:
            pred (Sequence[Any]): The input predictions. In the per-sample inference path,
                each element is itself a list of 3 or 4 arrays:
                - A 4-element list is expected to be in the format [boxes, class_ids, scores, masks],
                  compatible with SOLOv2 output.
                - A 3-element list is expected to be in the format [boxes, box_nums, masks],
                  compatible with PP-YOLOE_seg output.
                Otherwise, pred is the batched [boxes, box_nums, masks] output for the whole batch.

        Returns:
            List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks'
                (for SOLOv2) or 'boxes' and 'masks' (for the other instance segmentation models).
        """
        box_idx_start = 0
        pred_box = []

        if isinstance(pred[0], list) and len(pred[0]) == 4:
            # Adapt to SOLOv2, which only supports prediction with a batch size of 1.
            pred_class_id = [[pred_[1], pred_[2]] for pred_ in pred]
            pred_mask = [pred_[3] for pred_ in pred]
            return [
                {
                    "class_id": np.array(pred_class_id[i]),
                    "masks": np.array(pred_mask[i]),
                }
                for i in range(len(pred_class_id))
            ]

        if isinstance(pred[0], list) and len(pred[0]) == 3:
            # Adapt to PP-YOLOE_seg-S, which only supports prediction with a batch size of 1.
            return [
                {"boxes": np.array(pred[i][0]), "masks": np.array(pred[i][2])}
                for i in range(len(pred))
            ]

        pred_mask = []
        for idx in range(len(pred[1])):
            np_boxes_num = pred[1][idx]
            box_idx_end = box_idx_start + np_boxes_num
            np_boxes = pred[0][box_idx_start:box_idx_end]
            pred_box.append(np_boxes)
            np_masks = pred[2][box_idx_start:box_idx_end]
            pred_mask.append(np_masks)
            box_idx_start = box_idx_end

        return [
            {"boxes": np.array(pred_box[i]), "masks": np.array(pred_mask[i])}
            for i in range(len(pred_box))
        ]
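
    # Worked example of the `box_nums` split above (numbers are hypothetical): if
    # pred[1] == [2, 3], then pred[0] stacks 5 boxes and pred[2] stacks 5 masks for
    # the whole batch; the loop slices pred[0][0:2] / pred[2][0:2] for the first
    # image and pred[0][2:5] / pred[2][2:5] for the second, yielding one
    # {"boxes", "masks"} dict per image.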

    def build_postprocess(self):
        return InstanceSegPostProcess(
            threshold=self.config["draw_threshold"], labels=self.config["label_list"]
        )