Browse Source

bugfix: run_mode defult setting is not work when device passed from create_model or PaddlePredictorOption

gaotingquan 5 months ago
parent
commit
81af5ba238

+ 2 - 0
paddlex/inference/models/base/predictor/base_predictor.py

@@ -337,9 +337,11 @@ class BasePredictor(
             pp_option = PaddlePredictorOption(model_name=self.model_name)
         elif pp_option.model_name is None:
             pp_option.model_name = self.model_name
+            pp_option.reset_run_mode_by_default(model_name=self.model_name)
         if device_info:
             pp_option.device_type = device_info[0]
             pp_option.device_id = device_info[1]
+            pp_option.reset_run_mode_by_default(device_type=device_info[0])
         hpi_info = self.get_hpi_info()
         if hpi_info is not None:
             hpi_info = hpi_info.model_dump(exclude_unset=True)

+ 27 - 18
paddlex/inference/utils/pp_option.py

@@ -122,12 +122,16 @@ class PaddlePredictorOption(object):
 
     def _get_default_config(self):
         """get default config"""
-        device_type, device_ids = parse_device(get_default_device())
+        if self.device_type is None:
+            device_type, device_ids = parse_device(get_default_device())
+            device_id = None if device_ids is None else device_ids[0]
+        else:
+            device_type, device_id = self.device_type, self.device_id
 
         default_config = {
             "run_mode": get_default_run_mode(self.model_name, device_type),
             "device_type": device_type,
-            "device_id": None if device_ids is None else device_ids[0],
+            "device_id": device_id,
             "cpu_threads": 8,
             "delete_pass": [],
             "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
@@ -148,9 +152,14 @@ class PaddlePredictorOption(object):
         self._cfg[k] = v
         self.changed = True
 
+    def reset_run_mode_by_default(self, model_name=None, device_type=None):
+        model_name = model_name or self.model_name
+        device_type = device_type or self.device_type
+        self._update("run_mode", get_default_run_mode(model_name, device_type))
+
     @property
     def run_mode(self):
-        return self._cfg["run_mode"]
+        return self._cfg.get("run_mode")
 
     @run_mode.setter
     def run_mode(self, run_mode: str):
@@ -193,7 +202,7 @@ class PaddlePredictorOption(object):
 
     @property
     def device_type(self):
-        return self._cfg["device_type"]
+        return self._cfg.get("device_type")
 
     @device_type.setter
     def device_type(self, device_type):
@@ -211,7 +220,7 @@ class PaddlePredictorOption(object):
 
     @property
     def device_id(self):
-        return self._cfg["device_id"]
+        return self._cfg.get("device_id")
 
     @device_id.setter
     def device_id(self, device_id):
@@ -219,7 +228,7 @@ class PaddlePredictorOption(object):
 
     @property
     def cpu_threads(self):
-        return self._cfg["cpu_threads"]
+        return self._cfg.get("cpu_threads")
 
     @cpu_threads.setter
     def cpu_threads(self, cpu_threads):
@@ -230,7 +239,7 @@ class PaddlePredictorOption(object):
 
     @property
     def delete_pass(self):
-        return self._cfg["delete_pass"]
+        return self._cfg.get("delete_pass")
 
     @delete_pass.setter
     def delete_pass(self, delete_pass):
@@ -238,7 +247,7 @@ class PaddlePredictorOption(object):
 
     @property
     def enable_new_ir(self):
-        return self._cfg["enable_new_ir"]
+        return self._cfg.get("enable_new_ir")
 
     @enable_new_ir.setter
     def enable_new_ir(self, enable_new_ir: bool):
@@ -247,7 +256,7 @@ class PaddlePredictorOption(object):
 
     @property
     def enable_cinn(self):
-        return self._cfg["enable_cinn"]
+        return self._cfg.get("enable_cinn")
 
     @enable_cinn.setter
     def enable_cinn(self, enable_cinn: bool):
@@ -256,7 +265,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_cfg_setting(self):
-        return self._cfg["trt_cfg_setting"]
+        return self._cfg.get("trt_cfg_setting")
 
     @trt_cfg_setting.setter
     def trt_cfg_setting(self, config: Dict):
@@ -268,7 +277,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_use_dynamic_shapes(self):
-        return self._cfg["trt_use_dynamic_shapes"]
+        return self._cfg.get("trt_use_dynamic_shapes")
 
     @trt_use_dynamic_shapes.setter
     def trt_use_dynamic_shapes(self, trt_use_dynamic_shapes):
@@ -276,7 +285,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_collect_shape_range_info(self):
-        return self._cfg["trt_collect_shape_range_info"]
+        return self._cfg.get("trt_collect_shape_range_info")
 
     @trt_collect_shape_range_info.setter
     def trt_collect_shape_range_info(self, trt_collect_shape_range_info):
@@ -284,7 +293,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_discard_cached_shape_range_info(self):
-        return self._cfg["trt_discard_cached_shape_range_info"]
+        return self._cfg.get("trt_discard_cached_shape_range_info")
 
     @trt_discard_cached_shape_range_info.setter
     def trt_discard_cached_shape_range_info(self, trt_discard_cached_shape_range_info):
@@ -294,7 +303,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_dynamic_shapes(self):
-        return self._cfg["trt_dynamic_shapes"]
+        return self._cfg.get("trt_dynamic_shapes")
 
     @trt_dynamic_shapes.setter
     def trt_dynamic_shapes(self, trt_dynamic_shapes: Dict[str, List[List[int]]]):
@@ -305,7 +314,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_dynamic_shape_input_data(self):
-        return self._cfg["trt_dynamic_shape_input_data"]
+        return self._cfg.get("trt_dynamic_shape_input_data")
 
     @trt_dynamic_shape_input_data.setter
     def trt_dynamic_shape_input_data(
@@ -315,7 +324,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_shape_range_info_path(self):
-        return self._cfg["trt_shape_range_info_path"]
+        return self._cfg.get("trt_shape_range_info_path")
 
     @trt_shape_range_info_path.setter
     def trt_shape_range_info_path(self, trt_shape_range_info_path: str):
@@ -324,7 +333,7 @@ class PaddlePredictorOption(object):
 
     @property
     def trt_allow_rebuild_at_runtime(self):
-        return self._cfg["trt_allow_rebuild_at_runtime"]
+        return self._cfg.get("trt_allow_rebuild_at_runtime")
 
     @trt_allow_rebuild_at_runtime.setter
     def trt_allow_rebuild_at_runtime(self, trt_allow_rebuild_at_runtime):
@@ -332,7 +341,7 @@ class PaddlePredictorOption(object):
 
     @property
     def mkldnn_cache_capacity(self):
-        return self._cfg["mkldnn_cache_capacity"]
+        return self._cfg.get("mkldnn_cache_capacity")
 
     @mkldnn_cache_capacity.setter
     def mkldnn_cache_capacity(self, capacity: int):