|
@@ -179,21 +179,28 @@ class StaticInfer:
|
|
|
if self.option.run_mode.startswith("trt"):
|
|
if self.option.run_mode.startswith("trt"):
|
|
|
assert self.option.device == "gpu"
|
|
assert self.option.device == "gpu"
|
|
|
if not USE_PIR_TRT:
|
|
if not USE_PIR_TRT:
|
|
|
- if not os.path.exists(self.option.shape_info_filename):
|
|
|
|
|
|
|
+ if self.option.shape_info_filename is None:
|
|
|
|
|
+ shape_range_info_path = (
|
|
|
|
|
+ self.model_dir / "shape_range_info.pbtxt"
|
|
|
|
|
+ ).as_posix()
|
|
|
|
|
+ else:
|
|
|
|
|
+ shape_range_info_path = self.option.shape_info_filename
|
|
|
|
|
+ if not os.path.exists(shape_range_info_path):
|
|
|
logging.info(
|
|
logging.info(
|
|
|
- f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
|
|
|
|
|
|
|
+ f"Dynamic shape info is collected into: {shape_range_info_path}"
|
|
|
)
|
|
)
|
|
|
collect_trt_shapes(
|
|
collect_trt_shapes(
|
|
|
model_file,
|
|
model_file,
|
|
|
params_file,
|
|
params_file,
|
|
|
self.option.device_id,
|
|
self.option.device_id,
|
|
|
- self.option.shape_info_filename,
|
|
|
|
|
|
|
+ shape_range_info_path,
|
|
|
self.option.trt_dynamic_shapes,
|
|
self.option.trt_dynamic_shapes,
|
|
|
)
|
|
)
|
|
|
else:
|
|
else:
|
|
|
logging.info(
|
|
logging.info(
|
|
|
- f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. No need to collect again."
|
|
|
|
|
|
|
+ f"A dynamic shape info file ( {shape_range_info_path} ) already exists. No need to collect again."
|
|
|
)
|
|
)
|
|
|
|
|
+ self.option.shape_info_filename = shape_range_info_path
|
|
|
else:
|
|
else:
|
|
|
trt_save_path = (
|
|
trt_save_path = (
|
|
|
Path(self.model_dir) / "trt" / self.model_prefix
|
|
Path(self.model_dir) / "trt" / self.model_prefix
|