|
|
@@ -110,8 +110,11 @@ class TSModel(BaseModel):
|
|
|
cli_args.append(CLIArgument("--num_workers", num_workers))
|
|
|
# PDX related settings
|
|
|
uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
|
|
|
+ export_with_pir = kwargs.pop("export_with_pir", False)
|
|
|
config.update({"uniform_output_enabled": uniform_output_enabled})
|
|
|
config.update({"pdx_model_name": self.name})
|
|
|
+ if export_with_pir:
|
|
|
+ config.update({"export_with_pir": export_with_pir})
|
|
|
|
|
|
self._assert_empty_kwargs(kwargs)
|
|
|
|
|
|
@@ -234,13 +237,15 @@ class TSModel(BaseModel):
|
|
|
if device is not None:
|
|
|
device_type, _ = parse_device(device)
|
|
|
cli_args.append(CLIArgument("--device", device_type))
|
|
|
-
|
|
|
+ export_with_pir = kwargs.pop("export_with_pir", False)
|
|
|
self._assert_empty_kwargs(kwargs)
|
|
|
with self._create_new_config_file() as config_path:
|
|
|
# Update YAML config file
|
|
|
config = self.config.copy()
|
|
|
config.update_pretrained_weights(weight_path)
|
|
|
config.update({"pdx_model_name": self.name})
|
|
|
+ if export_with_pir:
|
|
|
+ config.update({"export_with_pir": export_with_pir})
|
|
|
config.dump(config_path)
|
|
|
|
|
|
return self.runner.export(config_path, cli_args, device)
|