|
|
@@ -12,9 +12,9 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-from ...utils.device import parse_device
|
|
|
-from ....utils.func_register import FuncRegister
|
|
|
-from ....utils import logging
|
|
|
+from .device import parse_device
|
|
|
+from ...utils.func_register import FuncRegister
|
|
|
+from ...utils import logging
|
|
|
|
|
|
|
|
|
class PaddlePredictorOption(object):
|
|
|
@@ -53,7 +53,6 @@ class PaddlePredictorOption(object):
|
|
|
"""get default config"""
|
|
|
return {
|
|
|
"run_mode": "paddle",
|
|
|
- "batch_size": 1,
|
|
|
"device": "gpu",
|
|
|
"device_id": 0,
|
|
|
"min_subgraph_size": 3,
|
|
|
@@ -62,6 +61,7 @@ class PaddlePredictorOption(object):
|
|
|
"cpu_threads": 1,
|
|
|
"trt_use_static": False,
|
|
|
"delete_pass": [],
|
|
|
+ "enable_new_ir": True,
|
|
|
}
|
|
|
|
|
|
@register("run_mode")
|
|
|
@@ -74,13 +74,6 @@ class PaddlePredictorOption(object):
|
|
|
)
|
|
|
self._cfg["run_mode"] = run_mode
|
|
|
|
|
|
- @register("batch_size")
|
|
|
- def set_batch_size(self, batch_size: int):
|
|
|
- """set batch size"""
|
|
|
- if not isinstance(batch_size, int) or batch_size < 1:
|
|
|
- raise Exception()
|
|
|
- self._cfg["batch_size"] = batch_size
|
|
|
-
|
|
|
@register("device")
|
|
|
def set_device(self, device: str):
|
|
|
"""set device"""
|
|
|
@@ -128,6 +121,11 @@ class PaddlePredictorOption(object):
|
|
|
def set_delete_pass(self, delete_pass):
|
|
|
self._cfg["delete_pass"] = delete_pass
|
|
|
|
|
|
+ @register("enable_new_ir")
|
|
|
+ def set_enable_new_ir(self, enable_new_ir: bool):
|
|
|
+ """set run mode"""
|
|
|
+ self._cfg["enable_new_ir"] = enable_new_ir
|
|
|
+
|
|
|
def get_support_run_mode(self):
|
|
|
"""get supported run mode"""
|
|
|
return self.SUPPORT_RUN_MODE
|