gaotingquan преди 10 месеца
родител
ревизия
adb12d98b6
променени са 1 файла, в които са добавени 11 реда и са изтрити 4 реда
  1. 11 4
      paddlex/inference/models/common/static_infer.py

+ 11 - 4
paddlex/inference/models/common/static_infer.py

@@ -179,21 +179,28 @@ class StaticInfer:
         if self.option.run_mode.startswith("trt"):
             assert self.option.device == "gpu"
             if not USE_PIR_TRT:
-                if not os.path.exists(self.option.shape_info_filename):
+                if self.option.shape_info_filename is None:
+                    shape_range_info_path = (
+                        self.model_dir / "shape_range_info.pbtxt"
+                    ).as_posix()
+                else:
+                    shape_range_info_path = self.option.shape_info_filename
+                if not os.path.exists(shape_range_info_path):
                     logging.info(
-                        f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
+                        f"Dynamic shape info is collected into: {shape_range_info_path}"
                     )
                     collect_trt_shapes(
                         model_file,
                         params_file,
                         self.option.device_id,
-                        self.option.shape_info_filename,
+                        shape_range_info_path,
                         self.option.trt_dynamic_shapes,
                     )
                 else:
                     logging.info(
-                        f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. No need to collect again."
+                        f"A dynamic shape info file ( {shape_range_info_path} ) already exists. No need to collect again."
                     )
+                self.option.shape_info_filename = shape_range_info_path
             else:
                 trt_save_path = (
                     Path(self.model_dir) / "trt" / self.model_prefix