浏览代码

support PADDLE_PDX_MODEL_SOURCE to set the downloading source (#4180)

Tingquan Gao 5 月之前
父节点
当前提交
abbd54b428
共有 2 个文件被更改,包括 11 次插入1 次删除
  1. 10 1
      paddlex/inference/utils/official_models.py
  2. 1 0
      paddlex/utils/flags.py

+ 10 - 1
paddlex/inference/utils/official_models.py

@@ -24,6 +24,7 @@ import requests
 from ...utils import logging
 from ...utils.cache import CACHE_DIR
 from ...utils.download import download_and_extract
+from ...utils.flags import MODEL_SOURCE
 
 OFFICIAL_MODELS = {
     "ResNet18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/ResNet18_infer.tar",
@@ -459,8 +460,16 @@ class OfficialModelsDict(dict):
             f"Using official model ({key}), the model files will be automatically downloaded and saved in {self._save_dir}."
         )
 
-        if is_huggingface_accessible and key in HUGGINGFACE_MODELS:
+        if (
+            MODEL_SOURCE.lower() == "huggingface"
+            and is_huggingface_accessible
+            and key in HUGGINGFACE_MODELS
+        ):
             return _download_from_hf()
+        elif MODEL_SOURCE.lower() == "modelscope":
+            raise Exception(
+                f"ModelScope is not supported! Please use `HuggingFace` or `BOS`."
+            )
         else:
             return _download_from_bos()
 

+ 1 - 0
paddlex/utils/flags.py

@@ -52,6 +52,7 @@ USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", True)
 DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
 DISABLE_CINN_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_CINN_MODEL_WL", False)
 LOCAL_FONT_FILE_PATH = get_flag_from_env_var("PADDLE_PDX_LOCAL_FONT_FILE_PATH", None)
+MODEL_SOURCE = os.environ.get("PADDLE_PDX_MODEL_SOURCE", "huggingface")
 
 
 # Inference Benchmark