Răsfoiți Sursa

support paddle_fp32 and paddle_fp16 mode

paddle_fp32 is completely equivalent with paddle
gaotingquan 8 luni în urmă
părinte
comite
774737c06b

+ 55 - 47
paddlex/inference/models/common/static_infer.py

@@ -385,61 +385,69 @@ class StaticInfer(object):
                 params_file,
                 cache_dir,
             )
-        else:
-            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
-
-        if self._option.device_type == "gpu":
             config.exp_disable_mixed_precision_ops({"feed", "fetch"})
             config.enable_use_gpu(100, self._option.device_id)
-            if not self._option.run_mode.startswith("trt"):
+        # for Native Paddle and MKLDNN
+        else:
+            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
+            if self._option.device_type == "gpu":
+                config.exp_disable_mixed_precision_ops({"feed", "fetch"})
+                from lazy_paddle.inference import PrecisionType
+
+                precision = (
+                    PrecisionType.Half
+                    if self._option.run_mode == "paddle_fp16"
+                    else PrecisionType.Float32
+                )
+                config.enable_use_gpu(100, self._option.device_id, precision)
                 if hasattr(config, "enable_new_ir"):
                     config.enable_new_ir(self._option.enable_new_ir)
                 if hasattr(config, "enable_new_executor"):
                     config.enable_new_executor()
                 config.set_optimization_level(3)
-        elif self._option.device_type == "npu":
-            config.enable_custom_device("npu")
-            if hasattr(config, "enable_new_executor"):
-                config.enable_new_executor()
-        elif self._option.device_type == "xpu":
-            if hasattr(config, "enable_new_executor"):
-                config.enable_new_executor()
-        elif self._option.device_type == "mlu":
-            config.enable_custom_device("mlu")
-            if hasattr(config, "enable_new_executor"):
-                config.enable_new_executor()
-        elif self._option.device_type == "dcu":
-            config.enable_use_gpu(100, self._option.device_id)
-            if hasattr(config, "enable_new_executor"):
-                config.enable_new_executor()
-            # XXX: is_compiled_with_rocm() must be True on dcu platform ?
-            if lazy_paddle.is_compiled_with_rocm():
-                # Delete unsupported passes in dcu
-                config.delete_pass("conv2d_add_act_fuse_pass")
-                config.delete_pass("conv2d_add_fuse_pass")
-        else:
-            assert self._option.device_type == "cpu"
-            config.disable_gpu()
-            if "mkldnn" in self._option.run_mode:
-                try:
-                    config.enable_mkldnn()
-                    if "bf16" in self._option.run_mode:
-                        config.enable_mkldnn_bfloat16()
-                except Exception as e:
-                    logging.warning(
-                        "MKL-DNN is not available. We will disable MKL-DNN."
-                    )
-                config.set_mkldnn_cache_capacity(-1)
+            elif self._option.device_type == "npu":
+                config.enable_custom_device("npu")
+                if hasattr(config, "enable_new_executor"):
+                    config.enable_new_executor()
+            elif self._option.device_type == "xpu":
+                if hasattr(config, "enable_new_executor"):
+                    config.enable_new_executor()
+            elif self._option.device_type == "mlu":
+                config.enable_custom_device("mlu")
+                if hasattr(config, "enable_new_executor"):
+                    config.enable_new_executor()
+            elif self._option.device_type == "dcu":
+                config.enable_use_gpu(100, self._option.device_id)
+                if hasattr(config, "enable_new_executor"):
+                    config.enable_new_executor()
+                # XXX: is_compiled_with_rocm() must be True on dcu platform ?
+                if lazy_paddle.is_compiled_with_rocm():
+                    # Delete unsupported passes in dcu
+                    config.delete_pass("conv2d_add_act_fuse_pass")
+                    config.delete_pass("conv2d_add_fuse_pass")
             else:
-                if hasattr(config, "disable_mkldnn"):
-                    config.disable_mkldnn()
-            config.set_cpu_math_library_num_threads(self._option.cpu_threads)
-
-            if hasattr(config, "enable_new_ir"):
-                config.enable_new_ir(self._option.enable_new_ir)
-            if hasattr(config, "enable_new_executor"):
-                config.enable_new_executor()
-            config.set_optimization_level(3)
+                assert self._option.device_type == "cpu"
+                config.disable_gpu()
+                if "mkldnn" in self._option.run_mode:
+                    try:
+                        config.enable_mkldnn()
+                        if "bf16" in self._option.run_mode:
+                            config.enable_mkldnn_bfloat16()
+                    except Exception as e:
+                        logging.warning(
+                            "MKL-DNN is not available. We will disable MKL-DNN."
+                        )
+                    config.set_mkldnn_cache_capacity(-1)
+                else:
+                    if hasattr(config, "disable_mkldnn"):
+                        config.disable_mkldnn()
+                config.set_cpu_math_library_num_threads(self._option.cpu_threads)
+
+                if hasattr(config, "enable_new_ir"):
+                    config.enable_new_ir(self._option.enable_new_ir)
+                if hasattr(config, "enable_new_executor"):
+                    config.enable_new_executor()
+                config.set_optimization_level(3)
 
         config.enable_memory_optim()
         for del_p in self._option.delete_pass:

+ 2 - 0
paddlex/inference/utils/pp_option.py

@@ -32,6 +32,8 @@ class PaddlePredictorOption(object):
     # NOTE: TRT modes start with `trt_`
     SUPPORT_RUN_MODE = (
         "paddle",
+        "paddle_fp32",
+        "paddle_fp16",
         "trt_fp32",
         "trt_fp16",
         "trt_int8",