Procházet zdrojové kódy

Merge pull request #595 from FlyingQianMM/develop_nnn

remove paddlehub because paddlehub2.0.0 cannot download pretrained weights correctly
Jason před 4 roky
rodič
revize
708ff2d2c5

+ 4 - 3
paddlex/__init__.py

@@ -15,6 +15,7 @@
 from __future__ import absolute_import
 
 __version__ = '1.3.5'
+gui_mode = True
 
 import os
 if 'FLAGS_eager_delete_tensor_gb' not in os.environ:
@@ -32,9 +33,9 @@ if version[0] == '1':
         raise Exception(
             'For running paddlex(v{}), Version of paddlepaddle should be greater than 1.8.3'.
             format(__version__))
-    import paddlehub as hub
-    if hub.__version__.strip().split('.')[0] > '1':
-        raise Exception("Try to reinstall Paddlehub by 'pip install paddlehub==1.8.2' while paddlepaddle < 2.0")
+    #import paddlehub as hub
+    #if hub.__version__.strip().split('.')[0] > '1':
+    #    raise Exception("Try to reinstall Paddlehub by 'pip install paddlehub==1.8.2' while paddlepaddle < 2.0")
 
 if hasattr(paddle, 'enable_static'):
     paddle.enable_static()

+ 19 - 19
paddlex/cv/models/slim/prune_config.py

@@ -15,7 +15,6 @@
 import numpy as np
 import os.path as osp
 import paddle.fluid as fluid
-#import paddlehub as hub
 import paddlex
 
 sensitivities_data = {
@@ -130,25 +129,26 @@ def get_sensitivities(flag, model, save_dir):
             model_type)
         url = sensitivities_data[model_type]
         fname = osp.split(url)[-1]
-        paddlex.utils.download(url, path=save_dir)
-        return osp.join(save_dir, fname)
+        if getattr(paddlex, 'gui_mode', False):
+            paddlex.utils.download(url, path=save_dir)
+            return osp.join(save_dir, fname)
 
-#        try:
-#            hub.download(fname, save_path=save_dir)
-#        except Exception as e:
-#            if isinstance(e, hub.ResourceNotFoundError):
-#                raise Exception(
-#                    "Resource for model {}(key='{}') not found".format(
-#                        model_type, fname))
-#            elif isinstance(e, hub.ServerConnectionError):
-#                raise Exception(
-#                    "Cannot get reource for model {}(key='{}'), please check your internet connection"
-#                    .format(model_type, fname))
-#            else:
-#                raise Exception(
-#                    "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
-#                    format(str(e)))
-#        return osp.join(save_dir, fname)
+        import paddlehub as hub
+        try:
+            hub.download(fname, save_path=save_dir)
+        except Exception as e:
+            if isinstance(e, hub.ResourceNotFoundError):
+                raise Exception("Resource for model {}(key='{}') not found".
+                                format(model_type, fname))
+            elif isinstance(e, hub.ServerConnectionError):
+                raise Exception(
+                    "Cannot get reource for model {}(key='{}'), please check your internet connection"
+                    .format(model_type, fname))
+            else:
+                raise Exception(
+                    "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
+                    format(str(e)))
+        return osp.join(save_dir, fname)
     else:
         raise Exception(
             "sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."

+ 4 - 4
paddlex/cv/models/utils/pretrain_weights.py

@@ -267,10 +267,10 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
                 "Couldn't download pretrain weight, you can download it manualy from {} (decompress the file if it is a compressed file), and set pretrain weights by your self".
                 format(url),
                 exit=False)
-            if isinstance(hub.ResourceNotFoundError):
+            if isinstance(e, hub.ResourceNotFoundError):
                 raise Exception("Resource for backbone {} not found".format(
                     backbone))
-            elif isinstance(hub.ServerConnectionError):
+            elif isinstance(e, hub.ServerConnectionError):
                 raise Exception(
                     "Cannot get reource for backbone {}, please check your internet connection"
                     .format(backbone))
@@ -300,10 +300,10 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
                 "Couldn't download pretrain weight, you can download it manualy from {} (decompress the file if it is a compressed file), and set pretrain weights by your self".
                 format(url),
                 exit=False)
-            if isinstance(hub.ResourceNotFoundError):
+            if isinstance(e, hub.ResourceNotFoundError):
                 raise Exception("Resource for backbone {} not found".format(
                     backbone))
-            elif isinstance(hub.ServerConnectionError):
+            elif isinstance(e, hub.ServerConnectionError):
                 raise Exception(
                     "Cannot get reource for backbone {}, please check your internet connection"
                     .format(backbone))

+ 11 - 3
setup.py

@@ -38,9 +38,17 @@ setuptools.setup(
     include_data_files=True,
     setup_requires=['cython', 'numpy'],
     install_requires=[
-        "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
-        'paddleslim==1.1.1', 'visualdl>=2.0.0', 'paddlehub==1.8.2',
-        'shapely>=1.7.0', 'opencv-python', 'flask_cors', 'sklearn', 'psutil',
+        "pycocotools;platform_system!='Windows'",
+        'pyyaml',
+        'colorama',
+        'tqdm',
+        'paddleslim==1.1.1',
+        'visualdl>=2.0.0',  #'paddlehub==1.8.2',
+        'shapely>=1.7.0',
+        'opencv-python',
+        'flask_cors',
+        'sklearn',
+        'psutil',
         'xlwt'
     ],
     classifiers=[