Pārlūkot izejas kodu

add tempfile for 3d infer data

warrentdrew 9 mēneši atpakaļ
vecāks
revīzija
3f84237efb

+ 10 - 5
paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py

@@ -29,6 +29,10 @@ from .base_batch_sampler import BaseBatchSampler
 
 class Det3DBatchSampler(BaseBatchSampler):
 
+    def __init__(self, temp_dir) -> None:
+        super().__init__()
+        self.temp_dir = temp_dir
+
     # XXX: auto download for url
     def _download_from_url(self, in_path: str) -> str:
         file_name = Path(in_path).name
@@ -103,10 +107,10 @@ class Det3DBatchSampler(BaseBatchSampler):
                 logging.warning(
                     f"Not supported input data type! Only `str` is supported! So has been ignored: {input}."
                 )
-            # extract tar file
-            tar_root_dir = os.path.dirname(ann_path)
-            self.extract_tar(ann_path, tar_root_dir)
-            data_root_dir, _ = os.path.splitext(ann_path)
+            # extract tar file to tempdir
+            self.extract_tar(ann_path, self.temp_dir)
+            dataset_name = os.path.basename(ann_path).split(".")[0]
+            data_root_dir = os.path.join(self.temp_dir, dataset_name)
             ann_pkl_path = os.path.join(data_root_dir, "nuscenes_infos_val.pkl")
             self.data_infos = self.load_annotations(ann_pkl_path, data_root_dir)
             sample_set.extend(self.data_infos)
@@ -129,7 +133,8 @@ class Det3DBatchSampler(BaseBatchSampler):
     def extract_tar(self, tar_path, extract_path="."):
         try:
             with tarfile.open(tar_path, "r") as tar:
-                tar.extractall(path=extract_path)
+                for member in tar.getmembers():
+                    tar.extract(member, path=extract_path)
                 print(f"file extract to {extract_path}")
         except Exception as e:
             print(f"error occurred while extracting tar file: {e}")

+ 33 - 4
paddlex/inference/models/3d_bev_detection/predictor.py

@@ -12,16 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Union, Dict, List, Tuple
+from typing import Any, Union, Dict, List, Tuple, Iterator
+import shutil
+import tempfile
 from importlib import import_module
 import lazy_paddle
 
+from ....utils import logging
+
 if lazy_paddle.is_compiled_with_cuda() and not lazy_paddle.is_compiled_with_rocm():
     from ....ops.voxelize import hard_voxelize
     from ....ops.iou3d_nms import nms_gpu
 else:
-    from ....utils import logging
-
     logging.error("3D BEVFusion custom ops only support GPU platform!")
 from ....utils.func_register import FuncRegister
 
@@ -32,6 +34,7 @@ from ...common.batch_sampler import Det3DBatchSampler
 from ...common.reader import ReadNuscenesData
 from ..common import StaticInfer
 from ..base import BasicPredictor
+from ..base.predictor.base_predictor import PredictionWrap
 from .processors import (
     LoadPointsFromFile,
     LoadPointsFromMultiSweeps,
@@ -60,6 +63,10 @@ class BEVDet3DPredictor(BasicPredictor):
             *args: Arbitrary positional arguments passed to the superclass.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
+        self.temp_dir = tempfile.mkdtemp()
+        logging.info(
+            f"infer data will be stored in temporary directory {self.temp_dir}"
+        )
         super().__init__(*args, **kwargs)
         self.pre_tfs, self.infer = self._build()
 
@@ -69,7 +76,7 @@ class BEVDet3DPredictor(BasicPredictor):
         Returns:
             Det3DBatchSampler: An instance of Det3DBatchSampler.
         """
-        return Det3DBatchSampler()
+        return Det3DBatchSampler(temp_dir=self.temp_dir)
 
     def _get_result_class(self) -> type:
         """Returns the result class, BEV3DDetResult.
@@ -279,3 +286,25 @@ class BEVDet3DPredictor(BasicPredictor):
     @register("GetInferInput")
     def build_get_infer_input(self):
         return "GetInferInput", GetInferInput()
+
+    def apply(self, input: Any, **kwargs) -> Iterator[Any]:
+        """
+        Do predicting with the input data and yields predictions.
+
+        Args:
+            input (Any): The input data to be predicted.
+
+        Yields:
+            Iterator[Any]: An iterator yielding prediction results.
+        """
+
+        try:
+            for batch_data in self.batch_sampler(input):
+                prediction = self.process(batch_data, **kwargs)
+                prediction = PredictionWrap(prediction, len(batch_data))
+                for idx in range(len(batch_data)):
+                    yield self.result_class(prediction.get_by_idx(idx))
+        except Exception as e:
+            print(f"An error occurred in 3d bev detection inference: {e}")
+        finally:
+            shutil.rmtree(self.temp_dir)