Browse Source

download & git pull (#1880)

Tingquan Gao, 1 year ago
parent
commit
403217f1f8

+ 4 - 2
paddlex/paddlex_cli.py

@@ -28,7 +28,7 @@ def args_cfg():
 
     def parse_str(s):
         """convert str type value
-           to None type if it is "None", 
+           to None type if it is "None",
            to bool type if it means True or False.
         """
         if s in ("None"):
@@ -59,6 +59,7 @@ def args_cfg():
         action='store_true',
         default=False,
         help="Use local repos when installing.")
+    parser.add_argument('--force_clone', action='store_true', default=False)
 
     ################# pipeline predict #################
     parser.add_argument('--predict', action='store_true', default=True, help="")
@@ -89,7 +90,8 @@ def install(args):
         no_deps=args.no_deps,
         platform=args.platform,
         update_repos=args.update_repos,
-        use_local_repos=args.use_local_repos)
+        use_local_repos=args.use_local_repos,
+        force_clone=args.force_clone)
     return
 
 

+ 4 - 3
paddlex/repo_manager/core.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -89,7 +89,8 @@ def setup(repo_names,
           constraints=None,
           platform=None,
           update_repos=False,
-          use_local_repos=False):
+          use_local_repos=False,
+          force_clone=False):
     """ setup """
     repo_names = list(set(_parse_repo_deps(repo_names)))
 
