浏览代码

fix server deploy

warrentdrew 9 月之前
父节点
当前提交
90b7801956

+ 11 - 4
paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py

@@ -108,8 +108,7 @@ class Det3DBatchSampler(BaseBatchSampler):
                     f"Not supported input data type! Only `str` is supported! So has been ignored: {input}."
                 )
             # extract tar file to tempdir
-            self.extract_tar(ann_path, self.temp_dir)
-            dataset_name = os.path.basename(ann_path).split(".")[0]
+            dataset_name = self.extract_tar(ann_path, self.temp_dir)
             data_root_dir = os.path.join(self.temp_dir, dataset_name)
             ann_pkl_path = os.path.join(data_root_dir, "nuscenes_infos_val.pkl")
             self.data_infos = self.load_annotations(ann_pkl_path, data_root_dir)
@@ -132,9 +131,17 @@ class Det3DBatchSampler(BaseBatchSampler):
 
     def extract_tar(self, tar_path, extract_path="."):
         try:
+            memdirs = set()
             with tarfile.open(tar_path, "r") as tar:
                 for member in tar.getmembers():
+                    memdir = member.name.split("/")[0]
+                    memdirs.add(memdir)
                     tar.extract(member, path=extract_path)
-                print(f"file extract to {extract_path}")
+                logging.info(f"file extract to {extract_path}")
+            assert (
+                len(memdirs) == 1
+            ), "Only one base directory is allowed for 3d bev dataset!"
+            return list(memdirs)[0]
         except Exception as e:
-            print(f"error occurred while extracting tar file: {e}")
+            logging.error(f"error occurred while extracting tar file")
+            raise e

+ 3 - 2
paddlex/inference/models/3d_bev_detection/predictor.py

@@ -26,6 +26,7 @@ if lazy_paddle.is_compiled_with_cuda() and not lazy_paddle.is_compiled_with_rocm
 else:
     logging.error("3D BEVFusion custom ops only support GPU platform!")
 from ....utils.func_register import FuncRegister
+from ....utils.cache import TEMP_DIR
 
 module_3d_bev_detection = import_module(".3d_bev_detection", "paddlex.modules")
 module_3d_model_list = getattr(module_3d_bev_detection, "model_list")
@@ -63,7 +64,7 @@ class BEVDet3DPredictor(BasicPredictor):
             *args: Arbitrary positional arguments passed to the superclass.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
-        self.temp_dir = tempfile.mkdtemp()
+        self.temp_dir = tempfile.mkdtemp(dir=TEMP_DIR)
         logging.info(
             f"infer data will be stored in temporary directory {self.temp_dir}"
         )
@@ -305,6 +306,6 @@ class BEVDet3DPredictor(BasicPredictor):
                 for idx in range(len(batch_data)):
                     yield self.result_class(prediction.get_by_idx(idx))
         except Exception as e:
-            print(f"An error occurred in 3d bev detection inference: {e}")
+            raise e
         finally:
             shutil.rmtree(self.temp_dir)