Переглянути джерело

dbg:

1. use cpu to export model
2. automatically install dependent repo(s)
gaotingquan 1 рік тому
батько
коміт
ff72ae1a39

+ 2 - 2
paddlex/modules/base/trainer/train_deamon.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import sys
 import time
@@ -203,7 +202,8 @@ class BaseTrainDeamon(ABC):
         if model_name not in self.models:
             config, model = build_model(
                 model_name,
-                device=self.global_config.device,
+                # using CPU to export model
+                device="cpu",
                 config_path=config_path)
             self.models[model_name] = model
         return self.models[model_name]

+ 12 - 3
paddlex/repo_manager/core.py

@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
 import os
 import sys
 from collections import OrderedDict
 
 from ..utils import logging
 from .utils import install_deps_using_pip
-from .meta import get_all_repo_names
+from .meta import get_all_repo_names, get_repo_meta
 from .repo import build_repo_instance, build_repo_group_cloner, build_repo_group_installer
 
 __all__ = [
@@ -29,6 +27,15 @@ __all__ = [
 ]
 
 
+def _parse_repo_deps(repos):
+    ret = []
+    for repo_name in repos:
+        repo_meta = get_repo_meta(repo_name)
+        ret.extend(_parse_repo_deps(repo_meta.get('requires', [])))
+        ret.append(repo_name)
+    return ret
+
+
 class _GlobalContext(object):
     REPO_PARENT_DIR = None
     PDX_COLLECTION_MOD = None
@@ -83,6 +90,8 @@ def setup(repo_names,
           platform=None,
           update_repos=False):
     """ setup """
+    repo_names = list(set(_parse_repo_deps(repo_names)))
+
     repos = []
     for repo_name in repo_names:
         repo = _GlobalContext.build_repo_instance(repo_name)

+ 1 - 10
paddlex/repo_manager/repo.py

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
 import os
 import os.path as osp
 import importlib
@@ -317,16 +315,9 @@ class RepositoryGroupInstaller(object):
     def _sort_repos(self, repos, check_missing=False):
         # We sort the repos to ensure that the dependencies precede the
         # dependant in the list.
-        def _parse_repo_deps(name, repo_meta):
-            ret = []
-            for n in repo_meta.get('requires', []):
-                ret.extend(_parse_repo_deps(n, get_repo_meta(n)))
-            ret.append((name, repo_meta))
-            return ret
-
         name_meta_pairs = []
         for repo in repos:
-            name_meta_pairs.extend(_parse_repo_deps(repo.name, repo.meta))
+            name_meta_pairs.append((repo.name, repo.meta))
 
         unique_pairs = []
         hashset = set()