Browse Source

Merge pull request #78 from SunAhong1993/syf0519

fix the vis.py
Jason 5 years ago
parent
commit
311cab5efe
2 changed files with 19 additions and 12 deletions
  1. 18 4
      paddlex/interpret/core/_session_preparation.py
  2. 1 8
      paddlex/interpret/visualize.py

+ 18 - 4
paddlex/interpret/core/_session_preparation.py

@@ -13,15 +13,29 @@
 #limitations under the License.
 
 import os
+import os.path as osp
 import paddle.fluid as fluid
 import numpy as np
 from paddle.fluid.param_attr import ParamAttr
 from ..as_data_reader.readers import preprocess_image
 
-root_path = os.environ['HOME']
-root_path = os.path.join(root_path, '.paddlex')
-h_pre_models = os.path.join(root_path, "pre_models")
-h_pre_models_kmeans = os.path.join(h_pre_models, "kmeans_model.pkl")
+def gen_user_home():
+    if "HOME" in os.environ:
+        home_path = os.environ["HOME"]
+        if os.path.exists(home_path) and os.path.isdir(home_path):
+            return home_path
+    return os.path.expanduser('~')
+
+
+root_path = gen_user_home()
+root_path = osp.join(root_path, '.paddlex')
+h_pre_models = osp.join(root_path, "pre_models")
+if not osp.exists(h_pre_models):
+    if not osp.exists(root_path):
+        os.makedirs(root_path)
+    url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+    pdx.utils.download_and_decompress(url, path=root_path)
+h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
 
 
 def paddle_get_fc_weights(var_name="fc_0.w_0"):

+ 1 - 8
paddlex/interpret/visualize.py

@@ -21,14 +21,7 @@ import paddlex as pdx
 from .interpretation_predict import interpretation_predict
 from .core.interpretation import Interpretation
 from .core.normlime_base import precompute_normlime_weights
-
-
-def gen_user_home():
-    if "HOME" in os.environ:
-        home_path = os.environ["HOME"]
-        if os.path.exists(home_path) and os.path.isdir(home_path):
-            return home_path
-    return os.path.expanduser('~')
+from .core._session_preparation import gen_user_home
 
 def visualize(img_file, 
               model,