repo.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import importlib
  17. import tempfile
  18. import shutil
  19. from ..utils import logging
  20. from ..utils.file_interface import custom_open
  21. from ..utils.download import download_and_extract
  22. from .meta import get_repo_meta, REPO_DOWNLOAD_BASE
  23. from .utils import (
  24. install_packages_using_pip,
  25. fetch_repo_using_git,
  26. reset_repo_using_git,
  27. uninstall_package_using_pip,
  28. remove_repo_using_rm,
  29. check_installation_using_pip,
  30. build_wheel_using_pip,
  31. mute,
  32. switch_working_dir,
  33. to_dep_spec_pep508,
  34. env_marker_ast2expr,
  35. install_external_deps,
  36. )
  37. __all__ = ["build_repo_instance", "build_repo_group_installer"]
  38. def build_repo_instance(repo_name, *args, **kwargs):
  39. """build_repo_instance"""
  40. # XXX: Hard-code type
  41. repo_cls = PPRepository
  42. repo_instance = repo_cls(repo_name, *args, **kwargs)
  43. return repo_instance
  44. def build_repo_group_installer(*repos):
  45. """build_repo_group_installer"""
  46. return RepositoryGroupInstaller(list(repos))
  47. def build_repo_group_getter(*repos):
  48. """build_repo_group_getter"""
  49. return RepositoryGroupGetter(list(repos))
  50. class PPRepository(object):
  51. """
  52. Installation, initialization, and PDX module import handler for a
  53. PaddlePaddle repository.
  54. """
  55. def __init__(self, name, repo_parent_dir, pdx_collection_mod):
  56. super().__init__()
  57. self.name = name
  58. self.repo_parent_dir = repo_parent_dir
  59. self.root_dir = osp.join(repo_parent_dir, self.name)
  60. self.meta = get_repo_meta(self.name)
  61. self.git_path = self.meta["git_path"]
  62. self.pkg_name = self.meta["pkg_name"]
  63. self.lib_name = self.meta["lib_name"]
  64. self.pdx_mod_name = (
  65. pdx_collection_mod.__name__ + "." + self.meta["pdx_pkg_name"]
  66. )
  67. self.main_req_file = self.meta.get("main_req_file", "requirements.txt")
  68. def initialize(self):
  69. """initialize"""
  70. if not self.check_installation(quick_check=True):
  71. return False
  72. if "path_env" in self.meta:
  73. # Set env var
  74. os.environ[self.meta["path_env"]] = osp.abspath(self.root_dir)
  75. # NOTE: By calling `self.get_pdx()` we actually loads the repo PDX package
  76. # and do all registration.
  77. self.get_pdx()
  78. return True
  79. def check_installation(self, quick_check=False):
  80. """check_installation"""
  81. if quick_check:
  82. lib = self._get_lib(load=False)
  83. return lib is not None
  84. else:
  85. # TODO: Also check if correct dependencies are installed.
  86. return check_installation_using_pip(self.pkg_name)
  87. def replace_repo_deps(self, deps_to_replace, src_requirements):
  88. """replace_repo_deps"""
  89. with custom_open(src_requirements, "r") as file:
  90. lines = file.readlines()
  91. existing_deps = []
  92. for line in lines:
  93. line = line.strip()
  94. if not line or line.startswith("#"):
  95. continue
  96. dep_to_replace = next((dep for dep in deps_to_replace if dep in line), None)
  97. if dep_to_replace:
  98. existing_deps.append(
  99. f"{dep_to_replace}=={deps_to_replace[dep_to_replace]}"
  100. )
  101. else:
  102. existing_deps.append(line)
  103. with open(src_requirements, "w") as file:
  104. file.writelines([l + "\n" for l in existing_deps])
  105. def check_repo_exiting(self, quick_check=False):
  106. """check_repo_exiting"""
  107. return os.path.exists(os.path.join(self.root_dir, ".git"))
  108. def install(self, *args, **kwargs):
  109. """install"""
  110. return RepositoryGroupInstaller([self]).install(*args, **kwargs)
  111. def uninstall(self, *args, **kwargs):
  112. """uninstall"""
  113. return RepositoryGroupInstaller([self]).uninstall(*args, **kwargs)
  114. def install_deps(self, *args, **kwargs):
  115. """install_deps"""
  116. return RepositoryGroupInstaller([self]).install_deps(*args, **kwargs)
  117. def install_package(self, no_deps=False, clean=True, install_extra_only=False):
  118. """install_package"""
  119. editable = self.meta.get("editable", True)
  120. extra_editable = self.meta.get("extra_editable", None)
  121. if editable:
  122. logging.warning(f"{self.pkg_name} will be installed in editable mode.")
  123. with switch_working_dir(self.root_dir):
  124. if install_extra_only:
  125. src_requirements = os.path.join(self.root_dir, "requirements.txt")
  126. paddlex_requirements = os.path.join(
  127. self.root_dir, "requirements_paddlex.txt"
  128. )
  129. shutil.copy(paddlex_requirements, src_requirements)
  130. try:
  131. install_packages_using_pip(["."], editable=editable, no_deps=no_deps)
  132. install_external_deps(self.name, self.root_dir)
  133. finally:
  134. if clean:
  135. # Clean build artifacts
  136. tmp_build_dir = os.path.join(self.root_dir, "build")
  137. if os.path.exists(tmp_build_dir):
  138. shutil.rmtree(tmp_build_dir)
  139. if extra_editable:
  140. with switch_working_dir(os.path.join(self.root_dir, extra_editable)):
  141. try:
  142. install_packages_using_pip(["."], editable=True, no_deps=no_deps)
  143. finally:
  144. if clean:
  145. # Clean build artifacts
  146. tmp_build_dir = os.path.join(self.root_dir, "build")
  147. if os.path.exists(tmp_build_dir):
  148. shutil.rmtree(tmp_build_dir)
  149. def uninstall_package(self):
  150. """uninstall_package"""
  151. uninstall_package_using_pip(self.pkg_name)
  152. def download(self):
  153. """download from remote"""
  154. download_url = f"{REPO_DOWNLOAD_BASE}{self.name}.tar"
  155. os.makedirs(self.repo_parent_dir, exist_ok=True)
  156. download_and_extract(download_url, self.repo_parent_dir, self.name)
  157. # reset_repo_using_git('FETCH_HEAD')
  158. def remove(self):
  159. """remove"""
  160. with switch_working_dir(self.repo_parent_dir):
  161. remove_repo_using_rm(self.name)
  162. def update(self, platform=None):
  163. """update"""
  164. branch = self.meta.get("branch", None)
  165. git_url = f"https://{platform}{self.git_path}"
  166. with switch_working_dir(self.root_dir):
  167. try:
  168. fetch_repo_using_git(branch=branch, url=git_url)
  169. reset_repo_using_git("FETCH_HEAD")
  170. except Exception as e:
  171. logging.warning(
  172. f"Update {self.name} from {git_url} failed, check your network connection. Error:\n{e}"
  173. )
  174. def wheel(self, dst_dir):
  175. """wheel"""
  176. with tempfile.TemporaryDirectory() as td:
  177. tmp_repo_dir = osp.join(td, self.name)
  178. tmp_dst_dir = osp.join(td, "dist")
  179. shutil.copytree(self.root_dir, tmp_repo_dir, symlinks=False)
  180. # NOTE: Installation of the repo relies on `self.main_req_file` in root directory
  181. # Thus, we overwrite the content of it.
  182. main_req_file_path = osp.join(tmp_repo_dir, self.main_req_file)
  183. deps_str = self.get_deps()
  184. with open(main_req_file_path, "w", encoding="utf-8") as f:
  185. f.write(deps_str)
  186. install_packages_using_pip([], req_files=[main_req_file_path])
  187. with switch_working_dir(tmp_repo_dir):
  188. build_wheel_using_pip(".", tmp_dst_dir)
  189. shutil.copytree(tmp_dst_dir, dst_dir)
  190. def _get_lib(self, load=True):
  191. """_get_lib"""
  192. import importlib.util
  193. importlib.invalidate_caches()
  194. if load:
  195. try:
  196. with mute():
  197. return importlib.import_module(self.lib_name)
  198. except ImportError:
  199. return None
  200. else:
  201. spec = importlib.util.find_spec(self.lib_name)
  202. if spec is not None and not osp.exists(spec.origin):
  203. return None
  204. else:
  205. return spec
  206. def get_pdx(self):
  207. """get_pdx"""
  208. return importlib.import_module(self.pdx_mod_name)
  209. def get_deps(self, install_extra_only=False, deps_to_replace=None):
  210. """get_deps"""
  211. # Merge requirement files
  212. if install_extra_only:
  213. req_list = []
  214. else:
  215. req_list = [self.main_req_file]
  216. req_list.extend(self.meta.get("extra_req_files", []))
  217. if deps_to_replace is not None:
  218. deps_dict = {}
  219. for dep in deps_to_replace:
  220. part, version = dep.split("=")
  221. repo_name, dep_name = part.split(".")
  222. deps_dict[repo_name] = {dep_name: version}
  223. src_requirements = os.path.join(self.root_dir, "requirements.txt")
  224. if self.name in deps_dict:
  225. self.replace_repo_deps(deps_dict[self.name], src_requirements)
  226. deps = []
  227. for req in req_list:
  228. with open(osp.join(self.root_dir, req), "r", encoding="utf-8") as f:
  229. deps.append(f.read())
  230. for dep in self.meta.get("pdx_pkg_deps", []):
  231. deps.append(dep)
  232. deps = "\n".join(deps)
  233. return deps
  234. def get_version(self):
  235. """get_version"""
  236. version_file = osp.join(self.root_dir, ".pdx_gen.version")
  237. with open(version_file, "r", encoding="utf-8") as f:
  238. lines = f.readlines()
  239. sta_ver = lines[0].rstrip()
  240. commit = lines[1].rstrip()
  241. ret = [sta_ver, commit]
  242. # TODO: Get dynamic version in a subprocess.
  243. ret.append(None)
  244. return ret
  245. def __str__(self):
  246. return f"({self.name}, {id(self)})"
  247. class RepositoryGroupInstaller(object):
  248. """RepositoryGroupInstaller"""
  249. def __init__(self, repos):
  250. super().__init__()
  251. self.repos = repos
  252. def install(
  253. self,
  254. force_reinstall=False,
  255. no_deps=False,
  256. constraints=None,
  257. deps_to_replace=None,
  258. ):
  259. """install"""
  260. # Rollback on failure is not yet supported. A failed installation
  261. # could leave a broken environment.
  262. if force_reinstall:
  263. self.uninstall()
  264. ins_flags = []
  265. repos = self._sort_repos(self.repos, check_missing=True)
  266. for repo in repos:
  267. if force_reinstall or not repo.check_installation():
  268. ins_flags.append(True)
  269. else:
  270. ins_flags.append(False)
  271. if not no_deps:
  272. # We collect the dependencies and install them all at once
  273. # such that we can make use of the pip resolver.
  274. self.install_deps(constraints=constraints, deps_to_replace=deps_to_replace)
  275. # XXX: For historical reasons the repo packages are sequentially
  276. # installed, and we have no failure rollbacks. Meanwhile, installation
  277. # failure of one repo package aborts the entire installation process.
  278. for ins_flag, repo in zip(ins_flags, repos):
  279. if ins_flag:
  280. if repo.name in ["PaddleVideo"]:
  281. repo.install_package(
  282. no_deps=True,
  283. install_extra_only=True,
  284. )
  285. else:
  286. repo.install_package(no_deps=True)
  287. def uninstall(self):
  288. """uninstall"""
  289. repos = self._sort_repos(self.repos, check_missing=False)
  290. repos = repos[::-1]
  291. for repo in repos:
  292. if repo.check_installation():
  293. # NOTE: Dependencies are not uninstalled.
  294. repo.uninstall_package()
  295. def get_deps(self, deps_to_replace=None):
  296. """get_deps"""
  297. deps_list = []
  298. repos = self._sort_repos(self.repos, check_missing=True)
  299. for repo in repos:
  300. if repo.name in ["PaddleVideo"]:
  301. deps = repo.get_deps(
  302. install_extra_only=True, deps_to_replace=deps_to_replace
  303. )
  304. else:
  305. deps = repo.get_deps(deps_to_replace=deps_to_replace)
  306. deps = self._normalize_deps(deps, headline=f"# {repo.name} dependencies")
  307. deps_list.append(deps)
  308. # Add an extra new line to separate dependencies of different repos.
  309. return "\n\n".join(deps_list)
  310. def install_deps(self, constraints, deps_to_replace=None):
  311. """install_deps"""
  312. deps_str = self.get_deps(deps_to_replace=deps_to_replace)
  313. with tempfile.TemporaryDirectory() as td:
  314. req_file = os.path.join(td, "requirements.txt")
  315. with open(req_file, "w", encoding="utf-8") as fr:
  316. fr.write(deps_str)
  317. if constraints is not None:
  318. cons_file = os.path.join(td, "constraints.txt")
  319. with open(cons_file, "w", encoding="utf-8") as fc:
  320. fc.write(constraints)
  321. cons_files = [cons_file]
  322. else:
  323. cons_files = []
  324. install_packages_using_pip([], req_files=[req_file], cons_files=cons_files)
  325. def _sort_repos(self, repos, check_missing=False):
  326. # We sort the repos to ensure that the dependencies precede the
  327. # dependant in the list.
  328. name_meta_pairs = []
  329. for repo in repos:
  330. name_meta_pairs.append((repo.name, repo.meta))
  331. unique_pairs = []
  332. hashset = set()
  333. for name, meta in name_meta_pairs:
  334. if name in hashset:
  335. continue
  336. else:
  337. unique_pairs.append((name, meta))
  338. hashset.add(name)
  339. sorted_repos = []
  340. missing_names = []
  341. name2repo = {repo.name: repo for repo in repos}
  342. for name, meta in unique_pairs:
  343. if name in name2repo:
  344. repo = name2repo[name]
  345. sorted_repos.append(repo)
  346. else:
  347. missing_names.append(name)
  348. if check_missing and len(missing_names) > 0:
  349. be = "is" if len(missing_names) == 1 else "are"
  350. raise RuntimeError(f"{missing_names} {be} required in the installation.")
  351. else:
  352. assert len(sorted_repos) == len(self.repos)
  353. return sorted_repos
  354. def _normalize_deps(self, deps, headline=None):
  355. repo_pkgs = set(repo.pkg_name for repo in self.repos)
  356. normed_lines = []
  357. if headline is not None:
  358. normed_lines.append(headline)
  359. for line in deps.splitlines():
  360. line_s = line.strip()
  361. if len(line_s) == 0 or line_s.startswith("#"):
  362. continue
  363. # If `line` is not a comment, it must be a requirement specifier.
  364. # Other forms may cause a parse error.
  365. n, e, v, m = to_dep_spec_pep508(line_s)
  366. if isinstance(v, str):
  367. raise RuntimeError("Currently, URL based lookup is not supported.")
  368. if n in repo_pkgs:
  369. # Skip repo packages
  370. continue
  371. elif check_installation_using_pip(n):
  372. continue
  373. else:
  374. line_n = [n]
  375. fe = f"[{','.join(e)}]" if e else ""
  376. if fe:
  377. line_n.append(fe)
  378. fv = []
  379. for tup in v:
  380. fv.append(" ".join(tup))
  381. fv = ", ".join(fv) if fv else ""
  382. if fv:
  383. line_n.append(fv)
  384. if m is not None:
  385. fm = f"; {env_marker_ast2expr(m)}"
  386. line_n.append(fm)
  387. line_n = " ".join(line_n)
  388. normed_lines.append(line_n)
  389. return "\n".join(normed_lines)
  390. class RepositoryGroupGetter(object):
  391. """RepositoryGroupGetter"""
  392. def __init__(self, repos):
  393. super().__init__()
  394. self.repos = repos
  395. def get(self, force=False, platform=None):
  396. """clone"""
  397. if force:
  398. self.remove()
  399. for repo in self.repos:
  400. repo.download()
  401. repo.update(platform=platform)
  402. def remove(self):
  403. """remove"""
  404. for repo in self.repos:
  405. repo.remove()