#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Author: PaddlePaddle Authors
"""
import os
import os.path as osp
import importlib
import tempfile
import shutil

from ..utils import logging
from .meta import get_repo_meta
from .utils import (install_packages_using_pip, clone_repos_using_git,
                    update_repos_using_git, uninstall_package_using_pip,
                    remove_repos_using_rm, check_installation_using_pip,
                    build_wheel_using_pip, mute, switch_working_dir,
                    to_dep_spec_pep508, env_marker_ast2expr)

__all__ = ['build_repo_instance', 'build_repo_group_installer']
def build_repo_instance(repo_name, *args, **kwargs):
    """ build_repo_instance """
    # XXX: Hard-code type
    repo_cls = PPRepository
    repo_instance = repo_cls(repo_name, *args, **kwargs)
    return repo_instance


def build_repo_group_installer(*repos):
    """ build_repo_group_installer """
    return RepositoryGroupInstaller(list(repos))


def build_repo_group_cloner(*repos):
    """ build_repo_group_cloner """
    return RepositoryGroupCloner(list(repos))
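# Illustrative usage (hypothetical argument values; the parent directory and
# the PDX collection module depend on how the caller wires things up):
#
#   repo = build_repo_instance('PaddleOCR', repos_dir, pdx_collection_mod)
#   installer = build_repo_group_installer(repo)
#   installer.install(force_reinstall=False, no_deps=False, constraints=None)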
class PPRepository(object):
    """
    Installation, initialization, and PDX module import handler for a
    PaddlePaddle repository.
    """

    def __init__(self, name, repo_parent_dir, pdx_collection_mod):
        super().__init__()
        self.name = name
        self.repo_parent_dir = repo_parent_dir
        self.root_dir = osp.join(repo_parent_dir, self.name)

        self.meta = get_repo_meta(self.name)
        self.repo_url = self.meta['repo_url']
        self.pkg_name = self.meta['pkg_name']
        self.lib_name = self.meta['lib_name']
        self.pdx_mod_name = pdx_collection_mod.__name__ + '.' + self.meta[
            'pdx_pkg_name']
        self.main_req_file = self.meta.get('main_req_file', 'requirements.txt')
    def initialize(self):
        """ initialize """
        if not self.check_installation(quick_check=True):
            return False
        if 'path_env' in self.meta:
            # Set env var
            os.environ[self.meta['path_env']] = osp.abspath(self.root_dir)
        # NOTE: Calling `self.get_pdx()` actually loads the repo PDX package
        # and performs all the registration.
        self.get_pdx()
        return True
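    # Illustrative note (hypothetical key value): a meta entry such as
    # 'path_env': 'SOME_REPO_ROOT' makes initialize() export the absolute
    # repo root through os.environ so the repo's own tooling can locate it.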
    def check_installation(self, quick_check=False):
        """ check_installation """
        if quick_check:
            lib = self._get_lib(load=False)
            return lib is not None
        else:
            # TODO: Also check if correct dependencies are installed.
            return check_installation_using_pip(self.pkg_name)

    def check_repo_exiting(self, quick_check=False):
        """ check_repo_exiting """
        return os.path.exists(os.path.join(self.root_dir, '.git'))
    def install(self, *args, **kwargs):
        """ install """
        return RepositoryGroupInstaller([self]).install(*args, **kwargs)

    def uninstall(self, *args, **kwargs):
        """ uninstall """
        return RepositoryGroupInstaller([self]).uninstall(*args, **kwargs)

    def install_deps(self, *args, **kwargs):
        """ install_deps """
        return RepositoryGroupInstaller([self]).install_deps(*args, **kwargs)

    def install_package(self, no_deps=False, clean=True):
        """ install_package """
        editable = self.meta.get('editable', True)
        extra_editable = self.meta.get('extra_editable', None)
        if editable:
            logging.warning(
                f"{self.pkg_name} will be installed in editable mode.")
        with switch_working_dir(self.root_dir):
            try:
                install_packages_using_pip(
                    ['.'], editable=editable, no_deps=no_deps)
            finally:
                if clean:
                    # Clean build artifacts
                    tmp_build_dir = os.path.join(self.root_dir, 'build')
                    if os.path.exists(tmp_build_dir):
                        shutil.rmtree(tmp_build_dir)
        if extra_editable:
            with switch_working_dir(
                    os.path.join(self.root_dir, extra_editable)):
                try:
                    install_packages_using_pip(
                        ['.'], editable=True, no_deps=no_deps)
                finally:
                    if clean:
                        # Clean build artifacts of the extra package, which
                        # are created in its own directory.
                        tmp_build_dir = os.path.join(self.root_dir,
                                                     extra_editable, 'build')
                        if os.path.exists(tmp_build_dir):
                            shutil.rmtree(tmp_build_dir)
    def uninstall_package(self):
        """ uninstall_package """
        uninstall_package_using_pip(self.pkg_name)

    def clone(self, *args, **kwargs):
        """ clone """
        return RepositoryGroupCloner([self]).clone(*args, **kwargs)

    def remove(self, *args, **kwargs):
        """ remove """
        return RepositoryGroupCloner([self]).remove(*args, **kwargs)
    def clone_repos(self, platform=None):
        """ clone_repos """
        branch = self.meta.get('branch', None)
        repo_url = f'https://{platform}{self.repo_url}'
        with switch_working_dir(self.repo_parent_dir):
            clone_repos_using_git(repo_url, branch=branch)
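    # Illustrative example (hypothetical meta values): with platform
    # 'gitee.com' and a meta 'repo_url' of '/PaddlePaddle/PaddleOCR.git',
    # clone_repos() builds 'https://gitee.com/PaddlePaddle/PaddleOCR.git'
    # and checks out the meta 'branch' if one is configured.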
    def update_repos(self):
        """ update_repos """
        with switch_working_dir(self.root_dir):
            try:
                update_repos_using_git()
            except Exception as e:
                logging.warning(
                    f"Pulling {self.name} from {self.repo_url} failed ({e}). "
                    "Please check your network connection.")
    def remove_repos(self):
        """ remove_repos """
        with switch_working_dir(self.repo_parent_dir):
            remove_repos_using_rm(self.name)
    def wheel(self, dst_dir):
        """ wheel """
        with tempfile.TemporaryDirectory() as td:
            tmp_repo_dir = osp.join(td, self.name)
            tmp_dst_dir = osp.join(td, 'dist')
            shutil.copytree(self.root_dir, tmp_repo_dir, symlinks=False)
            # NOTE: Installation of the repo relies on `self.main_req_file`
            # in the root directory, so we overwrite its content.
            main_req_file_path = osp.join(tmp_repo_dir, self.main_req_file)
            deps_str = self.get_deps()
            with open(main_req_file_path, 'w', encoding='utf-8') as f:
                f.write(deps_str)
            install_packages_using_pip([], req_files=[main_req_file_path])
            with switch_working_dir(tmp_repo_dir):
                build_wheel_using_pip('.', tmp_dst_dir)
            shutil.copytree(tmp_dst_dir, dst_dir)
    def _get_lib(self, load=True):
        """ _get_lib """
        import importlib.util
        importlib.invalidate_caches()
        if load:
            try:
                with mute():
                    return importlib.import_module(self.lib_name)
            except ImportError:
                return None
        else:
            spec = importlib.util.find_spec(self.lib_name)
            if spec is not None and spec.origin is not None and not osp.exists(
                    spec.origin):
                # The module is still registered but its source is gone.
                return None
            else:
                return spec
    def get_pdx(self):
        """ get_pdx """
        return importlib.import_module(self.pdx_mod_name)

    def get_deps(self):
        """ get_deps """
        # Merge requirement files
        req_list = [self.main_req_file]
        req_list.extend(self.meta.get('extra_req_files', []))
        deps = []
        for req in req_list:
            with open(osp.join(self.root_dir, req), 'r', encoding='utf-8') as f:
                deps.append(f.read())
        for dep in self.meta.get('pdx_pkg_deps', []):
            deps.append(dep)
        deps = '\n'.join(deps)
        return deps
    def get_version(self):
        """ get_version """
        version_file = osp.join(self.root_dir, '.pdx_gen.version')
        with open(version_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        sta_ver = lines[0].rstrip()
        commit = lines[1].rstrip()
        ret = [sta_ver, commit]
        # TODO: Get dynamic version in a subprocess.
        ret.append(None)
        return ret
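    # Illustrative example (assumed file layout, inferred from the reads
    # above): a generated `.pdx_gen.version` file is expected to hold the
    # static version on the first line and the commit ID on the second, e.g.
    #
    #   2.6.0
    #   0123abcd...
    #
    # The third element returned by get_version() is a placeholder (None)
    # for a dynamic version that is not collected yet.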
    def __str__(self):
        return f"({self.name}, {id(self)})"
class RepositoryGroupInstaller(object):
    """ RepositoryGroupInstaller """

    def __init__(self, repos):
        super().__init__()
        self.repos = repos

    def install(self, force_reinstall=False, no_deps=False, constraints=None):
        """ install """
        # Rollback on failure is not yet supported. A failed installation
        # could leave a broken environment.
        if force_reinstall:
            self.uninstall()
        ins_flags = []
        repos = self._sort_repos(self.repos, check_missing=True)
        for repo in repos:
            if force_reinstall or not repo.check_installation():
                ins_flags.append(True)
            else:
                ins_flags.append(False)
        if not no_deps:
            # We collect the dependencies and install them all at once
            # so that we can make use of the pip resolver.
            self.install_deps(constraints=constraints)
        # XXX: For historical reasons the repo packages are installed
        # sequentially, with no rollback on failure; a failed installation
        # of one repo package aborts the entire process.
        for ins_flag, repo in zip(ins_flags, repos):
            if ins_flag:
                repo.install_package(no_deps=True)
    def uninstall(self):
        """ uninstall """
        repos = self._sort_repos(self.repos, check_missing=False)
        repos = repos[::-1]
        for repo in repos:
            if repo.check_installation():
                # NOTE: Dependencies are not uninstalled.
                repo.uninstall_package()

    def update(self):
        """ update """
        for repo in self.repos:
            repo.update_repos()

    def get_deps(self):
        """ get_deps """
        deps_list = []
        repos = self._sort_repos(self.repos, check_missing=True)
        for repo in repos:
            deps = repo.get_deps()
            deps = self._normalize_deps(
                deps, headline=f"# {repo.name} dependencies")
            deps_list.append(deps)
        # Add an extra new line to separate dependencies of different repos.
        return '\n\n'.join(deps_list)
    def install_deps(self, constraints):
        """ install_deps """
        deps_str = self.get_deps()
        with tempfile.TemporaryDirectory() as td:
            req_file = os.path.join(td, 'requirements.txt')
            with open(req_file, 'w', encoding='utf-8') as fr:
                fr.write(deps_str)
            if constraints is not None:
                cons_file = os.path.join(td, 'constraints.txt')
                with open(cons_file, 'w', encoding='utf-8') as fc:
                    fc.write(constraints)
                cons_files = [cons_file]
            else:
                cons_files = []
            install_packages_using_pip(
                [], req_files=[req_file], cons_files=cons_files)
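    # Illustrative example (hypothetical pin): calling
    #   installer.install_deps(constraints="numpy<2.0\n")
    # writes the merged requirements and the constraint text to temporary
    # files and hands both to pip, so the pin bounds the resolution without
    # being installed as a requirement itself.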
    def _sort_repos(self, repos, check_missing=False):
        # We sort the repos so that each repo's dependencies precede it in
        # the list.
        def _parse_repo_deps(name, repo_meta):
            ret = []
            for n in repo_meta.get('requires', []):
                ret.extend(_parse_repo_deps(n, get_repo_meta(n)))
            ret.append((name, repo_meta))
            return ret

        name_meta_pairs = []
        for repo in repos:
            name_meta_pairs.extend(_parse_repo_deps(repo.name, repo.meta))
        unique_pairs = []
        hashset = set()
        for name, meta in name_meta_pairs:
            if name in hashset:
                continue
            else:
                unique_pairs.append((name, meta))
                hashset.add(name)
        sorted_repos = []
        missing_names = []
        name2repo = {repo.name: repo for repo in repos}
        for name, meta in unique_pairs:
            if name in name2repo:
                repo = name2repo[name]
                sorted_repos.append(repo)
            else:
                missing_names.append(name)
        if check_missing and len(missing_names) > 0:
            be = 'is' if len(missing_names) == 1 else 'are'
            raise RuntimeError(
                f"{missing_names} {be} required in the installation.")
        else:
            assert len(sorted_repos) == len(self.repos)
        return sorted_repos
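    # Illustrative example (hypothetical repo names): if the meta of repo 'B'
    # declares 'requires': ['A'], then _sort_repos([B, A]) yields [A, B], so
    # A is installed before B. If A were not among the given repos,
    # check_missing=True would raise a RuntimeError naming it.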
    def _normalize_deps(self, deps, headline=None):
        repo_pkgs = set(repo.pkg_name for repo in self.repos)
        normed_lines = []
        if headline is not None:
            normed_lines.append(headline)
        for line in deps.splitlines():
            line_s = line.strip()
            if len(line_s) == 0 or line_s.startswith('#'):
                continue
            # If `line` is not a comment, it must be a requirement specifier.
            # Other forms may cause a parse error.
            n, e, v, m = to_dep_spec_pep508(line_s)
            if isinstance(v, str):
                raise RuntimeError(
                    "Currently, URL-based requirements are not supported.")
            if n in repo_pkgs:
                # Skip repo packages
                continue
            else:
                line_n = [n]
                fe = f"[{','.join(e)}]" if e else ''
                if fe:
                    line_n.append(fe)
                fv = []
                for tup in v:
                    fv.append(' '.join(tup))
                fv = ', '.join(fv) if fv else ''
                if fv:
                    line_n.append(fv)
                if m is not None:
                    fm = f"; {env_marker_ast2expr(m)}"
                    line_n.append(fm)
                line_n = ' '.join(line_n)
                normed_lines.append(line_n)
        return '\n'.join(normed_lines)
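    # Illustrative example (assuming to_dep_spec_pep508 splits a PEP 508 line
    # into name, extras, version clauses, and environment marker): an input
    # such as
    #   opencv-contrib-python<=4.10.0; python_version >= "3.8"
    # would be rebuilt roughly as
    #   opencv-contrib-python <= 4.10.0 ; python_version >= "3.8"
    # while a line naming one of the managed repo packages is skipped, since
    # those are installed from their local clones by install_package().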
class RepositoryGroupCloner(object):
    """ RepositoryGroupCloner """

    def __init__(self, repos):
        super().__init__()
        self.repos = repos

    def clone(self, force_reclone=False, platform=None):
        """ clone """
        if force_reclone:
            self.remove()
        for repo in self.repos:
            repo.clone_repos(platform=platform)

    def remove(self):
        """ remove """
        for repo in self.repos:
            repo.remove_repos()
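# Illustrative usage (hypothetical repo instances repo_a and repo_b):
#
#   cloner = build_repo_group_cloner(repo_a, repo_b)
#   cloner.clone(force_reclone=False, platform='github.com')
#   # ... later, to drop the local clones:
#   cloner.remove()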