Prechádzať zdrojové kódy

lazy import paddle to avoid conflicts with HPI

gaotingquan pred 1 rokom
rodič
commit
120ade1531
30 zmenených súborov, 122 pridaní a 57 odobraní
  1. 4 4
      .pre-commit-config.yaml
  2. 8 5
      .precommit/check_custom.py
  3. 0 0
      .precommit/clang_format.hook
  4. 5 0
      paddlex/__init__.py
  5. 3 3
      paddlex/inference/components/paddle_predictor/predictor.py
  6. 1 1
      paddlex/inference/components/task_related/text_rec.py
  7. 3 2
      paddlex/modules/base/predictor/utils/paddle_inference_predictor.py
  8. 1 1
      paddlex/modules/base/trainer/train_deamon.py
  9. 2 3
      paddlex/modules/base/utils/topk_eval.py
  10. 1 1
      paddlex/modules/image_classification/trainer.py
  11. 13 15
      paddlex/modules/multilabel_classification/trainer.py
  12. 1 1
      paddlex/modules/object_detection/trainer.py
  13. 1 1
      paddlex/modules/semantic_segmentation/trainer.py
  14. 1 1
      paddlex/modules/table_recognition/predictor/transforms.py
  15. 5 2
      paddlex/modules/table_recognition/trainer.py
  16. 5 2
      paddlex/modules/text_detection/trainer.py
  17. 1 2
      paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py
  18. 1 1
      paddlex/modules/text_recognition/predictor/transforms.py
  19. 5 2
      paddlex/modules/text_recognition/trainer.py
  20. 1 1
      paddlex/modules/ts_anomaly_detection/trainer.py
  21. 1 1
      paddlex/modules/ts_classification/trainer.py
  22. 1 1
      paddlex/modules/ts_forecast/trainer.py
  23. 1 1
      paddlex/pipelines/OCR/utils.py
  24. 2 1
      paddlex/repo_apis/PaddleClas_api/cls/config.py
  25. 6 2
      paddlex/repo_apis/PaddleSeg_api/base_seg_config.py
  26. 0 1
      paddlex/repo_apis/PaddleSeg_api/seg/config.py
  27. 4 1
      paddlex/repo_apis/PaddleTS_api/ts_base/config.py
  28. 1 1
      paddlex/utils/device.py
  29. 43 0
      paddlex/utils/lazy_loader.py
  30. 1 0
      requirements.txt

+ 4 - 4
.pre-commit-config.yaml

@@ -22,7 +22,7 @@ repos:
     -   id: clang-format
         name: clang-format
         description: Format files with ClangFormat
-        entry: bash .clang_format.hook -i
+        entry: bash .precommit/clang_format.hook -i
         language: system
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
 # For Python files
@@ -46,8 +46,8 @@ repos:
 # check license
 -   repo: local
     hooks:
-    -   id: check-license
-        name: Check License
-        entry: python .check_license.py
+    -   id: check-custom
+        name: Check Custom
+        entry: python .precommit/check_custom.py
         language: python
         files: \.py$

+ 8 - 5
.check_license.py → .precommit/check_custom.py

@@ -31,12 +31,15 @@ LICENSE_TEXT = """# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
 """
 
 
-def check_license(file_path):
+def check(file_path):
     with open(file_path, "r") as f:
         content = f.read()
-        if not content.startswith(LICENSE_TEXT):
-            print(f"License header missing in {file_path}")
-            return False
+    if not content.startswith(LICENSE_TEXT):
+        print(f"License header missing in {file_path}")
+        return False
+    if "import paddle" in content or "from paddle import " in content:
+        print(f"Please using `lazy_paddle` instead `paddle` when import in {file_path}")
+        return False
     return True
 
 
@@ -44,7 +47,7 @@ def main():
     files = sys.argv[1:]
     all_files_valid = True
     for file in files:
-        if not check_license(file):
+        if not check(file):
             all_files_valid = False
     if not all_files_valid:
         sys.exit(1)

+ 0 - 0
.clang_format.hook → .precommit/clang_format.hook


+ 5 - 0
paddlex/__init__.py

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .utils.lazy_loader import LazyLoader
+import sys
+
+sys.modules["lazy_paddle"] = LazyLoader("lazy_paddle", globals(), "paddle")
+
 import os
 
 from . import version

+ 3 - 3
paddlex/inference/components/paddle_predictor/predictor.py

@@ -14,9 +14,7 @@
 
 import os
 from abc import abstractmethod
