|
|
@@ -12,27 +12,10 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-
|
|
|
-from functools import wraps, partial
|
|
|
-
|
|
|
+from ....utils.func_register import FuncRegister
|
|
|
from ....utils import logging
|
|
|
|
|
|
|
|
|
-def register(register_map, key):
|
|
|
- """register the option setting func"""
|
|
|
-
|
|
|
- def decorator(func):
|
|
|
- register_map[key] = func
|
|
|
-
|
|
|
- @wraps(func)
|
|
|
- def wrapper(self, *args, **kwargs):
|
|
|
- return func(self, *args, **kwargs)
|
|
|
-
|
|
|
- return wrapper
|
|
|
-
|
|
|
- return decorator
|
|
|
-
|
|
|
-
|
|
|
class PaddlePredictorOption(object):
|
|
|
"""Paddle Inference Engine Option"""
|
|
|
|
|
|
@@ -45,9 +28,9 @@ class PaddlePredictorOption(object):
|
|
|
"mkldnn_bf16",
|
|
|
)
|
|
|
SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu")
|
|
|
- _REGISTER_MAP = {}
|
|
|
|
|
|
- register2self = partial(register, _REGISTER_MAP)
|
|
|
+ _FUNC_MAP = {}
|
|
|
+ register = FuncRegister(_FUNC_MAP)
|
|
|
|
|
|
def __init__(self, **kwargs):
|
|
|
super().__init__()
|
|
|
@@ -80,7 +63,7 @@ class PaddlePredictorOption(object):
|
|
|
"delete_pass": [],
|
|
|
}
|
|
|
|
|
|
- @register2self("run_mode")
|
|
|
+ @register("run_mode")
|
|
|
def set_run_mode(self, run_mode: str):
|
|
|
"""set run mode"""
|
|
|
if run_mode not in self.SUPPORT_RUN_MODE:
|
|
|
@@ -90,14 +73,14 @@ class PaddlePredictorOption(object):
|
|
|
)
|
|
|
self._cfg["run_mode"] = run_mode
|
|
|
|
|
|
- @register2self("batch_size")
|
|
|
+ @register("batch_size")
|
|
|
def set_batch_size(self, batch_size: int):
|
|
|
"""set batch size"""
|
|
|
if not isinstance(batch_size, int) or batch_size < 1:
|
|
|
raise Exception()
|
|
|
self._cfg["batch_size"] = batch_size
|
|
|
|
|
|
- @register2self("device")
|
|
|
+ @register("device")
|
|
|
def set_device(self, device_setting: str):
|
|
|
"""set device"""
|
|
|
if len(device_setting.split(":")) == 1:
|
|
|
@@ -117,36 +100,36 @@ class PaddlePredictorOption(object):
|
|
|
self._cfg["device"] = device.lower()
|
|
|
self._cfg["device_id"] = int(device_id)
|
|
|
|
|
|
- @register2self("min_subgraph_size")
|
|
|
+ @register("min_subgraph_size")
|
|
|
def set_min_subgraph_size(self, min_subgraph_size: int):
|
|
|
"""set min subgraph size"""
|
|
|
if not isinstance(min_subgraph_size, int):
|
|
|
raise Exception()
|
|
|
self._cfg["min_subgraph_size"] = min_subgraph_size
|
|
|
|
|
|
- @register2self("shape_info_filename")
|
|
|
+ @register("shape_info_filename")
|
|
|
def set_shape_info_filename(self, shape_info_filename: str):
|
|
|
"""set shape info filename"""
|
|
|
self._cfg["shape_info_filename"] = shape_info_filename
|
|
|
|
|
|
- @register2self("trt_calib_mode")
|
|
|
+ @register("trt_calib_mode")
|
|
|
def set_trt_calib_mode(self, trt_calib_mode):
|
|
|
"""set trt calib mode"""
|
|
|
self._cfg["trt_calib_mode"] = trt_calib_mode
|
|
|
|
|
|
- @register2self("cpu_threads")
|
|
|
+ @register("cpu_threads")
|
|
|
def set_cpu_threads(self, cpu_threads):
|
|
|
"""set cpu threads"""
|
|
|
if not isinstance(cpu_threads, int) or cpu_threads < 1:
|
|
|
raise Exception()
|
|
|
self._cfg["cpu_threads"] = cpu_threads
|
|
|
|
|
|
- @register2self("trt_use_static")
|
|
|
+ @register("trt_use_static")
|
|
|
def set_trt_use_static(self, trt_use_static):
|
|
|
"""set trt use static"""
|
|
|
self._cfg["trt_use_static"] = trt_use_static
|
|
|
|
|
|
- @register2self("delete_pass")
|
|
|
+ @register("delete_pass")
|
|
|
def set_delete_pass(self, delete_pass):
|
|
|
self._cfg["delete_pass"] = delete_pass
|
|
|
|