repo.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import os
  16. import os.path as osp
  17. import shutil
  18. import tempfile
  19. from packaging.requirements import Requirement
  20. from ..utils import logging
  21. from ..utils.download import download_and_extract
  22. from ..utils.file_interface import custom_open
  23. from ..utils.install import (
  24. install_packages,
  25. install_packages_from_requirements_file,
  26. uninstall_packages,
  27. )
  28. from .meta import REPO_DIST_NAMES, REPO_DOWNLOAD_BASE, get_repo_meta
  29. from .utils import (
  30. fetch_repo_using_git,
  31. install_external_deps,
  32. remove_repo_using_rm,
  33. reset_repo_using_git,
  34. switch_working_dir,
  35. )
  36. __all__ = ["build_repo_instance", "build_repo_group_installer"]
  37. def build_repo_instance(repo_name, *args, **kwargs):
  38. """build_repo_instance"""
  39. # XXX: Hard-code type
  40. repo_cls = PPRepository
  41. repo_instance = repo_cls(repo_name, *args, **kwargs)
  42. return repo_instance
  43. def build_repo_group_installer(*repos):
  44. """build_repo_group_installer"""
  45. return RepositoryGroupInstaller(list(repos))
  46. def build_repo_group_getter(*repos):
  47. """build_repo_group_getter"""
  48. return RepositoryGroupGetter(list(repos))
  49. class PPRepository(object):
  50. """
  51. Installation, initialization, and PDX module import handler for a
  52. PaddlePaddle repository.
  53. """
  54. def __init__(self, name, repo_parent_dir, pdx_collection_mod):
  55. super().__init__()
  56. self.name = name
  57. self.repo_parent_dir = repo_parent_dir
  58. self.root_dir = osp.join(repo_parent_dir, self.name)
  59. self.meta = get_repo_meta(self.name)
  60. self.git_path = self.meta["git_path"]
  61. self.dist_name = self.meta.get("dist_name", None)
  62. self.import_name = self.meta.get("import_name", None)
  63. self.pdx_mod_name = (
  64. pdx_collection_mod.__name__ + "." + self.meta["pdx_pkg_name"]
  65. )
  66. self.main_req_file = self.meta.get("main_req_file", "requirements.txt")
  67. def initialize(self):
  68. """initialize"""
  69. if not self.check_installation():
  70. return False
  71. if "path_env" in self.meta:
  72. # Set env var
  73. os.environ[self.meta["path_env"]] = osp.abspath(self.root_dir)
  74. # NOTE: By calling `self.get_pdx()` we actually loads the repo PDX package
  75. # and do all registration.
  76. self.get_pdx()
  77. return True
  78. def check_installation(self):
  79. """check_installation"""
  80. return osp.exists(osp.join(self.root_dir, ".installed"))
  81. def replace_repo_deps(self, deps_to_replace, src_requirements):
  82. """replace_repo_deps"""
  83. with custom_open(src_requirements, "r") as file:
  84. lines = file.readlines()
  85. existing_deps = []
  86. for line in lines:
  87. line = line.strip()
  88. if not line or line.startswith("#"):
  89. continue
  90. dep_to_replace = next((dep for dep in deps_to_replace if dep in line), None)
  91. if dep_to_replace:
  92. if deps_to_replace[dep_to_replace] == "None":
  93. continue
  94. else:
  95. existing_deps.append(
  96. f"{dep_to_replace}=={deps_to_replace[dep_to_replace]}"
  97. )
  98. else:
  99. existing_deps.append(line)
  100. with open(src_requirements, "w") as file:
  101. file.writelines([l + "\n" for l in existing_deps])
  102. def check_repo_exiting(self):
  103. """check_repo_exiting"""
  104. return osp.exists(osp.join(self.root_dir, ".git"))
  105. def install_packages(self, clean=True):
  106. """install_packages"""
  107. if self.meta["install_pkg"]:
  108. editable = self.meta.get("editable", True)
  109. if editable:
  110. logging.warning(
  111. f"{self.import_name} will be installed in editable mode."
  112. )
  113. with switch_working_dir(self.root_dir):
  114. try:
  115. pip_install_opts = ["--no-deps"]
  116. if editable:
  117. reqs = ["-e ."]
  118. else:
  119. reqs = ["."]
  120. install_packages(reqs, pip_install_opts=pip_install_opts)
  121. install_external_deps(self.name, self.root_dir)
  122. finally:
  123. if clean:
  124. # Clean build artifacts
  125. tmp_build_dir = "build"
  126. if osp.exists(tmp_build_dir):
  127. shutil.rmtree(tmp_build_dir)
  128. for e in self.meta.get("extra_pkgs", []):
  129. if isinstance(e, tuple):
  130. with switch_working_dir(osp.join(self.root_dir, e[0])):
  131. pip_install_opts = ["--no-deps"]
  132. if e[3]:
  133. reqs = ["-e ."]
  134. else:
  135. reqs = ["."]
  136. try:
  137. install_packages(reqs, pip_install_opts=pip_install_opts)
  138. finally:
  139. if clean:
  140. tmp_build_dir = "build"
  141. if osp.exists(tmp_build_dir):
  142. shutil.rmtree(tmp_build_dir)
  143. def uninstall_packages(self):
  144. """uninstall_packages"""
  145. pkgs = []
  146. if self.meta["install_pkg"]:
  147. pkgs.append(self.dist_name)
  148. for e in self.meta.get("extra_pkgs", []):
  149. if isinstance(e, tuple):
  150. pkgs.append(e[1])
  151. uninstall_packages(pkgs)
  152. def mark_installed(self):
  153. with open(osp.join(self.root_dir, ".installed"), "wb"):
  154. pass
  155. def mark_uninstalled(self):
  156. os.unlink(osp.join(self.root_dir, ".installed"))
  157. def download(self):
  158. """download from remote"""
  159. download_url = f"{REPO_DOWNLOAD_BASE}{self.name}.tar"
  160. os.makedirs(self.repo_parent_dir, exist_ok=True)
  161. download_and_extract(download_url, self.repo_parent_dir, self.name)
  162. # reset_repo_using_git('FETCH_HEAD')
  163. def remove(self):
  164. """remove"""
  165. with switch_working_dir(self.repo_parent_dir):
  166. remove_repo_using_rm(self.name)
  167. def update(self, platform=None):
  168. """update"""
  169. branch = self.meta.get("branch", None)
  170. git_url = f"https://{platform}{self.git_path}"
  171. with switch_working_dir(self.root_dir):
  172. try:
  173. fetch_repo_using_git(branch=branch, url=git_url)
  174. reset_repo_using_git("FETCH_HEAD")
  175. except Exception as e:
  176. logging.warning(
  177. f"Update {self.name} from {git_url} failed, check your network connection. Error:\n{e}"
  178. )
  179. def get_pdx(self):
  180. """get_pdx"""
  181. return importlib.import_module(self.pdx_mod_name)
  182. def get_deps(self, deps_to_replace=None):
  183. """get_deps"""
  184. # Merge requirement files
  185. req_list = [self.main_req_file]
  186. for e in self.meta.get("extra_pkgs", []):
  187. if isinstance(e, tuple):
  188. e = e[2] or osp.join(e[0], "requirements.txt")
  189. req_list.append(e)
  190. if deps_to_replace is not None:
  191. deps_dict = {}
  192. for dep in deps_to_replace:
  193. part, version = dep.split("=")
  194. repo_name, dep_name = part.split(".")
  195. deps_dict[repo_name] = {dep_name: version}
  196. src_requirements = osp.join(self.root_dir, "requirements.txt")
  197. if self.name in deps_dict:
  198. self.replace_repo_deps(deps_dict[self.name], src_requirements)
  199. deps = []
  200. for req in req_list:
  201. with open(osp.join(self.root_dir, req), "r", encoding="utf-8") as f:
  202. deps.append(f.read())
  203. for dep in self.meta.get("pdx_pkg_deps", []):
  204. deps.append(dep)
  205. deps = "\n".join(deps)
  206. return deps
  207. def get_version(self):
  208. """get_version"""
  209. version_file = osp.join(self.root_dir, ".pdx_gen.version")
  210. with open(version_file, "r", encoding="utf-8") as f:
  211. lines = f.readlines()
  212. sta_ver = lines[0].rstrip()
  213. commit = lines[1].rstrip()
  214. ret = [sta_ver, commit]
  215. # TODO: Get dynamic version in a subprocess.
  216. ret.append(None)
  217. return ret
  218. def __str__(self):
  219. return f"({self.name}, {id(self)})"
  220. class RepositoryGroupInstaller(object):
  221. """RepositoryGroupInstaller"""
  222. def __init__(self, repos):
  223. super().__init__()
  224. self.repos = repos
  225. def install(
  226. self,
  227. force_reinstall=False,
  228. no_deps=False,
  229. constraints=None,
  230. deps_to_replace=None,
  231. ):
  232. """install"""
  233. # Rollback on failure is not yet supported. A failed installation
  234. # could leave a broken environment.
  235. if force_reinstall:
  236. self.uninstall()
  237. ins_flags = []
  238. repos = self._sort_repos(self.repos, check_missing=True)
  239. for repo in repos:
  240. if force_reinstall or not repo.check_installation():
  241. ins_flags.append(True)
  242. else:
  243. ins_flags.append(False)
  244. if not no_deps:
  245. # We collect the dependencies and install them all at once
  246. # such that we can make use of the pip resolver.
  247. self.install_deps(constraints=constraints, deps_to_replace=deps_to_replace)
  248. # XXX: For historical reasons the repo packages are sequentially
  249. # installed, and we have no failure rollbacks. Meanwhile, installation
  250. # failure of one repo package aborts the entire installation process.
  251. for ins_flag, repo in zip(ins_flags, repos):
  252. if ins_flag:
  253. repo.install_packages()
  254. repo.mark_installed()
  255. def uninstall(self):
  256. """uninstall"""
  257. repos = self._sort_repos(self.repos, check_missing=False)
  258. repos = repos[::-1]
  259. for repo in repos:
  260. if repo.check_installation():
  261. # NOTE: Dependencies are not uninstalled.
  262. repo.uninstall_packages()
  263. repo.mark_uninstalled()
  264. def get_deps(self, deps_to_replace=None):
  265. """get_deps"""
  266. deps_list = []
  267. repos = self._sort_repos(self.repos, check_missing=True)
  268. for repo in repos:
  269. deps = repo.get_deps(deps_to_replace=deps_to_replace)
  270. deps = self._normalize_deps(deps, headline=f"# {repo.name} dependencies")
  271. deps_list.append(deps)
  272. # Add an extra new line to separate dependencies of different repos.
  273. return "\n\n".join(deps_list)
  274. def install_deps(self, constraints, deps_to_replace=None):
  275. """install_deps"""
  276. deps_str = self.get_deps(deps_to_replace=deps_to_replace)
  277. with tempfile.TemporaryDirectory() as td:
  278. req_file = osp.join(td, "requirements.txt")
  279. with open(req_file, "w", encoding="utf-8") as fr:
  280. fr.write(deps_str)
  281. cons_file = osp.join(td, "constraints.txt")
  282. with open(cons_file, "w", encoding="utf-8") as fc:
  283. if constraints is not None:
  284. fc.write(constraints)
  285. # HACK: Avoid installing OpenCV variants unexpectedly
  286. fc.write("opencv-python == 0.0.0\n")
  287. fc.write("opencv-python-headless == 0.0.0\n")
  288. fc.write("opencv-contrib-python-headless == 0.0.0\n")
  289. pip_install_opts = []
  290. pip_install_opts.append("-c")
  291. pip_install_opts.append(cons_file)
  292. install_packages_from_requirements_file(
  293. req_file, pip_install_opts=pip_install_opts
  294. )
  295. def _sort_repos(self, repos, check_missing=False):
  296. # We sort the repos to ensure that the dependencies precede the
  297. # dependent in the list.
  298. name_meta_pairs = []
  299. for repo in repos:
  300. name_meta_pairs.append((repo.name, repo.meta))
  301. unique_pairs = []
  302. hashset = set()
  303. for name, meta in name_meta_pairs:
  304. if name in hashset:
  305. continue
  306. else:
  307. unique_pairs.append((name, meta))
  308. hashset.add(name)
  309. sorted_repos = []
  310. missing_names = []
  311. name2repo = {repo.name: repo for repo in repos}
  312. for name, meta in unique_pairs:
  313. if name in name2repo:
  314. repo = name2repo[name]
  315. sorted_repos.append(repo)
  316. else:
  317. missing_names.append(name)
  318. if check_missing and len(missing_names) > 0:
  319. be = "is" if len(missing_names) == 1 else "are"
  320. raise RuntimeError(f"{missing_names} {be} required in the installation.")
  321. else:
  322. assert len(sorted_repos) == len(self.repos)
  323. return sorted_repos
  324. def _normalize_deps(self, deps, headline=None):
  325. lines = []
  326. if headline is not None:
  327. lines.append(headline)
  328. for line in deps.splitlines():
  329. line_s = line.strip()
  330. if not line_s:
  331. continue
  332. pos = line_s.find("#")
  333. if pos == 0:
  334. continue
  335. elif pos > 0:
  336. line_s = line_s[:pos]
  337. # If `line` is not an empty line or a comment, it must be a requirement specifier.
  338. # Other forms may cause a parse error.
  339. req = Requirement(line_s)
  340. if req.name in REPO_DIST_NAMES:
  341. # Skip repo packages
  342. continue
  343. elif req.name.replace("_", "-") in (
  344. "opencv-python",
  345. "opencv-contrib-python",
  346. "opencv-python-headless",
  347. "opencv-contrib-python-headless",
  348. ):
  349. # FIXME: The original version specifiers are ignored. It would be better to check them here.
  350. # The resolver will get the version info from the constraints file.
  351. line_s = "opencv-contrib-python"
  352. elif req.name == "albumentations":
  353. # HACK
  354. line_s = "albumentations @ https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/patched_packages/albumentations-1.4.10%2Bpdx-py3-none-any.whl"
  355. line_s += "\nalbucore @ https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/patched_packages/albucore-0.0.13%2Bpdx-py3-none-any.whl"
  356. elif req.name.replace("_", "-") == "nuscenes-devkit":
  357. # HACK
  358. line_s = "nuscenes-devkit @ https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/patched_packages/nuscenes_devkit-1.1.11%2Bpdx-py3-none-any.whl"
  359. elif req.name == "imgaug":
  360. # HACK
  361. line_s = "imgaug @ https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/patched_packages/imgaug-0.4.0%2Bpdx-py2.py3-none-any.whl"
  362. lines.append(line_s)
  363. return "\n".join(lines)
  364. class RepositoryGroupGetter(object):
  365. """RepositoryGroupGetter"""
  366. def __init__(self, repos):
  367. super().__init__()
  368. self.repos = repos
  369. def get(self, force=False, platform=None):
  370. """clone"""
  371. if force:
  372. self.remove()
  373. for repo in self.repos:
  374. repo.download()
  375. repo.update(platform=platform)
  376. def remove(self):
  377. """remove"""
  378. for repo in self.repos:
  379. repo.remove()