|
|
@@ -120,9 +120,14 @@ exporting!"
|
|
|
"""
|
|
|
check_supported_device(self.global_config.device, self.global_config.model)
|
|
|
set_env_for_device(self.global_config.device)
|
|
|
- if using_device_number:
|
|
|
- return update_device_num(self.global_config.device, using_device_number)
|
|
|
- return self.global_config.device
|
|
|
+ device_setting = (
|
|
|
+ update_device_num(self.global_config.device, using_device_number)
|
|
|
+ if using_device_number
|
|
|
+ else self.global_config.device
|
|
|
+ )
|
|
|
+ # replace "dcu" with "gpu"
|
|
|
+ device_setting = device_setting.replace("dcu", "gpu")
|
|
|
+ return device_setting
|
|
|
|
|
|
def update_config(self):
|
|
|
"""update export config"""
|