# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
from typing import Any, Dict, Iterator, List, Tuple

from ....modules.m_3d_bev_detection.model_list import MODELS
from ....utils import logging
from ....utils.func_register import FuncRegister
from ...common.batch_sampler import Det3DBatchSampler
from ...common.reader import ReadNuscenesData
from ..base import BasePredictor
from ..base.predictor.base_predictor import PredictionWrap
from .processors import (
    GetInferInput,
    LoadMultiViewImageFromFiles,
    LoadPointsFromFile,
    LoadPointsFromMultiSweeps,
    NormalizeImage,
    PadImage,
    ResizeImage,
    SampleFilterByKey,
)
from .result import BEV3DDetResult


class BEVDet3DPredictor(BasePredictor):
    """BEVDet3DPredictor that inherits from BasePredictor."""

    entities = MODELS

    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)

    def __init__(self, *args: List, **kwargs: Dict) -> None:
        """Initializes BEVDet3DPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        self.temp_dir = tempfile.mkdtemp()
        logging.info(
            f"Infer data will be stored in the temporary directory {self.temp_dir}"
        )
        super().__init__(*args, **kwargs)
        self.pre_tfs, self.infer = self._build()

    def _build_batch_sampler(self) -> Det3DBatchSampler:
        """Builds and returns a Det3DBatchSampler instance.

        Returns:
            Det3DBatchSampler: An instance of Det3DBatchSampler.
        """
        return Det3DBatchSampler(temp_dir=self.temp_dir)

    def _get_result_class(self) -> type:
        """Returns the result class, BEV3DDetResult.

        Returns:
            type: The BEV3DDetResult class.
        """
        return BEV3DDetResult

    def _build(self) -> Tuple:
        """Builds the preprocessors and inference engine based on the configuration.

        Returns:
            tuple: A tuple containing the preprocessors and the inference engine.
        """
        import paddle

        if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
            from ....ops.iou3d_nms import nms_gpu  # noqa: F401
            from ....ops.voxelize import hard_voxelize  # noqa: F401
        else:
            logging.error("3D BEVFusion custom ops only support the GPU platform!")

        pre_tfs = {"Read": ReadNuscenesData()}
        # Each entry of "transform_ops" is a single-key dict mapping a registered
        # transform name to its keyword arguments.
        for cfg in self.config["PreProcess"]["transform_ops"]:
            tf_key = list(cfg.keys())[0]
            func = self._FUNC_MAP[tf_key]
            args = cfg.get(tf_key, {})
            name, op = func(self, **args) if args else func(self)
            if op:
                pre_tfs[name] = op
        pre_tfs["GetInferInput"] = GetInferInput()

        infer = self.create_static_infer()
        return pre_tfs, infer
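
    # Note: ``self.config["PreProcess"]["transform_ops"]``, consumed by ``_build``
    # above, is expected to be a list of single-key dicts, one per registered
    # transform. The values below are illustrative assumptions rather than a
    # shipped model config; only the overall shape matters:
    #
    #   PreProcess:
    #     transform_ops:
    #       - LoadPointsFromFile: {load_dim: 6, use_dim: [0, 1, 2]}
    #       - LoadPointsFromMultiSweeps: {sweeps_num: 10}
    #       - LoadMultiViewImageFromFiles: {to_float32: true}
    #       - ResizeImage: {img_scale: [800, 448], keep_ratio: true}
    #       - NormalizeImage: {mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375], to_rgb: false}
    #       - PadImage: {size_divisor: 32}
    #       - SampleFilterByKey: {keys: [points, img]}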
""" input_lidar_path = img_metas["input_lidar_path"] input_img_paths = img_metas["input_img_paths"] sample_id = img_metas["sample_id"] results = {} out_bboxes_3d = [] out_scores_3d = [] out_labels_3d = [] input_imgs = [] input_points = [] input_lidar2imgs = [] input_ids = [] input_lidar_path_list = [] input_img_paths_list = [] out_bboxes_3d.append(outs[0]) out_labels_3d.append(outs[1]) out_scores_3d.append(outs[2]) input_imgs.append(infer_input[1]) input_points.append(infer_input[0]) input_lidar2imgs.append(infer_input[2]) input_ids.append(sample_id) input_lidar_path_list.append(input_lidar_path) input_img_paths_list.append(input_img_paths) results["input_path"] = input_lidar_path_list results["input_img_paths"] = input_img_paths_list results["sample_id"] = input_ids results["boxes_3d"] = out_bboxes_3d results["labels_3d"] = out_labels_3d results["scores_3d"] = out_scores_3d return results def process(self, batch_data: List[str]) -> Dict[str, Any]: """ Process a batch of data through the preprocessing and inference. Args: batch_data (List[str]): A batch of input data (e.g., sample anno file paths). Returns: dict: A dictionary containing the input path, input img, input points, input lidar2img, output bboxes, output labels, output scores and label names. Keys include 'input_path', 'input_img', 'input_points', 'input_lidar2img', 'boxes_3d', 'labels_3d' and 'scores_3d'. """ sample = self.pre_tfs["Read"](batch_data=batch_data) sample = self.pre_tfs["LoadPointsFromFile"](results=sample[0]) sample = self.pre_tfs["LoadPointsFromMultiSweeps"](results=sample) sample = self.pre_tfs["LoadMultiViewImageFromFiles"](sample=sample) sample = self.pre_tfs["ResizeImage"](results=sample) sample = self.pre_tfs["NormalizeImage"](results=sample) sample = self.pre_tfs["PadImage"](results=sample) sample = self.pre_tfs["SampleFilterByKey"](sample=sample) infer_input, img_metas = self.pre_tfs["GetInferInput"](sample=sample) infer_output = self.infer(x=infer_input) results = self._format_output(infer_input, infer_output, img_metas) return results @register("LoadPointsFromFile") def build_load_img_from_file( self, load_dim=6, use_dim=[0, 1, 2], shift_height=False, use_color=False ): return "LoadPointsFromFile", LoadPointsFromFile( load_dim=load_dim, use_dim=use_dim, shift_height=shift_height, use_color=use_color, ) @register("LoadPointsFromMultiSweeps") def build_load_points_from_multi_sweeps( self, sweeps_num=10, load_dim=5, use_dim=[0, 1, 2, 4], pad_empty_sweeps=False, remove_close=False, test_mode=False, point_cloud_angle_range=None, ): return "LoadPointsFromMultiSweeps", LoadPointsFromMultiSweeps( sweeps_num=sweeps_num, load_dim=load_dim, use_dim=use_dim, pad_empty_sweeps=pad_empty_sweeps, remove_close=remove_close, test_mode=test_mode, point_cloud_angle_range=point_cloud_angle_range, ) @register("LoadMultiViewImageFromFiles") def build_load_multi_view_image_from_files( self, to_float32=False, project_pts_to_img_depth=False, cam_depth_range=[4.0, 45.0, 1.0], constant_std=0.5, imread_flag=-1, ): return "LoadMultiViewImageFromFiles", LoadMultiViewImageFromFiles( to_float32=to_float32, project_pts_to_img_depth=project_pts_to_img_depth, cam_depth_range=cam_depth_range, constant_std=constant_std, imread_flag=imread_flag, ) @register("ResizeImage") def build_resize_image( self, img_scale=None, multiscale_mode="range", ratio_range=None, keep_ratio=True, bbox_clip_border=True, backend="cv2", override=False, ): return "ResizeImage", ResizeImage( img_scale=img_scale, multiscale_mode=multiscale_mode, 
            ratio_range=ratio_range,
            keep_ratio=keep_ratio,
            bbox_clip_border=bbox_clip_border,
            backend=backend,
            override=override,
        )

    @register("NormalizeImage")
    def build_normalize_image(self, mean, std, to_rgb=True):
        return "NormalizeImage", NormalizeImage(mean=mean, std=std, to_rgb=to_rgb)

    @register("PadImage")
    def build_pad_image(self, size=None, size_divisor=None, pad_val=0):
        return "PadImage", PadImage(
            size=size, size_divisor=size_divisor, pad_val=pad_val
        )

    @register("SampleFilterByKey")
    def build_sample_filter_by_key(
        self,
        keys,
        meta_keys=(
            "filename",
            "ori_shape",
            "img_shape",
            "lidar2img",
            "depth2img",
            "cam2img",
            "pad_shape",
            "scale_factor",
            "flip",
            "pcd_horizontal_flip",
            "pcd_vertical_flip",
            "box_type_3d",
            "img_norm_cfg",
            "pcd_trans",
            "sample_idx",
            "pcd_scale_factor",
            "pcd_rotation",
            "pts_filename",
            "transformation_3d_flow",
        ),
    ):
        return "SampleFilterByKey", SampleFilterByKey(keys=keys, meta_keys=meta_keys)

    @register("GetInferInput")
    def build_get_infer_input(self):
        return "GetInferInput", GetInferInput()

    def apply(self, input: Any, **kwargs) -> Iterator[Any]:
        """Runs prediction on the input data and yields prediction results.

        Args:
            input (Any): The input data to be predicted.

        Yields:
            Iterator[Any]: An iterator yielding prediction results.
        """
        try:
            for batch_data in self.batch_sampler(input):
                prediction = self.process(batch_data, **kwargs)
                prediction = PredictionWrap(prediction, len(batch_data))
                for idx in range(len(batch_data)):
                    yield self.result_class(prediction.get_by_idx(idx))
        finally:
            # Always remove the temporary directory created in ``__init__``,
            # whether prediction succeeds or raises.
            shutil.rmtree(self.temp_dir)
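

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the predictor API).
    # Constructor arguments are forwarded verbatim to BasePredictor, which is
    # assumed here to accept a ``model_dir`` pointing at an exported inference
    # model; the paths below are placeholders.
    predictor = BEVDet3DPredictor(model_dir="path/to/exported_bev_model")
    # ``apply`` expects input that Det3DBatchSampler can iterate, e.g. a
    # nuScenes-style sample annotation file path, and yields one BEV3DDetResult
    # per sample containing 'boxes_3d', 'labels_3d' and 'scores_3d'.
    for result in predictor.apply(input="path/to/nuscenes_infos_val.pkl"):
        print(result)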