|
|
@@ -112,9 +112,13 @@ class ClsModel(BaseModel):
|
|
|
os.environ[env_name] = str(env_value)
|
|
|
else:
|
|
|
config._update_amp(amp)
|
|
|
-
|
|
|
# PDX related settings
|
|
|
- config.update(["Global.uniform_output_enabled=True"])
|
|
|
+ device_type = device.split(":")[0]
|
|
|
+ if device_type in ["npu", "xpu", "mlu"]:
|
|
|
+ uniform_output_enabled = False
|
|
|
+ else:
|
|
|
+ uniform_output_enabled = True
|
|
|
+ config.update([f"Global.uniform_output_enabled={uniform_output_enabled}"])
|
|
|
config.update([f"Global.pdx_model_name={self.name}"])
|
|
|
hpi_config_path = self.model_info.get("hpi_config_path", None)
|
|
|
config.update([f"Global.hpi_config_path={hpi_config_path}"])
|
|
|
@@ -227,7 +231,9 @@ class ClsModel(BaseModel):
|
|
|
config = self.config.copy()
|
|
|
config.update_pretrained_weights(weight_path)
|
|
|
config._update_save_inference_dir(save_dir)
|
|
|
-
|
|
|
+ device = kwargs.pop("device", None)
|
|
|
+ if device:
|
|
|
+ config.update_device(device)
|
|
|
# PDX related settings
|
|
|
config.update([f"Global.pdx_model_name={self.name}"])
|
|
|
hpi_config_path = self.model_info.get("hpi_config_path", None)
|