|
|
@@ -35,6 +35,7 @@ class PaddlePredictorOption(object):
|
|
|
self.model_name = model_name
|
|
|
self._cfg = {}
|
|
|
self._init_option(**kwargs)
|
|
|
+ self._observers = []
|
|
|
|
|
|
def _init_option(self, **kwargs):
|
|
|
for k, v in kwargs.items():
|
|
|
@@ -63,6 +64,10 @@ class PaddlePredictorOption(object):
|
|
|
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
|
|
|
}
|
|
|
|
|
|
+ def _update(self, k, v):
|
|
|
+ self._cfg[k] = v
|
|
|
+ self.notify()
|
|
|
+
|
|
|
@property
|
|
|
def run_mode(self):
|
|
|
return self._cfg["run_mode"]
|
|
|
@@ -75,7 +80,7 @@ class PaddlePredictorOption(object):
|
|
|
raise ValueError(
|
|
|
f"`run_mode` must be {support_run_mode_str}, but received {repr(run_mode)}."
|
|
|
)
|
|
|
- self._cfg["run_mode"] = run_mode
|
|
|
+ self._update("run_mode", run_mode)
|
|
|
|
|
|
@property
|
|
|
def device_type(self):
|
|
|
@@ -100,9 +105,9 @@ class PaddlePredictorOption(object):
|
|
|
raise ValueError(
|
|
|
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
|
|
|
)
|
|
|
- self._cfg["device"] = device_type
|
|
|
+ self._update("device", device_type)
|
|
|
device_id = device_ids[0] if device_ids is not None else 0
|
|
|
- self._cfg["device_id"] = device_id
|
|
|
+ self._update("device_id", device_id)
|
|
|
set_env_for_device(device)
|
|
|
if device_type not in ("cpu"):
|
|
|
if device_ids is None or len(device_ids) > 1:
|
|
|
@@ -117,7 +122,7 @@ class PaddlePredictorOption(object):
|
|
|
"""set min subgraph size"""
|
|
|
if not isinstance(min_subgraph_size, int):
|
|
|
raise Exception()
|
|
|
- self._cfg["min_subgraph_size"] = min_subgraph_size
|
|
|
+ self._update("min_subgraph_size", min_subgraph_size)
|
|
|
|
|
|
@property
|
|
|
def shape_info_filename(self):
|
|
|
@@ -126,7 +131,7 @@ class PaddlePredictorOption(object):
|
|
|
@shape_info_filename.setter
|
|
|
def shape_info_filename(self, shape_info_filename: str):
|
|
|
"""set shape info filename"""
|
|
|
- self._cfg["shape_info_filename"] = shape_info_filename
|
|
|
+ self._update("shape_info_filename", shape_info_filename)
|
|
|
|
|
|
@property
|
|
|
def trt_calib_mode(self):
|
|
|
@@ -135,7 +140,7 @@ class PaddlePredictorOption(object):
|
|
|
@trt_calib_mode.setter
|
|
|
def trt_calib_mode(self, trt_calib_mode):
|
|
|
"""set trt calib mode"""
|
|
|
- self._cfg["trt_calib_mode"] = trt_calib_mode
|
|
|
+ self._update("trt_calib_mode", trt_calib_mode)
|
|
|
|
|
|
@property
|
|
|
def cpu_threads(self):
|
|
|
@@ -146,7 +151,7 @@ class PaddlePredictorOption(object):
|
|
|
"""set cpu threads"""
|
|
|
if not isinstance(cpu_threads, int) or cpu_threads < 1:
|
|
|
raise Exception()
|
|
|
- self._cfg["cpu_threads"] = cpu_threads
|
|
|
+ self._update("cpu_threads", cpu_threads)
|
|
|
|
|
|
@property
|
|
|
def trt_use_static(self):
|
|
|
@@ -155,7 +160,7 @@ class PaddlePredictorOption(object):
|
|
|
@trt_use_static.setter
|
|
|
def trt_use_static(self, trt_use_static):
|
|
|
"""set trt use static"""
|
|
|
- self._cfg["trt_use_static"] = trt_use_static
|
|
|
+ self._update("trt_use_static", trt_use_static)
|
|
|
|
|
|
@property
|
|
|
def delete_pass(self):
|
|
|
@@ -163,7 +168,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@delete_pass.setter
|
|
|
def delete_pass(self, delete_pass):
|
|
|
- self._cfg["delete_pass"] = delete_pass
|
|
|
+ self._update("delete_pass", delete_pass)
|
|
|
|
|
|
@property
|
|
|
def enable_new_ir(self):
|
|
|
@@ -172,7 +177,7 @@ class PaddlePredictorOption(object):
|
|
|
@enable_new_ir.setter
|
|
|
def enable_new_ir(self, enable_new_ir: bool):
|
|
|
"""set run mode"""
|
|
|
- self._cfg["enable_new_ir"] = enable_new_ir
|
|
|
+ self._update("enable_new_ir", enable_new_ir)
|
|
|
|
|
|
def get_support_run_mode(self):
|
|
|
"""get supported run mode"""
|
|
|
@@ -205,3 +210,17 @@ class PaddlePredictorOption(object):
|
|
|
for name, prop in vars(self.__class__).items()
|
|
|
if isinstance(prop, property) and prop.fset is not None
|
|
|
]
|
|
|
+
|
|
|
+ def attach(self, observer):
|
|
|
+ if observer not in self._observers:
|
|
|
+ self._observers.append(observer)
|
|
|
+
|
|
|
+ def detach(self, observer):
|
|
|
+ try:
|
|
|
+ self._observers.remove(observer)
|
|
|
+ except ValueError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ def notify(self):
|
|
|
+ for observer in self._observers:
|
|
|
+ observer.reset()
|