utils.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import contextlib
  15. import importlib.metadata
  16. import os
  17. import platform
  18. import subprocess
  19. import sys
  20. from ..utils import logging
  21. from ..utils.env import get_device_type
  22. PLATFORM = platform.system()
  23. def _check_call(*args, **kwargs):
  24. return subprocess.check_call(*args, **kwargs)
  25. def _compare_version(version1, version2):
  26. import re
  27. def parse_version(version_str):
  28. version_pattern = re.compile(
  29. r"^(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<pre_release>.*))?(?:\+(?P<build_metadata>.+))?$"
  30. )
  31. match = version_pattern.match(version_str)
  32. if not match:
  33. raise ValueError(f"Unexpected version string: {version_str}")
  34. return (
  35. int(match.group("major")),
  36. int(match.group("minor")),
  37. int(match.group("patch")),
  38. match.group("pre_release"),
  39. )
  40. v1_infos = parse_version(version1)
  41. v2_infos = parse_version(version2)
  42. for v1_info, v2_info in zip(v1_infos, v2_infos):
  43. if v1_info is None and v2_info is None:
  44. continue
  45. if v1_info is None or (v2_info is not None and v1_info < v2_info):
  46. return -1
  47. if v2_info is None or (v1_info is not None and v1_info > v2_info):
  48. return 1
  49. return 0
  50. def check_package_installation(package):
  51. try:
  52. importlib.metadata.distribution(package)
  53. except importlib.metadata.PackageNotFoundError:
  54. return False
  55. return True
  56. def install_external_deps(repo_name, repo_root):
  57. """install paddle repository custom dependencies"""
  58. import paddle
  59. def get_gcc_version():
  60. return subprocess.check_output(["gcc", "--version"]).decode("utf-8").split()[2]
  61. if repo_name == "PaddleDetection":
  62. if os.path.exists(os.path.join(repo_root, "ppdet", "ext_op")):
  63. """Install custom op for rotated object detection"""
  64. if (
  65. PLATFORM == "Linux"
  66. and _compare_version(get_gcc_version(), "8.2.0") >= 0
  67. and "gpu" in get_device_type()
  68. and (
  69. paddle.is_compiled_with_cuda()
  70. and not paddle.is_compiled_with_rocm()
  71. )
  72. ):
  73. with switch_working_dir(os.path.join(repo_root, "ppdet", "ext_op")):
  74. # TODO: Apply constraints here
  75. args = [sys.executable, "setup.py", "install"]
  76. _check_call(args)
  77. else:
  78. logging.warning(
  79. "The custom operators in PaddleDetection for Rotated Object Detection is only supported when using CUDA, GCC>=8.2.0 and Paddle>=2.0.1, "
  80. "your environment does not meet these requirements, so we will skip the installation of custom operators under PaddleDetection/ppdet/ext_ops, "
  81. "which means you can not train the Rotated Object Detection models."
  82. )
  83. def clone_repo_using_git(url, branch=None):
  84. """clone_repo_using_git"""
  85. args = ["git", "clone", "--depth", "1"]
  86. if isinstance(url, str):
  87. url = [url]
  88. args.extend(url)
  89. if branch is not None:
  90. args.extend(["-b", branch])
  91. return _check_call(args)
  92. def fetch_repo_using_git(branch, url, depth=1):
  93. """fetch_repo_using_git"""
  94. args = ["git", "fetch", url, branch, "--depth", str(depth)]
  95. _check_call(args)
  96. def reset_repo_using_git(pointer, hard=True):
  97. """reset_repo_using_git"""
  98. args = ["git", "reset", "--hard", pointer]
  99. return _check_call(args)
  100. def remove_repo_using_rm(name):
  101. """remove_repo_using_rm"""
  102. if os.path.exists(name):
  103. if PLATFORM == "Windows":
  104. return _check_call(["rmdir", "/S", "/Q", name], shell=True)
  105. else:
  106. return _check_call(["rm", "-rf", name])
  107. @contextlib.contextmanager
  108. def mute():
  109. """mute"""
  110. with open(os.devnull, "w") as f:
  111. with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):
  112. yield
  113. @contextlib.contextmanager
  114. def switch_working_dir(new_wd):
  115. """switch_working_dir"""
  116. cwd = os.getcwd()
  117. os.chdir(new_wd)
  118. try:
  119. yield
  120. finally:
  121. os.chdir(cwd)