predictor.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, List, Sequence, Optional
  15. import numpy as np
  16. from ....utils.func_register import FuncRegister
  17. from ....modules.object_detection.model_list import MODELS
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ..common import StaticInfer
  20. from ..base import BasicPredictor
  21. from .processors import (
  22. DetPad,
  23. DetPostProcess,
  24. Normalize,
  25. PadStride,
  26. ReadImage,
  27. Resize,
  28. ToBatch,
  29. ToCHWImage,
  30. WarpAffine,
  31. )
  32. from .result import DetResult
  class DetPredictor(BasicPredictor):
      """Predictor for object-detection style models.

      ``_build`` assembles the pipeline from the model config:

      - a ``ReadImage`` op, then one op per entry of ``self.config["Preprocess"]``
        (each entry's ``"type"`` is looked up in ``_FUNC_MAP``);
      - a final ``ToBatch`` op;
      - a ``StaticInfer`` runner;
      - a ``DetPostProcess`` op.
      """

      entities = MODELS

      # Registry of preprocessing builders keyed by the config "type" name;
      # filled by the ``@register(<type>)`` decorator (via ``FuncRegister``).
      _FUNC_MAP = {}
      register = FuncRegister(_FUNC_MAP)

      # Integer interpolation codes accepted by ``build_resize`` and the
      # interpolation names they are translated to before reaching ``Resize``.
      _INTERP_NAMES = {
          0: "NEAREST",
          1: "LINEAR",
          2: "CUBIC",
          3: "AREA",
          4: "LANCZOS4",
      }

      # A model whose name contains any of these substrings gets "img_size" as
      # an extra batched input (see ``build_to_batch``). "BlazeFace-FPN-SSH" is
      # already matched by "BlazeFace"; it is kept for explicitness.
      _MODELS_REQUIRING_IMG_SIZE = (
          "DETR",
          "RCNN",
          "YOLOv3",
          "CenterNet",
          "BlazeFace",
          "BlazeFace-FPN-SSH",
      )

      def __init__(self, *args, threshold: Optional[float] = None, **kwargs):
          """Initialize the predictor and build its processing pipeline.

          Args:
              *args: Positional arguments forwarded to ``BasicPredictor``.
              threshold (Optional[float]): Default score threshold passed to the
                  postprocess op when ``process`` is called without one.
                  Defaults to None.
              **kwargs: Keyword arguments forwarded to ``BasicPredictor``.
          """
          super().__init__(*args, **kwargs)
          self.threshold = threshold
          self.pre_ops, self.infer, self.post_op = self._build()

      def _build_batch_sampler(self):
          """Return the sampler that batches image inputs."""
          return ImageBatchSampler()

      def _get_result_class(self):
          """Return the class used to wrap each prediction result."""
          return DetResult

      def _build(self):
          """Build the preprocessing ops, the inference runner and the postprocess op.

          Returns:
              tuple: ``(pre_ops, infer, post_op)``; ``pre_ops`` always starts with
              ``ReadImage`` and ends with the ``ToBatch`` op.

          Raises:
              KeyError: If a preprocessing entry has no ``"type"`` or names an
                  operator that is not registered in ``_FUNC_MAP``.
          """
          pre_ops = [ReadImage(format="RGB")]
          for cfg in self.config["Preprocess"]:
              tf_key = cfg["type"]
              func = self._FUNC_MAP.get(tf_key)
              if func is None:
                  raise KeyError(f"Unsupported preprocessing op type: {tf_key!r}")
              # Copy the remaining options rather than popping "type" from cfg:
              # popping mutated self.config, so building a second time from the
              # same config failed with KeyError on cfg["type"].
              args = {key: value for key, value in cfg.items() if key != "type"}
              # Registered builders are plain functions, so pass self explicitly.
              op = func(self, **args)
              # Only truthy ops are added to the pipeline.
              if op:
                  pre_ops.append(op)
          pre_ops.append(self.build_to_batch())

          infer = StaticInfer(
              model_dir=self.model_dir,
              model_prefix=self.MODEL_FILE_PREFIX,
              option=self.pp_option,
          )

          post_op = self.build_postprocess()
          return pre_ops, infer, post_op

      def _format_output(self, pred: Sequence[Any]) -> List[dict]:
          """Split batched model outputs into one result dict per image.

          Args:
              pred (Sequence[Any]): Raw model outputs. Supported layouts:

                  - 4 elements (SOLOv2), documented as
                    ``[boxes, class_ids, scores, masks]``. Element 0 is not used,
                    and a single dict is returned for the whole batch.
                  - 3 elements (instance segmentation): ``[boxes, boxes_num, masks]``.
                  - Any other length of at least 2, e.g. ``[boxes, boxes_num]``:
                    only boxes are returned.

                  For the non-SOLOv2 layouts, ``boxes`` (and ``masks``) hold the
                  rows of every image concatenated, and ``boxes_num[i]`` is the
                  number of rows belonging to image ``i``.

          Returns:
              List[dict]: The keys depend on the layout:

              - SOLOv2: ``{"class_id", "masks"}``, where ``class_id`` is
                ``np.array([pred[1], pred[2]])``.
              - Instance segmentation: ``{"boxes", "masks"}`` per image.
              - Otherwise: ``{"boxes"}`` per image.
          """
          if len(pred) == 4:
              # NOTE(review): one dict regardless of batch size; presumably SOLOv2
              # is only run with batch size 1 - confirm.
              return [
                  {
                      "class_id": np.array([pred[1], pred[2]]),
                      "masks": np.array(pred[3]),
                  }
              ]

          has_masks = len(pred) == 3
          boxes, boxes_num = pred[0], pred[1]
          results = []
          start = 0
          for num in boxes_num:
              end = start + num
              item = {"boxes": np.array(boxes[start:end])}
              if has_masks:
                  item["masks"] = np.array(pred[2][start:end])
              results.append(item)
              start = end
          return results

      def process(self, batch_data: List[Any], threshold: Optional[float] = None):
          """Run preprocessing, inference and postprocessing on one batch.

          Args:
              batch_data (List[Any]): Batch of inputs accepted by the first
                  preprocessing op (``ReadImage``), e.g. image file paths.
              threshold (Optional[float]): Score threshold for this call. When
                  None, ``self.threshold`` is used instead.

          Returns:
              dict: Three parallel per-input lists:

              - ``"input_path"``: None when an input's data dict has no
                ``"img_path"``.
              - ``"input_img"``: the original images (``"ori_img"``).
              - ``"boxes"``: the postprocess op's output.
          """
          datas = batch_data
          for pre_op in self.pre_ops[:-1]:
              datas = pre_op(datas)
          # The last op is the ToBatch op appended by _build; it produces the
          # batched model inputs while `datas` keeps the per-image dicts.
          batch_inputs = self.pre_ops[-1](datas)
          batch_preds = self.infer(batch_inputs)
          preds_list = self._format_output(batch_preds)
          effective_threshold = threshold if threshold is not None else self.threshold
          boxes = self.post_op(preds_list, datas, effective_threshold)
          return {
              "input_path": [data.get("img_path", None) for data in datas],
              "input_img": [data["ori_img"] for data in datas],
              "boxes": boxes,
          }

      @register("Resize")
      def build_resize(self, target_size, keep_ratio=False, interp=2):
          """Build a ``Resize`` op from a ``Resize`` preprocessing entry.

          Args:
              target_size: Non-empty size sequence from the config. It is passed
                  to ``Resize`` reversed; presumably the config order is (h, w)
                  and ``Resize`` expects (w, h) - confirm.
              keep_ratio (bool): Forwarded to ``Resize``.
              interp: Integer codes 0-4 are translated via ``_INTERP_NAMES``;
                  any non-int value is forwarded unchanged.

          Raises:
              ValueError: If ``target_size`` is empty or None.
              KeyError: If ``interp`` is an int outside 0-4.
          """
          if not target_size:
              # Explicit check rather than `assert`, which `python -O` strips.
              raise ValueError("`target_size` must be a non-empty sequence")
          if isinstance(interp, int):
              interp = self._INTERP_NAMES[interp]
          return Resize(
              target_size=target_size[::-1], keep_ratio=keep_ratio, interp=interp
          )

      @register("NormalizeImage")
      def build_normalize(
          self,
          norm_type=None,
          mean=None,
          std=None,
          is_scale=True,
      ):
          """Build a ``Normalize`` op from a ``NormalizeImage`` preprocessing entry.

          Args:
              norm_type: None, ``"none"`` or ``"mean_std"`` keep mean/std
                  normalization. Any other value replaces mean/std with 0/1.
              mean: Per-channel mean; defaults to ``[0.485, 0.456, 0.406]``.
              std: Per-channel std; defaults to ``[0.229, 0.224, 0.225]``.
              is_scale (bool): If True, pixel values are scaled by 1/255.

          Returns:
              Normalize: The configured normalization op.
          """
          # None sentinels instead of list defaults, so every call gets fresh
          # lists rather than sharing one mutable default object.
          if mean is None:
              mean = [0.485, 0.456, 0.406]
          if std is None:
              std = [0.229, 0.224, 0.225]
          scale = 1.0 / 255.0 if is_scale else 1
          if not norm_type or norm_type == "none":
              norm_type = "mean_std"
          if norm_type != "mean_std":
              mean = 0
              std = 1
          return Normalize(scale=scale, mean=mean, std=std)

      @register("Permute")
      def build_to_chw(self):
          """Build the op that converts images to channel-first (CHW) layout."""
          return ToCHWImage()

      @register("Pad")
      def build_pad(self, fill_value=None, size=None):
          """Build a ``DetPad`` op.

          Args:
              fill_value: Padding value; defaults to ``[127.5, 127.5, 127.5]``.
              size: Padded size; defaults to ``[3, 640, 640]``.
          """
          if fill_value is None:
              fill_value = [127.5, 127.5, 127.5]
          if size is None:
              size = [3, 640, 640]
          return DetPad(size=size, fill_value=fill_value)

      @register("PadStride")
      def build_pad_stride(self, stride=32):
          """Build a ``PadStride`` op with the given stride (default 32)."""
          return PadStride(stride=stride)

      @register("WarpAffine")
      def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
          """Build a ``WarpAffine`` op; all arguments are forwarded unchanged."""
          return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)

      def build_to_batch(self):
          """Build the ``ToBatch`` op that assembles the model's batched inputs.

          The required keys depend on the model name:

          - If it contains any entry of ``_MODELS_REQUIRING_IMG_SIZE``:
            ``("img_size", "img", "scale_factors")``.
          - Otherwise: ``("img", "scale_factors")``.
          """
          needs_img_size = any(
              name in self.model_name for name in self._MODELS_REQUIRING_IMG_SIZE
          )
          if needs_img_size:
              ordered_required_keys = ("img_size", "img", "scale_factors")
          else:
              ordered_required_keys = ("img", "scale_factors")
          return ToBatch(ordered_required_keys=ordered_required_keys)

      def build_postprocess(self):
          """Build the ``DetPostProcess`` op from the model config.

          The config must provide ``draw_threshold`` and ``label_list``.
          ``layout_postprocess`` is optional and defaults to False.
          """
          return DetPostProcess(
              threshold=self.config["draw_threshold"],
              labels=self.config["label_list"],
              layout_postprocess=self.config.get("layout_postprocess", False),
          )