Forráskód Böngészése

fix_export_device_bug (#2124)

zhangyubo0722 1 éve
szülő
commit
e34420a6bc

+ 1 - 0
paddlex/modules/base/exportor.py

@@ -117,4 +117,5 @@ exporting!"
         return {
             "weight_path": self.export_config.weight_path,
             "save_dir": self.global_config.output,
+            "device": self.get_device(),
         }

+ 9 - 3
paddlex/repo_apis/PaddleClas_api/cls/model.py

@@ -112,9 +112,13 @@ class ClsModel(BaseModel):
                         os.environ[env_name] = str(env_value)
             else:
                 config._update_amp(amp)
-
             # PDX related settings
-            config.update(["Global.uniform_output_enabled=True"])
+            device_type = device.split(":")[0]
+            if device_type in ["npu", "xpu", "mlu"]:
+                uniform_output_enabled = False
+            else:
+                uniform_output_enabled = True
+            config.update([f"Global.uniform_output_enabled={uniform_output_enabled}"])
             config.update([f"Global.pdx_model_name={self.name}"])
             hpi_config_path = self.model_info.get("hpi_config_path", None)
             config.update([f"Global.hpi_config_path={hpi_config_path}"])
@@ -227,7 +231,9 @@ class ClsModel(BaseModel):
             config = self.config.copy()
             config.update_pretrained_weights(weight_path)
             config._update_save_inference_dir(save_dir)
-
+            device = kwargs.pop("device", None)
+            if device:
+                config.update_device(device)
             # PDX related settings
             config.update([f"Global.pdx_model_name={self.name}"])
             hpi_config_path = self.model_info.get("hpi_config_path", None)

+ 9 - 0
paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py

@@ -261,6 +261,15 @@ class InstanceSegConfig(DetConfig):
         """
         if device_type.lower() == "gpu":
             self["use_gpu"] = True
+        elif device_type.lower() == "xpu":
+            self["use_xpu"] = True
+            self["use_gpu"] = False
+        elif device_type.lower() == "npu":
+            self["use_npu"] = True
+            self["use_gpu"] = False
+        elif device_type.lower() == "mlu":
+            self["use_mlu"] = True
+            self["use_gpu"] = False
         else:
             assert device_type.lower() == "cpu"
             self["use_gpu"] = False

+ 10 - 1
paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py

@@ -126,7 +126,11 @@ class InstanceSegModel(BaseModel):
                 cli_args.append(CLIArgument("--enable_ce", enable_ce))
 
         # PDX related settings
-        config.update({"uniform_output_enabled": True})
+        if device_type in ["npu", "xpu", "mlu"]:
+            uniform_output_enabled = False
+        else:
+            uniform_output_enabled = True
+        config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
         hpi_config_path = self.model_info.get("hpi_config_path", None)
         if hpi_config_path:
@@ -255,6 +259,11 @@ class InstanceSegModel(BaseModel):
         config = self.config.copy()
         cli_args = []
 
+        device = kwargs.pop("device", None)
+        if device:
+            device_type, _ = parse_device(device)
+            config.update_device(device_type)
+
         if not weight_path.startswith("http"):
             weight_path = abspath(weight_path)
         config.update_weights(weight_path)

+ 10 - 1
paddlex/repo_apis/PaddleDetection_api/object_det/model.py

@@ -130,7 +130,11 @@ class DetModel(BaseModel):
                 cli_args.append(CLIArgument("--enable_ce", enable_ce))
 
         # PDX related settings
-        config.update({"uniform_output_enabled": True})
+        if device_type in ["npu", "xpu", "mlu"]:
+            uniform_output_enabled = False
+        else:
+            uniform_output_enabled = True
+        config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
         hpi_config_path = self.model_info.get("hpi_config_path", None)
         if hpi_config_path:
@@ -259,6 +263,11 @@ class DetModel(BaseModel):
         config = self.config.copy()
         cli_args = []
 
+        device = kwargs.pop("device", None)
+        if device:
+            device_type, _ = parse_device(device)
+            config.update_device(device_type)
+
         if not weight_path.startswith("http"):
             weight_path = abspath(weight_path)
         config.update_weights(weight_path)

+ 10 - 1
paddlex/repo_apis/PaddleOCR_api/text_rec/model.py

@@ -133,7 +133,12 @@ class TextRecModel(BaseModel):
                     os.environ[env_name] = str(env_value)
 
         # PDX related settings
-        config.update({"Global.uniform_output_enabled": True})
+        device_type = device.split(":")[0]
+        if device_type in ["npu", "xpu", "mlu"]:
+            uniform_output_enabled = False
+        else:
+            uniform_output_enabled = True
+        config.update({"Global.uniform_output_enabled": uniform_output_enabled})
         config.update({"Global.pdx_model_name": self.name})
         hpi_config_path = self.model_info.get("hpi_config_path", None)
         config.update({"Global.hpi_config_path": hpi_config_path})
@@ -252,6 +257,10 @@ class TextRecModel(BaseModel):
         """
         config = self.config.copy()
 
+        device = kwargs.pop("device", None)
+        if device:
+            config.update_device(device)
+
         if not weight_path.startswith("http"):
             weight_path = abspath(weight_path)
         config.update_pretrained_weights(weight_path)

+ 5 - 1
paddlex/repo_apis/PaddleSeg_api/seg/model.py

@@ -166,7 +166,11 @@ class SegModel(BaseModel):
                 cli_args.append(CLIArgument("--seed", seed))
 
         # PDX related settings
-        config.set_val("uniform_output_enabled", True)
+        if device_type in ["npu", "xpu", "mlu"]:
+            uniform_output_enabled = False
+        else:
+            uniform_output_enabled = True
+        config.set_val("uniform_output_enabled", uniform_output_enabled)
         config.set_val("pdx_model_name", self.name)
         hpi_config_path = self.model_info.get("hpi_config_path", None)
         if hpi_config_path:

+ 6 - 1
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -109,7 +109,12 @@ class TSModel(BaseModel):
         else:
             if num_workers is not None:
                 cli_args.append(CLIArgument("--num_workers", num_workers))
-        config.update({"uniform_output_enabled": True})
+        # PDX related settings
+        if device_type in ["npu", "xpu", "mlu"]:
+            uniform_output_enabled = False
+        else:
+            uniform_output_enabled = True
+        config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
 
         self._assert_empty_kwargs(kwargs)