@@ -162,7 +163,7 @@ def setup(repo_names,
     installer = build_repo_group_installer(*repos_to_install)
 
     logging.info("Now cloning the repos...")
-    cloner.clone(force_reclone=False, platform=platform)
+    cloner.clone(force_clone=force_clone, platform=platform)
     logging.info("All repos are existing.")
 
     if not no_deps:

+ 11 - 9
paddlex/repo_manager/meta.py

@@ -14,6 +14,8 @@
 
 __all__ = ['get_all_repo_names']
 
+REPO_DOWNLOAD_BASE = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/repos/"
+
 REPO_NAMES = [
     'PaddleClas', 'PaddleOCR', 'PaddleDetection', 'PaddleSeg', 'PaddleNLP',
     'PaddleTS'
@@ -21,7 +23,7 @@ REPO_NAMES = [
 
 REPO_META = {
     'PaddleSeg': {
-        'repo_url': '/PaddlePaddle/PaddleSeg.git',
+        'git_url': '/PaddlePaddle/PaddleSeg.git',
         'platform': 'github',
         'branch': 'develop',
         'pkg_name': 'paddleseg',
@@ -32,7 +34,7 @@ REPO_META = {
         'path_env': 'PADDLE_PDX_PADDLESEG_PATH',
     },
     'PaddleClas': {
-        'repo_url': '/PaddlePaddle/PaddleClas.git',
+        'git_url': '/PaddlePaddle/PaddleClas.git',
         'platform': 'github',
         'branch': 'develop',
         'pkg_name': 'paddleclas',
@@ -44,7 +46,7 @@ REPO_META = {
         'path_env': 'PADDLE_PDX_PADDLECLAS_PATH',
     },
     'PaddleDetection': {
-        'repo_url': '/PaddlePaddle/PaddleDetection.git',
+        'git_url': '/PaddlePaddle/PaddleDetection.git',
         'platform': 'github',
         'branch': 'develop',
         'pkg_name': 'paddledet',
@@ -54,7 +56,7 @@ REPO_META = {
         'path_env': 'PADDLE_PDX_PADDLEDETECTION_PATH',
     },
     'PaddleOCR': {
-        'repo_url': '/PaddlePaddle/PaddleOCR.git',
+        'git_url': '/PaddlePaddle/PaddleOCR.git',
         'platform': 'github',
         'branch': 'main',
         'pkg_name': 'paddleocr',
@@ -66,7 +68,7 @@ REPO_META = {
         'requires': ['PaddleNLP'],
     },
     'PaddleTS': {
-        'repo_url': '/PaddlePaddle/PaddleTS.git',
+        'git_url': '/PaddlePaddle/PaddleTS.git',
         'platform': 'github',
         'branch': 'release_v1.1',
         'pkg_name': 'paddlets',
@@ -77,7 +79,7 @@ REPO_META = {
         'pdx_pkg_deps': ['pandas', 'ruamel.yaml'],
     },
     'PaddleNLP': {
-        'repo_url': '/PaddlePaddle/PaddleNLP.git',
+        'git_url': '/PaddlePaddle/PaddleNLP.git',
         'platform': 'github',
         'branch': 'release/2.9',
         'pkg_name': 'paddlenlp',
@@ -87,7 +89,7 @@ REPO_META = {
         'path_env': 'PADDLE_PDX_PADDLENLP_PATH',
     },
     'PaddleSpeech': {
-        'repo_url': '/PaddlePaddle/PaddleSpeech.git',
+        'git_url': '/PaddlePaddle/PaddleSpeech.git',
         'platform': 'github',
         'branch': 'develop',
         'pkg_name': 'paddlespeech',
@@ -98,7 +100,7 @@ REPO_META = {
         'requires': ['PaddleNLP'],
     },
     'PARL': {
-        'repo_url': '/PaddlePaddle/PARL.git',
+        'git_url': '/PaddlePaddle/PARL.git',
         'platform': 'github',
         'branch': 'develop',
         'pkg_name': 'parl',
@@ -108,7 +110,7 @@ REPO_META = {
         'path_env': 'PADDLE_PDX_PARL_PATH',
     },
     'PaddleMIX': {
-        'repo_url': '/PaddlePaddle/PaddleMIX.git',
+        'git_url': '/PaddlePaddle/PaddleMIX.git',
         'platform': 'github',
         'branch': 'develop',
         'pkg_name': 'paddlemix',

+ 26 - 13
paddlex/repo_manager/repo.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -19,7 +19,8 @@ import tempfile
 import shutil
 
 from ..utils import logging
-from .meta import get_repo_meta
+from ..utils.download import download_and_extract
+from .meta import get_repo_meta, REPO_DOWNLOAD_BASE
 from .utils import (install_packages_using_pip, clone_repos_using_git,
                     update_repos_using_git, uninstall_package_using_pip,
                     remove_repos_using_rm, check_installation_using_pip,
@@ -60,7 +61,7 @@ class PPRepository(object):
         self.root_dir = osp.join(repo_parent_dir, self.name)
 
         self.meta = get_repo_meta(self.name)
-        self.repo_url = self.meta['repo_url']
+        self.git_url = self.meta['git_url']
         self.pkg_name = self.meta['pkg_name']
         self.lib_name = self.meta['lib_name']
         self.pdx_mod_name = pdx_collection_mod.__name__ + '.' + self.meta[
@@ -149,22 +150,30 @@ class PPRepository(object):
     def clone_repos(self, platform=None):
         """ clone_repos """
         branch = self.meta.get('branch', None)
-        repo_url = f'https://{platform}{self.repo_url}'
+        git_url = f'https://{platform}{self.git_url}'
         # uncomment this if you prefer using ssh connection (requires additional setup)
         # if platform == 'github.com':
-        #    repo_url = f'git@github.com:{self.repo_url}'
+        #    git_url = f'git@github.com:{self.git_url}'
         os.makedirs(self.repo_parent_dir, exist_ok=True)
         with switch_working_dir(self.repo_parent_dir):
-            clone_repos_using_git(repo_url, branch=branch)
+            clone_repos_using_git(git_url, branch=branch)
+
+    def download_repos(self):
+        """ download and pull repos """
+        download_url = f'{REPO_DOWNLOAD_BASE}{self.name}.tar'
+        os.makedirs(self.repo_parent_dir, exist_ok=True)
+        download_and_extract(download_url, self.repo_parent_dir, self.name)
 
-    def update_repos(self):
+    def update_repos(self, platform=None):
         """ update_repos """
+        branch = self.meta.get('branch', None)
+        git_url = f'https://{platform}{self.git_url}'
         with switch_working_dir(self.root_dir):
             try:
-                update_repos_using_git()
+                update_repos_using_git(branch=branch, url=git_url)
             except Exception as e:
                 logging.warning(
-                    f"Pull {self.name} from {self.repo_url} failed, check your network connection."
+                    f"Pull {self.name} from {self.git_url} failed, check your network connection."
                 )
 
     def remove_repos(self):
@@ -393,12 +402,16 @@ class RepositoryGroupCloner(object):
         super().__init__()
         self.repos = repos
 
-    def clone(self, force_reclone=False, platform=None):
+    def clone(self, force_clone=False, platform=None):
         """ clone """
-        if force_reclone:
+        if force_clone:
             self.remove()
-        for repo in self.repos:
-            repo.clone_repos(platform=platform)
+            for repo in self.repos:
+                repo.clone_repos(platform=platform)
+        else:
+            for repo in self.repos:
+                repo.download_repos()
+                repo.update_repos(platform=platform)
 
     def remove(self):
         """ remove """

+ 10 - 6
paddlex/repo_manager/utils.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
 import os
 import sys
 import json
@@ -89,10 +87,16 @@ def clone_repos_using_git(url, branch=None):
     return _check_call(args)
 
 
-def update_repos_using_git():
+def update_repos_using_git(branch=None, url=None):
     """ update_repos_using_git """
-    args = ['git', 'pull']
-    return _check_call(args)
+    if url:
+        args = ['git', 'fetch', url, branch]
+        _check_call(args)
+        args = ['git', 'merge', 'FETCH_HEAD']
+        return _check_call(args)
+    else:
+        args = ['git', 'pull']
+        return _check_call(args)
 
 
 def remove_repos_using_rm(name):

+ 3 - 5
paddlex/utils/download.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
 import os
 import sys
 import time
@@ -149,7 +147,7 @@ def download_and_extract(url,
                          overwrite=False,
                          no_interm_dir=True):
     """ download and extract """
-    # NOTE: `url` MUST come from a trusted source, since we do not provide a solution 
+    # NOTE: `url` MUST come from a trusted source, since we do not provide a solution
     # to secure against CVE-2007-4559.
     os.makedirs(save_dir, exist_ok=True)
     dst_path = os.path.join(save_dir, dst_name)
@@ -174,7 +172,7 @@ def download_and_extract(url,
                     raise FileNotFoundError
                 dp = os.path.join(save_dir, file_name)
                 if os.path.isdir(sp):
-                    shutil.copytree(sp, dp)
+                    shutil.copytree(sp, dp, symlinks=True)
                 else:
                     shutil.copyfile(sp, dp)
                 extd_file = dp