-import numpy as np
-import paddle
-from paddle.inference import Config, create_predictor
+import lazy_paddle as paddle
 import numpy as np
 
 from ..base import BaseComponent
@@ -44,6 +42,8 @@ class BasePaddlePredictor(BaseComponent):
 
     def _create(self, model_dir, model_prefix, option):
         """_create"""
+        from lazy_paddle.inference import Config, create_predictor
+
         use_pir = (
             hasattr(paddle.framework, "use_pir_api") and paddle.framework.use_pir_api()
         )

+ 1 - 1
paddlex/inference/components/task_related/text_rec.py

@@ -21,7 +21,7 @@ import numpy as np
 from PIL import Image
 import cv2
 import math
-import paddle
+import lazy_paddle as paddle
 import json
 import tempfile
 from tokenizers import Tokenizer as TokenizerFast

+ 3 - 2
paddlex/modules/base/predictor/utils/paddle_inference_predictor.py

@@ -13,8 +13,7 @@
 # limitations under the License.
 
 import os
-import paddle
-from paddle.inference import Config, create_predictor
+import lazy_paddle as paddle
 
 from .....utils import logging
 
@@ -34,6 +33,8 @@ class _PaddleInferencePredictor(object):
 
     def _create(self, model_dir, model_prefix, option, delete_pass):
         """_create"""
+        from lazy_paddle.inference import Config, create_predictor
+
         use_pir = (
             hasattr(paddle.framework, "use_pir_api") and paddle.framework.use_pir_api()
         )

+ 1 - 1
paddlex/modules/base/trainer/train_deamon.py

@@ -20,7 +20,7 @@ import traceback
 import threading
 from abc import ABC, abstractmethod
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..build_model import build_model
 from ....utils.file_interface import write_json_file

+ 2 - 3
paddlex/modules/base/utils/topk_eval.py

@@ -16,8 +16,7 @@
 import os
 import json
 import argparse
-from paddle import nn
-import paddle
+import lazy_paddle as paddle
 
 from ....utils import logging
 
@@ -34,7 +33,7 @@ def parse_args():
     return args
 
 
-class AvgMetrics(nn.Layer):
+class AvgMetrics(paddle.nn.Layer):
     """Average metrics"""
 
     def __init__(self):

+ 1 - 1
paddlex/modules/image_classification/trainer.py

@@ -14,7 +14,7 @@
 
 import json
 import shutil
-import paddle
+import lazy_paddle as paddle
 from pathlib import Path
 
 from ..base import BaseTrainer, BaseTrainDeamon

+ 13 - 15
paddlex/modules/multilabel_classification/trainer.py

@@ -14,7 +14,7 @@
 
 import json
 import shutil
-import paddle
+import lazy_paddle as paddle
 from pathlib import Path
 
 from ..base import BaseTrainer, BaseTrainDeamon
@@ -48,34 +48,32 @@ class MLClsTrainer(BaseTrainer):
         return ClsTrainDeamon(config)
 
     def update_config(self):
-        """update training config
-        """
+        """update training config"""
         if self.train_config.log_interval:
             self.pdx_config.update_log_interval(self.train_config.log_interval)
         if self.train_config.eval_interval:
-            self.pdx_config.update_eval_interval(
-                self.train_config.eval_interval)
+            self.pdx_config.update_eval_interval(self.train_config.eval_interval)
         if self.train_config.save_interval:
-            self.pdx_config.update_save_interval(
-                self.train_config.save_interval)
+            self.pdx_config.update_save_interval(self.train_config.save_interval)
 
-        self.pdx_config.update_dataset(self.global_config.dataset_dir,
-                                       "MLClsDataset")
+        self.pdx_config.update_dataset(self.global_config.dataset_dir, "MLClsDataset")
         if self.train_config.num_classes is not None:
             self.pdx_config.update_num_classes(self.train_config.num_classes)
-        if self.train_config.pretrain_weight_path and self.train_config.pretrain_weight_path != "":
+        if (
+            self.train_config.pretrain_weight_path
+            and self.train_config.pretrain_weight_path != ""
+        ):
             self.pdx_config.update_pretrained_weights(
-                self.train_config.pretrain_weight_path)
+                self.train_config.pretrain_weight_path
+            )
 
-        label_dict_path = Path(self.global_config.dataset_dir).joinpath(
-            "label.txt")
+        label_dict_path = Path(self.global_config.dataset_dir).joinpath("label.txt")
         if label_dict_path.exists():
             self.dump_label_dict(label_dict_path)
         if self.train_config.batch_size is not None:
             self.pdx_config.update_batch_size(self.train_config.batch_size)
         if self.train_config.learning_rate is not None:
