repo.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import importlib
  17. import tempfile
  18. import shutil
  19. from ..utils import logging
  20. from ..utils.download import download_and_extract
  21. from .meta import get_repo_meta, REPO_DOWNLOAD_BASE
  22. from .utils import (install_packages_using_pip, clone_repos_using_git,
  23. update_repos_using_git, uninstall_package_using_pip,
  24. remove_repos_using_rm, check_installation_using_pip,
  25. build_wheel_using_pip, mute, switch_working_dir,
  26. to_dep_spec_pep508, env_marker_ast2expr)
  27. __all__ = ['build_repo_instance', 'build_repo_group_installer']
  28. def build_repo_instance(repo_name, *args, **kwargs):
  29. """ build_repo_instance """
  30. # XXX: Hard-code type
  31. repo_cls = PPRepository
  32. repo_instance = repo_cls(repo_name, *args, **kwargs)
  33. return repo_instance
  34. def build_repo_group_installer(*repos):
  35. """ build_repo_group_installer """
  36. return RepositoryGroupInstaller(list(repos))
  37. def build_repo_group_cloner(*repos):
  38. """ build_repo_group_cloner """
  39. return RepositoryGroupCloner(list(repos))
  40. class PPRepository(object):
  41. """
  42. Installation, initialization, and PDX module import handler for a
  43. PaddlePaddle repository.
  44. """
  45. def __init__(self, name, repo_parent_dir, pdx_collection_mod):
  46. super().__init__()
  47. self.name = name
  48. self.repo_parent_dir = repo_parent_dir
  49. self.root_dir = osp.join(repo_parent_dir, self.name)
  50. self.meta = get_repo_meta(self.name)
  51. self.git_url = self.meta['git_url']
  52. self.pkg_name = self.meta['pkg_name']
  53. self.lib_name = self.meta['lib_name']
  54. self.pdx_mod_name = pdx_collection_mod.__name__ + '.' + self.meta[
  55. 'pdx_pkg_name']
  56. self.main_req_file = self.meta.get('main_req_file', 'requirements.txt')
  57. def initialize(self):
  58. """ initialize """
  59. if not self.check_installation(quick_check=True):
  60. return False
  61. if 'path_env' in self.meta:
  62. # Set env var
  63. os.environ[self.meta['path_env']] = osp.abspath(self.root_dir)
  64. # NOTE: By calling `self.get_pdx()` we actually loads the repo PDX package
  65. # and do all registration.
  66. self.get_pdx()
  67. return True
  68. def check_installation(self, quick_check=False):
  69. """ check_installation """
  70. if quick_check:
  71. lib = self._get_lib(load=False)
  72. return lib is not None
  73. else:
  74. # TODO: Also check if correct dependencies are installed.
  75. return check_installation_using_pip(self.pkg_name)
  76. def check_repo_exiting(self, quick_check=False):
  77. """ check_repo_exiting """
  78. return os.path.exists(os.path.join(self.root_dir, '.git'))
  79. def install(self, *args, **kwargs):
  80. """ install """
  81. return RepositoryGroupInstaller([self]).install(*args, **kwargs)
  82. def uninstall(self, *args, **kwargs):
  83. """ uninstall """
  84. return RepositoryGroupInstaller([self]).uninstall(*args, **kwargs)
  85. def install_deps(self, *args, **kwargs):
  86. """ install_deps """
  87. return RepositoryGroupInstaller([self]).install_deps(*args, **kwargs)
  88. def install_package(self, no_deps=False, clean=True):
  89. """ install_package """
  90. editable = self.meta.get('editable', True)
  91. extra_editable = self.meta.get('extra_editable', None)
  92. if editable:
  93. logging.warning(
  94. f"{self.pkg_name} will be installed in editable mode.")
  95. with switch_working_dir(self.root_dir):
  96. try:
  97. install_packages_using_pip(
  98. ['.'], editable=editable, no_deps=no_deps)
  99. finally:
  100. if clean:
  101. # Clean build artifacts
  102. tmp_build_dir = os.path.join(self.root_dir, 'build')
  103. if os.path.exists(tmp_build_dir):
  104. shutil.rmtree(tmp_build_dir)
  105. if extra_editable:
  106. with switch_working_dir(
  107. os.path.join(self.root_dir, extra_editable)):
  108. try:
  109. install_packages_using_pip(
  110. ['.'], editable=True, no_deps=no_deps)
  111. finally:
  112. if clean:
  113. # Clean build artifacts
  114. tmp_build_dir = os.path.join(self.root_dir, 'build')
  115. if os.path.exists(tmp_build_dir):
  116. shutil.rmtree(tmp_build_dir)
  117. def uninstall_package(self):
  118. """ uninstall_package """
  119. uninstall_package_using_pip(self.pkg_name)
  120. def clone(self, *args, **kwargs):
  121. """ clone """
  122. return RepositoryGroupCloner([self]).clone(*args, **kwargs)
  123. def remove(self, *args, **kwargs):
  124. """ remove """
  125. return RepositoryGroupCloner([self]).remove(*args, **kwargs)
  126. def clone_repos(self, platform=None):
  127. """ clone_repos """
  128. branch = self.meta.get('branch', None)
  129. git_url = f'https://{platform}{self.git_url}'
  130. # uncomment this if you prefer using ssh connection (requires additional setup)
  131. # if platform == 'github.com':
  132. # git_url = f'git@github.com:{self.git_url}'
  133. os.makedirs(self.repo_parent_dir, exist_ok=True)
  134. with switch_working_dir(self.repo_parent_dir):
  135. clone_repos_using_git(git_url, branch=branch)
  136. def download_repos(self):
  137. """ download and pull repos """
  138. download_url = f'{REPO_DOWNLOAD_BASE}{self.name}.tar'
  139. os.makedirs(self.repo_parent_dir, exist_ok=True)
  140. download_and_extract(download_url, self.repo_parent_dir, self.name)
  141. def update_repos(self, platform=None):
  142. """ update_repos """
  143. branch = self.meta.get('branch', None)
  144. git_url = f'https://{platform}{self.git_url}'
  145. with switch_working_dir(self.root_dir):
  146. try:
  147. update_repos_using_git(branch=branch, url=git_url)
  148. except Exception as e:
  149. logging.warning(
  150. f"Pull {self.name} from {self.git_url} failed, check your network connection."
  151. )
  152. def remove_repos(self):
  153. """ remove_repos """
  154. with switch_working_dir(self.repo_parent_dir):
  155. remove_repos_using_rm(self.name)
  156. def wheel(self, dst_dir):
  157. """ wheel """
  158. with tempfile.TemporaryDirectory() as td:
  159. tmp_repo_dir = osp.join(td, self.name)
  160. tmp_dst_dir = osp.join(td, 'dist')
  161. shutil.copytree(self.root_dir, tmp_repo_dir, symlinks=False)
  162. # NOTE: Installation of the repo relies on `self.main_req_file` in root directory
  163. # Thus, we overwrite the content of it.
  164. main_req_file_path = osp.join(tmp_repo_dir, self.main_req_file)
  165. deps_str = self.get_deps()
  166. with open(main_req_file_path, 'w', encoding='utf-8') as f:
  167. f.write(deps_str)
  168. install_packages_using_pip([], req_files=[main_req_file_path])
  169. with switch_working_dir(tmp_repo_dir):
  170. build_wheel_using_pip('.', tmp_dst_dir)
  171. shutil.copytree(tmp_dst_dir, dst_dir)
  172. def _get_lib(self, load=True):
  173. """ _get_lib """
  174. import importlib.util
  175. importlib.invalidate_caches()
  176. if load:
  177. try:
  178. with mute():
  179. return importlib.import_module(self.lib_name)
  180. except ImportError:
  181. return None
  182. else:
  183. spec = importlib.util.find_spec(self.lib_name)
  184. if spec is not None and not osp.exists(spec.origin):
  185. return None
  186. else:
  187. return spec
  188. def get_pdx(self):
  189. """ get_pdx """
  190. return importlib.import_module(self.pdx_mod_name)
  191. def get_deps(self):
  192. """ get_deps """
  193. # Merge requirement files
  194. req_list = [self.main_req_file]
  195. req_list.extend(self.meta.get('extra_req_files', []))
  196. deps = []
  197. for req in req_list:
  198. with open(osp.join(self.root_dir, req), 'r', encoding='utf-8') as f:
  199. deps.append(f.read())
  200. for dep in self.meta.get('pdx_pkg_deps', []):
  201. deps.append(dep)
  202. deps = '\n'.join(deps)
  203. return deps
  204. def get_version(self):
  205. """ get_version """
  206. version_file = osp.join(self.root_dir, '.pdx_gen.version')
  207. with open(version_file, 'r', encoding='utf-8') as f:
  208. lines = f.readlines()
  209. sta_ver = lines[0].rstrip()
  210. commit = lines[1].rstrip()
  211. ret = [sta_ver, commit]
  212. # TODO: Get dynamic version in a subprocess.
  213. ret.append(None)
  214. return ret
  215. def __str__(self):
  216. return f"({self.name}, {id(self)})"
  217. class RepositoryGroupInstaller(object):
  218. """ RepositoryGroupInstaller """
  219. def __init__(self, repos):
  220. super().__init__()
  221. self.repos = repos
  222. def install(self, force_reinstall=False, no_deps=False, constraints=None):
  223. """ install """
  224. # Rollback on failure is not yet supported. A failed installation
  225. # could leave a broken environment.
  226. if force_reinstall:
  227. self.uninstall()
  228. ins_flags = []
  229. repos = self._sort_repos(self.repos, check_missing=True)
  230. for repo in repos:
  231. if force_reinstall or not repo.check_installation():
  232. ins_flags.append(True)
  233. else:
  234. ins_flags.append(False)
  235. if not no_deps:
  236. # We collect the dependencies and install them all at once
  237. # such that we can make use of the pip resolver.
  238. self.install_deps(constraints=constraints)
  239. # XXX: For historical reasons the repo packages are sequentially
  240. # installed, and we have no failure rollbacks. Meanwhile, installation
  241. # failure of one repo package aborts the entire installation process.
  242. for ins_flag, repo in zip(ins_flags, repos):
  243. if ins_flag:
  244. repo.install_package(no_deps=True)
  245. def uninstall(self):
  246. """ uninstall """
  247. repos = self._sort_repos(self.repos, check_missing=False)
  248. repos = repos[::-1]
  249. for repo in repos:
  250. if repo.check_installation():
  251. # NOTE: Dependencies are not uninstalled.
  252. repo.uninstall_package()
  253. def update(self):
  254. """ update """
  255. for repo in self.repos:
  256. repo.update_repos()
  257. def get_deps(self):
  258. """ get_deps """
  259. deps_list = []
  260. repos = self._sort_repos(self.repos, check_missing=True)
  261. for repo in repos:
  262. deps = repo.get_deps()
  263. deps = self._normalize_deps(
  264. deps, headline=f"# {repo.name} dependencies")
  265. deps_list.append(deps)
  266. # Add an extra new line to separate dependencies of different repos.
  267. return '\n\n'.join(deps_list)
  268. def install_deps(self, constraints):
  269. """ install_deps """
  270. deps_str = self.get_deps()
  271. with tempfile.TemporaryDirectory() as td:
  272. req_file = os.path.join(td, 'requirements.txt')
  273. with open(req_file, 'w', encoding='utf-8') as fr:
  274. fr.write(deps_str)
  275. if constraints is not None:
  276. cons_file = os.path.join(td, 'constraints.txt')
  277. with open(cons_file, 'w', encoding='utf-8') as fc:
  278. fc.write(constraints)
  279. cons_files = [cons_file]
  280. else:
  281. cons_files = []
  282. install_packages_using_pip(
  283. [], req_files=[req_file], cons_files=cons_files)
  284. def _sort_repos(self, repos, check_missing=False):
  285. # We sort the repos to ensure that the dependencies precede the
  286. # dependant in the list.
  287. name_meta_pairs = []
  288. for repo in repos:
  289. name_meta_pairs.append((repo.name, repo.meta))
  290. unique_pairs = []
  291. hashset = set()
  292. for name, meta in name_meta_pairs:
  293. if name in hashset:
  294. continue
  295. else:
  296. unique_pairs.append((name, meta))
  297. hashset.add(name)
  298. sorted_repos = []
  299. missing_names = []
  300. name2repo = {repo.name: repo for repo in repos}
  301. for name, meta in unique_pairs:
  302. if name in name2repo:
  303. repo = name2repo[name]
  304. sorted_repos.append(repo)
  305. else:
  306. missing_names.append(name)
  307. if check_missing and len(missing_names) > 0:
  308. be = 'is' if len(missing_names) == 1 else 'are'
  309. raise RuntimeError(
  310. f"{missing_names} {be} required in the installation.")
  311. else:
  312. assert len(sorted_repos) == len(self.repos)
  313. return sorted_repos
  314. def _normalize_deps(self, deps, headline=None):
  315. repo_pkgs = set(repo.pkg_name for repo in self.repos)
  316. normed_lines = []
  317. if headline is not None:
  318. normed_lines.append(headline)
  319. for line in deps.splitlines():
  320. line_s = line.strip()
  321. if len(line_s) == 0 or line_s.startswith('#'):
  322. continue
  323. # If `line` is not a comment, it must be a requirement specifier.
  324. # Other forms may cause a parse error.
  325. n, e, v, m = to_dep_spec_pep508(line_s)
  326. if isinstance(v, str):
  327. raise RuntimeError(
  328. "Currently, URL based lookup is not supported.")
  329. if n in repo_pkgs:
  330. # Skip repo packages
  331. continue
  332. else:
  333. line_n = [n]
  334. fe = f"[{','.join(e)}]" if e else ''
  335. if fe:
  336. line_n.append(fe)
  337. fv = []
  338. for tup in v:
  339. fv.append(' '.join(tup))
  340. fv = ', '.join(fv) if fv else ''
  341. if fv:
  342. line_n.append(fv)
  343. if m is not None:
  344. fm = f"; {env_marker_ast2expr(m)}"
  345. line_n.append(fm)
  346. line_n = ' '.join(line_n)
  347. normed_lines.append(line_n)
  348. return '\n'.join(normed_lines)
  349. class RepositoryGroupCloner(object):
  350. """ RepositoryGroupCloner """
  351. def __init__(self, repos):
  352. super().__init__()
  353. self.repos = repos
  354. def clone(self, force_clone=False, platform=None):
  355. """ clone """
  356. if force_clone:
  357. self.remove()
  358. for repo in self.repos:
  359. repo.clone_repos(platform=platform)
  360. else:
  361. for repo in self.repos:
  362. repo.download_repos()
  363. repo.update_repos(platform=platform)
  364. def remove(self):
  365. """ remove """
  366. for repo in self.repos:
  367. repo.remove_repos()