Sfoglia il codice sorgente

support enabling CINN for inference (#3835)

* support enable_new_ir when using pdmodel on mlu or npu device

* when new ir is enabled, support enable_cinn to infer on gpu device
Tingquan Gao 7 mesi fa
parent
commit
c5b8ee46a3

+ 8 - 1
paddlex/inference/models/common/static_infer.py

@@ -404,11 +404,15 @@ class PaddleInfer(StaticInfer):
                 config.enable_use_gpu(100, self._option.device_id, precision)
                 if hasattr(config, "enable_new_ir"):
                     config.enable_new_ir(self._option.enable_new_ir)
+                    if self._option.enable_new_ir and self._option.enable_cinn:
+                        config.enable_cinn()
                 if hasattr(config, "enable_new_executor"):
                     config.enable_new_executor()
                 config.set_optimization_level(3)
             elif self._option.device_type == "npu":
                 config.enable_custom_device("npu", self._option.device_id)
+                if hasattr(config, "enable_new_ir"):
+                    config.enable_new_ir(self._option.enable_new_ir)
                 if hasattr(config, "enable_new_executor"):
                     config.enable_new_executor()
             elif self._option.device_type == "xpu":
@@ -418,6 +422,8 @@ class PaddleInfer(StaticInfer):
                     config.enable_new_executor()
             elif self._option.device_type == "mlu":
                 config.enable_custom_device("mlu")
+                if hasattr(config, "enable_new_ir"):
+                    config.enable_new_ir(self._option.enable_new_ir)
                 if hasattr(config, "enable_new_executor"):
                     config.enable_new_executor()
             elif self._option.device_type == "gcu":
@@ -425,8 +431,9 @@ class PaddleInfer(StaticInfer):
 
                 gcu_passes.setUp()
                 config.enable_custom_device("gcu")
-                if hasattr(config, "enable_new_executor"):
+                if hasattr(config, "enable_new_ir"):
                     config.enable_new_ir()
+                if hasattr(config, "enable_new_executor"):
                     config.enable_new_executor()
                 else:
                     pass_builder = config.pass_builder()

+ 10 - 0
paddlex/inference/utils/pp_option.py

@@ -97,6 +97,7 @@ class PaddlePredictorOption(object):
             "cpu_threads": 8,
             "delete_pass": [],
             "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
+            "enable_cinn": False,
             "trt_cfg_setting": {},
             "trt_use_dynamic_shapes": True,  # only for trt
             "trt_collect_shape_range_info": True,  # only for trt
@@ -188,6 +189,15 @@ class PaddlePredictorOption(object):
         self._update("enable_new_ir", enable_new_ir)
 
     @property
+    def enable_cinn(self):
+        return self._cfg["enable_cinn"]
+
+    @enable_cinn.setter
+    def enable_cinn(self, enable_cinn: bool):
+        """set whether to enable CINN"""
+        self._update("enable_cinn", enable_cinn)
+
+    @property
     def trt_cfg_setting(self):
         return self._cfg["trt_cfg_setting"]