predictor.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, List, Optional, Sequence, Tuple, Union
  15. import numpy as np
  16. from ....modules.object_detection.model_list import MODELS
  17. from ....utils.func_register import FuncRegister
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ..base import BasePredictor
  20. from .processors import (
  21. DetPad,
  22. DetPostProcess,
  23. Normalize,
  24. PadStride,
  25. ReadImage,
  26. Resize,
  27. ToBatch,
  28. ToCHWImage,
  29. WarpAffine,
  30. )
  31. from .result import DetResult
  32. from .utils import STATIC_SHAPE_MODEL_LIST
  33. class DetPredictor(BasePredictor):
  34. entities = MODELS
  35. _FUNC_MAP = {}
  36. register = FuncRegister(_FUNC_MAP)
  37. def __init__(
  38. self,
  39. *args,
  40. img_size: Optional[Union[int, Tuple[int, int]]] = None,
  41. threshold: Optional[Union[float, dict]] = None,
  42. layout_nms: Optional[bool] = None,
  43. layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
  44. layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
  45. **kwargs,
  46. ):
  47. """Initializes DetPredictor.
  48. Args:
  49. *args: Arbitrary positional arguments passed to the superclass.
  50. img_size (Optional[Union[int, Tuple[int, int]]], optional): The input image size (w, h). Defaults to None.
  51. threshold (Optional[float], optional): The threshold for filtering out low-confidence predictions.
  52. Defaults to None.
  53. layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
  54. layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
  55. Defaults to None.
  56. If it's a single number, then both width and height are used.
  57. If it's a tuple of two numbers, then they are used separately for width and height respectively.
  58. If it's None, then no unclipping will be performed.
  59. layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging bounding boxes. Defaults to None.
  60. **kwargs: Arbitrary keyword arguments passed to the superclass.
  61. """
  62. super().__init__(*args, **kwargs)
  63. if img_size is not None:
  64. assert (
  65. self.model_name not in STATIC_SHAPE_MODEL_LIST
  66. ), f"The model {self.model_name} is not supported set input shape"
  67. if isinstance(img_size, int):
  68. img_size = (img_size, img_size)
  69. elif isinstance(img_size, (tuple, list)):
  70. assert len(img_size) == 2, f"The length of `img_size` should be 2."
  71. else:
  72. raise ValueError(
  73. f"The type of `img_size` must be int or Tuple[int, int], but got {type(img_size)}."
  74. )
  75. if layout_unclip_ratio is not None:
  76. if isinstance(layout_unclip_ratio, float):
  77. layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
  78. elif isinstance(layout_unclip_ratio, (tuple, list)):
  79. assert (
  80. len(layout_unclip_ratio) == 2
  81. ), f"The length of `layout_unclip_ratio` should be 2."
  82. elif isinstance(layout_unclip_ratio, dict):
  83. pass
  84. else:
  85. raise ValueError(
  86. f"The type of `layout_unclip_ratio` must be float, Tuple[float, float] or Dict, but got {type(layout_unclip_ratio)}."
  87. )
  88. if layout_merge_bboxes_mode is not None:
  89. if isinstance(layout_merge_bboxes_mode, str):
  90. assert layout_merge_bboxes_mode in [
  91. "union",
  92. "large",
  93. "small",
  94. ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'] or a dict, but got {layout_merge_bboxes_mode}"
  95. self.img_size = img_size
  96. self.threshold = threshold
  97. self.layout_nms = layout_nms
  98. self.layout_unclip_ratio = layout_unclip_ratio
  99. self.layout_merge_bboxes_mode = layout_merge_bboxes_mode
  100. self.pre_ops, self.infer, self.post_op = self._build()
  101. def _build_batch_sampler(self):
  102. return ImageBatchSampler()
  103. def _get_result_class(self):
  104. return DetResult
  105. def _build(self) -> Tuple:
  106. """Build the preprocessors, inference engine, and postprocessors based on the configuration.
  107. Returns:
  108. tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
  109. """
  110. # build preprocess ops
  111. pre_ops = [ReadImage(format="RGB")]
  112. for cfg in self.config["Preprocess"]:
  113. tf_key = cfg["type"]
  114. func = self._FUNC_MAP[tf_key]
  115. cfg.pop("type")
  116. args = cfg
  117. op = func(self, **args) if args else func(self)
  118. if op:
  119. pre_ops.append(op)
  120. pre_ops.append(self.build_to_batch())
  121. if self.img_size is not None:
  122. if isinstance(pre_ops[1], Resize):
  123. pre_ops.pop(1)
  124. pre_ops.insert(1, self.build_resize(self.img_size, False, 2))
  125. # build infer
  126. infer = self.create_static_infer()
  127. # build postprocess op
  128. post_op = self.build_postprocess()
  129. return pre_ops, infer, post_op
  130. def _format_output(self, pred: Sequence[Any]) -> List[dict]:
  131. """
  132. Transform batch outputs into a list of single image output.
  133. Args:
  134. pred (Sequence[Any]): The input predictions, which can be either a list of 3 or 4 elements.
  135. - When len(pred) == 4, it is expected to be in the format [boxes, class_ids, scores, masks],
  136. compatible with SOLOv2 output.
  137. - When len(pred) == 3, it is expected to be in the format [boxes, box_nums, masks],
  138. compatible with Instance Segmentation output.
  139. Returns:
  140. List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks' (for SOLOv2),
  141. or 'boxes' and 'masks' (for Instance Segmentation), or just 'boxes' if no masks are provided.
  142. """
  143. box_idx_start = 0
  144. pred_box = []
  145. if len(pred) == 4:
  146. # Adapt to SOLOv2
  147. pred_class_id = []
  148. pred_mask = []
  149. pred_class_id.append([pred[1], pred[2]])
  150. pred_mask.append(pred[3])
  151. return [
  152. {
  153. "class_id": np.array(pred_class_id[i]),
  154. "masks": np.array(pred_mask[i]),
  155. }
  156. for i in range(len(pred_class_id))
  157. ]
  158. if len(pred) == 3:
  159. # Adapt to Instance Segmentation
  160. pred_mask = []
  161. for idx in range(len(pred[1])):
  162. np_boxes_num = pred[1][idx]
  163. box_idx_end = box_idx_start + np_boxes_num
  164. np_boxes = pred[0][box_idx_start:box_idx_end]
  165. pred_box.append(np_boxes)
  166. if len(pred) == 3:
  167. np_masks = pred[2][box_idx_start:box_idx_end]
  168. pred_mask.append(np_masks)
  169. box_idx_start = box_idx_end
  170. if len(pred) == 3:
  171. return [
  172. {"boxes": np.array(pred_box[i]), "masks": np.array(pred_mask[i])}
  173. for i in range(len(pred_box))
  174. ]
  175. else:
  176. return [{"boxes": np.array(res)} for res in pred_box]
  177. def process(
  178. self,
  179. batch_data: List[Any],
  180. threshold: Optional[Union[float, dict]] = None,
  181. layout_nms: bool = False,
  182. layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
  183. layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
  184. ):
  185. """
  186. Process a batch of data through the preprocessing, inference, and postprocessing.
  187. Args:
  188. batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
  189. threshold (Optional[float, dict], optional): The threshold for filtering out low-confidence predictions.
  190. layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to None.
  191. layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
  192. layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging bounding boxes. Defaults to None.
  193. Returns:
  194. dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
  195. for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
  196. """
  197. datas = batch_data.instances
  198. # preprocess
  199. for pre_op in self.pre_ops[:-1]:
  200. datas = pre_op(datas)
  201. # use `ToBatch` format batch inputs
  202. batch_inputs = self.pre_ops[-1](datas)
  203. # do infer
  204. batch_preds = self.infer(batch_inputs)
  205. # process a batch of predictions into a list of single image result
  206. preds_list = self._format_output(batch_preds)
  207. # postprocess
  208. boxes = self.post_op(
  209. preds_list,
  210. datas,
  211. threshold=threshold if threshold is not None else self.threshold,
  212. layout_nms=layout_nms or self.layout_nms,
  213. layout_unclip_ratio=layout_unclip_ratio or self.layout_unclip_ratio,
  214. layout_merge_bboxes_mode=layout_merge_bboxes_mode
  215. or self.layout_merge_bboxes_mode,
  216. )
  217. return {
  218. "input_path": batch_data.input_paths,
  219. "page_index": batch_data.page_indexes,
  220. "input_img": [data["ori_img"] for data in datas],
  221. "boxes": boxes,
  222. }
  223. @register("Resize")
  224. def build_resize(self, target_size, keep_ratio=False, interp=2):
  225. assert target_size
  226. if isinstance(interp, int):
  227. interp = {
  228. 0: "NEAREST",
  229. 1: "LINEAR",
  230. 2: "BICUBIC",
  231. 3: "AREA",
  232. 4: "LANCZOS4",
  233. }[interp]
  234. op = Resize(target_size=target_size[::-1], keep_ratio=keep_ratio, interp=interp)
  235. return op
  236. @register("NormalizeImage")
  237. def build_normalize(
  238. self,
  239. norm_type=None,
  240. mean=[0.485, 0.456, 0.406],
  241. std=[0.229, 0.224, 0.225],
  242. is_scale=True,
  243. ):
  244. if is_scale:
  245. scale = 1.0 / 255.0
  246. else:
  247. scale = 1
  248. if not norm_type or norm_type == "none":
  249. norm_type = "mean_std"
  250. if norm_type != "mean_std":
  251. mean = 0
  252. std = 1
  253. return Normalize(scale=scale, mean=mean, std=std)
  254. @register("Permute")
  255. def build_to_chw(self):
  256. return ToCHWImage()
  257. @register("Pad")
  258. def build_pad(self, fill_value=None, size=None):
  259. if fill_value is None:
  260. fill_value = [127.5, 127.5, 127.5]
  261. if size is None:
  262. size = [3, 640, 640]
  263. return DetPad(size=size, fill_value=fill_value)
  264. @register("PadStride")
  265. def build_pad_stride(self, stride=32):
  266. return PadStride(stride=stride)
  267. @register("WarpAffine")
  268. def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
  269. return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)
  270. def build_to_batch(self):
  271. models_required_imgsize = [
  272. "DETR",
  273. "DINO",
  274. "RCNN",
  275. "YOLOv3",
  276. "CenterNet",
  277. "BlazeFace",
  278. "BlazeFace-FPN-SSH",
  279. "PP-DocLayout-L",
  280. "PP-DocLayout_plus-L",
  281. "PP-DocBlockLayout",
  282. ]
  283. if any(name in self.model_name for name in models_required_imgsize):
  284. ordered_required_keys = (
  285. "img_size",
  286. "img",
  287. "scale_factors",
  288. )
  289. else:
  290. ordered_required_keys = ("img", "scale_factors")
  291. return ToBatch(ordered_required_keys=ordered_required_keys)
  292. def build_postprocess(self):
  293. if self.threshold is None:
  294. self.threshold = self.config.get("draw_threshold", 0.5)
  295. if not self.layout_nms:
  296. self.layout_nms = self.config.get("layout_nms", None)
  297. if self.layout_unclip_ratio is None:
  298. self.layout_unclip_ratio = self.config.get("layout_unclip_ratio", None)
  299. if self.layout_merge_bboxes_mode is None:
  300. self.layout_merge_bboxes_mode = self.config.get(
  301. "layout_merge_bboxes_mode", None
  302. )
  303. return DetPostProcess(labels=self.config["label_list"])