|
|
@@ -21,7 +21,7 @@ from ...utils.device import (
|
|
|
set_env_for_device,
|
|
|
update_device_num,
|
|
|
)
|
|
|
-from ...utils.flags import DISABLE_CINN_MODEL_WL
|
|
|
+from ...utils.flags import FLAGS_json_format_model, DISABLE_CINN_MODEL_WL
|
|
|
from ...utils.misc import AutoRegisterABCMetaClass
|
|
|
from .build_model import build_model
|
|
|
from .utils.cinn_setting import CINN_WHITELIST, enable_cinn_backend
|
|
|
@@ -75,9 +75,9 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
train_args = self.get_train_kwargs()
|
|
|
if self.benchmark_config is not None:
|
|
|
train_args.update({"benchmark": self.benchmark_config})
|
|
|
- export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
|
|
|
- "FLAGS_json_format_model"
|
|
|
- ) in ["1", "True"]
|
|
|
+ export_with_pir = (
|
|
|
+ self.global_config.get("export_with_pir", False) or FLAGS_json_format_model
|
|
|
+ )
|
|
|
train_args.update(
|
|
|
{
|
|
|
"uniform_output_enabled": self.train_config.get(
|