|
|
@@ -34,8 +34,8 @@ class PaddlePredictorOption(object):
|
|
|
super().__init__()
|
|
|
self.model_name = model_name
|
|
|
self._cfg = {}
|
|
|
- self._init_option(**kwargs)
|
|
|
self._observers = []
|
|
|
+ self._init_option(**kwargs)
|
|
|
|
|
|
def _init_option(self, **kwargs):
|
|
|
for k, v in kwargs.items():
|
|
|
@@ -62,6 +62,7 @@ class PaddlePredictorOption(object):
|
|
|
"trt_use_static": False,
|
|
|
"delete_pass": [],
|
|
|
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
|
|
|
+ "batch_size": 1, # only for trt
|
|
|
}
|
|
|
|
|
|
def _update(self, k, v):
|
|
|
@@ -179,6 +180,14 @@ class PaddlePredictorOption(object):
|
|
|
"""set run mode"""
|
|
|
self._update("enable_new_ir", enable_new_ir)
|
|
|
|
|
|
+ @property
|
|
|
+ def batch_size(self):
|
|
|
+ return self._cfg["batch_size"]
|
|
|
+
|
|
|
+ @batch_size.setter
|
|
|
+ def batch_size(self, batch_size):
|
|
|
+ self._update("batch_size", batch_size)
|
|
|
+
|
|
|
def get_support_run_mode(self):
|
|
|
"""get supported run mode"""
|
|
|
return self.SUPPORT_RUN_MODE
|