
Support downloading inference models from Hugging Face Hub (#4121)

Tingquan Gao, 5 months ago
parent commit 3836897674
3 files changed with 45 additions and 5 deletions
  1. .precommit/check_imports.py (+1, -0)
  2. paddlex/inference/utils/official_models.py (+42, -5)
  3. setup.py (+2, -0)
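
At a high level, this change makes official_models[...] resolve a model by downloading it from the Hugging Face Hub repo PaddlePaddle/<model_name> first, falling back to the existing BOS URL when the Hub is unreachable or the Hub download fails. A minimal usage sketch of the new behavior, assuming paddlex is installed; the model key below is illustrative, not taken from this diff:

    # Sketch only; assumes the key exists in OFFICIAL_MODELS.
    from paddlex.inference.utils.official_models import official_models

    # Resolving a key downloads the model files (Hugging Face Hub first, BOS as
    # fallback) into CACHE_DIR/official_models/<key> and returns that local path.
    local_dir = official_models["PP-OCRv5_mobile_rec"]
    print(local_dir)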

.precommit/check_imports.py (+1, -0)

@@ -43,6 +43,7 @@ MOD_TO_DEP = {
     "filetype": "filetype",
     "ftfy": "ftfy",
     "GPUtil": "GPUtil",
+    "huggingface_hub": "huggingface_hub",
     "imagesize": "imagesize",
     "jinja2": "Jinja2",
     "joblib": "joblib",

paddlex/inference/utils/official_models.py (+42, -5)

@@ -12,8 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import lru_cache
 from pathlib import Path
 
+import huggingface_hub as hf_hub
+
+hf_hub.logging.set_verbosity_error()
+
+import requests
+
 from ...utils import logging
 from ...utils.cache import CACHE_DIR
 from ...utils.download import download_and_extract
@@ -352,17 +359,47 @@ PP-OCRv5_mobile_rec_infer.tar",
 }
 
 
+@lru_cache(1)
+def is_huggingface_accessible():
+    try:
+        response = requests.get("https://huggingface.co", timeout=1)
+        return response.ok
+    except requests.exceptions.RequestException:
+        return False
+
+
 class OfficialModelsDict(dict):
     """Official Models Dict"""
 
+    _save_dir = Path(CACHE_DIR) / "official_models"
+
     def __getitem__(self, key):
-        url = super().__getitem__(key)
-        save_dir = Path(CACHE_DIR) / "official_models"
+        def _download_from_bos():
+            url = super(OfficialModelsDict, self).__getitem__(key)
+            download_and_extract(url, self._save_dir, f"{key}", overwrite=False)
+            return self._save_dir / f"{key}"
+
+        def _download_from_hf():
+            local_dir = self._save_dir / f"{key}"
+            try:
+                hf_hub.snapshot_download(
+                    repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
+                )
+            except Exception as e:
+                logging.warning(
+                    f"Encounter exception when download model from huggingface: \n{e}.\nPaddleX would try to download from BOS."
+                )
+                return _download_from_bos()
+            return local_dir
+
         logging.info(
-            f"Using official model ({key}), the model files will be automatically downloaded and saved in {save_dir}."
+            f"Using official model ({key}), the model files will be automatically downloaded and saved in {self._save_dir}."
         )
-        download_and_extract(url, save_dir, f"{key}", overwrite=False)
-        return save_dir / f"{key}"
+
+        if is_huggingface_accessible():
+            return _download_from_hf()
+        else:
+            return _download_from_bos()
 
 
 official_models = OfficialModelsDict(OFFICIAL_MODELS)
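
Design note: the reachability probe is wrapped in functools.lru_cache, so the HTTP request to https://huggingface.co is issued at most once per process, and any exception raised by snapshot_download (not only network errors) falls back to the original BOS download path. A rough sketch of the cached probe's behavior, using the names from the diff above:

    # Sketch only; illustrates the lru_cache(1) behavior of the probe.
    from paddlex.inference.utils.official_models import is_huggingface_accessible

    first = is_huggingface_accessible()   # performs the HTTP GET (1 s timeout)
    second = is_huggingface_accessible()  # returns the cached result, no new request
    assert first == second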

setup.py (+2, -0)

@@ -34,6 +34,7 @@ DEP_SPECS = {
     "filetype": ">= 1.2",
     "ftfy": "",
     "GPUtil": ">= 1.4",
+    "huggingface_hub": "",
     "imagesize": "",
     "Jinja2": "",
     "joblib": "",
@@ -80,6 +81,7 @@ REQUIRED_DEPS = [
     "colorlog",
     "filelock",
     "GPUtil",
+    "huggingface_hub",
     "numpy",
     "packaging",
     # Currently it is not easy to make `pandas` optional