predictor.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple, Iterator
  15. import shutil
  16. import tempfile
  17. from importlib import import_module
  18. import lazy_paddle
  19. from ....utils import logging
  20. from ....utils.func_register import FuncRegister
  21. module_3d_bev_detection = import_module(".3d_bev_detection", "paddlex.modules")
  22. module_3d_model_list = getattr(module_3d_bev_detection, "model_list")
  23. MODELS = getattr(module_3d_model_list, "MODELS")
  24. from ...common.batch_sampler import Det3DBatchSampler
  25. from ...common.reader import ReadNuscenesData
  26. from ..common import StaticInfer
  27. from ..base import BasicPredictor
  28. from ..base.predictor.base_predictor import PredictionWrap
  29. from .processors import (
  30. LoadPointsFromFile,
  31. LoadPointsFromMultiSweeps,
  32. LoadMultiViewImageFromFiles,
  33. ResizeImage,
  34. NormalizeImage,
  35. PadImage,
  36. SampleFilterByKey,
  37. GetInferInput,
  38. )
  39. from .result import BEV3DDetResult
  40. class BEVDet3DPredictor(BasicPredictor):
  41. """BEVDet3DPredictor that inherits from BasicPredictor."""
  42. entities = MODELS
  43. _FUNC_MAP = {}
  44. register = FuncRegister(_FUNC_MAP)
  45. def __init__(self, *args: List, **kwargs: Dict) -> None:
  46. """Initializes BEVDet3DPredictor.
  47. Args:
  48. *args: Arbitrary positional arguments passed to the superclass.
  49. **kwargs: Arbitrary keyword arguments passed to the superclass.
  50. """
  51. self.temp_dir = tempfile.mkdtemp()
  52. logging.info(
  53. f"infer data will be stored in temporary directory {self.temp_dir}"
  54. )
  55. super().__init__(*args, **kwargs)
  56. self.pre_tfs, self.infer = self._build()
  57. def _build_batch_sampler(self) -> Det3DBatchSampler:
  58. """Builds and returns an Det3DBatchSampler instance.
  59. Returns:
  60. Det3DBatchSampler: An instance of Det3DBatchSampler.
  61. """
  62. return Det3DBatchSampler(temp_dir=self.temp_dir)
  63. def _get_result_class(self) -> type:
  64. """Returns the result class, BEV3DDetResult.
  65. Returns:
  66. type: The BEV3DDetResult class.
  67. """
  68. return BEV3DDetResult
  69. def _build(self) -> Tuple:
  70. """Build the preprocessors and inference engine based on the configuration.
  71. Returns:
  72. tuple: A tuple containing the preprocessors and inference engine.
  73. """
  74. if (
  75. lazy_paddle.is_compiled_with_cuda()
  76. and not lazy_paddle.is_compiled_with_rocm()
  77. ):
  78. from ....ops.voxelize import hard_voxelize
  79. from ....ops.iou3d_nms import nms_gpu
  80. else:
  81. logging.error("3D BEVFusion custom ops only support GPU platform!")
  82. pre_tfs = {"Read": ReadNuscenesData()}
  83. for cfg in self.config["PreProcess"]["transform_ops"]:
  84. tf_key = list(cfg.keys())[0]
  85. func = self._FUNC_MAP[tf_key]
  86. args = cfg.get(tf_key, {})
  87. name, op = func(self, **args) if args else func(self)
  88. if op:
  89. pre_tfs[name] = op
  90. pre_tfs["GetInferInput"] = GetInferInput()
  91. infer = StaticInfer(
  92. model_dir=self.model_dir,
  93. model_prefix=self.MODEL_FILE_PREFIX,
  94. option=self.pp_option,
  95. )
  96. return pre_tfs, infer
  97. def _format_output(
  98. self, infer_input: List[Any], outs: List[Any], img_metas: Dict[str, Any]
  99. ) -> Dict[str, Any]:
  100. """format inference input and output into predict result
  101. Args:
  102. infer_input(List): Model infer inputs with list containing images, points and lidar2img matrix.
  103. outs(List): Model infer output containing bboxes, scores, labels result.
  104. img_metas(Dict): Image metas info of input sample.
  105. Returns:
  106. Dict: A Dict containing formatted inference output results.
  107. """
  108. input_lidar_path = img_metas["input_lidar_path"]
  109. input_img_paths = img_metas["input_img_paths"]
  110. sample_id = img_metas["sample_id"]
  111. results = {}
  112. out_bboxes_3d = []
  113. out_scores_3d = []
  114. out_labels_3d = []
  115. input_imgs = []
  116. input_points = []
  117. input_lidar2imgs = []
  118. input_ids = []
  119. input_lidar_path_list = []
  120. input_img_paths_list = []
  121. out_bboxes_3d.append(outs[0])
  122. out_labels_3d.append(outs[1])
  123. out_scores_3d.append(outs[2])
  124. input_imgs.append(infer_input[1])
  125. input_points.append(infer_input[0])
  126. input_lidar2imgs.append(infer_input[2])
  127. input_ids.append(sample_id)
  128. input_lidar_path_list.append(input_lidar_path)
  129. input_img_paths_list.append(input_img_paths)
  130. results["input_path"] = input_lidar_path_list
  131. results["input_img_paths"] = input_img_paths_list
  132. results["sample_id"] = input_ids
  133. results["boxes_3d"] = out_bboxes_3d
  134. results["labels_3d"] = out_labels_3d
  135. results["scores_3d"] = out_scores_3d
  136. return results
  137. def process(self, batch_data: List[str]) -> Dict[str, Any]:
  138. """
  139. Process a batch of data through the preprocessing and inference.
  140. Args:
  141. batch_data (List[str]): A batch of input data (e.g., sample anno file paths).
  142. Returns:
  143. dict: A dictionary containing the input path, input img, input points, input lidar2img, output bboxes, output labels, output scores and label names. Keys include 'input_path', 'input_img', 'input_points', 'input_lidar2img', 'boxes_3d', 'labels_3d' and 'scores_3d'.
  144. """
  145. sample = self.pre_tfs["Read"](batch_data=batch_data)
  146. sample = self.pre_tfs["LoadPointsFromFile"](results=sample[0])
  147. sample = self.pre_tfs["LoadPointsFromMultiSweeps"](results=sample)
  148. sample = self.pre_tfs["LoadMultiViewImageFromFiles"](sample=sample)
  149. sample = self.pre_tfs["ResizeImage"](results=sample)
  150. sample = self.pre_tfs["NormalizeImage"](results=sample)
  151. sample = self.pre_tfs["PadImage"](results=sample)
  152. sample = self.pre_tfs["SampleFilterByKey"](sample=sample)
  153. infer_input, img_metas = self.pre_tfs["GetInferInput"](sample=sample)
  154. infer_output = self.infer(x=infer_input)
  155. results = self._format_output(infer_input, infer_output, img_metas)
  156. return results
  157. @register("LoadPointsFromFile")
  158. def build_load_img_from_file(
  159. self, load_dim=6, use_dim=[0, 1, 2], shift_height=False, use_color=False
  160. ):
  161. return "LoadPointsFromFile", LoadPointsFromFile(
  162. load_dim=load_dim,
  163. use_dim=use_dim,
  164. shift_height=shift_height,
  165. use_color=use_color,
  166. )
  167. @register("LoadPointsFromMultiSweeps")
  168. def build_load_points_from_multi_sweeps(
  169. self,
  170. sweeps_num=10,
  171. load_dim=5,
  172. use_dim=[0, 1, 2, 4],
  173. pad_empty_sweeps=False,
  174. remove_close=False,
  175. test_mode=False,
  176. point_cloud_angle_range=None,
  177. ):
  178. return "LoadPointsFromMultiSweeps", LoadPointsFromMultiSweeps(
  179. sweeps_num=sweeps_num,
  180. load_dim=load_dim,
  181. use_dim=use_dim,
  182. pad_empty_sweeps=pad_empty_sweeps,
  183. remove_close=remove_close,
  184. test_mode=test_mode,
  185. point_cloud_angle_range=point_cloud_angle_range,
  186. )
  187. @register("LoadMultiViewImageFromFiles")
  188. def build_load_multi_view_image_from_files(
  189. self,
  190. to_float32=False,
  191. project_pts_to_img_depth=False,
  192. cam_depth_range=[4.0, 45.0, 1.0],
  193. constant_std=0.5,
  194. imread_flag=-1,
  195. ):
  196. return "LoadMultiViewImageFromFiles", LoadMultiViewImageFromFiles(
  197. to_float32=to_float32,
  198. project_pts_to_img_depth=project_pts_to_img_depth,
  199. cam_depth_range=cam_depth_range,
  200. constant_std=constant_std,
  201. imread_flag=imread_flag,
  202. )
  203. @register("ResizeImage")
  204. def build_resize_image(
  205. self,
  206. img_scale=None,
  207. multiscale_mode="range",
  208. ratio_range=None,
  209. keep_ratio=True,
  210. bbox_clip_border=True,
  211. backend="cv2",
  212. override=False,
  213. ):
  214. return "ResizeImage", ResizeImage(
  215. img_scale=img_scale,
  216. multiscale_mode=multiscale_mode,
  217. ratio_range=ratio_range,
  218. keep_ratio=keep_ratio,
  219. bbox_clip_border=bbox_clip_border,
  220. backend=backend,
  221. override=override,
  222. )
  223. @register("NormalizeImage")
  224. def build_normalize_image(self, mean, std, to_rgb=True):
  225. return "NormalizeImage", NormalizeImage(mean=mean, std=std, to_rgb=to_rgb)
  226. @register("PadImage")
  227. def build_pad_image(self, size=None, size_divisor=None, pad_val=0):
  228. return "PadImage", PadImage(
  229. size=size, size_divisor=size_divisor, pad_val=pad_val
  230. )
  231. @register("SampleFilterByKey")
  232. def build_sample_filter_by_key(
  233. self,
  234. keys,
  235. meta_keys=(
  236. "filename",
  237. "ori_shape",
  238. "img_shape",
  239. "lidar2img",
  240. "depth2img",
  241. "cam2img",
  242. "pad_shape",
  243. "scale_factor",
  244. "flip",
  245. "pcd_horizontal_flip",
  246. "pcd_vertical_flip",
  247. "box_type_3d",
  248. "img_norm_cfg",
  249. "pcd_trans",
  250. "sample_idx",
  251. "pcd_scale_factor",
  252. "pcd_rotation",
  253. "pts_filename",
  254. "transformation_3d_flow",
  255. ),
  256. ):
  257. return "SampleFilterByKey", SampleFilterByKey(keys=keys, meta_keys=meta_keys)
  258. @register("GetInferInput")
  259. def build_get_infer_input(self):
  260. return "GetInferInput", GetInferInput()
  261. def apply(self, input: Any, **kwargs) -> Iterator[Any]:
  262. """
  263. Do predicting with the input data and yields predictions.
  264. Args:
  265. input (Any): The input data to be predicted.
  266. Yields:
  267. Iterator[Any]: An iterator yielding prediction results.
  268. """
  269. try:
  270. for batch_data in self.batch_sampler(input):
  271. prediction = self.process(batch_data, **kwargs)
  272. prediction = PredictionWrap(prediction, len(batch_data))
  273. for idx in range(len(batch_data)):
  274. yield self.result_class(prediction.get_by_idx(idx))
  275. except Exception as e:
  276. raise e
  277. finally:
  278. shutil.rmtree(self.temp_dir)