Pārlūkot izejas kodu

support export with pir (#2716)

zhangyubo0722 10 mēneši atpakaļ
vecāks
revīzija
54a8203cfc

+ 4 - 0
paddlex/modules/base/exportor.py

@@ -122,8 +122,12 @@ exporting!"
 
     def get_export_kwargs(self):
         """get key-value arguments of model export function"""
+        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
+            "FLAGS_json_format_model"
+        ) in ["1", "True"]
         return {
             "weight_path": self.export_config.weight_path,
             "save_dir": self.global_config.output,
             "device": self.get_device(1),
+            "export_with_pir": export_with_pir,
         }

+ 5 - 1
paddlex/modules/base/trainer.py

@@ -71,11 +71,15 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
         train_args = self.get_train_kwargs()
         if self.benchmark_config is not None:
             train_args.update({"benchmark": self.benchmark_config})
+        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
+            "FLAGS_json_format_model"
+        ) in ["1", "True"]
         train_args.update(
             {
                 "uniform_output_enabled": self.train_config.get(
                     "uniform_output_enabled", True
-                )
+                ),
+                "export_with_pir": export_with_pir,
             }
         )
         train_result = self.pdx_model.train(**train_args)

+ 11 - 0
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -35,6 +35,17 @@ class TSADTrainer(BaseTrainer):
         self.update_config()
         self.dump_config()
         train_args = self.get_train_kwargs()
+        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
+            "FLAGS_json_format_model"
+        ) in ["1", "True"]
+        train_args.update(
+            {
+                "uniform_output_enabled": self.train_config.get(
+                    "uniform_output_enabled", True
+                ),
+                "export_with_pir": export_with_pir,
+            }
+        )
         if self.benchmark_config is not None:
             train_args.update({"benchmark": self.benchmark_config})
         train_result = self.pdx_model.train(**train_args)

+ 11 - 0
paddlex/modules/ts_classification/trainer.py

@@ -35,6 +35,17 @@ class TSCLSTrainer(BaseTrainer):
         self.update_config()
         self.dump_config()
         train_args = self.get_train_kwargs()
+        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
+            "FLAGS_json_format_model"
+        ) in ["1", "True"]
+        train_args.update(
+            {
+                "uniform_output_enabled": self.train_config.get(
+                    "uniform_output_enabled", True
+                ),
+                "export_with_pir": export_with_pir,
+            }
+        )
         if self.benchmark_config is not None:
             train_args.update({"benchmark": self.benchmark_config})
         train_result = self.pdx_model.train(**train_args)

+ 11 - 0
paddlex/modules/ts_forecast/trainer.py

@@ -35,6 +35,17 @@ class TSFCTrainer(BaseTrainer):
         self.update_config()
         self.dump_config()
         train_args = self.get_train_kwargs()
+        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
+            "FLAGS_json_format_model"
+        ) in ["1", "True"]
+        train_args.update(
+            {
+                "uniform_output_enabled": self.train_config.get(
+                    "uniform_output_enabled", True
+                ),
+                "export_with_pir": export_with_pir,
+            }
+        )
         if self.benchmark_config is not None:
             train_args.update({"benchmark": self.benchmark_config})
         train_result = self.pdx_model.train(**train_args)

+ 6 - 0
paddlex/repo_apis/PaddleClas_api/cls/model.py

@@ -115,8 +115,11 @@ class ClsModel(BaseModel):
             # PDX related settings
             device_type = device.split(":")[0]
             uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+            export_with_pir = kwargs.pop("export_with_pir", False)
             config.update([f"Global.uniform_output_enabled={uniform_output_enabled}"])
             config.update([f"Global.pdx_model_name={self.name}"])
+            if export_with_pir:
+                config.update([f"Global.export_with_pir={export_with_pir}"])
 
             config.dump(config_path)
             self._assert_empty_kwargs(kwargs)
@@ -231,8 +234,11 @@ class ClsModel(BaseModel):
                 config.update_device(device)
             # PDX related settings
             uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+            export_with_pir = kwargs.pop("export_with_pir", False)
             config.update([f"Global.uniform_output_enabled={uniform_output_enabled}"])
             config.update([f"Global.pdx_model_name={self.name}"])
+            if export_with_pir:
+                config.update([f"Global.export_with_pir={export_with_pir}"])
 
             config.dump(config_path)
 

+ 6 - 0
paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py

@@ -127,8 +127,11 @@ class InstanceSegModel(BaseModel):
 
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
 
         self._assert_empty_kwargs(kwargs)
 
@@ -278,8 +281,11 @@ class InstanceSegModel(BaseModel):
 
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
 
         self._assert_empty_kwargs(kwargs)
 

+ 6 - 0
paddlex/repo_apis/PaddleDetection_api/object_det/model.py

@@ -131,8 +131,11 @@ class DetModel(BaseModel):
 
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
 
         self._assert_empty_kwargs(kwargs)
 
@@ -282,8 +285,11 @@ class DetModel(BaseModel):
 
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
 
         if self.name in official_categories.keys():
             anno_val_file = abspath(

+ 6 - 0
paddlex/repo_apis/PaddleOCR_api/text_rec/model.py

@@ -135,8 +135,11 @@ class TextRecModel(BaseModel):
         # PDX related settings
         device_type = device.split(":")[0]
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.update({"Global.uniform_output_enabled": uniform_output_enabled})
         config.update({"Global.pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"Global.export_with_pir": export_with_pir})
 
         self._assert_empty_kwargs(kwargs)
 
@@ -269,8 +272,11 @@ class TextRecModel(BaseModel):
 
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.update({"Global.uniform_output_enabled": uniform_output_enabled})
         config.update({"Global.pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"Global.export_with_pir": export_with_pir})
 
         self._assert_empty_kwargs(kwargs)
 

+ 6 - 0
paddlex/repo_apis/PaddleSeg_api/seg/model.py

@@ -167,8 +167,11 @@ class SegModel(BaseModel):
 
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.set_val("uniform_output_enabled", uniform_output_enabled)
         config.set_val("pdx_model_name", self.name)
+        if export_with_pir:
+            config.set_val("export_with_pir", export_with_pir)
 
         self._assert_empty_kwargs(kwargs)
 
@@ -352,8 +355,11 @@ class SegModel(BaseModel):
 
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.set_val("uniform_output_enabled", uniform_output_enabled)
         config.set_val("pdx_model_name", self.name)
+        if export_with_pir:
+            config.set_val("export_with_pir", export_with_pir)
 
         self._assert_empty_kwargs(kwargs)
 

+ 6 - 1
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -110,8 +110,11 @@ class TSModel(BaseModel):
                 cli_args.append(CLIArgument("--num_workers", num_workers))
         # PDX related settings
         uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
         config.update({"uniform_output_enabled": uniform_output_enabled})
         config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
 
         self._assert_empty_kwargs(kwargs)
 
@@ -234,13 +237,15 @@ class TSModel(BaseModel):
         if device is not None:
             device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
-
+        export_with_pir = kwargs.pop("export_with_pir", False)
         self._assert_empty_kwargs(kwargs)
         with self._create_new_config_file() as config_path:
             # Update YAML config file
             config = self.config.copy()
             config.update_pretrained_weights(weight_path)
             config.update({"pdx_model_name": self.name})
+            if export_with_pir:
+                config.update({"export_with_pir": export_with_pir})
             config.dump(config_path)
 
             return self.runner.export(config_path, cli_args, device)