utils.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import contextlib
  15. import importlib.metadata
  16. import os
  17. import platform
  18. import subprocess
  19. import sys
  20. import lazy_paddle as paddle
  21. from ..utils import logging
  22. from ..utils.env import get_device_type
# Host OS name as returned by platform.system() (e.g. "Linux", "Windows").
# Evaluated once at import time; used below to choose OS-specific commands
# (Windows `rmdir` vs. `rm -rf`) and to gate the Linux-only custom-op build.
PLATFORM = platform.system()
  24. def _check_call(*args, **kwargs):
  25. return subprocess.check_call(*args, **kwargs)
  26. def _compare_version(version1, version2):
  27. import re
  28. def parse_version(version_str):
  29. version_pattern = re.compile(
  30. r"^(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<pre_release>.*))?(?:\+(?P<build_metadata>.+))?$"
  31. )
  32. match = version_pattern.match(version_str)
  33. if not match:
  34. raise ValueError(f"Unexpected version string: {version_str}")
  35. return (
  36. int(match.group("major")),
  37. int(match.group("minor")),
  38. int(match.group("patch")),
  39. match.group("pre_release"),
  40. )
  41. v1_infos = parse_version(version1)
  42. v2_infos = parse_version(version2)
  43. for v1_info, v2_info in zip(v1_infos, v2_infos):
  44. if v1_info is None and v2_info is None:
  45. continue
  46. if v1_info is None or (v2_info is not None and v1_info < v2_info):
  47. return -1
  48. if v2_info is None or (v1_info is not None and v1_info > v2_info):
  49. return 1
  50. return 0
  51. def check_package_installation(package):
  52. try:
  53. importlib.metadata.distribution(package)
  54. except importlib.metadata.PackageNotFoundError:
  55. return False
  56. return True
  57. def install_external_deps(repo_name, repo_root):
  58. """install paddle repository custom dependencies"""
  59. def get_gcc_version():
  60. return subprocess.check_output(["gcc", "--version"]).decode("utf-8").split()[2]
  61. if repo_name == "PaddleDetection":
  62. if os.path.exists(os.path.join(repo_root, "ppdet", "ext_op")):
  63. """Install custom op for rotated object detection"""
  64. if (
  65. PLATFORM == "Linux"
  66. and _compare_version(get_gcc_version(), "8.2.0") >= 0
  67. and "gpu" in get_device_type()
  68. and (
  69. paddle.is_compiled_with_cuda()
  70. and not paddle.is_compiled_with_rocm()
  71. )
  72. ):
  73. with switch_working_dir(os.path.join(repo_root, "ppdet", "ext_op")):
  74. args = [sys.executable, "setup.py", "install"]
  75. _check_call(args)
  76. else:
  77. logging.warning(
  78. "The custom operators in PaddleDetection for Rotated Object Detection is only supported when using CUDA, GCC>=8.2.0 and Paddle>=2.0.1, "
  79. "your environment does not meet these requirements, so we will skip the installation of custom operators under PaddleDetection/ppdet/ext_ops, "
  80. "which means you can not train the Rotated Object Detection models."
  81. )
  82. def clone_repo_using_git(url, branch=None):
  83. """clone_repo_using_git"""
  84. args = ["git", "clone", "--depth", "1"]
  85. if isinstance(url, str):
  86. url = [url]
  87. args.extend(url)
  88. if branch is not None:
  89. args.extend(["-b", branch])
  90. return _check_call(args)
  91. def fetch_repo_using_git(branch, url, depth=1):
  92. """fetch_repo_using_git"""
  93. args = ["git", "fetch", url, branch, "--depth", str(depth)]
  94. _check_call(args)
  95. def reset_repo_using_git(pointer, hard=True):
  96. """reset_repo_using_git"""
  97. args = ["git", "reset", "--hard", pointer]
  98. return _check_call(args)
  99. def remove_repo_using_rm(name):
  100. """remove_repo_using_rm"""
  101. if os.path.exists(name):
  102. if PLATFORM == "Windows":
  103. return _check_call(["rmdir", "/S", "/Q", name], shell=True)
  104. else:
  105. return _check_call(["rm", "-rf", name])
  106. def build_wheel_using_pip(pkg, dst_dir="./", with_deps=False, pip_flags=None):
  107. """build_wheel_using_pip"""
  108. args = [sys.executable, "-m", "pip", "wheel", "--wheel-dir", dst_dir]
  109. if not with_deps:
  110. args.append("--no-deps")
  111. if pip_flags is not None:
  112. args.extend(pip_flags)
  113. args.append(pkg)
  114. return _check_call(args)
  115. @contextlib.contextmanager
  116. def mute():
  117. """mute"""
  118. with open(os.devnull, "w") as f:
  119. with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):
  120. yield
  121. @contextlib.contextmanager
  122. def switch_working_dir(new_wd):
  123. """switch_working_dir"""
  124. cwd = os.getcwd()
  125. os.chdir(new_wd)
  126. try:
  127. yield
  128. finally:
  129. os.chdir(cwd)