Selaa lähdekoodia

get FLAGS_json_format_model from the flag module

gaotingquan 7 kuukautta sitten
vanhempi
commit
f660ca3b18

+ 4 - 4
paddlex/modules/base/exportor.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 from abc import ABC
 from pathlib import Path
 
@@ -23,6 +22,7 @@ from ...utils.device import (
     set_env_for_device,
     update_device_num,
 )
+from ...utils.flags import FLAGS_json_format_model
 from ...utils.misc import AutoRegisterABCMetaClass
 from .build_model import build_model
 
@@ -134,9 +134,9 @@ exporting!"
 
     def get_export_kwargs(self):
         """get key-value arguments of model export function"""
-        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
-            "FLAGS_json_format_model"
-        ) in ["1", "True"]
+        export_with_pir = (
+            self.global_config.get("export_with_pir", False) or FLAGS_json_format_model
+        )
         return {
             "weight_path": self.export_config.weight_path,
             "save_dir": self.global_config.output,

+ 4 - 4
paddlex/modules/base/trainer.py

@@ -21,7 +21,7 @@ from ...utils.device import (
     set_env_for_device,
     update_device_num,
 )
-from ...utils.flags import DISABLE_CINN_MODEL_WL
+from ...utils.flags import FLAGS_json_format_model, DISABLE_CINN_MODEL_WL
 from ...utils.misc import AutoRegisterABCMetaClass
 from .build_model import build_model
 from .utils.cinn_setting import CINN_WHITELIST, enable_cinn_backend
@@ -75,9 +75,9 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
         train_args = self.get_train_kwargs()
         if self.benchmark_config is not None:
             train_args.update({"benchmark": self.benchmark_config})
-        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
-            "FLAGS_json_format_model"
-        ) in ["1", "True"]
+        export_with_pir = (
+            self.global_config.get("export_with_pir", False) or FLAGS_json_format_model
+        )
         train_args.update(
             {
                 "uniform_output_enabled": self.train_config.get(

+ 4 - 3
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -16,6 +16,7 @@ import os
 import tarfile
 from pathlib import Path
 
+from ...utils.flags import FLAGS_json_format_model
 from ..base import BaseTrainer
 from .model_list import MODELS
 
@@ -32,9 +33,9 @@ class TSADTrainer(BaseTrainer):
         self.update_config()
         self.dump_config()
         train_args = self.get_train_kwargs()
-        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
-            "FLAGS_json_format_model"
-        ) in ["1", "True"]
+        export_with_pir = (
+            self.global_config.get("export_with_pir", False) or FLAGS_json_format_model
+        )
         train_args.update(
             {
                 "uniform_output_enabled": self.train_config.get(

+ 4 - 3
paddlex/modules/ts_classification/trainer.py

@@ -16,6 +16,7 @@ import os
 import tarfile
 from pathlib import Path
 
+from ...utils.flags import FLAGS_json_format_model
 from ..base import BaseTrainer
 from .model_list import MODELS
 
@@ -32,9 +33,9 @@ class TSCLSTrainer(BaseTrainer):
         self.update_config()
         self.dump_config()
         train_args = self.get_train_kwargs()
-        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
-            "FLAGS_json_format_model"
-        ) in ["1", "True"]
+        export_with_pir = (
+            self.global_config.get("export_with_pir", False) or FLAGS_json_format_model
+        )
         train_args.update(
             {
                 "uniform_output_enabled": self.train_config.get(

+ 4 - 3
paddlex/modules/ts_forecast/trainer.py

@@ -16,6 +16,7 @@ import os
 import tarfile
 from pathlib import Path
 
+from ...utils.flags import FLAGS_json_format_model
 from ..base import BaseTrainer
 from .model_list import MODELS
 
@@ -32,9 +33,9 @@ class TSFCTrainer(BaseTrainer):
         self.update_config()
         self.dump_config()
         train_args = self.get_train_kwargs()
-        export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
-            "FLAGS_json_format_model"
-        ) in ["1", "True"]
+        export_with_pir = (
+            self.global_config.get("export_with_pir", False) or FLAGS_json_format_model
+        )
         train_args.update(
             {
                 "uniform_output_enabled": self.train_config.get(