repo.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import importlib
  17. import tempfile
  18. import shutil
  19. from ..utils import logging
  20. from ..utils.file_interface import custom_open
  21. from ..utils.download import download_and_extract
  22. from .meta import get_repo_meta, REPO_DOWNLOAD_BASE
  23. from .utils import (
  24. install_packages_using_pip,
  25. fetch_repo_using_git,
  26. reset_repo_using_git,
  27. uninstall_package_using_pip,
  28. remove_repo_using_rm,
  29. check_installation_using_pip,
  30. build_wheel_using_pip,
  31. mute,
  32. switch_working_dir,
  33. to_dep_spec_pep508,
  34. env_marker_ast2expr,
  35. install_external_deps,
  36. )
  37. __all__ = ["build_repo_instance", "build_repo_group_installer"]
  38. def build_repo_instance(repo_name, *args, **kwargs):
  39. """build_repo_instance"""
  40. # XXX: Hard-code type
  41. repo_cls = PPRepository
  42. repo_instance = repo_cls(repo_name, *args, **kwargs)
  43. return repo_instance
  44. def build_repo_group_installer(*repos):
  45. """build_repo_group_installer"""
  46. return RepositoryGroupInstaller(list(repos))
  47. def build_repo_group_getter(*repos):
  48. """build_repo_group_getter"""
  49. return RepositoryGroupGetter(list(repos))
  50. class PPRepository(object):
  51. """
  52. Installation, initialization, and PDX module import handler for a
  53. PaddlePaddle repository.
  54. """
  55. def __init__(self, name, repo_parent_dir, pdx_collection_mod):
  56. super().__init__()
  57. self.name = name
  58. self.repo_parent_dir = repo_parent_dir
  59. self.root_dir = osp.join(repo_parent_dir, self.name)
  60. self.meta = get_repo_meta(self.name)
  61. self.git_path = self.meta["git_path"]
  62. self.pkg_name = self.meta["pkg_name"]
  63. self.lib_name = self.meta["lib_name"]
  64. self.pdx_mod_name = (
  65. pdx_collection_mod.__name__ + "." + self.meta["pdx_pkg_name"]
  66. )
  67. self.main_req_file = self.meta.get("main_req_file", "requirements.txt")
  68. def initialize(self):
  69. """initialize"""
  70. if not self.check_installation(quick_check=True):
  71. return False
  72. if "path_env" in self.meta:
  73. # Set env var
  74. os.environ[self.meta["path_env"]] = osp.abspath(self.root_dir)
  75. # NOTE: By calling `self.get_pdx()` we actually loads the repo PDX package
  76. # and do all registration.
  77. self.get_pdx()
  78. return True
  79. def check_installation(self, quick_check=False):
  80. """check_installation"""
  81. if quick_check:
  82. lib = self._get_lib(load=False)
  83. return lib is not None
  84. else:
  85. # TODO: Also check if correct dependencies are installed.
  86. return check_installation_using_pip(self.pkg_name)
  87. def replace_repo_deps(self, deps_to_replace, src_requirements):
  88. """replace_repo_deps"""
  89. with custom_open(src_requirements, "r") as file:
  90. lines = file.readlines()
  91. existing_deps = []
  92. for line in lines:
  93. line = line.strip()
  94. if not line or line.startswith("#"):
  95. continue
  96. dep_to_replace = next((dep for dep in deps_to_replace if dep in line), None)
  97. if dep_to_replace:
  98. if deps_to_replace[dep_to_replace] == "None":
  99. continue
  100. else:
  101. existing_deps.append(
  102. f"{dep_to_replace}=={deps_to_replace[dep_to_replace]}"
  103. )
  104. else:
  105. existing_deps.append(line)
  106. with open(src_requirements, "w") as file:
  107. file.writelines([l + "\n" for l in existing_deps])
  108. def check_repo_exiting(self, quick_check=False):
  109. """check_repo_exiting"""
  110. return os.path.exists(os.path.join(self.root_dir, ".git"))
  111. def install(self, *args, **kwargs):
  112. """install"""
  113. return RepositoryGroupInstaller([self]).install(*args, **kwargs)
  114. def uninstall(self, *args, **kwargs):
  115. """uninstall"""
  116. return RepositoryGroupInstaller([self]).uninstall(*args, **kwargs)
  117. def install_deps(self, *args, **kwargs):
  118. """install_deps"""
  119. return RepositoryGroupInstaller([self]).install_deps(*args, **kwargs)
  120. def install_package(self, no_deps=False, clean=True, install_extra_only=False):
  121. """install_package"""
  122. editable = self.meta.get("editable", True)
  123. extra_editable = self.meta.get("extra_editable", None)
  124. if editable:
  125. logging.warning(f"{self.pkg_name} will be installed in editable mode.")
  126. with switch_working_dir(self.root_dir):
  127. if install_extra_only:
  128. src_requirements = os.path.join(self.root_dir, "requirements.txt")
  129. paddlex_requirements = os.path.join(
  130. self.root_dir, "requirements_paddlex.txt"
  131. )
  132. shutil.copy(paddlex_requirements, src_requirements)
  133. try:
  134. install_packages_using_pip(["."], editable=editable, no_deps=no_deps)
  135. install_external_deps(self.name, self.root_dir)
  136. finally:
  137. if clean:
  138. # Clean build artifacts
  139. tmp_build_dir = os.path.join(self.root_dir, "build")
  140. if os.path.exists(tmp_build_dir):
  141. shutil.rmtree(tmp_build_dir)
  142. if extra_editable:
  143. with switch_working_dir(os.path.join(self.root_dir, extra_editable)):
  144. try:
  145. install_packages_using_pip(["."], editable=True, no_deps=no_deps)
  146. finally:
  147. if clean:
  148. # Clean build artifacts
  149. tmp_build_dir = os.path.join(self.root_dir, "build")
  150. if os.path.exists(tmp_build_dir):
  151. shutil.rmtree(tmp_build_dir)
  152. def uninstall_package(self):
  153. """uninstall_package"""
  154. uninstall_package_using_pip(self.pkg_name)
  155. def download(self):
  156. """download from remote"""
  157. download_url = f"{REPO_DOWNLOAD_BASE}{self.name}.tar"
  158. os.makedirs(self.repo_parent_dir, exist_ok=True)
  159. download_and_extract(download_url, self.repo_parent_dir, self.name)
  160. # reset_repo_using_git('FETCH_HEAD')
  161. def remove(self):
  162. """remove"""
  163. with switch_working_dir(self.repo_parent_dir):
  164. remove_repo_using_rm(self.name)
  165. def update(self, platform=None):
  166. """update"""
  167. branch = self.meta.get("branch", None)
  168. git_url = f"https://{platform}{self.git_path}"
  169. with switch_working_dir(self.root_dir):
  170. try:
  171. fetch_repo_using_git(branch=branch, url=git_url)
  172. reset_repo_using_git("FETCH_HEAD")
  173. except Exception as e:
  174. logging.warning(
  175. f"Update {self.name} from {git_url} failed, check your network connection. Error:\n{e}"
  176. )
  177. def wheel(self, dst_dir):
  178. """wheel"""
  179. with tempfile.TemporaryDirectory() as td:
  180. tmp_repo_dir = osp.join(td, self.name)
  181. tmp_dst_dir = osp.join(td, "dist")
  182. shutil.copytree(self.root_dir, tmp_repo_dir, symlinks=False)
  183. # NOTE: Installation of the repo relies on `self.main_req_file` in root directory
  184. # Thus, we overwrite the content of it.
  185. main_req_file_path = osp.join(tmp_repo_dir, self.main_req_file)
  186. deps_str = self.get_deps()
  187. with open(main_req_file_path, "w", encoding="utf-8") as f:
  188. f.write(deps_str)
  189. install_packages_using_pip([], req_files=[main_req_file_path])
  190. with switch_working_dir(tmp_repo_dir):
  191. build_wheel_using_pip(".", tmp_dst_dir)
  192. shutil.copytree(tmp_dst_dir, dst_dir)
  193. def _get_lib(self, load=True):
  194. """_get_lib"""
  195. import importlib.util
  196. importlib.invalidate_caches()
  197. if load:
  198. try:
  199. with mute():
  200. return importlib.import_module(self.lib_name)
  201. except ImportError:
  202. return None
  203. else:
  204. spec = importlib.util.find_spec(self.lib_name)
  205. if spec is not None and not osp.exists(spec.origin):
  206. return None
  207. else:
  208. return spec
  209. def get_pdx(self):
  210. """get_pdx"""
  211. return importlib.import_module(self.pdx_mod_name)
  212. def get_deps(self, install_extra_only=False, deps_to_replace=None):
  213. """get_deps"""
  214. # Merge requirement files
  215. if install_extra_only:
  216. req_list = []
  217. else:
  218. req_list = [self.main_req_file]
  219. req_list.extend(self.meta.get("extra_req_files", []))
  220. if deps_to_replace is not None:
  221. deps_dict = {}
  222. for dep in deps_to_replace:
  223. part, version = dep.split("=")
  224. repo_name, dep_name = part.split(".")
  225. deps_dict[repo_name] = {dep_name: version}
  226. src_requirements = os.path.join(self.root_dir, "requirements.txt")
  227. if self.name in deps_dict:
  228. self.replace_repo_deps(deps_dict[self.name], src_requirements)
  229. deps = []
  230. for req in req_list:
  231. with open(osp.join(self.root_dir, req), "r", encoding="utf-8") as f:
  232. deps.append(f.read())
  233. for dep in self.meta.get("pdx_pkg_deps", []):
  234. deps.append(dep)
  235. deps = "\n".join(deps)
  236. return deps
  237. def get_version(self):
  238. """get_version"""
  239. version_file = osp.join(self.root_dir, ".pdx_gen.version")
  240. with open(version_file, "r", encoding="utf-8") as f:
  241. lines = f.readlines()
  242. sta_ver = lines[0].rstrip()
  243. commit = lines[1].rstrip()
  244. ret = [sta_ver, commit]
  245. # TODO: Get dynamic version in a subprocess.
  246. ret.append(None)
  247. return ret
  248. def __str__(self):
  249. return f"({self.name}, {id(self)})"
  250. class RepositoryGroupInstaller(object):
  251. """RepositoryGroupInstaller"""
  252. def __init__(self, repos):
  253. super().__init__()
  254. self.repos = repos
  255. def install(
  256. self,
  257. force_reinstall=False,
  258. no_deps=False,
  259. constraints=None,
  260. deps_to_replace=None,
  261. ):
  262. """install"""
  263. # Rollback on failure is not yet supported. A failed installation
  264. # could leave a broken environment.
  265. if force_reinstall:
  266. self.uninstall()
  267. ins_flags = []
  268. repos = self._sort_repos(self.repos, check_missing=True)
  269. for repo in repos:
  270. if force_reinstall or not repo.check_installation():
  271. ins_flags.append(True)
  272. else:
  273. ins_flags.append(False)
  274. if not no_deps:
  275. # We collect the dependencies and install them all at once
  276. # such that we can make use of the pip resolver.
  277. self.install_deps(constraints=constraints, deps_to_replace=deps_to_replace)
  278. # XXX: For historical reasons the repo packages are sequentially
  279. # installed, and we have no failure rollbacks. Meanwhile, installation
  280. # failure of one repo package aborts the entire installation process.
  281. for ins_flag, repo in zip(ins_flags, repos):
  282. if ins_flag:
  283. if repo.name in ["PaddleVideo"]:
  284. repo.install_package(
  285. no_deps=True,
  286. install_extra_only=True,
  287. )
  288. else:
  289. repo.install_package(no_deps=True)
  290. def uninstall(self):
  291. """uninstall"""
  292. repos = self._sort_repos(self.repos, check_missing=False)
  293. repos = repos[::-1]
  294. for repo in repos:
  295. if repo.check_installation():
  296. # NOTE: Dependencies are not uninstalled.
  297. repo.uninstall_package()
  298. def get_deps(self, deps_to_replace=None):
  299. """get_deps"""
  300. deps_list = []
  301. repos = self._sort_repos(self.repos, check_missing=True)
  302. for repo in repos:
  303. if repo.name in ["PaddleVideo"]:
  304. deps = repo.get_deps(
  305. install_extra_only=True, deps_to_replace=deps_to_replace
  306. )
  307. else:
  308. deps = repo.get_deps(deps_to_replace=deps_to_replace)
  309. deps = self._normalize_deps(deps, headline=f"# {repo.name} dependencies")
  310. deps_list.append(deps)
  311. # Add an extra new line to separate dependencies of different repos.
  312. return "\n\n".join(deps_list)
  313. def install_deps(self, constraints, deps_to_replace=None):
  314. """install_deps"""
  315. deps_str = self.get_deps(deps_to_replace=deps_to_replace)
  316. with tempfile.TemporaryDirectory() as td:
  317. req_file = os.path.join(td, "requirements.txt")
  318. with open(req_file, "w", encoding="utf-8") as fr:
  319. fr.write(deps_str)
  320. if constraints is not None:
  321. cons_file = os.path.join(td, "constraints.txt")
  322. with open(cons_file, "w", encoding="utf-8") as fc:
  323. fc.write(constraints)
  324. cons_files = [cons_file]
  325. else:
  326. cons_files = []
  327. install_packages_using_pip([], req_files=[req_file], cons_files=cons_files)
  328. def _sort_repos(self, repos, check_missing=False):
  329. # We sort the repos to ensure that the dependencies precede the
  330. # dependant in the list.
  331. name_meta_pairs = []
  332. for repo in repos:
  333. name_meta_pairs.append((repo.name, repo.meta))
  334. unique_pairs = []
  335. hashset = set()
  336. for name, meta in name_meta_pairs:
  337. if name in hashset:
  338. continue
  339. else:
  340. unique_pairs.append((name, meta))
  341. hashset.add(name)
  342. sorted_repos = []
  343. missing_names = []
  344. name2repo = {repo.name: repo for repo in repos}
  345. for name, meta in unique_pairs:
  346. if name in name2repo:
  347. repo = name2repo[name]
  348. sorted_repos.append(repo)
  349. else:
  350. missing_names.append(name)
  351. if check_missing and len(missing_names) > 0:
  352. be = "is" if len(missing_names) == 1 else "are"
  353. raise RuntimeError(f"{missing_names} {be} required in the installation.")
  354. else:
  355. assert len(sorted_repos) == len(self.repos)
  356. return sorted_repos
  357. def _normalize_deps(self, deps, headline=None):
  358. repo_pkgs = set(repo.pkg_name for repo in self.repos)
  359. normed_lines = []
  360. if headline is not None:
  361. normed_lines.append(headline)
  362. for line in deps.splitlines():
  363. line_s = line.strip()
  364. if len(line_s) == 0 or line_s.startswith("#"):
  365. continue
  366. # If `line` is not a comment, it must be a requirement specifier.
  367. # Other forms may cause a parse error.
  368. n, e, v, m = to_dep_spec_pep508(line_s)
  369. if isinstance(v, str):
  370. raise RuntimeError("Currently, URL based lookup is not supported.")
  371. if n in repo_pkgs:
  372. # Skip repo packages
  373. continue
  374. elif check_installation_using_pip(n):
  375. continue
  376. else:
  377. line_n = [n]
  378. fe = f"[{','.join(e)}]" if e else ""
  379. if fe:
  380. line_n.append(fe)
  381. fv = []
  382. for tup in v:
  383. fv.append(" ".join(tup))
  384. fv = ", ".join(fv) if fv else ""
  385. if fv:
  386. line_n.append(fv)
  387. if m is not None:
  388. fm = f"; {env_marker_ast2expr(m)}"
  389. line_n.append(fm)
  390. line_n = " ".join(line_n)
  391. normed_lines.append(line_n)
  392. return "\n".join(normed_lines)
  393. class RepositoryGroupGetter(object):
  394. """RepositoryGroupGetter"""
  395. def __init__(self, repos):
  396. super().__init__()
  397. self.repos = repos
  398. def get(self, force=False, platform=None):
  399. """clone"""
  400. if force:
  401. self.remove()
  402. for repo in self.repos:
  403. repo.download()
  404. repo.update(platform=platform)
  405. def remove(self):
  406. """remove"""
  407. for repo in self.repos:
  408. repo.remove()