|
@@ -12,8 +12,15 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
+from functools import lru_cache
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
+import huggingface_hub as hf_hub
|
|
|
|
|
+
|
|
|
|
|
+hf_hub.logging.set_verbosity_error()
|
|
|
|
|
+
|
|
|
|
|
+import requests
|
|
|
|
|
+
|
|
|
from ...utils import logging
|
|
from ...utils import logging
|
|
|
from ...utils.cache import CACHE_DIR
|
|
from ...utils.cache import CACHE_DIR
|
|
|
from ...utils.download import download_and_extract
|
|
from ...utils.download import download_and_extract
|
|
@@ -352,17 +359,47 @@ PP-OCRv5_mobile_rec_infer.tar",
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+@lru_cache(1)
|
|
|
|
|
+def is_huggingface_accessible():
|
|
|
|
|
+ try:
|
|
|
|
|
+ response = requests.get("https://huggingface.co", timeout=1)
|
|
|
|
|
+ return response.ok == True
|
|
|
|
|
+ except requests.exceptions.RequestException as e:
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class OfficialModelsDict(dict):
|
|
class OfficialModelsDict(dict):
|
|
|
"""Official Models Dict"""
|
|
"""Official Models Dict"""
|
|
|
|
|
|
|
|
|
|
+ _save_dir = Path(CACHE_DIR) / "official_models"
|
|
|
|
|
+
|
|
|
def __getitem__(self, key):
|
|
def __getitem__(self, key):
|
|
|
- url = super().__getitem__(key)
|
|
|
|
|
- save_dir = Path(CACHE_DIR) / "official_models"
|
|
|
|
|
|
|
+ def _download_from_bos():
|
|
|
|
|
+ url = super(OfficialModelsDict, self).__getitem__(key)
|
|
|
|
|
+ download_and_extract(url, self._save_dir, f"{key}", overwrite=False)
|
|
|
|
|
+ return self._save_dir / f"{key}"
|
|
|
|
|
+
|
|
|
|
|
+ def _download_from_hf():
|
|
|
|
|
+ local_dir = self._save_dir / f"{key}"
|
|
|
|
|
+ try:
|
|
|
|
|
+ hf_hub.snapshot_download(
|
|
|
|
|
+ repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logging.warning(
|
|
|
|
|
+ f"Encounter exception when download model from huggingface: \n{e}.\nPaddleX would try to download from BOS."
|
|
|
|
|
+ )
|
|
|
|
|
+ return _download_from_bos()
|
|
|
|
|
+ return local_dir
|
|
|
|
|
+
|
|
|
logging.info(
|
|
logging.info(
|
|
|
- f"Using official model ({key}), the model files will be automatically downloaded and saved in {save_dir}."
|
|
|
|
|
|
|
+ f"Using official model ({key}), the model files will be automatically downloaded and saved in {self._save_dir}."
|
|
|
)
|
|
)
|
|
|
- download_and_extract(url, save_dir, f"{key}", overwrite=False)
|
|
|
|
|
- return save_dir / f"{key}"
|
|
|
|
|
|
|
+
|
|
|
|
|
+ if is_huggingface_accessible:
|
|
|
|
|
+ return _download_from_hf()
|
|
|
|
|
+ else:
|
|
|
|
|
+ return _download_from_bos()
|
|
|
|
|
|
|
|
|
|
|
|
|
official_models = OfficialModelsDict(OFFICIAL_MODELS)
|
|
official_models = OfficialModelsDict(OFFICIAL_MODELS)
|