Преглед изворни кода

[Fix] Allow bypassing model whitelist checks when performing internal tests (#3499)

* Fix device check

* Fix LaTeXOCR bug
Lin Manhui пре 8 месеци
родитељ
комит
bea5ad68f9

+ 1 - 1
docs/other_devices_support/how_to_contribute_device.en.md

@@ -34,7 +34,7 @@ If the relevant device has specific requirements for the PaddlePaddle version, y
 
 ### 2.1.2 Setting Environment Variables (Optional)
 
-If special environment variables need to be set when using the relevant device, you can modify the device environment setup code. The relevant code is located in the `set_env_for_device` function in [PaddleX Environment Variable Settings](https://github.com/PaddlePaddle/PaddleX/blob/develop/paddlex/utils/device.py).
+If special environment variables need to be set when using the relevant device, you can modify the device environment setup code. The relevant code is located in the `set_env_for_device_type` function in [PaddleX Environment Variable Settings](https://github.com/PaddlePaddle/PaddleX/blob/develop/paddlex/utils/device.py).
 
 ### 2.1.3 Creating a Predictor
 

+ 1 - 1
docs/other_devices_support/how_to_contribute_device.md

@@ -34,7 +34,7 @@
 
 ### 2.1.2 设置环境变量(可忽略)
 
-如果相关硬件在使用时,需要设定特殊的环境变量,可以修改设备环境设置代码,相关代码位于 [PaddleX环境变量设置](https://github.com/PaddlePaddle/PaddleX/blob/develop/paddlex/utils/device.py)中的 `set_env_for_device`
+如果相关硬件在使用时,需要设定特殊的环境变量,可以修改设备环境设置代码,相关代码位于 [PaddleX环境变量设置](https://github.com/PaddlePaddle/PaddleX/blob/develop/paddlex/utils/device.py)中的 `set_env_for_device_type`
 
 ### 2.1.3 创建Predictor
 

+ 8 - 1
paddlex/inference/models/common/static_infer.py

@@ -310,7 +310,10 @@ class StaticInfer(object):
             raise RuntimeError("No valid Paddle model found")
         model_file, params_file = model_paths["paddle"]
 
-        if self._option.model_name == "LaTeX_OCR_rec":
+        if (
+            self._option.model_name == "LaTeX_OCR_rec"
+            and self._option.device_type == "cpu"
+        ):
             import cpuinfo
 
             if (
@@ -323,6 +326,10 @@ class StaticInfer(object):
             self._option.run_mode = "mkldnn"
             logging.debug("`run_mode` updated to 'mkldnn'")
 
+        if self._option.device_type == "cpu" and self._option.device_id is not None:
+            self._option.device_id = None
+            logging.debug("`device_id` has been set to None")
+
         if (
             self._option.device_type in ("gpu", "dcu")
             and self._option.device_id is None

+ 0 - 9
paddlex/inference/models/formula_recognition/predictor.py

@@ -66,15 +66,6 @@ class FormulaRecPredictor(BasicPredictor):
                 pre_tfs[name] = op
         pre_tfs["ToBatch"] = ToBatch()
 
-        if self.model_name in ("LaTeX_OCR_rec") and self.pp_option.device in ("cpu"):
-            import cpuinfo
-
-            if "GenuineIntel" in cpuinfo.get_cpu_info().get("vendor_id_raw", ""):
-                self.pp_option.run_mode = "mkldnn"
-                logging.warning(
-                    "Now, the `LaTeX_OCR_rec` model only support `mkldnn` mode when running on Intel CPU devices. So using `mkldnn` instead."
-                )
-
         infer = StaticInfer(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,

+ 12 - 13
paddlex/inference/utils/pp_option.py

@@ -20,7 +20,7 @@ from ...utils.device import (
     check_supported_device_type,
     get_default_device,
     parse_device,
-    set_env_for_device,
+    set_env_for_device_type,
 )
 from .new_ir_blacklist import NEWIR_BLOCKLIST
 from .trt_blacklist import TRT_BLOCKLIST
@@ -122,8 +122,17 @@ class PaddlePredictorOption(object):
 
     @device_type.setter
     def device_type(self, device_type):
+        if device_type not in self.SUPPORT_DEVICE:
+            support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
+            raise ValueError(
+                f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
+            )
         check_supported_device_type(device_type, self.model_name)
         self._update("device_type", device_type)
+        set_env_for_device_type(device_type)
+        # XXX(gaotingquan): set flag to accelerate inference in paddle 3.0b2
+        if device_type in ("gpu", "cpu"):
+            os.environ["FLAGS_enable_pir_api"] = "1"
 
     @property
     def device_id(self):
@@ -309,21 +318,11 @@ class PaddlePredictorOption(object):
         if not device:
             return
         device_type, device_ids = parse_device(device)
-        if device_type not in self.SUPPORT_DEVICE:
-            support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
-            raise ValueError(
-                f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
-            )
         self.device_type = device_type
         device_id = device_ids[0] if device_ids is not None else None
         self.device_id = device_id
-        set_env_for_device(device)
-        if device_type not in ("cpu"):
-            if device_ids is None or len(device_ids) > 1:
-                logging.debug(f"The device ID has been set to {device_id}.")
-        # XXX(gaotingquan): set flag to accelerate inference in paddle 3.0b2
-        if device_type in ("gpu", "cpu"):
-            os.environ["FLAGS_enable_pir_api"] = "1"
+        if device_ids is None or len(device_ids) > 1:
+            logging.debug(f"The device ID has been set to {device_id}.")
 
     def get_support_run_mode(self):
         """get supported run mode"""

+ 15 - 7
paddlex/utils/device.py

@@ -31,6 +31,8 @@ SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu", "dcu"]
 
 
 def constr_device(device_type, device_ids):
+    if device_type == "cpu" and device_ids is not None:
+        raise ValueError("`device_ids` must be None for CPUs")
     if device_ids:
         device_ids = ",".join(map(str, device_ids))
         return f"{device_type}:{device_ids}"
@@ -73,6 +75,8 @@ def parse_device(device):
     device_type = device_type.lower()
     # raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
     assert device_type.lower() in SUPPORTED_DEVICE_TYPE
+    if device_type == "cpu" and device_ids is not None:
+        raise ValueError("No Device ID should be specified for CPUs")
     return device_type, device_ids
 
 
@@ -86,12 +90,16 @@ def update_device_num(device, num):
 
 
 def set_env_for_device(device):
+    device_type, _ = parse_device(device)
+    return set_env_for_device_type(device_type)
+
+
+def set_env_for_device_type(device_type):
     def _set(envs):
         for key, val in envs.items():
             os.environ[key] = val
             logging.debug(f"{key} has been set to {val}.")
 
-    device_type, device_ids = parse_device(device)
     # XXX: is_compiled_with_rocm() must be True on dcu platform ?
     if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
         envs = {"FLAGS_conv_workspace_size_limit": "2000"}
@@ -122,17 +130,12 @@ def set_env_for_device(device):
         _set(envs)
 
 
-def check_supported_device(device, model_name):
+def check_supported_device_type(device_type, model_name):
     if DISABLE_DEV_MODEL_WL:
         logging.warning(
             "Skip checking if model is supported on device because the flag `PADDLE_PDX_DISABLE_DEV_MODEL_WL` has been set."
         )
         return
-    device_type, device_ids = parse_device(device)
-    return check_supported_device_type(device_type, model_name)
-
-
-def check_supported_device_type(device_type, model_name):
     if device_type == "dcu":
         assert (
             model_name in DCU_WHITELIST
@@ -153,3 +156,8 @@ def check_supported_device_type(device_type, model_name):
         assert (
             model_name in GCU_WHITELIST
         ), f"The GCU device does not yet support `{model_name}` model!"
+
+
+def check_supported_device(device, model_name):
+    device_type, _ = parse_device(device)
+    return check_supported_device_type(device_type, model_name)