|
|
@@ -12,16 +12,18 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-from typing import Any, Union, Dict, List, Tuple
|
|
|
+from typing import Any, Union, Dict, List, Tuple, Iterator
|
|
|
+import shutil
|
|
|
+import tempfile
|
|
|
from importlib import import_module
|
|
|
import lazy_paddle
|
|
|
|
|
|
+from ....utils import logging
|
|
|
+
|
|
|
if lazy_paddle.is_compiled_with_cuda() and not lazy_paddle.is_compiled_with_rocm():
|
|
|
from ....ops.voxelize import hard_voxelize
|
|
|
from ....ops.iou3d_nms import nms_gpu
|
|
|
else:
|
|
|
- from ....utils import logging
|
|
|
-
|
|
|
logging.error("3D BEVFusion custom ops only support GPU platform!")
|
|
|
from ....utils.func_register import FuncRegister
|
|
|
|
|
|
@@ -32,6 +34,7 @@ from ...common.batch_sampler import Det3DBatchSampler
|
|
|
from ...common.reader import ReadNuscenesData
|
|
|
from ..common import StaticInfer
|
|
|
from ..base import BasicPredictor
|
|
|
+from ..base.predictor.base_predictor import PredictionWrap
|
|
|
from .processors import (
|
|
|
LoadPointsFromFile,
|
|
|
LoadPointsFromMultiSweeps,
|
|
|
@@ -60,6 +63,10 @@ class BEVDet3DPredictor(BasicPredictor):
|
|
|
*args: Arbitrary positional arguments passed to the superclass.
|
|
|
**kwargs: Arbitrary keyword arguments passed to the superclass.
|
|
|
"""
|
|
|
+ self.temp_dir = tempfile.mkdtemp()
|
|
|
+ logging.info(
|
|
|
+ f"infer data will be stored in temporary directory {self.temp_dir}"
|
|
|
+ )
|
|
|
super().__init__(*args, **kwargs)
|
|
|
self.pre_tfs, self.infer = self._build()
|
|
|
|
|
|
@@ -69,7 +76,7 @@ class BEVDet3DPredictor(BasicPredictor):
|
|
|
Returns:
|
|
|
Det3DBatchSampler: An instance of Det3DBatchSampler.
|
|
|
"""
|
|
|
- return Det3DBatchSampler()
|
|
|
+ return Det3DBatchSampler(temp_dir=self.temp_dir)
|
|
|
|
|
|
def _get_result_class(self) -> type:
|
|
|
"""Returns the result class, BEV3DDetResult.
|
|
|
@@ -279,3 +286,25 @@ class BEVDet3DPredictor(BasicPredictor):
|
|
|
@register("GetInferInput")
|
|
|
def build_get_infer_input(self):
|
|
|
return "GetInferInput", GetInferInput()
|
|
|
+
|
|
|
+ def apply(self, input: Any, **kwargs) -> Iterator[Any]:
|
|
|
+ """
|
|
|
+ Do predicting with the input data and yields predictions.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ input (Any): The input data to be predicted.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ Iterator[Any]: An iterator yielding prediction results.
|
|
|
+ """
|
|
|
+
|
|
|
+ try:
|
|
|
+ for batch_data in self.batch_sampler(input):
|
|
|
+ prediction = self.process(batch_data, **kwargs)
|
|
|
+ prediction = PredictionWrap(prediction, len(batch_data))
|
|
|
+ for idx in range(len(batch_data)):
|
|
|
+ yield self.result_class(prediction.get_by_idx(idx))
|
|
|
+ except Exception as e:
|
|
|
+ print(f"An error occurred in 3d bev detection inference: {e}")
|
|
|
+ finally:
|
|
|
+ shutil.rmtree(self.temp_dir)
|