-            self.pdx_config.update_learning_rate(
-                self.train_config.learning_rate)
+            self.pdx_config.update_learning_rate(self.train_config.learning_rate)
         if self.train_config.epochs_iters is not None:
             self.pdx_config._update_epochs(self.train_config.epochs_iters)
         if self.train_config.warmup_steps is not None:

+ 1 - 1
paddlex/modules/object_detection/trainer.py

@@ -14,7 +14,7 @@
 
 
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict

+ 1 - 1
paddlex/modules/semantic_segmentation/trainer.py

@@ -16,7 +16,7 @@
 import os
 import glob
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict

+ 1 - 1
paddlex/modules/table_recognition/predictor/transforms.py

@@ -17,7 +17,7 @@ import os
 import os.path as osp
 import numpy as np
 import cv2
-import paddle
+import lazy_paddle as paddle
 
 from .keys import TableRecKeys as K
 from ...base import BaseTransform

+ 5 - 2
paddlex/modules/table_recognition/trainer.py

@@ -15,7 +15,7 @@
 
 import os
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict
@@ -76,7 +76,10 @@ class TableRecTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        return {"device": self.get_device(), "dy2st": self.train_config.get("dy2st", False)}
+        return {
+            "device": self.get_device(),
+            "dy2st": self.train_config.get("dy2st", False),
+        }
 
 
 class TableRecTrainDeamon(BaseTrainDeamon):

+ 5 - 2
paddlex/modules/text_detection/trainer.py

@@ -15,7 +15,7 @@
 
 import os
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict
@@ -74,7 +74,10 @@ class TextDetTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        return {"device": self.get_device(), "dy2st": self.train_config.get("dy2st", False)}
+        return {
+            "device": self.get_device(),
+            "dy2st": self.train_config.get("dy2st", False),
+        }
 
 
 class TextDetTrainDeamon(BaseTrainDeamon):

+ 1 - 2
paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py

@@ -21,7 +21,7 @@ import math
 import pickle
 from tqdm import tqdm
 from collections import defaultdict
-from paddle.utils import try_import
+import imagesize
 from .....utils.errors import ConvertFailedError
 from .....utils.logging import info, warning
 
@@ -66,7 +66,6 @@ def convert_pkl_dataset(root_dir):
 
 
 def txt2pickle(images, equations, save_dir):
-    imagesize = try_import("imagesize")
     phase = os.path.basename(equations).replace(".txt", "")
     save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(phase))
     min_dimensions = (32, 32)

+ 1 - 1
paddlex/modules/text_recognition/predictor/transforms.py

@@ -21,7 +21,7 @@ import numpy as np
 from PIL import Image
 import cv2
 import math
-import paddle
+import lazy_paddle as paddle
 import json
 import tempfile
 from tokenizers import Tokenizer as TokenizerFast

+ 5 - 2
paddlex/modules/text_recognition/trainer.py

@@ -16,7 +16,7 @@
 import os
 import shutil
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict
@@ -108,7 +108,10 @@ class TextRecTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        return {"device": self.get_device(), "dy2st": self.train_config.get("dy2st", False)}
+        return {
+            "device": self.get_device(),
+            "dy2st": self.train_config.get("dy2st", False),
+        }
 
 
 class TextRecTrainDeamon(BaseTrainDeamon):

+ 1 - 1
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -17,7 +17,7 @@ import json
 import time
 from pathlib import Path
 import tarfile
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict

+ 1 - 1
paddlex/modules/ts_classification/trainer.py

@@ -17,7 +17,7 @@ import json
 import time
 import tarfile
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict

+ 1 - 1
paddlex/modules/ts_forecast/trainer.py

@@ -17,7 +17,7 @@ import json
 import time
 import tarfile
 from pathlib import Path
-import paddle
+import lazy_paddle as paddle
 
 from ..base import BaseTrainer, BaseTrainDeamon
 from ...utils.config import AttrDict

+ 1 - 1
paddlex/pipelines/OCR/utils.py

@@ -21,7 +21,7 @@ import random
 import math
 import copy
 
