gaotingquan 11 bulan lalu
induk
melakukan
0d127af2ac
2 mengubah file dengan 59 tambahan dan 18 penghapusan
  1. 31 18
      paddlex/repo_manager/utils.py
  2. 28 0
      paddlex/utils/env.py

+ 31 - 18
paddlex/repo_manager/utils.py

@@ -18,8 +18,11 @@ import json
 import platform
 import subprocess
 import contextlib
-
 from parsley import makeGrammar
+import lazy_paddle as paddle
+
+from ..utils.env import get_device_type
+from ..utils import logging
 
 
 PLATFORM = platform.system()
@@ -32,28 +35,32 @@ def _check_call(*args, **kwargs):
 def _check_output(*args, **kwargs):
     return subprocess.check_output(*args, **kwargs)
 
+
 def _compare_version(version1, version2):
     import re
+
     def parse_version(version_str):
         version_pattern = re.compile(
-            r'^(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<pre_release>.*))?(?:\+(?P<build_metadata>.+))?$'
+            r"^(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<pre_release>.*))?(?:\+(?P<build_metadata>.+))?$"
         )
         match = version_pattern.match(version_str)
         if not match:
             raise ValueError(f"Unexpected version string: {version_str}")
         return (
-            int(match.group('major')), int(match.group('minor')),
-            int(match.group('patch')), match.group('pre_release')
+            int(match.group("major")),
+            int(match.group("minor")),
+            int(match.group("patch")),
+            match.group("pre_release"),
         )
 
     v1_infos = parse_version(version1)
     v2_infos = parse_version(version2)
     for v1_info, v2_info in zip(v1_infos, v2_infos):
-        if v1_info is None and v2_info is None: 
+        if v1_info is None and v2_info is None:
             continue
-        if v1_info is None or (v2_info is not None and v1_info < v2_info): 
+        if v1_info is None or (v2_info is not None and v1_info < v2_info):
             return -1
-        if v2_info is None or (v1_info is not None and v1_info > v2_info): 
+        if v2_info is None or (v1_info is not None and v1_info > v2_info):
             return 1
     return 0
 
@@ -93,20 +100,26 @@ def install_packages_using_pip(
         args.extend(pip_flags)
     return _check_call(args)
 
+
 def install_external_deps(repo_name, repo_root):
     """install paddle repository custom dependencies"""
-    import paddle
-    from ..utils import logging
-    paddle_version = paddle.__version__
-    paddle_w_cuda = paddle.is_compiled_with_cuda()
-    gcc_version = subprocess.check_output(["gcc", "--version"]).decode('utf-8').split()[2]
-
-    if repo_name == 'PaddleDetection':
-        if os.path.exists(os.path.join(repo_root, 'ppdet', 'ext_op')):
+    gcc_version = (
+        subprocess.check_output(["gcc", "--version"]).decode("utf-8").split()[2]
+    )
+
+    if repo_name == "PaddleDetection":
+        if os.path.exists(os.path.join(repo_root, "ppdet", "ext_op")):
             """Install custom op for rotated object detection"""
-            if _compare_version(paddle_version, '2.0.1') >= 0 and paddle_w_cuda and _compare_version(gcc_version, '8.2.0') >= 0:
-                with switch_working_dir(os.path.join(repo_root, 'ppdet', 'ext_op')):
-                    args = [sys.executable, 'setup.py', 'install']
+            if (
+                _compare_version(gcc_version, "8.2.0") >= 0
+                and "gpu" in get_device_type()
+                and (
+                    paddle.is_compiled_with_cuda()
+                    and not paddle.is_compiled_with_rocm()
+                )
+            ):
+                with switch_working_dir(os.path.join(repo_root, "ppdet", "ext_op")):
+                    args = [sys.executable, "setup.py", "install"]
                     _check_call(args)
             else:
                 logging.warning(

+ 28 - 0
paddlex/utils/env.py

@@ -0,0 +1,28 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import lazy_paddle as paddle
+
+
+def get_device_type():
+    device_str = paddle.get_device()
+    return device_str.split(":")[0]
+
+
+def get_paddle_version():
+    version = paddle.__version__.split(".")
+    # ref: https://github.com/PaddlePaddle/Paddle/blob/release/3.0-beta2/setup.py#L316
+    assert len(version) == 3
+    major_v, minor_v, patch_v = version
+    return major_v, minor_v, patch_v