| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import shutil
- import tempfile
- from typing import Any, Dict, Iterator, List, Tuple
- from ....modules.m_3d_bev_detection.model_list import MODELS
- from ....utils import logging
- from ....utils.func_register import FuncRegister
- from ...common.batch_sampler import Det3DBatchSampler
- from ...common.reader import ReadNuscenesData
- from ..base import BasePredictor
- from ..base.predictor.base_predictor import PredictionWrap
- from .processors import (
- GetInferInput,
- LoadMultiViewImageFromFiles,
- LoadPointsFromFile,
- LoadPointsFromMultiSweeps,
- NormalizeImage,
- PadImage,
- ResizeImage,
- SampleFilterByKey,
- )
- from .result import BEV3DDetResult
- class BEVDet3DPredictor(BasePredictor):
- """BEVDet3DPredictor that inherits from BasePredictor."""
- entities = MODELS
- _FUNC_MAP = {}
- register = FuncRegister(_FUNC_MAP)
- def __init__(self, *args: List, **kwargs: Dict) -> None:
- """Initializes BEVDet3DPredictor.
- Args:
- *args: Arbitrary positional arguments passed to the superclass.
- **kwargs: Arbitrary keyword arguments passed to the superclass.
- """
- self.temp_dir = tempfile.mkdtemp()
- logging.info(
- f"infer data will be stored in temporary directory {self.temp_dir}"
- )
- super().__init__(*args, **kwargs)
- self.pre_tfs, self.infer = self._build()
- def _build_batch_sampler(self) -> Det3DBatchSampler:
- """Builds and returns an Det3DBatchSampler instance.
- Returns:
- Det3DBatchSampler: An instance of Det3DBatchSampler.
- """
- return Det3DBatchSampler(temp_dir=self.temp_dir)
- def _get_result_class(self) -> type:
- """Returns the result class, BEV3DDetResult.
- Returns:
- type: The BEV3DDetResult class.
- """
- return BEV3DDetResult
- def _build(self) -> Tuple:
- """Build the preprocessors and inference engine based on the configuration.
- Returns:
- tuple: A tuple containing the preprocessors and inference engine.
- """
- import paddle
- if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
- from ....ops.iou3d_nms import nms_gpu # noqa: F401
- from ....ops.voxelize import hard_voxelize # noqa: F401
- else:
- logging.error("3D BEVFusion custom ops only support GPU platform!")
- pre_tfs = {"Read": ReadNuscenesData()}
- for cfg in self.config["PreProcess"]["transform_ops"]:
- tf_key = list(cfg.keys())[0]
- func = self._FUNC_MAP[tf_key]
- args = cfg.get(tf_key, {})
- name, op = func(self, **args) if args else func(self)
- if op:
- pre_tfs[name] = op
- pre_tfs["GetInferInput"] = GetInferInput()
- infer = self.create_static_infer()
- return pre_tfs, infer
- def _format_output(
- self, infer_input: List[Any], outs: List[Any], img_metas: Dict[str, Any]
- ) -> Dict[str, Any]:
- """format inference input and output into predict result
- Args:
- infer_input(List): Model infer inputs with list containing images, points and lidar2img matrix.
- outs(List): Model infer output containing bboxes, scores, labels result.
- img_metas(Dict): Image metas info of input sample.
- Returns:
- Dict: A Dict containing formatted inference output results.
- """
- input_lidar_path = img_metas["input_lidar_path"]
- input_img_paths = img_metas["input_img_paths"]
- sample_id = img_metas["sample_id"]
- results = {}
- out_bboxes_3d = []
- out_scores_3d = []
- out_labels_3d = []
- input_imgs = []
- input_points = []
- input_lidar2imgs = []
- input_ids = []
- input_lidar_path_list = []
- input_img_paths_list = []
- out_bboxes_3d.append(outs[0])
- out_labels_3d.append(outs[1])
- out_scores_3d.append(outs[2])
- input_imgs.append(infer_input[1])
- input_points.append(infer_input[0])
- input_lidar2imgs.append(infer_input[2])
- input_ids.append(sample_id)
- input_lidar_path_list.append(input_lidar_path)
- input_img_paths_list.append(input_img_paths)
- results["input_path"] = input_lidar_path_list
- results["input_img_paths"] = input_img_paths_list
- results["sample_id"] = input_ids
- results["boxes_3d"] = out_bboxes_3d
- results["labels_3d"] = out_labels_3d
- results["scores_3d"] = out_scores_3d
- return results
- def process(self, batch_data: List[str]) -> Dict[str, Any]:
- """
- Process a batch of data through the preprocessing and inference.
- Args:
- batch_data (List[str]): A batch of input data (e.g., sample anno file paths).
- Returns:
- dict: A dictionary containing the input path, input img, input points, input lidar2img, output bboxes, output labels, output scores and label names. Keys include 'input_path', 'input_img', 'input_points', 'input_lidar2img', 'boxes_3d', 'labels_3d' and 'scores_3d'.
- """
- sample = self.pre_tfs["Read"](batch_data=batch_data)
- sample = self.pre_tfs["LoadPointsFromFile"](results=sample[0])
- sample = self.pre_tfs["LoadPointsFromMultiSweeps"](results=sample)
- sample = self.pre_tfs["LoadMultiViewImageFromFiles"](sample=sample)
- sample = self.pre_tfs["ResizeImage"](results=sample)
- sample = self.pre_tfs["NormalizeImage"](results=sample)
- sample = self.pre_tfs["PadImage"](results=sample)
- sample = self.pre_tfs["SampleFilterByKey"](sample=sample)
- infer_input, img_metas = self.pre_tfs["GetInferInput"](sample=sample)
- infer_output = self.infer(x=infer_input)
- results = self._format_output(infer_input, infer_output, img_metas)
- return results
- @register("LoadPointsFromFile")
- def build_load_img_from_file(
- self, load_dim=6, use_dim=[0, 1, 2], shift_height=False, use_color=False
- ):
- return "LoadPointsFromFile", LoadPointsFromFile(
- load_dim=load_dim,
- use_dim=use_dim,
- shift_height=shift_height,
- use_color=use_color,
- )
- @register("LoadPointsFromMultiSweeps")
- def build_load_points_from_multi_sweeps(
- self,
- sweeps_num=10,
- load_dim=5,
- use_dim=[0, 1, 2, 4],
- pad_empty_sweeps=False,
- remove_close=False,
- test_mode=False,
- point_cloud_angle_range=None,
- ):
- return "LoadPointsFromMultiSweeps", LoadPointsFromMultiSweeps(
- sweeps_num=sweeps_num,
- load_dim=load_dim,
- use_dim=use_dim,
- pad_empty_sweeps=pad_empty_sweeps,
- remove_close=remove_close,
- test_mode=test_mode,
- point_cloud_angle_range=point_cloud_angle_range,
- )
- @register("LoadMultiViewImageFromFiles")
- def build_load_multi_view_image_from_files(
- self,
- to_float32=False,
- project_pts_to_img_depth=False,
- cam_depth_range=[4.0, 45.0, 1.0],
- constant_std=0.5,
- imread_flag=-1,
- ):
- return "LoadMultiViewImageFromFiles", LoadMultiViewImageFromFiles(
- to_float32=to_float32,
- project_pts_to_img_depth=project_pts_to_img_depth,
- cam_depth_range=cam_depth_range,
- constant_std=constant_std,
- imread_flag=imread_flag,
- )
- @register("ResizeImage")
- def build_resize_image(
- self,
- img_scale=None,
- multiscale_mode="range",
- ratio_range=None,
- keep_ratio=True,
- bbox_clip_border=True,
- backend="cv2",
- override=False,
- ):
- return "ResizeImage", ResizeImage(
- img_scale=img_scale,
- multiscale_mode=multiscale_mode,
- ratio_range=ratio_range,
- keep_ratio=keep_ratio,
- bbox_clip_border=bbox_clip_border,
- backend=backend,
- override=override,
- )
- @register("NormalizeImage")
- def build_normalize_image(self, mean, std, to_rgb=True):
- return "NormalizeImage", NormalizeImage(mean=mean, std=std, to_rgb=to_rgb)
- @register("PadImage")
- def build_pad_image(self, size=None, size_divisor=None, pad_val=0):
- return "PadImage", PadImage(
- size=size, size_divisor=size_divisor, pad_val=pad_val
- )
- @register("SampleFilterByKey")
- def build_sample_filter_by_key(
- self,
- keys,
- meta_keys=(
- "filename",
- "ori_shape",
- "img_shape",
- "lidar2img",
- "depth2img",
- "cam2img",
- "pad_shape",
- "scale_factor",
- "flip",
- "pcd_horizontal_flip",
- "pcd_vertical_flip",
- "box_type_3d",
- "img_norm_cfg",
- "pcd_trans",
- "sample_idx",
- "pcd_scale_factor",
- "pcd_rotation",
- "pts_filename",
- "transformation_3d_flow",
- ),
- ):
- return "SampleFilterByKey", SampleFilterByKey(keys=keys, meta_keys=meta_keys)
- @register("GetInferInput")
- def build_get_infer_input(self):
- return "GetInferInput", GetInferInput()
- def apply(self, input: Any, **kwargs) -> Iterator[Any]:
- """
- Do predicting with the input data and yields predictions.
- Args:
- input (Any): The input data to be predicted.
- Yields:
- Iterator[Any]: An iterator yielding prediction results.
- """
- try:
- for batch_data in self.batch_sampler(input):
- prediction = self.process(batch_data, **kwargs)
- prediction = PredictionWrap(prediction, len(batch_data))
- for idx in range(len(batch_data)):
- yield self.result_class(prediction.get_by_idx(idx))
- except Exception as e:
- raise e
- finally:
- shutil.rmtree(self.temp_dir)
|