-from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH
+from ...utils.fonts import PINGFANG_FONT_FILE_PATH
 
 
 def draw_ocr_box_txt(

+ 2 - 1
paddlex/repo_apis/PaddleClas_api/cls/config.py

@@ -14,7 +14,6 @@
 
 import yaml
 from typing import Union
-from paddleclas.ppcls.utils.config import get_config, override_config
 
 from ...base import BaseConfig
 from ....utils.misc import abspath
@@ -33,6 +32,8 @@ class ClsConfig(BaseConfig):
                     'VALID.transforms.1.ResizeImage.resize_short=300'
                 ]
         """
+        from paddleclas.ppcls.utils.config import override_config
+
         dict_ = override_config(self.dict, list_like_obj)
         self.reset_from_dict(dict_)
 

+ 6 - 2
paddlex/repo_apis/PaddleSeg_api/base_seg_config.py

@@ -15,8 +15,6 @@
 from urllib.parse import urlparse
 
 import yaml
-from paddleseg.utils import NoAliasDumper
-from paddleseg.cvlibs.config import parse_from_yaml, merge_config_dicts
 
 from ..base import BaseConfig
 from ...utils.misc import abspath
@@ -27,11 +25,15 @@ class BaseSegConfig(BaseConfig):
 
     def update(self, dict_like_obj):
         """update"""
+        from paddleseg.cvlibs.config import merge_config_dicts
+
         dict_ = merge_config_dicts(dict_like_obj, self.dict)
         self.reset_from_dict(dict_)
 
     def load(self, config_path):
         """load"""
+        from paddleseg.cvlibs.config import parse_from_yaml
+
         dict_ = parse_from_yaml(config_path)
         if not isinstance(dict_, dict):
             raise TypeError
@@ -39,6 +41,8 @@ class BaseSegConfig(BaseConfig):
 
     def dump(self, config_path):
         """dump"""
+        from paddleseg.utils import NoAliasDumper
+
         with open(config_path, "w", encoding="utf-8") as f:
             yaml.dump(self.dict, f, Dumper=NoAliasDumper)
 

+ 0 - 1
paddlex/repo_apis/PaddleSeg_api/seg/config.py

@@ -18,7 +18,6 @@ from functools import lru_cache
 
 import yaml
 from typing import Union
-from paddleseg.utils import NoAliasDumper
 
 from ..base_seg_config import BaseSegConfig
 from ....utils.misc import abspath

+ 4 - 1
paddlex/repo_apis/PaddleTS_api/ts_base/config.py

@@ -16,7 +16,6 @@ import os
 from urllib.parse import urlparse
 
 import ruamel.yaml
-from paddlets.utils.config import parse_from_yaml, merge_config_dicts
 
 from ...base import BaseConfig
 from ....utils.misc import abspath
@@ -31,6 +30,8 @@ class BaseTSConfig(BaseConfig):
         Args:
             dict_like_obj (dict): dict of pairs(key0.key1.idx.key2=value).
         """
+        from paddlets.utils.config import merge_config_dicts
+
         dict_ = merge_config_dicts(dict_like_obj, self.dict)
         self.reset_from_dict(dict_)
 
@@ -43,6 +44,8 @@ class BaseTSConfig(BaseConfig):
         Raises:
             TypeError: the content of yaml file `config_file_path` error.
         """
+        from paddlets.utils.config import parse_from_yaml
+
         dict_ = parse_from_yaml(config_file_path)
         if not isinstance(dict_, dict):
             raise TypeError

+ 1 - 1
paddlex/utils/device.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os
-import paddle
+import lazy_paddle as paddle
 from .errors import raise_unsupported_device_error
 
 SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]

+ 43 - 0
paddlex/utils/lazy_loader.py

@@ -0,0 +1,43 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Code copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py
+import importlib
+import types
+
+
+class LazyLoader(types.ModuleType):
+    """Lazily import a module, mainly to avoid pulling in large dependencies."""
+
+    def __init__(self, local_name, parent_module_globals, name):
+        self._local_name = local_name
+        self._parent_module_globals = parent_module_globals
+
+        super(LazyLoader, self).__init__(name)
+
+    def _load(self):
+        module = importlib.import_module(self.__name__)
+        self._parent_module_globals[self._local_name] = module
+
+        self.__dict__.update(module.__dict__)
+
+        return module
+
+    def __getattr__(self, item):
+        module = self._load()
+        return getattr(module, item)
+
+    def __dir__(self):
+        module = self._load()
+        return dir(module)

+ 1 - 0
requirements.txt

@@ -1,3 +1,4 @@
+imagesize
 colorlog
 PyYAML
 filelock