Prechádzať zdrojové kódy

support modify deps for repos (#3489)

zhangyubo0722 8 mesiacov pred
rodič
commit
0c9edf8d46

+ 8 - 0
paddlex/paddlex_cli.py

@@ -92,6 +92,13 @@ def args_cfg():
         default=False,
         help="Use local repositories if they exist.",
     )
+    install_group.add_argument(
+        "--deps_to_replace",
+        type=str,
+        nargs="+",
+        default=None,
+        help="Replace dependency version when installing from repositories.",
+    )
 
     ################# pipeline predict #################
     pipeline_group.add_argument(
@@ -296,6 +303,7 @@ def install(args):
         platform=args.platform,
         update_repos=args.update_repos,
         use_local_repos=args.use_local_repos,
+        deps_to_replace=args.deps_to_replace,
     )
     return
 

+ 7 - 1
paddlex/repo_manager/core.py

@@ -99,6 +99,7 @@ def setup(
     platform=None,
     update_repos=False,
     use_local_repos=False,
+    deps_to_replace=None,
 ):
     """setup"""
     if update_repos and use_local_repos:
@@ -201,7 +202,12 @@ def setup(
         logging.info(installer.get_deps())
 
     logging.info("Now installing the packages...")
-    installer.install(force_reinstall=False, no_deps=no_deps, constraints=constraints)
+    installer.install(
+        force_reinstall=False,
+        no_deps=no_deps,
+        constraints=constraints,
+        deps_to_replace=deps_to_replace,
+    )
     install_deps_using_pip()
     logging.info("All packages are installed.")
 

+ 49 - 9
paddlex/repo_manager/repo.py

@@ -19,6 +19,7 @@ import tempfile
 import shutil
 
 from ..utils import logging
+from ..utils.file_interface import custom_open
 from ..utils.download import download_and_extract
 from .meta import get_repo_meta, REPO_DOWNLOAD_BASE
 from .utils import (
@@ -99,6 +100,25 @@ class PPRepository(object):
             # TODO: Also check if correct dependencies are installed.
             return check_installation_using_pip(self.pkg_name)
 
+    def replace_repo_deps(self, deps_to_replace, src_requirements):
+        """replace_repo_deps"""
+        with custom_open(src_requirements, "r") as file:
+            lines = file.readlines()
+        existing_deps = []
+        for line in lines:
+            line = line.strip()
+            if not line or line.startswith("#"):
+                continue
+            dep_to_replace = next((dep for dep in deps_to_replace if dep in line), None)
+            if dep_to_replace:
+                existing_deps.append(
+                    f"{dep_to_replace}=={deps_to_replace[dep_to_replace]}"
+                )
+            else:
+                existing_deps.append(line)
+        with open(src_requirements, "w") as file:
+            file.writelines([l + "\n" for l in existing_deps])
+
     def check_repo_exiting(self, quick_check=False):
         """check_repo_exiting"""
         return os.path.exists(os.path.join(self.root_dir, ".git"))
@@ -217,7 +237,7 @@ class PPRepository(object):
         """get_pdx"""
         return importlib.import_module(self.pdx_mod_name)
 
-    def get_deps(self, install_extra_only=False):
+    def get_deps(self, install_extra_only=False, deps_to_replace=None):
         """get_deps"""
         # Merge requirement files
         if install_extra_only:
@@ -225,6 +245,15 @@ class PPRepository(object):
         else:
             req_list = [self.main_req_file]
         req_list.extend(self.meta.get("extra_req_files", []))
+        if deps_to_replace is not None:
+            deps_dict = {}
+            for dep in deps_to_replace:
+                part, version = dep.split("=")
+                repo_name, dep_name = part.split(".")
+                deps_dict[repo_name] = {dep_name: version}
+            src_requirements = os.path.join(self.root_dir, "requirements.txt")
+            if self.name in deps_dict:
+                self.replace_repo_deps(deps_dict[self.name], src_requirements)
         deps = []
         for req in req_list:
             with open(osp.join(self.root_dir, req), "r", encoding="utf-8") as f:
@@ -257,7 +286,13 @@ class RepositoryGroupInstaller(object):
         super().__init__()
         self.repos = repos
 
-    def install(self, force_reinstall=False, no_deps=False, constraints=None):
+    def install(
+        self,
+        force_reinstall=False,
+        no_deps=False,
+        constraints=None,
+        deps_to_replace=None,
+    ):
         """install"""
         # Rollback on failure is not yet supported. A failed installation
         # could leave a broken environment.
@@ -273,14 +308,17 @@ class RepositoryGroupInstaller(object):
         if not no_deps:
             # We collect the dependencies and install them all at once
             # such that we can make use of the pip resolver.
-            self.install_deps(constraints=constraints)
+            self.install_deps(constraints=constraints, deps_to_replace=deps_to_replace)
         # XXX: For historical reasons the repo packages are sequentially
         # installed, and we have no failure rollbacks. Meanwhile, installation
         # failure of one repo package aborts the entire installation process.
         for ins_flag, repo in zip(ins_flags, repos):
             if ins_flag:
                 if repo.name in ["PaddleVideo"]:
-                    repo.install_package(no_deps=True, install_extra_only=True)
+                    repo.install_package(
+                        no_deps=True,
+                        install_extra_only=True,
+                    )
                 else:
                     repo.install_package(no_deps=True)
 
@@ -293,23 +331,25 @@ class RepositoryGroupInstaller(object):
                 # NOTE: Dependencies are not uninstalled.
                 repo.uninstall_package()
 
-    def get_deps(self):
+    def get_deps(self, deps_to_replace=None):
         """get_deps"""
         deps_list = []
         repos = self._sort_repos(self.repos, check_missing=True)
         for repo in repos:
             if repo.name in ["PaddleVideo"]:
-                deps = repo.get_deps(install_extra_only=True)
+                deps = repo.get_deps(
+                    install_extra_only=True, deps_to_replace=deps_to_replace
+                )
             else:
-                deps = repo.get_deps()
+                deps = repo.get_deps(deps_to_replace=deps_to_replace)
             deps = self._normalize_deps(deps, headline=f"# {repo.name} dependencies")
             deps_list.append(deps)
         # Add an extra new line to separate dependencies of different repos.
         return "\n\n".join(deps_list)
 
-    def install_deps(self, constraints):
+    def install_deps(self, constraints, deps_to_replace=None):
         """install_deps"""
-        deps_str = self.get_deps()
+        deps_str = self.get_deps(deps_to_replace=deps_to_replace)
         with tempfile.TemporaryDirectory() as td:
             req_file = os.path.join(td, "requirements.txt")
             with open(req_file, "w", encoding="utf-8") as fr: