Forráskód Böngészése

support CINN models white list

gaotingquan 7 hónapja
szülő
commit
761b92d475

+ 11 - 0
paddlex/modules/base/trainer.py

@@ -21,8 +21,10 @@ from ...utils.device import (
     set_env_for_device,
     update_device_num,
 )
+from ...utils.flags import DISABLE_CINN_MODEL_WL
 from ...utils.misc import AutoRegisterABCMetaClass
 from .build_model import build_model
+from .utils.cinn_setting import CINN_WHITELIST, enable_cinn_backend
 
 
 def build_trainer(config: AttrDict) -> "BaseTrainer":
@@ -84,6 +86,15 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
                 "export_with_pir": export_with_pir,
             }
         )
+
+        # apply CINN when model is supported
+        if (
+            not DISABLE_CINN_MODEL_WL
+            and self.train_config.get("dy2st", False)
+            and self.global_config.model in CINN_WHITELIST
+        ):
+            enable_cinn_backend()
+
         train_result = self.pdx_model.train(**train_args)
         assert (
             train_result.returncode == 0

+ 42 - 0
paddlex/modules/base/utils/cinn_setting.py

@@ -0,0 +1,42 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ....utils import logging
+
+CINN_WHITELIST = [
+    # TODO: update models based testing result
+    "PP-LCNet_x0_25",
+]
+
+
+# TODO(gaotingquan): paddle v3.0.0 don't support enable CINN easily
+def enable_cinn_backend():
+    import paddle
+
+    if not paddle.is_compiled_with_cinn():
+        logging.debug(
+            "Your paddle is not compiled with CINN, can not use CINN backend."
+        )
+        return
+
+    # equivalent to `FLAGS_prim_all=1`
+    paddle.base.core._set_prim_all_enabled(True)
+    # equivalent to `FLAGS_prim_enable_dynamic=1`
+    paddle.base.framework.set_flags({"FLAGS_prim_enable_dynamic": True})
+    os.environ["FLAGS_prim_enable_dynamic"] = "1"
+    # equivalent to `FLAGS_use_cinn=1`
+    paddle.base.framework.set_flags({"FLAGS_use_cinn": True})
+    os.environ["FLAGS_use_cinn"] = "1"

+ 2 - 0
paddlex/utils/flags.py

@@ -28,6 +28,7 @@ __all__ = [
     "FLAGS_json_format_model",
     "USE_PIR_TRT",
     "DISABLE_DEV_MODEL_WL",
+    "DISABLE_CINN_MODEL_WL",
 ]
 
 
@@ -50,6 +51,7 @@ EAGER_INITIALIZATION = get_flag_from_env_var("PADDLE_PDX_EAGER_INIT", True)
 FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", None)
 USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", False)
 DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
+DISABLE_CINN_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_CINN_MODEL_WL", False)
 
 # Inference Benchmark
 INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", False)